import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from dotenv import load_dotenv
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
from langchain_openai import ChatOpenAI
import os
import seaborn as sns
import plotly.graph_objects as go
import json
import pdfkit
import io
import base64
from matplotlib.backends.backend_agg import FigureCanvasAgg
import html
import re
from openai import OpenAI
from io import StringIO
load_dotenv()
# --- Configuration ---
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY") or st.secrets.get("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)
csv_path = "asig_sales_31012025.csv"
if not os.path.exists(csv_path):
    # exit(1) would kill the whole Streamlit server; report the error and stop the script run instead.
    st.error(f"Error: CSV file '{csv_path}' not found.")
    st.stop()
def get_csv_sample(csv_path, sample_size=5):
    """Reads a CSV file and returns column info, a sample, and the DataFrame."""
    df = pd.read_csv(csv_path)
    sample_df = df.sample(n=min(sample_size, len(df)), random_state=42)
    return df.dtypes.to_string(), sample_df.to_string(index=False), df
column_info, sample_str, _ = get_csv_sample(csv_path)
# @observe()
def chat(response_text):
    return json.loads(response_text)  # Directly parse the JSON
def generate_code(question, column_info, sample_str, csv_path, model_name="gpt-4o"):
    """Asks OpenAI to generate Pandas code for a given question."""
    prompt = f"""You are a highly skilled Python data analyst with expert-level proficiency in Pandas. Your task is to write **concise, correct, and efficient** Pandas code to answer a specific question about data contained within a CSV file. The code you generate must be self-contained, directly executable, and produce the correct numerical output or DataFrame structure.
**CSV File Information:**
* **Path:** '{csv_path}'
* **Column Information:** (This tells you the names and data types of the columns)
```
{column_info}
```
* **Sample Data:** (This gives you a glimpse of the data's structure. Note the European date format DD/MM/YYYY)
```
{sample_str}
```
**Strict Requirements (Follow these EXACTLY):**
0. **Multi-part Questions:**
* If the user asks a multi-part question, **reformat it** to process each part correctly while maintaining the original meaning. **Do not change the intent** of the question.
* **For multi-part questions**, the code should reflect how each part of the question is handled. You must ensure that each part is processed and combined correctly at the end.
* **Print a statement** explaining how you processed the multi-part question, e.g., `print("Question was split into parts for processing.")`.
1. **Load Data and Parse Dates:** Your code *MUST* begin with the following line to load the data, correctly parsing *ALL* potential date columns:
```python
import pandas as pd
df = pd.read_csv('{csv_path}', parse_dates=['HIST_DATE', 'DATA_SEM_OFERTA', 'DATA_STARE_CERERE', 'DATA_IN_OFERTA', 'CTR_DATA_START', 'CTR_DATA_STATUS'], dayfirst=True)
```
Do *NOT* modify this line. The `parse_dates` argument is *critical* for correct date handling, and `dayfirst=True` is absolutely required because dates are in European DD/MM/YYYY format.
2. **Imports:** Do *NOT* import any libraries other than pandas (which is already imported as `pd`). Do *NOT* use `numpy` or `datetime` directly, unless it is used within the context of parsing in read_csv. Pandas is sufficient for all tasks.
3. **Output:**
* Store your final answer in a variable named `result`.
* Print the `result` variable using `print(result)`.
* Do *NOT* use `display()`.
* The output must be a Pandas DataFrame, Series, or a single value, as appropriate for the question. If it's a DataFrame or Series, ensure the index is reset where appropriate (e.g., after a `groupby()` followed by `.size()`).
4. **Conciseness and Style:**
* Write the *most concise* and efficient Pandas code possible.
* Use method chaining (e.g., `df.groupby(...).sum().sort_values().head()`) whenever possible and appropriate.
* Avoid unnecessary intermediate variables unless they *significantly* improve readability.
* Use clear and understandable variable names for filtered dataframes (for example: df_2010, df_filtered, etc.).
* If calculating a percentage or distribution, combine operations efficiently, ideally in a single chained expression.
5. **Correctness:** Your code *MUST* be syntactically correct Python and *MUST* produce the correct answer to the question. Double-check your logic, especially when grouping and aggregating. Pay close attention to the wording of the question.
6. **Date and Time Conditions (Implicit Filtering):**
* **Any question that refers to dates, time periods, months, years, or uses phrases like "issued in," "policies from," "between [dates]," etc., *MUST* filter the data using the `DATA_SEM_OFERTA` column.** This is the *implied* date column for policy issuance. Do *NOT* ask the user which column to use; assume `DATA_SEM_OFERTA`.
* When filtering dates, use combined boolean conditions for efficiency, e.g., `df[(df['DATA_SEM_OFERTA'].dt.year == 2010) & (df['DATA_SEM_OFERTA'].dt.month == 12)]` rather than separate filtering steps.
7. **Column Names:** Use the *exact* column names provided in the "CSV Column Information." Pay close attention to capitalization, spaces, and any special characters.
8. **No Explanations:** Output *ONLY* the Python code. Do *NOT* include any comments, explanations, surrounding text, or markdown formatting (like ```python). Just the code.
9. **Aggregation (VERY IMPORTANT):** When the question asks for:
* "top N" or "first N"
* "most frequent"
* "highest/lowest" (after grouping)
* "average/sum/count per [group]"
* **Calculate Percentage**: When percentage is asked, compute the correct percentage value
You *MUST* perform a `groupby()` operation *BEFORE* sorting or selecting the top N values. The correct order is:
1. Filter the DataFrame (if needed, using boolean indexing).
2. Group by the appropriate column(s) using `.groupby()`.
3. Apply an aggregation function (e.g., `.sum()`, `.mean()`, `.size()`, `.count()`, `.median()`).
4. *Then*, sort (if needed) using `.sort_values()` and/or select the top N (if needed) using `.nlargest()` or `.head()`.
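For example (an illustrative sketch only; adapt the filter, grouping column, and N to the actual question), a correct answer to "top 5 branches by number of policies issued in 2024" follows this exact order:
```python
import pandas as pd
df = pd.read_csv('{csv_path}', parse_dates=['HIST_DATE', 'DATA_SEM_OFERTA', 'DATA_STARE_CERERE', 'DATA_IN_OFERTA', 'CTR_DATA_START', 'CTR_DATA_STATUS'], dayfirst=True)
df_2024 = df[df['DATA_SEM_OFERTA'].dt.year == 2024]
result = df_2024.groupby('DENUMIRE_SUCURSALA').size().nlargest(5)
print(result)
```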
10. **Error Handling:** Assume the CSV file exists and is correctly formatted. You do *not* need to write any explicit error handling code.
11. **Clarity:** Use clear and meaningful variable names if you create intermediate dataframes, but prioritize conciseness.
12. **Romanian Terms:** `primele` means use `.nlargest`; `ultimele` means use `.nsmallest`.
**Column Usage Guidance:**
* Use `CTR_STATUS` when a concise or coded representation of the contract status is needed (e.g., for technical filtering or matching with system data).
* Use `CTR_DESCRIERE_STATUS` when a human-readable description is required (e.g., for distributions, summaries, or grouping by status type, such as "Activ", "Reziliat"). Default to `CTR_DESCRIERE_STATUS` for questions involving totals, distributions, or descriptive analysis unless the question specifies a coded status.
* Use `COD_SUCURSALA` for numerical branch identification (e.g., filtering or joining with other datasets); use `DENUMIRE_SUCURSALA` for human-readable branch names (e.g., grouping or summarizing by branch name).
* Use `COD_AGENTIE` for numerical agency identification; use `DENUMIRE_AGENTIE` for human-readable agency names, preferring the latter for summaries or rankings.
* Use `DATA_SEM_OFERTA` as the implied date column for policy issuance or time-based filtering (e.g., "issued in", "per month"), unless the question specifies another date column.
* Use `PBA_BAZA`, `PBA_ASIG_SUPLIM`, `PBA_TOTAL_SEMNARE_CERERE`, and `PBA_TOTAL_EMITERE_CERERE` for financial aggregations (e.g., sum, mean) based on the specific PBA type mentioned in the question.
**Question:**
{question}
"""
    response = client.chat.completions.create(
        model=model_name,
        temperature=0,  # Keep temperature at 0 for consistent, deterministic code
        messages=[
            {"role": "system", "content": "You are a helpful assistant that generates Python code."},
            {"role": "user", "content": prompt},
        ],
    )
    code_to_execute = response.choices[0].message.content.strip()
    code_to_execute = code_to_execute.replace("```python", "").replace("```", "").strip()
    return code_to_execute
def execute_code(generated_code, csv_path):
    """Executes the generated Pandas code and captures the output."""
    # Use a single namespace for both globals and locals: with separate dicts,
    # comprehensions inside exec'd code cannot see top-level names like `df`.
    namespace = {"pd": pd, "__file__": csv_path}
    exec(generated_code, namespace)
    return namespace.get("result")
def fig_to_base64(fig):
    """Encodes a Matplotlib figure as a base64 PNG string."""
    buf = io.BytesIO()
    fig.savefig(buf, format="png", bbox_inches="tight")
    buf.seek(0)
    img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
    buf.close()
    return img_str

def plotly_to_base64(fig):
    """Encodes a Plotly figure as a base64 PNG string (requires kaleido)."""
    img_bytes = fig.to_image(format="png", scale=2)
    return base64.b64encode(img_bytes).decode("utf-8")
def generate_plots(metadata, categories, values):
    """Renders a suite of charts for (category, value) pairs and returns them as (name, base64 PNG) tuples."""
    # Keep only numeric values and their matching categories
    numeric_values = [v for v in values if isinstance(v, (int, float))]
    numeric_categories = [c for c, v in zip(categories, values) if isinstance(v, (int, float))]
    if not numeric_values:
        st.warning("No numeric data to plot for this query.")
        return []
    plots = []
    if all(isinstance(c, str) for c in categories) and all(isinstance(v, (int, float)) for v in values):
        sorted_categories, sorted_values = zip(*sorted(zip(categories, values), key=lambda x: x[1], reverse=True))
        # Horizontal Bar Plot (main plot for string categories and numeric values)
        fig_bar = px.bar(x=sorted_values, y=sorted_categories, orientation="h",
                         labels={"x": "Value", "y": "Category"},
                         title=f"{metadata['query']} (Bar Chart)",
                         color=sorted_values, color_continuous_scale="blues")
        fig_bar.update_layout(yaxis=dict(categoryorder="total ascending"))
        st.plotly_chart(fig_bar)
        plots.append(("Bar Chart (Plotly)", plotly_to_base64(fig_bar)))
        # Vertical Bar Plot (Plotly)
        fig1 = px.bar(x=sorted_categories, y=sorted_values, labels={"x": "Category", "y": metadata.get("unit", "Value")},
                      title=f"{metadata['query']} (Plotly Bar)", color=sorted_values, color_continuous_scale="blues")
        st.plotly_chart(fig1)
        plots.append(("Bar Plot (Plotly)", plotly_to_base64(fig1)))
        # Pie Chart
        fig2, ax2 = plt.subplots(figsize=(10, 8))
        cmap = plt.get_cmap("tab20c")
        colors = [cmap(i) for i in range(len(sorted_categories))]
        wedges, texts = ax2.pie(sorted_values, labels=None, autopct=None, startangle=140, colors=colors, wedgeprops=dict(width=0.4))
        legend_labels = [f"{cat} ({val / sum(sorted_values):.1%})" for cat, val in zip(sorted_categories, sorted_values)]
        ax2.legend(wedges, legend_labels, title="Categories", loc="center left", bbox_to_anchor=(1, 0, 0.5, 1), fontsize=10)
        ax2.axis("equal")
        ax2.set_title(f"{metadata['query']} (Pie)", fontsize=16)
        st.pyplot(fig2)
        plots.append(("Pie Chart", fig_to_base64(fig2)))
        plt.close(fig2)
        # Histogram
        fig3, ax3 = plt.subplots(figsize=(10, 6))
        ax3.hist(sorted_values, bins=10, color="skyblue", edgecolor="black")
        ax3.set_title(f"Distribution of {metadata['query']} (Histogram)", fontsize=16)
        st.pyplot(fig3)
        plots.append(("Histogram", fig_to_base64(fig3)))
        plt.close(fig3)
        # Heatmap
        fig4, ax4 = plt.subplots(figsize=(10, 6))
        data_matrix = pd.DataFrame({metadata.get("unit", "Value"): sorted_values}, index=sorted_categories)
        sns.heatmap(data_matrix, annot=True, cmap="Blues", ax=ax4, fmt=".1f")
        ax4.set_title(f"{metadata['query']} (Heatmap)", fontsize=16)
        st.pyplot(fig4)
        plots.append(("Heatmap", fig_to_base64(fig4)))
        plt.close(fig4)
        # Scatter Plot
        fig5 = px.scatter(x=sorted_categories, y=sorted_values, title=f"{metadata['query']} (Scatter Plot)",
                          labels={"x": "Category", "y": metadata.get("unit", "Value")})
        st.plotly_chart(fig5)
        plots.append(("Scatter Plot (Plotly)", plotly_to_base64(fig5)))
        # Line Plot
        fig6 = px.line(x=sorted_categories, y=sorted_values, title=f"{metadata['query']} (Line Plot)",
                       labels={"x": "Category", "y": metadata.get("unit", "Value")})
        st.plotly_chart(fig6)
        plots.append(("Line Plot (Plotly)", plotly_to_base64(fig6)))
        # Box Plot
        fig7, ax7 = plt.subplots(figsize=(10, 6))
        ax7.boxplot(sorted_values, vert=False, tick_labels=["Data"], patch_artist=True)
        ax7.set_title(f"{metadata['query']} (Box Plot)", fontsize=16)
        st.pyplot(fig7)
        plots.append(("Box Plot", fig_to_base64(fig7)))
        plt.close(fig7)
        # Violin Plot
        fig8, ax8 = plt.subplots(figsize=(10, 6))
        ax8.violinplot(sorted_values, vert=False, showmeans=True, showextrema=True)
        ax8.set_title(f"{metadata['query']} (Violin Plot)", fontsize=16)
        st.pyplot(fig8)
        plots.append(("Violin Plot", fig_to_base64(fig8)))
        plt.close(fig8)
        # Area Chart
        fig9 = px.area(x=sorted_categories, y=sorted_values, title=f"{metadata['query']} (Area Chart)",
                       labels={"x": "Category", "y": metadata.get("unit", "Value")})
        st.plotly_chart(fig9)
        plots.append(("Area Chart (Plotly)", plotly_to_base64(fig9)))
        # Radar Chart
        fig10 = go.Figure(data=go.Scatterpolar(r=sorted_values, theta=sorted_categories, fill='toself', name=metadata['query']))
        fig10.update_layout(polar=dict(radialaxis=dict(visible=True)), showlegend=True, title=f"{metadata['query']} (Radar Chart)")
        st.plotly_chart(fig10)
        plots.append(("Radar Chart (Plotly)", plotly_to_base64(fig10)))
    else:
        st.warning("Data is not in a plottable (string category, numeric value) format.")
    return plots
def sanitize_filename(filename):
    """Replaces every non-alphanumeric character with an underscore."""
    return re.sub(r'[^a-zA-Z0-9]', '_', filename)
def generate_pdf(query, response_text, chat_response, plots):
    """Renders the query, response, and plots into an HTML report and converts it to PDF via pdfkit."""
    query = html.unescape(query)
    response_text = html.unescape(str(response_text))  # result may be a DataFrame/Series, not a string
    escaped_query = html.escape(query)
    escaped_response_text = html.escape(response_text)
    plots_html = "".join(
        f'<div class="no-break"><h4>{name}</h4><img src="data:image/png;base64,{img_b64}"/></div>'
        for name, img_b64 in plots  # avoid shadowing the base64 module
    )
    html_content = f"""
    <!DOCTYPE html>
    <html lang="ro">
    <head>
    <title>Data Analysis Report</title>
    <meta charset="UTF-8">
    <style>
        body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f9f9f9; color: #333; }}
        h1 {{ color: #1f77b4; text-align: center; }}
        h3 {{ color: #2c3e50; border-bottom: 2px solid #ddd; padding-bottom: 5px; }}
        h4 {{ color: #2980b9; }}
        p {{ line-height: 1.6; background-color: #fff; padding: 10px; border-radius: 5px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }}
        pre {{ background-color: #ecf0f1; padding: 10px; border-radius: 5px; font-size: 12px; }}
        table {{ border-collapse: collapse; width: 100%; margin: 10px 0; page-break-inside: avoid; }}
        th, td {{ border: 1px solid #bdc3c7; padding: 10px; text-align: left; }}
        th {{ background-color: #3498db; color: white; }}
        td {{ background-color: #fff; }}
        img {{ max-width: 100%; height: auto; margin: 10px 0; page-break-inside: avoid; }}
        .section {{ margin-bottom: 20px; }}
        .no-break {{ page-break-inside: avoid; }}
        .powered-by {{ text-align: center; margin-top: 20px; font-size: 10px; color: #777; }}
        .logo {{ height: 100px; }}
    </style>
    </head>
    <body>
    <h1>Data Analysis Agent Interface</h1>
    <div class="section no-break"><h3>Query</h3><p>{escaped_query}</p></div>
    <div class="section no-break"><h3>Response</h3><p>{escaped_response_text}</p></div>
    <div class="section no-break">
        <h3>Raw Structured Response</h3>
        <h4>Metadata</h4><pre>{json.dumps(chat_response["metadata"], indent=2, ensure_ascii=False)}</pre>
        <h4>Data</h4>{pd.DataFrame(chat_response["data"]).to_html(index=False, classes="no-break", escape=False)}
    </div>
    <div class="section"><h3>Plots</h3>{plots_html}</div>
    <div class="powered-by">Powered by <img src="data:image/png;base64,{get_zega_logo_base64()}" class="logo"></div>
    </body></html>
    """
    html_file = "temp.html"
    sanitized_query = sanitize_filename(query)
    os.makedirs("./exported_pdfs", exist_ok=True)
    pdf_file = f"./exported_pdfs/{sanitized_query}.pdf"
    with open(html_file, "w", encoding="utf-8") as f:
        f.write(html_content)
    options = {'encoding': "UTF-8", 'custom-header': [('Content-Type', 'text/html; charset=UTF-8')], 'no-outline': None}
    try:
        pdfkit.from_file(html_file, pdf_file, options=options)
    finally:
        os.remove(html_file)  # clean up the temporary HTML file even if conversion fails
    return pdf_file
def get_zega_logo_base64():
    """Returns the ZEGA logo PNG as a base64 string."""
    with open("zega_logo.png", "rb") as image_file:
        return base64.b64encode(image_file.read()).decode("utf-8")
# --- Streamlit Interface ---
st.title("Data Analysis Agent Interface")
st.sidebar.markdown(
    f"""
    <div style="text-align: center;">
        Powered by <img src="data:image/png;base64,{get_zega_logo_base64()}" style="height: 100px;">
    </div>
    """,
    unsafe_allow_html=True,
)
st.sidebar.header("Sample Questions")
sample_questions = [
"Da-mi top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
"Da-mi vânzările defalcate pe produse pentru top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
"Da-mi vânzările defalcate pe pachete pentru top cinci sucursale cu vânzări în perioada 01.03.2024-01.04.2024.",
]
selected_question = st.sidebar.selectbox("Select a sample question:", sample_questions)
user_query = st.text_area("Please write one question at a time.", value=selected_question, height=100)
def process_query():
    try:
        generated_code = generate_code(user_query, column_info, sample_str, csv_path)
        result = execute_code(generated_code, csv_path)
        if isinstance(result, pd.DataFrame):
            chat_response = {
                "metadata": {"query": user_query, "unit": "", "plot_types": []},
                "data": result.to_dict(orient='records'),
                "csv_data": result.to_dict(orient='records'),
            }
        elif isinstance(result, pd.Series):
            result = result.reset_index()
            chat_response = {
                "metadata": {"query": user_query, "unit": "", "plot_types": []},
                "data": result.to_dict(orient='records'),
                "csv_data": result.to_dict(orient='records'),
            }
        elif isinstance(result, list):
            if all(isinstance(item, (int, float)) for item in result):
                records = [{"category": str(i), "value": v} for i, v in enumerate(result)]
                chat_response = {
                    "metadata": {"query": user_query, "unit": "", "plot_types": []},
                    "data": records,
                    "csv_data": records,
                }
            elif all(isinstance(item, dict) for item in result):
                chat_response = {
                    "metadata": {"query": user_query, "unit": "", "plot_types": []},
                    "data": result,
                    "csv_data": result,
                }
            else:
                st.warning("Result is a list with mixed data types. Please inspect.")
                return
        else:
            chat_response = {
                "metadata": {"query": user_query, "unit": "", "plot_types": []},
                "data": [{"category": "Result", "value": result}],
                "csv_data": [{"category": "Result", "value": result}],
            }
        st.markdown("<h3 style='color: #2e86de;'>Question:</h3>", unsafe_allow_html=True)
        st.markdown(f"<p style='color: #2e86de;'>{user_query}</p>", unsafe_allow_html=True)
        st.write("-" * 200)
        # Initially hide the code.
        with st.expander("Show the code"):
            st.code(generated_code, language="python")
        st.write("-" * 200)
        st.markdown("### Data:")
        st.dataframe(pd.DataFrame(chat_response["data"]))
        metadata = chat_response["metadata"]
        data = chat_response["data"]
        if data and isinstance(data, list) and isinstance(data[0], dict):
            if len(data[0]) == 1:
                categories = [item[list(item.keys())[0]] for item in data]
                values = categories
            else:
                categories = list(data[0].keys())
                if len(categories) == 1:
                    values = [item[categories[0]] for item in data]
                    categories = values
                else:
                    prioritized_columns = ["DENUMIRE_SUCURSALA", "NUMAR_CERERE", "size", "HIST_DATE", "COD_SUCURSALA", "COD_AGENTIE",
                                           "DENUMIRE_AGENTIE", "PRODUS", "DATA_SEM_OFERTA", "DATA_STARE_CERERE", "STATUS_CERERE",
                                           "DESCRIERE_STARE_CERERE", "DATA_IN_OFERTA", "PBA_BAZA", "PBA_ASIG_SUM",
                                           "PBA_TOTAL_SEMNARE_CERERE", "PBA_CTR_ASOC", "PBA_TOTAL_EMITERE_CERERE", "FRECVENTA_PLATA"]
                    for col in prioritized_columns:
                        if all(col in item for item in data):
                            categories = [str(item[col]) for item in data]
                            if col not in ("NUMAR_CERERE", "size"):
                                if all("NUMAR_CERERE" in item for item in data):
                                    values = [item.get("NUMAR_CERERE", 0) for item in data]
                                elif all("size" in item for item in data):
                                    values = [item.get("size", 0) for item in data]
                                else:
                                    numeric_col = next((c for c in data[0] if isinstance(data[0][c], (int, float))), None)
                                    if numeric_col:
                                        values = [item.get(numeric_col, 0) for item in data]
                                    else:
                                        values = [str(list(item.values())[1]) for item in data]
                            else:
                                # The prioritized column itself holds the numeric values
                                values = [item.get(col, 0) for item in data]
                            break
                    else:
                        values = [str(list(item.values())[1]) for item in data]
        elif isinstance(data, list) and all(isinstance(item, (int, float)) for item in data):
            categories = list(range(len(data)))
            values = data
        elif isinstance(data, (int, float, str)):
            categories = ["Result"]
            values = [data]
        else:
            categories = []
            values = []
            st.warning("Unexpected data format. Check the query and data.")
        plots = generate_plots(metadata, categories, values)
        st.session_state["query"] = user_query
        st.session_state["response_text"] = result
        st.session_state["chat_response"] = chat_response
        st.session_state["plots"] = plots
        st.session_state["generated_code"] = generated_code  # Store the generated code
    except Exception as e:
        st.error(f"An error occurred: {e}")
if st.button("Submit"):
    with st.spinner("Processing query..."):
        try:
            process_query()
        except Exception as e:
            st.error(f"An error occurred: {e}")
if "chat_response" in st.session_state:
    if st.button("Download PDF"):
        with st.spinner("Generating PDF..."):
            try:
                pdf_file = generate_pdf(
                    st.session_state["query"],
                    st.session_state["response_text"],
                    st.session_state["chat_response"],
                    st.session_state["plots"],
                )
                with open(pdf_file, "rb") as f:
                    pdf_data = f.read()
                sanitized_query = sanitize_filename(st.session_state["query"])
                st.download_button(
                    label="Click Here to Download PDF",
                    data=pdf_data,
                    file_name=f"{sanitized_query}.pdf",
                    mime="application/pdf",
                )
            except Exception as e:
                st.error(f"PDF generation failed: {e}")