DataAnalystDemo / app_generated.py
Vlad Bastina
Merge branch 'main' of https://huggingface.co/spaces/VladB46/DataAnalystDemo into main
1ebe9f7
import streamlit as st
import pandas as pd
import matplotlib.pyplot as plt
import plotly.express as px
from dotenv import load_dotenv
from langchain.agents.agent_types import AgentType
from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent
from langchain_openai import ChatOpenAI
import os
import seaborn as sns
import plotly.graph_objects as go
import json
import pdfkit
import io
import base64
from matplotlib.backends.backend_agg import FigureCanvasAgg
import html
import re
from openai import OpenAI
from io import StringIO
from pathlib import Path
load_dotenv()
# --- Configuration ---
OPENAI_API_KEY=os.getenv("OPENAI_API_KEY") or st.secrets.get("OPENAI_API_KEY")
client = OpenAI(api_key=OPENAI_API_KEY)
csv_path = "SalesData.csv"
if not os.path.exists(csv_path):
print(f"Error: CSV file '{csv_path}' not found.")
exit(1)
def get_csv_sample(csv_path, sample_size=5):
"""Reads a CSV file and returns column info, a sample, and the DataFrame."""
df = pd.read_csv(csv_path)
sample_df = df.sample(n=min(sample_size, len(df)), random_state=42)
return df.dtypes.to_string(), sample_df.to_string(index=False), df
column_info, sample_str, _ = get_csv_sample(csv_path)
# @observe()
def chat(response_text):
return json.loads(response_text) # Directly parse the JSON
def generate_code(question, column_info, sample_str, csv_path, model_name="gpt-4o"):
"""Asks OpenAI to generate Pandas code for a given question."""
prompt = f"""You are a highly skilled Python data analyst with expert-level proficiency in Pandas. Your task is to write **concise, correct, and efficient** Pandas code to answer a specific question about data contained within a CSV file. The code you generate must be self-contained, directly executable, and produce the correct numerical output or DataFrame structure.
**CSV File Information:**
* **Path:** '{csv_path}'
* **Column Information:** (This tells you the names and data types of the columns)
```
{column_info}
```
* **Sample Data:** (This gives you a glimpse of the data's structure. Note the European date format DD/MM/YYYY)
```
{sample_str}
```
**Strict Requirements (Follow these EXACTLY):**
0. **Multi-part Questions:**
* If the user asks a multi-part question, **reformat it** to process each part correctly while maintaining the original meaning. **Do not change the intent** of the question.
* **For multi-part questions**, the code should reflect how each part of the question is handled. You must ensure that each part is processed and combined correctly at the end.
* **Print a statement** explaining how you processed the multi-part question, e.g., `print("Question was split into parts for processing.")`.
1. **Load Data and Parse Dates:** Your code *MUST* begin with the following line to load the data, correctly parsing *ALL* potential date columns:
```python
import pandas as pd
df = pd.read_csv('{csv_path}', parse_dates=['Order Date'])
```
Do *NOT* modify this line. The `parse_dates` argument is *critical* for correct date handling.
2. **Imports:** Do *NOT* import any libraries other than pandas (which is already imported as `pd`). Do *NOT* use `numpy` or `datetime` directly, unless it is used within the context of parsing in read_csv. Pandas is sufficient for all tasks.
3. **Output:**
* Store your final answer in a variable named `result`.
* Print the `result` variable using `print(result)`.
* Do *NOT* use `display()`.
* The output must be a Pandas DataFrame, Series, or a single value, as appropriate for the question. If it's a DataFrame or Series, ensure the index is reset where appropriate (e.g., after a `groupby()` followed by `.size()`).
4. **Conciseness and Style:**
* Write the *most concise* and efficient Pandas code possible.
* Use method chaining (e.g., `df.groupby(...).sum().sort_values().head()`) whenever possible and appropriate.
* Avoid unnecessary intermediate variables unless they *significantly* improve readability.
* Use clear and understandable variable names for filtered dataframes, (for example: df_2019, df_filtered etc)
* If calculating a percentage or distribution, combine operations efficiently, ideally in a single chained expression.
5. **Correctness:** Your code *MUST* be syntactically correct Python and *MUST* produce the correct answer to the question. Double-check your logic, especially when grouping and aggregating. Pay close attention to the wording of the question.
6. **Date and Time Conditions (Implicit Filtering):**
* **Any question that refers to dates, time periods, months, years, or uses phrases like "issued in," "policies from," "between [dates]," etc., *MUST* filter the data using the `DATA_SEM_OFERTA` column.** This is the *implied* date column for policy issuance. Do *NOT* ask the user which column to use; assume `DATA_SEM_OFERTA`.
* When filtering dates, use combined boolean conditions for efficiency, e.g., `df[(df['Order Date'].dt.year == 2019) & (df['Order Date'].dt.month == 12)]` rather than separate filtering steps.
7. **Column Names:** Use the *exact* column names provided in the "CSV Column Information." Pay close attention to capitalization, spaces, and any special characters.
8. **No Explanations:** Output *ONLY* the Python code. Do *NOT* include any comments, explanations, surrounding text, or markdown formatting (like ```python). Just the code.
9. **Aggregation (VERY IMPORTANT):** When the question asks for:
* "top N" or "first N"
* "most frequent"
* "highest/lowest" (after grouping)
* "average/sum/count per [group]"
* **Calculate Percentage**: When percentage is asked, compute the correct percentage value
You *MUST* perform a `groupby()` operation *BEFORE* sorting or selecting the top N values. The correct order is:
1. Filter the DataFrame (if needed, using boolean indexing).
2. Group by the appropriate column(s) using `.groupby()`.
3. Apply an aggregation function (e.g., `.sum()`, `.mean()`, `.size()`, `.count()`, `.median()`).
4. *Then*, sort (if needed) using `.sort_values()` and/or select the top N (if needed) using `.nlargest()` or `.head()`.
10. **Error Handling:** Assume the CSV file exists and is correctly formatted. You do *not* need to write any explicit error handling code.
11. **Clarity:** Use clear and meaningful variable names if you create intermediate dataframes, but prioritize conciseness.
**Column Usage Guidance:**
13. primele means .nlargest and ultimele means .nsmallest
* Use *Product* when referring to specific items sold (e.g., "most popular product," "top-selling product").
* Use *City* when grouping or summarizing sales by location (e.g., "which city had the highest revenue?").
* Use *Order* Date for any time-based filtering (e.g., "sales in December," "transactions between January and March").
* Use *Sales* for financial aggregations (e.g., total revenue, average sale per transaction).
* Use *Quantity* Ordered when analyzing product demand (e.g., "most sold product in terms of units").
* Use *Hour* to analyze time-based trends (e.g., "which hour has the highest number of purchases?").
**Question:**
{question}
"""
response = client.chat.completions.create(model=model_name,
temperature=0, # Keep temperature at 0 for consistent, deterministic code
messages=[
{"role": "system", "content": "You are a helpful assistant that generates Python code."},
{"role": "user", "content": prompt}
])
code_to_execute = response.choices[0].message.content.strip()
code_to_execute = code_to_execute.replace("```python", "").replace("```", "").strip()
return code_to_execute
def execute_code(generated_code, csv_path):
"""Executes the generated Pandas code and captures the output."""
local_vars = {"pd": pd, "__file__": csv_path}
exec(generated_code, {}, local_vars)
return local_vars.get("result")
def generate_plot_code(question, dataframe, model_name="gpt-4o"):
"""Asks OpenAI to generate plotting code based on the question and dataframe."""
# Convert dataframe to string representation
df_str = dataframe.to_string(index=False)
df_json = dataframe.to_json(orient="records")
prompt = f"""You are a data visualization expert. Create Python code to visualize the data below based on the user's question. The visualizations must comprehensively represent *all* the information returned by the query to effectively answer the question.
**User Question:**
{question}
**Data (first few rows):**
```
{df_str}
```
**Data (JSON format):**
```json
{df_json}
```
**Requirements:**
1. Create 4-7 different, meaningful visualizations that collectively represent all aspects of the data returned by the query, ensuring no key information is omitted.
2. Ensure each visualization is simple, clear, and directly tied to a specific part of the data or question, while together they cover the full scope of the result.
3. Use ONLY Matplotlib and Seaborn (avoid Plotly to prevent compatibility issues).
4. Include proper titles, labels, and legends for clarity, reflecting the specific data being visualized.
5. Use appropriate color schemes that are visually appealing and accessible (e.g., colorblind-friendly palettes like Seaborn's 'colorblind').
6. Return a list of tuples containing the plot title and the base64-encoded image.
7. Make sure to close all plt figures with plt.close() after adding each to the plots list to prevent memory issues.
8. If the data includes categories (e.g., sucursale, produse, pachete), ensure these are fully represented across the plots (e.g., bar charts, pie charts, or grouped visuals).
9. If the data includes numerical values (e.g., sales, totals), use appropriate plot types (e.g., bar, line, or scatter) to show trends, comparisons, or distributions.
10. If the question involves time periods, ensure at least one visualization reflects the temporal aspect using the relevant date information.
11. Put a padding to the plots so there won't be situations like "an Francisco" is displayed instead of "San Francisco".
**Output Format:**
Your code should ONLY include a function called `create_plots(data)` that takes a pandas DataFrame as input and returns a list of tuples containing the plot titles and the base64-encoded images.
Return only the function definition without any explanations, imports, or additional code. Do NOT include any Streamlit-specific code.
"""
response = client.chat.completions.create(model=model_name,
temperature=0.2, # Slightly higher temperature for creative visualizations
messages=[
{"role": "system", "content": "You are a data visualization expert who creates Python code for plotting data."},
{"role": "user", "content": prompt}
])
plot_code = response.choices[0].message.content.strip()
plot_code = plot_code.replace("```python", "").replace("```", "").strip()
return plot_code
def execute_plot_code(plot_code, result_df):
"""Executes the generated plotting code and captures the outputs."""
try:
# Create a dictionary with all the necessary imports
globals_dict = {
"pd": pd,
"plt": plt,
"px": px,
"sns": sns,
"go": go,
"io": io,
"base64": base64,
"np": __import__('numpy'),
"plotly": __import__('plotly')
}
# Create a local variables dictionary with the data
local_vars = {
"data": result_df
}
# Define the helper functions first
helper_code = """
def fig_to_base64(fig):
buf = io.BytesIO()
fig.savefig(buf, format="png", bbox_inches="tight")
buf.seek(0)
img_str = base64.b64encode(buf.getvalue()).decode("utf-8")
buf.close()
return img_str
def plotly_to_base64(fig):
# For Plotly figures, convert to image bytes and then to base64
img_bytes = fig.to_image(format="png", scale=2)
img_str = base64.b64encode(img_bytes).decode("utf-8")
return img_str
"""
# Execute the helper functions first
exec(helper_code, globals_dict, local_vars)
# Then execute the plot code
exec(plot_code, globals_dict, local_vars)
# Get the plots from the create_plots function
if "create_plots" in local_vars:
plots = local_vars["create_plots"](result_df)
return plots
elif "plots" in local_vars:
return local_vars["plots"]
else:
return []
except Exception as e:
st.error(f"Error executing plot code: {str(e)}")
import traceback
st.error(traceback.format_exc())
return []
def sanitize_filename(filename):
return re.sub(r'[^a-zA-Z0-9]', '_', filename)
def generate_pdf(query, response_text, chat_response, plots):
query = html.unescape(query)
response_text = html.unescape(response_text)
escaped_query = html.escape(query)
escaped_response_text = html.escape(response_text)
html_content = f"""
<!DOCTYPE html>
<html lang="ro">
<head>
<title>Data Analysis Report</title>
<meta charset="UTF-8">
<style>
body {{ font-family: Arial, sans-serif; margin: 20px; background-color: #f9f9f9; color: #333; }}
h1 {{ color: #1f77b4; text-align: center; }}
h3 {{ color: #2c3e50; border-bottom: 2px solid #ddd; padding-bottom: 5px; }}
h4 {{ color: #2980b9; }}
p {{ line-height: 1.6; background-color: #fff; padding: 10px; border-radius: 5px; box-shadow: 0 1px 3px rgba(0,0,0,0.1); }}
pre {{ background-color: #ecf0f1; padding: 10px; border-radius: 5px; font-size: 12px; }}
table {{ border-collapse: collapse; width: 100%; margin: 10px 0; page-break-inside: avoid; }}
th, td {{ border: 1px solid #bdc3c7; padding: 10px; text-align: left; }}
th {{ background-color: #3498db; color: white; }}
td {{ background-color: #fff; }}
img {{ max-width: 100%; height: auto; margin: 10px 0; page-break-inside: avoid; }}
.section {{ margin-bottom: 20px; }}
.no-break {{ page-break-inside: avoid; }}
.powered-by {{ text-align: center; margin-top: 20px; font-size: 10px; color: #777; }}
.logo {{ height: 100px; }}
</style>
</head>
<body>
<h1>Data Analysis Agent Interface</h1>
<div class="section no-break"><h3>Query</h3><p>{escaped_query}</p></div>
<div class="section no-break"><h3>Response</h3><p>{escaped_response_text}</p></div>
<div class="section no-break">
<h3>Raw Structured Response</h3>
<h4>Metadata</h4><pre>{json.dumps(chat_response["metadata"], indent=2, ensure_ascii=False)}</pre>
<h4>Data</h4>{pd.DataFrame(chat_response["data"]).to_html(index=False, classes="no-break", escape=False)}
</div>
<div class="section"><h3>Plots</h3>{"".join([f'<div class="no-break"><h4>{name}</h4><img src="data:image/png;base64,{base64}"/></div>' for name, base64 in plots])}</div>
<div class="powered-by">Powered by <img src="data:image/png;base64,{get_zega_logo_base64()}" class="logo"></div>
</body></html>
"""
html_file = "temp.html"
sanitized_query = sanitize_filename(query)
os.makedirs("./exported_pdfs", exist_ok=True)
pdf_file = f"./exported_pdfs/{sanitized_query}.pdf"
try:
with open(html_file, "w", encoding="utf-8") as f:
f.write(html_content)
options = {'encoding': "UTF-8", 'custom-header': [('Content-Type', 'text/html; charset=UTF-8')], 'no-outline': None}
pdfkit.from_file(html_file, pdf_file, options=options)
os.remove(html_file)
except Exception as e:
raise
return pdf_file
def get_zega_logo_base64():
try:
with open("zega_logo.png", "rb") as image_file:
encoded_string = base64.b64encode(image_file.read()).decode("utf-8")
return encoded_string
except Exception as e:
raise
def load_css(file_name):
"""Loads a CSS file and injects it into the Streamlit app."""
try:
css_path = Path(__file__).parent / file_name
with open(css_path) as f:
st.markdown(f'<style>{f.read()}</style>', unsafe_allow_html=True)
# st.info(f"Loaded CSS: {file_name}") # Optional: uncomment for debugging
except FileNotFoundError:
st.error(f"CSS file not found: {file_name}. Make sure it's in the same directory as app.py.")
except Exception as e:
st.error(f"Error loading CSS file {file_name}: {e}")
st.markdown("""
<link rel="preconnect" href="https://fonts.googleapis.com">
<link rel="preconnect" href="https://fonts.gstatic.com" crossorigin>
<link href="https://fonts.googleapis.com/css2?family=Inter+Tight:ital,wght@0,100..900;1,100..900&family=Space+Grotesk:wght@300..700&display=swap" rel="stylesheet">
""", unsafe_allow_html=True)
load_css("style.css")
# Streamlit Interface
st.title("Data Analysis Agent Interface")
st.write(" In this demo, you'll explore a sample CSV file containing sales data—such as product names, regions, dates, and revenue figures. Simply type in natural language questions like “What were the best-selling products in March?” or “How did sales vary by region?” and the tool will instantly generate clear reports and visualizations. No technical skills needed—just ask and explore.")
st.sidebar.header("Sample Questions")
sample_questions = [
"Top 5 cities with the highest sales?",
"Bottom 3 products by total sales?",
"Top 10 products with reference to items sold?",
"Top 10 products with reference to total sums sold?"
]
selected_question = st.sidebar.selectbox("Select a sample question:", sample_questions)
with open(csv_path, "rb") as f:
st.sidebar.download_button(
label="Download CSV",
data=f,
file_name="data.csv",
mime="text/csv"
)
user_query = st.text_area("Please write one question at a time.", value=selected_question, height=100)
def process_query():
try:
if len(user_query.strip()) == 0:
st.error("Please enter a query.")
return
elif not re.match("^[a-zA-Z0-9!?. ]*$", user_query):
st.error("Special characters are not allowed. Please use only letters and numbers.")
return
# Step 1: Generate and execute code to get the data
generated_code = generate_code(user_query, column_info, sample_str, csv_path)
result = execute_code(generated_code, csv_path)
# Convert result to DataFrame if it's not already
if isinstance(result, pd.DataFrame):
result_df = result
elif isinstance(result, pd.Series):
result_df = result.reset_index()
elif isinstance(result, list):
if all(isinstance(item, dict) for item in result):
result_df = pd.DataFrame(result)
else:
result_df = pd.DataFrame({"value": result})
else:
result_df = pd.DataFrame({"value": [result]})
# Step 2: Generate and execute plotting code
plot_code = generate_plot_code(user_query, result_df)
plots = execute_plot_code(plot_code, result_df)
# Prepare the chat response
if isinstance(result, pd.DataFrame):
chat_response = {
"metadata": {"query": user_query, "unit": "", "plot_types": []},
"data": result.to_dict(orient='records'),
"csv_data": result.to_dict(orient='records'),
}
elif isinstance(result, pd.Series):
result = result.reset_index()
chat_response = {
"metadata": {"query": user_query, "unit": "", "plot_types": []},
"data": result.to_dict(orient='records'),
"csv_data": result.to_dict(orient='records'),
}
elif isinstance(result, list):
if all(isinstance(item, (int, float)) for item in result):
chat_response = {
"metadata": {"query": user_query, "unit": "", "plot_types": []},
"data": [{"category": str(i), "value": v} for i, v in enumerate(result)],
"csv_data": [{"category": str(i), "value": v} for i, v in enumerate(result)],
}
elif all(isinstance(item, dict) for item in result):
chat_response = {
"metadata": {"query": user_query, "unit": "", "plot_types": []},
"data": result,
"csv_data": result,
}
else:
st.warning("Result is a list with mixed data types. Please inspect.")
return
else:
chat_response = {
"metadata": {"query": user_query, "unit": "", "plot_types": []},
"data": [{"category": "Result", "value": result}],
"csv_data": [{"category": "Result", "value": result}],
}
# Display the query and data
st.markdown(f"<h3 style='color: #2e86de;'>Question:</h3>", unsafe_allow_html=True)
st.markdown(f"<p style='color: #2e86de;'>{user_query}</p>", unsafe_allow_html=True)
st.write("-" * 200)
# Initially hide the code
with st.expander("Show the generated data code"):
st.code(generated_code, language="python")
with st.expander("Show the generated plotting code"):
st.code(plot_code, language="python")
st.write("-" * 200)
# Display the data
st.markdown("### Data:")
st.dataframe(result_df)
st.write("-" * 200)
# Display the plots
st.markdown("### Visualizations:")
for name, base64_img in plots:
st.markdown(f"#### {name}")
st.markdown(f'<img src="data:image/png;base64,{base64_img}" style="max-width:100%">', unsafe_allow_html=True)
st.write("-" * 100)
# Store the data for PDF generation
st.session_state["query"] = user_query
st.session_state["response_text"] = str(result)
st.session_state["chat_response"] = chat_response
st.session_state["plots"] = plots
st.session_state["generated_code"] = generated_code
st.session_state["plot_code"] = plot_code
except Exception as e:
st.error(f"An error occurred: {e}")
import traceback
st.error(traceback.format_exc())
if st.button("Submit"):
with st.spinner("Processing query..."):
try:
process_query()
except Exception as e:
st.error(f"An error occurred: {e}")
import traceback
st.error(traceback.format_exc())
if "chat_response" in st.session_state:
if st.button("Download PDF"):
with st.spinner("Generating PDF..."):
try:
pdf_file = generate_pdf(
st.session_state["query"],
st.session_state["response_text"],
st.session_state["chat_response"],
st.session_state["plots"]
)
with open(pdf_file, "rb") as f:
pdf_data = f.read()
sanitized_query = sanitize_filename(st.session_state["query"])
st.download_button(
label="Click Here to Download PDF",
data=pdf_data,
file_name=f"{sanitized_query}.pdf",
mime="application/pdf",
)
except Exception as e:
st.error(f"PDF generation failed: {e}")
import streamlit.components.v1 as components
components.html(
"""
<script>
function sendHeightWhenReady() {
const el = window.parent.document.getElementsByClassName('stMain')[0];
if (el) {
const height = el.scrollHeight;
window.parent.parent.postMessage({ type: 'setHeight', height: height }, '*');
} else {
// Retry in 100ms until the element appears
setTimeout(sendHeightWhenReady, 1000);
}
}
window.onload = sendHeightWhenReady;
window.addEventListener('resize', sendHeightWhenReady);
setInterval(sendHeightWhenReady, 1000);
</script>
"""
)