Spaces:
Build error
Build error
| import streamlit as st | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import plotly.express as px | |
| from dotenv import load_dotenv | |
| from langchain.agents.agent_types import AgentType | |
| from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent | |
| from langchain_openai import ChatOpenAI | |
| import os | |
| import seaborn as sns | |
| import plotly.graph_objects as go | |
| import json | |
| import pdfkit | |
| import io | |
| import base64 | |
| from matplotlib.backends.backend_agg import FigureCanvasAgg | |
| import html | |
| import re | |
| from openai import OpenAI | |
| from io import StringIO | |
load_dotenv()

# --- Configuration ---
# The environment (including values loaded from .env above) wins; Streamlit's
# secrets store is consulted only when the variable is unset or empty.
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or st.secrets.get("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)

csv_path = "asig_sales_31012025.csv"
if not os.path.exists(csv_path):
    # Keep the message in the server log, but also show it on the page:
    # stdout is invisible to the person using the app.
    print(f"Error: CSV file '{csv_path}' not found.")
    st.error(f"Error: CSV file '{csv_path}' not found.")
    # st.stop() ends this script run; the bare `exit()` helper is injected by
    # the `site` module and is not guaranteed to exist in every interpreter.
    st.stop()
def get_csv_sample(csv_path, sample_size=5):
    """Load a CSV file and summarise it as text.

    Args:
        csv_path: Path of the CSV file to read.
        sample_size: Upper bound on the number of rows placed in the sample.

    Returns:
        A 3-tuple ``(dtypes_text, sample_text, frame)``: the column dtypes
        rendered as text, a random sample of rows rendered as text without
        the index, and the full DataFrame.
    """
    frame = pd.read_csv(csv_path)
    # Never ask for more rows than exist; the fixed seed keeps the sample
    # identical from run to run.
    rows_to_draw = sample_size if sample_size < len(frame) else len(frame)
    sampled_rows = frame.sample(n=rows_to_draw, random_state=42)
    dtypes_text = frame.dtypes.to_string()
    sample_text = sampled_rows.to_string(index=False)
    return dtypes_text, sample_text, frame
# Summarise the CSV once per script run. Only the two text summaries are kept
# under meaningful names; the DataFrame is bound to the throwaway name `_`.
column_info, sample_str, _ = get_csv_sample(csv_path=csv_path)
# @observe()
def chat(response_text):
    """Decode a JSON document and return the resulting Python object.

    Args:
        response_text: JSON text to parse.

    Returns:
        Whatever ``json.loads`` produces for the given text.
    """
    parsed = json.loads(response_text)
    return parsed
def _build_codegen_prompt(question, column_info, sample_str, csv_path):
    """Render the instruction prompt sent to the model for one question."""
    return f"""You are a highly skilled Python data analyst with expert-level proficiency in Pandas. Your task is to write **concise, correct, and efficient** Pandas code to answer a specific question about data contained within a CSV file. The code you generate must be self-contained, directly executable, and produce the correct numerical output or DataFrame structure.
**CSV File Information:**
* **Path:** '{csv_path}'
* **Column Information:** (This tells you the names and data types of the columns)
```
{column_info}
```
* **Sample Data:** (This gives you a glimpse of the data's structure. Note the European date format DD/MM/YYYY)
```
{sample_str}
```
**Strict Requirements (Follow these EXACTLY):**
0. **Multi-part Questions:**
* If the user asks a multi-part question, **reformat it** to process each part correctly while maintaining the original meaning. **Do not change the intent** of the question.
* **For multi-part questions**, the code should reflect how each part of the question is handled. You must ensure that each part is processed and combined correctly at the end.
* **Print a statement** explaining how you processed the multi-part question, e.g., `print("Question was split into parts for processing.")`.
1. **Load Data and Parse Dates:** Your code *MUST* begin with the following line to load the data, correctly parsing *ALL* potential date columns:
```python
import pandas as pd
df = pd.read_csv('{csv_path}', parse_dates=['HIST_DATE', 'DATA_SEM_OFERTA', 'DATA_STARE_CERERE', 'DATA_IN_OFERTA', 'CTR_DATA_START', 'CTR_DATA_STATUS'], dayfirst=True)
```
Do *NOT* modify this line. The `parse_dates` argument is *critical* for correct date handling, and `dayfirst=True` is absolutely required because dates are in European DD/MM/YYYY format.
2. **Imports:** Do *NOT* import any libraries other than pandas (which is already imported as `pd`). Do *NOT* use `numpy` or `datetime` directly, unless it is used within the context of parsing in read_csv. Pandas is sufficient for all tasks.
3. **Output:**
* Store your final answer in a variable named `result`.
* Print the `result` variable using `print(result)`.
* Do *NOT* use `display()`.
* The output must be a Pandas DataFrame, Series, or a single value, as appropriate for the question. If it's a DataFrame or Series, ensure the index is reset where appropriate (e.g., after a `groupby()` followed by `.size()`).
4. **Conciseness and Style:**
* Write the *most concise* and efficient Pandas code possible.
* Use method chaining (e.g., `df.groupby(...).sum().sort_values().head()`) whenever possible and appropriate.
* Avoid unnecessary intermediate variables unless they *significantly* improve readability.
* Use clear and understandable variable names for filtered dataframes, (for example: df_2010, df_filtered etc)
* If calculating a percentage or distribution, combine operations efficiently, ideally in a single chained expression.
5. **Correctness:** Your code *MUST* be syntactically correct Python and *MUST* produce the correct answer to the question. Double-check your logic, especially when grouping and aggregating. Pay close attention to the wording of the question.
6. **Date and Time Conditions (Implicit Filtering):**
* **Any question that refers to dates, time periods, months, years, or uses phrases like "issued in," "policies from," "between [dates]," etc., *MUST* filter the data using the `DATA_SEM_OFERTA` column.** This is the *implied* date column for policy issuance. Do *NOT* ask the user which column to use; assume `DATA_SEM_OFERTA`.
* When filtering dates, use combined boolean conditions for efficiency, e.g., `df[(df['DATA_SEM_OFERTA'].dt.year == 2010) & (df['DATA_SEM_OFERTA'].dt.month == 12)]` rather than separate filtering steps.
7. **Column Names:** Use the *exact* column names provided in the "CSV Column Information." Pay close attention to capitalization, spaces, and any special characters.
8. **No Explanations:** Output *ONLY* the Python code. Do *NOT* include any comments, explanations, surrounding text, or markdown formatting (like ```python). Just the code.
9. **Aggregation (VERY IMPORTANT):** When the question asks for:
* "top N" or "first N"
* "most frequent"
* "highest/lowest" (after grouping)
* "average/sum/count per [group]"
* **Calculate Percentage**: When percentage is asked, compute the correct percentage value
You *MUST* perform a `groupby()` operation *BEFORE* sorting or selecting the top N values. The correct order is:
1. Filter the DataFrame (if needed, using boolean indexing).
2. Group by the appropriate column(s) using `.groupby()`.
3. Apply an aggregation function (e.g., `.sum()`, `.mean()`, `.size()`, `.count()`, `.median()`).
4. *Then*, sort (if needed) using `.sort_values()` and/or select the top N (if needed) using `.nlargest()` or `.head()`.
10. **Error Handling:** Assume the CSV file exists and is correctly formatted. You do *not* need to write any explicit error handling code.
11. **Clarity:** Use clear and meaningful variable names if you create intermediate dataframes, but prioritize conciseness.
**Column Usage Guidance:**
13. primele means .nlargest and ultimele means .nsmallest
* Use `CTR_STATUS` when a concise or coded representation of the contract status is needed (e.g., for technical filtering or matching with system data).
* Use `CTR_DESCRIERE_STATUS` when a human-readable description is required (e.g., for distributions, summaries, or grouping by status type, such as "Activ", "Reziliat"). Default to `CTR_DESCRIERE_STATUS` for questions involving totals, distributions, or descriptive analysis unless the question specifies a coded status.
* Use `COD_SUCURSALA` for numerical branch identification (e.g., filtering or joining with other datasets); use `DENUMIRE_SUCURSALA` for human-readable branch names (e.g., grouping or summarizing by branch name).
* Use `COD_AGENTIE` for numerical agency identification; use `DENUMIRE_AGENTIE` for human-readable agency names, preferring the latter for summaries or rankings.
* Use `DATA_SEM_OFERTA` as the implied date column for policy issuance or time-based filtering (e.g., "issued in", "per month"), unless the question specifies another date column.
* Use `PBA_BAZA`, `PBA_ASIG_SUPLIM`, `PBA_TOTAL_SEMNARE_CERERE`, and `PBA_TOTAL_EMITERE_CERERE` for financial aggregations (e.g., sum, mean) based on the specific PBA type mentioned in the question.
**Question:**
{question}
"""


def _strip_code_fences(text):
    """Remove Markdown code-fence markers the model may wrap its answer in."""
    return text.replace("```python", "").replace("```", "").strip()


def generate_code(question, column_info, sample_str, csv_path, model_name="gpt-4o"):
    """Ask OpenAI to generate Pandas code answering ``question``.

    Args:
        question: The user's question, inserted verbatim at the end of the prompt.
        column_info: Text describing the CSV columns and their dtypes.
        sample_str: Text rendering of a few sample rows.
        csv_path: Path of the CSV file the generated code should load.
        model_name: Chat model used for the request.

    Returns:
        The generated Python source with Markdown code fences removed.

    Raises:
        ValueError: If the API response carries no choices, no text content,
            or nothing but whitespace/fences.
    """
    prompt = _build_codegen_prompt(question, column_info, sample_str, csv_path)
    response = client.chat.completions.create(
        model=model_name,
        temperature=0,  # Keep temperature at 0 for consistent, deterministic code
        messages=[
            {"role": "system", "content": "You are a helpful assistant that generates Python code."},
            {"role": "user", "content": prompt},
        ],
    )
    if not response.choices:
        raise ValueError("The model returned no completion choices.")
    # `content` is optional in the API response; calling .strip() on None
    # would surface as an unhelpful AttributeError.
    content = response.choices[0].message.content
    if content is None:
        raise ValueError("The model returned a response without text content.")
    code_to_execute = _strip_code_fences(content.strip())
    if not code_to_execute:
        # Executing an empty program would silently yield a `None` answer.
        raise ValueError("The model returned no code.")
    return code_to_execute
def execute_code(generated_code, csv_path):
    """Execute generated Pandas code and return the value it binds to ``result``.

    Anything the code prints goes to the process's stdout; it is not captured.

    Args:
        generated_code: Python source to run.
        csv_path: Exposed to the executed code as ``__file__``.

    Returns:
        The object assigned to ``result`` by the executed code, or ``None``
        when the code never assigns that name.

    Raises:
        Exception: Whatever the executed code raises propagates unchanged.
    """
    # SECURITY NOTE(review): this runs model-generated code with the full
    # privileges of the app process. Only acceptable where that is understood;
    # consider sandboxing before exposing the app to untrusted users.
    #
    # A single dict serves as both globals and locals. With two separate
    # mappings, exec treats the code like a class body, so lambdas, nested
    # functions and comprehensions cannot see names such as `df` defined at
    # the top level of the generated code.
    namespace = {"pd": pd, "__file__": csv_path}
    exec(generated_code, namespace)
    return namespace.get("result")
def fig_to_base64(fig):
    """Render a figure to PNG and return the image as a base64 string.

    Args:
        fig: Object exposing ``savefig(buffer, format=..., bbox_inches=...)``.

    Returns:
        The PNG bytes, base64-encoded and decoded as UTF-8 text.
    """
    with io.BytesIO() as png_buffer:
        fig.savefig(png_buffer, format="png", bbox_inches="tight")
        png_buffer.seek(0)
        encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode("utf-8")
def plotly_to_base64(fig):
    """Export a figure as a PNG at 2x scale and return it base64-encoded.

    Args:
        fig: Object exposing ``to_image(format=..., scale=...)`` that returns
            the image bytes.

    Returns:
        The PNG bytes, base64-encoded and decoded as UTF-8 text.
    """
    png_bytes = fig.to_image(format="png", scale=2)
    return base64.b64encode(png_bytes).decode("utf-8")
def _emit_plotly(fig, label, plots):
    """Show a Plotly figure in the app and record its PNG snapshot under ``label``."""
    st.plotly_chart(fig)
    plots.append((label, plotly_to_base64(fig)))


def _emit_matplotlib(fig, label, plots):
    """Show a Matplotlib figure, record its PNG snapshot, and always close it."""
    try:
        st.pyplot(fig)
        plots.append((label, fig_to_base64(fig)))
    finally:
        # Close even when rendering/encoding fails so figures do not pile up
        # in pyplot's global registry across reruns.
        plt.close(fig)


def generate_plots(metadata, categories, values):
    """Render the standard set of charts for one query result.

    Only pairs whose value is an ``int`` or ``float`` are plotted, ordered by
    value from largest to smallest. Each chart is shown in the Streamlit page
    and also captured as a base64 PNG.

    Args:
        metadata: Dict with a ``"query"`` entry (used in chart titles) and an
            optional ``"unit"`` entry (used as the value-axis label).
        categories: Labels, paired positionally with ``values``.
        values: Values to plot; non-numeric entries are ignored.

    Returns:
        A list of ``(label, base64_png)`` tuples in display order, or an empty
        list when there is nothing numeric to plot.
    """
    # Filter category/value pairs together so the two sequences cannot drift
    # out of alignment when the inputs differ in length.
    numeric_pairs = [(c, v) for c, v in zip(categories, values) if isinstance(v, (int, float))]
    if not numeric_pairs:
        st.warning("No numeric data to plot for this query.")
        return []

    numeric_pairs.sort(key=lambda pair: pair[1], reverse=True)
    sorted_categories, sorted_values = zip(*numeric_pairs)

    query = metadata['query']
    unit = metadata.get("unit", "Value")
    plots = []

    # Horizontal bar chart: only when every category is text and every value
    # is numeric (in that case nothing was filtered out above).
    if all(isinstance(c, str) for c in categories) and all(isinstance(v, (int, float)) for v in values):
        fig_bar = px.bar(x=sorted_values, y=sorted_categories, orientation="h",
                         labels={"x": "Value", "y": "Category"},
                         title=f"{query} (Bar Chart)",
                         color=sorted_values, color_continuous_scale="blues")
        fig_bar.update_layout(yaxis=dict(categoryorder="total ascending"))
        _emit_plotly(fig_bar, "Bar Chart (Plotly)", plots)

    # Bar Plot (Plotly)
    fig_vbar = px.bar(x=sorted_categories, y=sorted_values, labels={"x": "Category", "y": unit},
                      title=f"{query} (Plotly Bar)", color=sorted_values, color_continuous_scale="blues")
    _emit_plotly(fig_vbar, "Bar Plot (Plotly)", plots)

    # Pie Chart: shares are undefined for a zero total and Matplotlib rejects
    # negative wedge sizes, so skip the chart instead of aborting every plot.
    # Values are sorted descending, so the last one is the minimum.
    total = sum(sorted_values)
    if total > 0 and sorted_values[-1] >= 0:
        fig_pie, ax_pie = plt.subplots(figsize=(10, 8))
        cmap = plt.get_cmap("tab20c")
        colors = [cmap(i) for i in range(len(sorted_categories))]
        wedges, _texts = ax_pie.pie(sorted_values, labels=None, autopct=None, startangle=140,
                                    colors=colors, wedgeprops=dict(width=0.4))
        legend_labels = [f"{cat} ({val / total:.1%})" for cat, val in zip(sorted_categories, sorted_values)]
        ax_pie.legend(wedges, legend_labels, title="Categories", loc="center left",
                      bbox_to_anchor=(1, 0, 0.5, 1), fontsize=10)
        ax_pie.axis("equal")
        ax_pie.set_title(f"{query} (Pie)", fontsize=16)
        _emit_matplotlib(fig_pie, "Pie Chart", plots)
    else:
        st.info("Pie chart skipped: it needs non-negative values with a positive total.")

    # Histogram
    fig_hist, ax_hist = plt.subplots(figsize=(10, 6))
    ax_hist.hist(sorted_values, bins=10, color="skyblue", edgecolor="black")
    ax_hist.set_title(f"Distribution of {query} (Histogram)", fontsize=16)
    _emit_matplotlib(fig_hist, "Histogram", plots)

    # Heatmap
    fig_heat, ax_heat = plt.subplots(figsize=(10, 6))
    data_matrix = pd.DataFrame({unit: sorted_values}, index=sorted_categories)
    sns.heatmap(data_matrix, annot=True, cmap="Blues", ax=ax_heat, fmt=".1f")
    ax_heat.set_title(f"{query} (Heatmap)", fontsize=16)
    _emit_matplotlib(fig_heat, "Heatmap", plots)

    # Scatter Plot
    fig_scatter = px.scatter(x=sorted_categories, y=sorted_values, title=f"{query} (Scatter Plot)",
                             labels={"x": "Category", "y": unit})
    _emit_plotly(fig_scatter, "Scatter Plot (Plotly)", plots)

    # Line Plot
    fig_line = px.line(x=sorted_categories, y=sorted_values, title=f"{query} (Line Plot)",
                       labels={"x": "Category", "y": unit})
    _emit_plotly(fig_line, "Line Plot (Plotly)", plots)

    # Box Plot
    fig_box, ax_box = plt.subplots(figsize=(10, 6))
    ax_box.boxplot(sorted_values, vert=False, tick_labels=["Data"], patch_artist=True)
    ax_box.set_title(f"{query} (Box Plot)", fontsize=16)
    _emit_matplotlib(fig_box, "Box Plot", plots)

    # Violin Plot
    fig_violin, ax_violin = plt.subplots(figsize=(10, 6))
    ax_violin.violinplot(sorted_values, vert=False, showmeans=True, showextrema=True)
    ax_violin.set_title(f"{query} (Violin Plot)", fontsize=16)
    _emit_matplotlib(fig_violin, "Violin Plot", plots)

    # Area Chart
    fig_area = px.area(x=sorted_categories, y=sorted_values, title=f"{query} (Area Chart)",
                       labels={"x": "Category", "y": unit})
    _emit_plotly(fig_area, "Area Chart (Plotly)", plots)

    # Radar Chart
    fig_radar = go.Figure(data=go.Scatterpolar(r=sorted_values, theta=sorted_categories,
                                               fill='toself', name=query))
    fig_radar.update_layout(polar=dict(radialaxis=dict(visible=True)), showlegend=True,
                            title=f"{query} (Radar Chart)")
    _emit_plotly(fig_radar, "Radar Chart (Plotly)", plots)

    return plots
def sanitize_filename(filename, max_length=200):
    """Turn arbitrary text into a string that is safe to use as a file name.

    Every character other than an ASCII letter or digit becomes ``_``. The
    result is then cut to ``max_length`` characters so that a long query plus
    an extension stays below common per-name filesystem limits.

    Args:
        filename: Text to sanitise.
        max_length: Maximum length of the result; ``None`` disables the cut.

    Returns:
        The sanitised, possibly truncated, string.
    """
    cleaned = re.sub(r'[^a-zA-Z0-9]', '_', filename)
    return cleaned[:max_length]
def generate_pdf(query, response_text, chat_response, plots):
    """Build an HTML report for one query and convert it to a PDF file.

    Args:
        query: The user's question; also used (sanitised) as the file name.
        response_text: The answer to show in the "Response" section. Non-string
            objects are converted with ``str()``.
        chat_response: Dict with ``"metadata"`` (JSON-serialisable) and
            ``"data"`` (records accepted by ``pd.DataFrame``).
        plots: Iterable of ``(label, base64_png)`` pairs to embed.

    Returns:
        Path of the generated PDF inside ``./exported_pdfs``.

    Raises:
        Exception: Errors from reading the logo, writing the temporary HTML
            file, or the PDF conversion propagate to the caller.
    """
    import tempfile  # local import: needed only by this function

    query = html.unescape(query)
    # Callers may pass the raw result object (DataFrame, number, ...);
    # html.unescape/escape only work on str.
    if not isinstance(response_text, str):
        response_text = str(response_text)
    response_text = html.unescape(response_text)
    escaped_query = html.escape(query)
    escaped_response_text = html.escape(response_text)

    # The metadata echoes the user's query and the table holds CSV-derived
    # cell values, so both are escaped before being placed in the markup.
    metadata_html = html.escape(
        json.dumps(chat_response["metadata"], indent=2, ensure_ascii=False), quote=False
    )
    data_table_html = pd.DataFrame(chat_response["data"]).to_html(
        index=False, classes="no-break", escape=True
    )
    plots_html = "".join(
        f'<div class="no-break"><h4>{html.escape(name)}</h4>'
        f'<img src="data:image/png;base64,{encoded_png}"/></div>'
        for name, encoded_png in plots
    )
    logo_b64 = get_zega_logo_base64()

    html_content = f"""
<!DOCTYPE html>
<html lang="ro">
<head>
<title>Data Analysis Report</title>
<meta charset="UTF-8">
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f9f9f9; color: #333; }}
h1 {{ color: #1f77b4; text-align: center; }}
h3 {{ color: #2c3e50; border-bottom: 2px solid #ddd; padding-bottom: 5px; }}
h4 {{ color: #2980b9; }}
p {{ line-height: 1.6; background-color: #fff; padding: 10px; border-radius: 5px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }}
pre {{ background-color: #ecf0f1; padding: 10px; border-radius: 5px; font-size: 12px; }}
table {{ border-collapse: collapse; width: 100%; margin: 10px 0; page-break-inside: avoid; }}
th, td {{ border: 1px solid #bdc3c7; padding: 10px; text-align: left; }}
th {{ background-color: #3498db; color: white; }}
td {{ background-color: #fff; }}
img {{ max-width: 100%; height: auto; margin: 10px 0; page-break-inside: avoid; }}
.section {{ margin-bottom: 20px; }}
.no-break {{ page-break-inside: avoid; }}
.powered-by {{ text-align: center; margin-top: 20px; font-size: 10px; color: #777; }}
.logo {{ height: 100px; }}
</style>
</head>
<body>
<h1>Data Analysis Agent Interface</h1>
<div class="section no-break"><h3>Query</h3><p>{escaped_query}</p></div>
<div class="section no-break"><h3>Response</h3><p>{escaped_response_text}</p></div>
<div class="section no-break">
<h3>Raw Structured Response</h3>
<h4>Metadata</h4><pre>{metadata_html}</pre>
<h4>Data</h4>{data_table_html}
</div>
<div class="section"><h3>Plots</h3>{plots_html}</div>
<div class="powered-by">Powered by <img src="data:image/png;base64,{logo_b64}" class="logo"></div>
</body></html>
"""

    sanitized_query = sanitize_filename(query)
    os.makedirs("./exported_pdfs", exist_ok=True)
    pdf_file = f"./exported_pdfs/{sanitized_query}.pdf"
    options = {'encoding': "UTF-8", 'custom-header': [('Content-Type', 'text/html; charset=UTF-8')], 'no-outline': None}

    # A uniquely named temporary file avoids two sessions overwriting each
    # other's HTML; the finally clause removes it even if conversion fails.
    html_handle = tempfile.NamedTemporaryFile(
        mode="w", suffix=".html", encoding="utf-8", delete=False
    )
    html_file = html_handle.name
    try:
        with html_handle:
            html_handle.write(html_content)
        pdfkit.from_file(html_file, pdf_file, options=options)
    finally:
        os.remove(html_file)
    return pdf_file
def get_zega_logo_base64():
    """Return the contents of ``zega_logo.png`` as base64 text.

    The file is looked up relative to the current working directory. Any
    error raised while opening or reading it propagates to the caller.

    Returns:
        The file's bytes, base64-encoded and decoded as UTF-8 text.
    """
    with open("zega_logo.png", "rb") as logo_stream:
        raw_bytes = logo_stream.read()
    return base64.b64encode(raw_bytes).decode("utf-8")
# Streamlit Interface
st.title("Data Analysis Agent Interface")

# Sidebar branding: the logo is embedded inline as a base64 data URI.
_sidebar_logo_html = f"""
<div style="text-align: center;">
Powered by <img src="data:image/png;base64,{get_zega_logo_base64()}" style="height: 100px;">
</div>
"""
st.sidebar.markdown(_sidebar_logo_html, unsafe_allow_html=True)

# Sample questions offered in the sidebar; the selected one pre-fills the
# query box below.
st.sidebar.header("Sample Questions")
sample_questions = [
    "Da-mi top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
    "Da-mi vânzările defalcate pe produse pentru top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
    "Da-mi vânzările defalcate pe pachete pentru top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
]
selected_question = st.sidebar.selectbox("Select a sample question:", sample_questions)
user_query = st.text_area(
    "Please write one question at a time.",
    value=selected_question,
    height=100,
)
def _build_chat_response(result, query):
    """Wrap a query result in the ``metadata``/``data``/``csv_data`` structure.

    Returns ``None`` for a list that is neither all numbers nor all dicts.
    """
    if isinstance(result, pd.Series):
        result = result.reset_index()
    if isinstance(result, pd.DataFrame):
        records = result.to_dict(orient='records')
    elif isinstance(result, list):
        if all(isinstance(item, (int, float)) for item in result):
            records = [{"category": str(i), "value": v} for i, v in enumerate(result)]
        elif all(isinstance(item, dict) for item in result):
            records = result
        else:
            return None
    else:
        records = [{"category": "Result", "value": result}]
    return {
        "metadata": {"query": query, "unit": "", "plot_types": []},
        "data": records,
        "csv_data": list(records),
    }


def _result_as_text(result):
    """Render a query result as plain text for the report's response section."""
    if isinstance(result, (pd.DataFrame, pd.Series)):
        return result.to_string()
    return str(result)


def _second_value_text(record):
    """Return the record's second field as text, or '' when it has fewer than two."""
    fields = list(record.values())
    return str(fields[1]) if len(fields) > 1 else ""


def _series_from_records(data):
    """Pick category and value lists from multi-column records."""
    prioritized_columns = ["DENUMIRE_SUCURSALA", "NUMAR_CERERE", "size", "HIST_DATE", "COD_SUCURSALA", "COD_AGENTIE",
                           "DENUMIRE_AGENTIE", "PRODUS", "DATA_SEM_OFERTA", "DATA_STARE_CERERE", "STATUS_CERERE",
                           "DESCRIERE_STARE_CERERE", "DATA_IN_OFERTA", "PBA_BAZA", "PBA_ASIG_SUM",
                           "PBA_TOTAL_SEMNARE_CERERE", "PBA_CTR_ASOC", "PBA_TOTAL_EMITERE_CERERE", "FRECVENTA_PLATA"]
    value_first_columns = ("NUMAR_CERERE", "size")

    categories = list(data[0].keys())
    for col in prioritized_columns:
        if not all(col in item for item in data):
            continue
        categories = [str(item[col]) for item in data]
        if col in value_first_columns:
            # These two columns are preferred as the value axis below, so the
            # scan continues in search of a separate category column.
            continue
        for value_col in value_first_columns:
            if all(value_col in item for item in data):
                return categories, [item.get(value_col, 0) for item in data]
        # Otherwise use the first numeric column of the first record, never
        # the category column itself (that would plot a column against itself).
        numeric_col = next(
            (c for c, v in data[0].items() if c != col and isinstance(v, (int, float))),
            None,
        )
        if numeric_col is not None:
            return categories, [item.get(numeric_col, 0) for item in data]
        return categories, [_second_value_text(item) for item in data]
    # No usable category column: fall back to the second field as text.
    return categories, [_second_value_text(item) for item in data]


def _extract_plot_series(data):
    """Derive ``(categories, values)`` for plotting, or ``None`` if unrecognised."""
    if data and isinstance(data, list) and isinstance(data[0], dict):
        if len(data[0]) == 1:
            # Single-column records: the column serves as both axes.
            only_column = [item[list(item.keys())[0]] for item in data]
            return only_column, only_column
        return _series_from_records(data)
    if isinstance(data, list) and all(isinstance(item, (int, float)) for item in data):
        return list(range(len(data))), data
    if isinstance(data, (int, float, str)):
        return ["Result"], [data]
    return None


def process_query():
    """Answer the current query end to end and remember the outcome.

    Reads the module-level ``user_query``, ``column_info``, ``sample_str`` and
    ``csv_path``; generates and executes code for the query; shows the
    question, the code, the data and the plots; then stores the query, a text
    rendering of the result, the structured response, the plots and the
    generated code in ``st.session_state``. Any exception is reported with
    ``st.error`` instead of being raised.
    """
    try:
        generated_code = generate_code(user_query, column_info, sample_str, csv_path)
        result = execute_code(generated_code, csv_path)

        chat_response = _build_chat_response(result, user_query)
        if chat_response is None:
            st.warning("Result is a list with mixed data types. Please inspect.")
            return

        st.markdown("<h3 style='color: #2e86de;'>Question:</h3>", unsafe_allow_html=True)
        # The query is free text typed by the user; escape it before placing
        # it inside raw HTML.
        st.markdown(f"<p style='color: #2e86de;'>{html.escape(user_query)}</p>", unsafe_allow_html=True)
        st.write("-" * 200)
        # Initially hide the code.
        with st.expander("Show the code"):
            st.code(generated_code, language="python")
        st.write("-" * 200)
        st.markdown("### Data:")
        st.dataframe(pd.DataFrame(chat_response["data"]))

        metadata = chat_response["metadata"]
        series = _extract_plot_series(chat_response["data"])
        if series is None:
            categories, values = [], []
            st.warning("Unexpected data format. Check the query and data.")
        else:
            categories, values = series
        plots = generate_plots(metadata, categories, values)

        st.session_state["query"] = user_query
        # Stored as text: the PDF export passes this value to html.escape.
        st.session_state["response_text"] = _result_as_text(result)
        st.session_state["chat_response"] = chat_response
        st.session_state["plots"] = plots
        st.session_state["generated_code"] = generated_code  # Store the generated code
    except Exception as e:
        st.error(f"An error occurred: {e}")
# Run the query when the user presses Submit.
submit_clicked = st.button("Submit")
if submit_clicked:
    with st.spinner("Processing query..."):
        try:
            process_query()
        except Exception as e:
            st.error(f"An error occurred: {e}")

# The export button is only created once a response has been stored.
if "chat_response" in st.session_state and st.button("Download PDF"):
    with st.spinner("Generating PDF..."):
        try:
            stored = st.session_state
            pdf_file = generate_pdf(
                stored["query"],
                stored["response_text"],
                stored["chat_response"],
                stored["plots"],
            )
            with open(pdf_file, "rb") as f:
                pdf_data = f.read()
            sanitized_query = sanitize_filename(stored["query"])
            st.download_button(
                label="Click Here to Download PDF",
                data=pdf_data,
                file_name=f"{sanitized_query}.pdf",
                mime="application/pdf",
            )
        except Exception as e:
            st.error(f"PDF generation failed: {e}")