|
|
|
|
|
|
|
|
import os |
|
|
import io |
|
|
import re |
|
|
import gradio as gr |
|
|
import pandas as pd |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
import traceback |
|
|
|
|
|
from groq import Groq |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Groq credentials and model selection come from environment variables.
# Surrounding whitespace is stripped from both values.
GROQ_API_KEY = (os.getenv("GROQ_API_KEY") or "").strip()

# Strip *before* applying the default so that a blank or whitespace-only
# GROQ_MODEL falls back to the default instead of producing an empty model name.
GROQ_MODEL = (os.getenv("GROQ_MODEL") or "").strip() or "llama-3.3-70b-versatile"

# Fail at import time when no key is configured. The message reads
# "GROQ_API_KEY is missing from the Space's Secrets".
if not GROQ_API_KEY:
    raise RuntimeError("Falta GROQ_API_KEY en Secrets del Space.")

# Shared client used by every generate_* function in this module.
groq_client = Groq(api_key=GROQ_API_KEY)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _read_csv_with_fallback(file_path, file_bytes):
    """Parse CSV content, retrying with latin-1 if the first attempt fails.

    The delimiter is sniffed (``sep=None`` with the python engine). A fresh
    source is built for each attempt so the retry never starts from a
    partially consumed buffer.
    """

    def make_source():
        return io.BytesIO(file_bytes) if file_bytes is not None else file_path

    try:
        return pd.read_csv(make_source(), sep=None, engine="python")
    except Exception:
        # Deliberate best-effort retry: any failure of the default read gets
        # one more attempt as latin-1. A second failure propagates to the caller.
        return pd.read_csv(make_source(), sep=None, engine="python", encoding="latin-1")


def load_file(file):
    """Load a CSV or Excel upload into a pandas DataFrame.

    Args:
        file: The uploaded file in any of these forms:
            * ``None``;
            * a dict with an optional ``"name"`` plus either ``"path"`` or
              ``"data"`` (a path string or raw ``bytes``/``bytearray``);
            * a plain ``str`` / ``os.PathLike`` filesystem path;
            * any other object exposing its path in a ``name`` attribute.

    Returns:
        The parsed DataFrame, or ``None`` when ``file`` is ``None``, no usable
        source is found, the extension is not CSV/XLSX/XLS, or reading fails.
        Read errors are printed and swallowed, never raised.
    """
    if file is None:
        return None

    try:
        file_name = ""
        file_path = None
        file_bytes = None

        if isinstance(file, dict):
            file_name = (file.get("name") or "").lower()
            if file.get("path"):
                file_path = file["path"]
            elif file.get("data") is not None:
                payload = file["data"]
                if isinstance(payload, str):
                    file_path = payload
                elif isinstance(payload, (bytes, bytearray)):
                    file_bytes = bytes(payload)
                else:
                    file_path = str(payload)
        elif isinstance(file, (str, os.PathLike)):
            # A bare path has no ``.name`` attribute; without this branch it
            # would end up with neither a name nor a path and be rejected.
            file_path = os.fsdecode(file)
        else:
            file_path = getattr(file, "name", None)
            file_name = (file_path or "").lower()

        # Fall back to the path itself to determine the file type.
        if not file_name and file_path:
            file_name = str(file_path).lower()

        if file_bytes is None and not file_path:
            return None

        # CSV detection is intentionally lenient: ".csv" anywhere in the name.
        if ".csv" in file_name:
            return _read_csv_with_fallback(file_path, file_bytes)

        if file_name.endswith(".xlsx"):
            excel_engine = "openpyxl"
        elif file_name.endswith(".xls"):
            excel_engine = "xlrd"
        else:
            return None

        source = io.BytesIO(file_bytes) if file_bytes is not None else file_path
        return pd.read_excel(source, engine=excel_engine)

    except Exception as e:
        print("Error load_file:", repr(e))
        return None
|
|
|
|
|
|
|
|
def preview_file(file):
    """Return the uploaded data for preview, or a one-cell error frame.

    Delegates parsing to ``load_file``; when that yields ``None`` a DataFrame
    with a single "Error" column is returned instead.
    """
    frame = load_file(file)
    if frame is not None:
        return frame
    return pd.DataFrame({"Error": ["Error loading file or unsupported file type."]})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def generate_basic_understanding_code(df_preview):
    """Ask the Groq model for exploratory-analysis code for a DataFrame.

    The prompt embeds the column names and the first three rows of
    ``df_preview`` and instructs the model to store its summary in a dict
    named ``basic_info``.

    Returns:
        The model's reply with surrounding whitespace stripped (an empty
        string when the reply has no content).
    """
    column_names = list(df_preview.columns)
    sample_rows = df_preview.head(3).to_dict(orient='records')

    prompt = f"""
You are a data analysis expert. Write Python code that performs an exploratory analysis of the DataFrame.
Assume a pandas DataFrame named 'df' is already loaded.
Output only raw Python code without any markdown formatting or code fences.
Assign the exploratory summary to a variable named 'basic_info' as a dictionary.
For each column in df, include its data type.
- For numeric columns (use pd.api.types.is_numeric_dtype), include summary statistics (mean, median, std, etc.).
- For non-numeric columns, treat them as categorical variables and include counts, unique values, mode, and frequency distributions.
When converting date strings to datetime, use pd.to_datetime() without a fixed format or with dayfirst=True.
If your analysis includes charts, call plt.show() after each chart so they can be captured.
Only reference columns that are present in df.columns.
Note: The following safe built-ins are available: list, dict, set, tuple, abs, min, max, sum, len, range, print, pd, plt, __import__.
DataFrame preview:
Columns: {column_names}
Sample Data (first 3 rows):
{sample_rows}
"""

    chat_messages = [
        {"role": "system", "content": "You are an expert data analysis assistant who outputs only raw Python code."},
        {"role": "user", "content": prompt},
    ]
    completion = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=chat_messages,
        temperature=0.3,
        max_tokens=3500,
    )
    reply = completion.choices[0].message.content
    return (reply or "").strip()
|
|
|
|
|
|
|
|
def generate_problem_solving_code(nl_query, df_preview, basic_info):
    """Ask the Groq model for code that answers the user's query.

    The prompt embeds the column names, the first three rows of
    ``df_preview`` and ``nl_query``, and instructs the model to store its
    output in a dict named ``result``.

    Note that ``basic_info`` is accepted but not interpolated into the prompt;
    the prompt only tells the model that a variable of that name exists.

    Returns:
        The model's reply with surrounding whitespace stripped (an empty
        string when the reply has no content).
    """
    column_names = list(df_preview.columns)
    sample_rows = df_preview.head(3).to_dict(orient='records')

    prompt = f"""
You are a data analysis expert. Write Python code that performs the analysis as described below.
Assume a pandas DataFrame named 'df' is already loaded and that you have already generated an exploratory summary stored in 'basic_info'.
Output only raw Python code without any markdown formatting or code fences.
Ensure that the final output is assigned to a variable named 'result' as a dictionary with the following keys: 'summary', 'detailed_stats', 'insights', and 'chart_descriptions'. The analysis should be verbose and include all relevant statistics, interpretations, and intermediate steps.
When processing the DataFrame, first inspect each column’s data type:
- For numeric columns (use pd.api.types.is_numeric_dtype), compute numeric statistics (mean, median, standard deviation, etc.).
- For non-numeric columns, treat them as categorical variables and compute appropriate descriptive statistics (counts, unique values, mode, and frequency distributions).
- Only generate charts and tables that are relevant to the problem at hand. Exclude fields that are not relevant to the problem from the charts and tables.
Incorporate insights from 'basic_info' if relevant.
When converting date strings to datetime, use pd.to_datetime() without a fixed format or with dayfirst=True.
If your analysis includes charts, call plt.show() after each chart so they can be captured.
Only reference columns that are present in df.columns.
Note: The following safe built-ins are available: list, dict, set, tuple, abs, min, max, sum, len, range, print, pd, plt, __import__.
DataFrame preview:
Columns: {column_names}
Sample Data (first 3 rows):
{sample_rows}
User Query: "{nl_query}"
"""

    chat_messages = [
        {"role": "system", "content": "You are an expert data analysis assistant who outputs only raw Python code."},
        {"role": "user", "content": prompt},
    ]
    completion = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=chat_messages,
        temperature=0.3,
        max_tokens=3500,
    )
    reply = completion.choices[0].message.content
    return (reply or "").strip()
|
|
|
|
|
|
|
|
def validate_generated_code(code, df):
    """Check that string-literal column accesses in ``code`` exist in ``df``.

    Only accesses of the form ``df['name']`` / ``df["name"]`` are inspected.
    Columns the code creates itself through a plain assignment
    (``df['new'] = ...``) are not treated as missing, so code that derives new
    columns is not rejected before it runs.

    Args:
        code: Source text of the generated analysis code.
        df: DataFrame the code will run against.

    Returns:
        ``(True, [])`` when every referenced column is available, otherwise
        ``(False, missing)`` where ``missing`` lists each unknown column once,
        in order of first appearance.
    """
    column_ref = r"df\[['\"]([^'\"]+)['\"]\]"
    referenced = re.findall(column_ref, code)

    # Assignment targets: the trailing ``=`` must not start ``==``, so
    # comparisons such as ``df['x'] == 1`` still count as reads.
    created = set(re.findall(column_ref + r"\s*=(?!=)", code))

    # dict.fromkeys de-duplicates while keeping first-appearance order.
    missing_cols = [
        col
        for col in dict.fromkeys(referenced)
        if col not in created and col not in df.columns
    ]
    if missing_cols:
        return False, missing_cols
    return True, []
|
|
|
|
|
|
|
|
def safe_exec_code(code, df, capture_charts=True, interactive=False, extra_globals=None):
    """Execute generated analysis code against ``df`` and collect its output.

    Args:
        code: Generated Python source; lines starting with a markdown code
            fence are dropped before execution.
        df: DataFrame exposed to the code as ``df``.
        capture_charts: When true, ``plt.show()`` is temporarily replaced so
            each call stores the current figure as a PIL image; if nothing was
            captured that way, any figures still open afterwards are captured.
        interactive: When true, ``show()`` is called on every captured image.
        extra_globals: Optional mapping of additional names made visible to
            the code.

    Returns:
        ``(output, charts)``. ``output`` is the ``result`` variable set by the
        code, else its ``basic_info`` variable, else an explanatory string.
        On validation or execution failure ``output`` is an error message.
        ``charts`` is the list of captured PIL images.
    """
    clean_code = "\n".join(
        line for line in code.splitlines() if not line.strip().startswith("```")
    ).strip()

    valid, missing_cols = validate_generated_code(clean_code, df)
    if not valid:
        return (f"Generated code references missing columns: {missing_cols}\nPlease adjust your prompt or data.", [])

    # NOTE(security): this is not a real sandbox. The code is model-generated
    # and ``__import__`` is exposed, so it can import any module (os,
    # subprocess, ...). Only run this where executing such code is acceptable.
    safe_builtins = {
        "abs": abs, "min": min, "max": max, "sum": sum, "len": len, "range": range, "print": print,
        "list": list, "dict": dict, "set": set, "tuple": tuple, "sorted": sorted, "zip": zip,
        "enumerate": enumerate, "pd": pd, "plt": plt, "str": str, "float": float, "int": int,
        "bool": bool, "complex": complex, "round": round, "__import__": __import__,
    }
    namespace = {"__builtins__": safe_builtins, "df": df, "plt": plt, "charts": []}

    try:
        import seaborn as sns
        namespace["sns"] = sns
    except ImportError:
        pass

    injected = dict(extra_globals) if extra_globals is not None else {}
    namespace.update(injected)
    # Hold our own reference so a rebinding of ``charts`` by the executed code
    # cannot change which list is returned.
    charts = namespace["charts"]

    def figure_to_image(fig):
        # Render a matplotlib figure to an in-memory PNG and load it with PIL.
        buf = io.BytesIO()
        fig.savefig(buf, format="png")
        buf.seek(0)
        return Image.open(buf).convert("RGB")

    def value_set_by_code(name):
        # Only report values the executed code produced itself; a value that
        # is still the injected object was never set by the code.
        value = namespace.get(name)
        if name in injected and value is injected[name]:
            return None
        return value

    original_show = plt.show
    if capture_charts:
        def capture_show(*args, **kwargs):
            charts.append(figure_to_image(plt.gcf()))
            plt.close()
        plt.show = capture_show

    try:
        try:
            # A single namespace is used on purpose: with separate globals and
            # locals, exec runs the code as if it were a class body, and helper
            # functions in the generated code cannot see names (including
            # imports) defined at its top level.
            exec(clean_code, namespace)
        except Exception:
            error_details = traceback.format_exc()
            if "ValueError: time data" in error_details:
                error_details += "\nHint: Use pd.to_datetime() without fixed format or with dayfirst=True."
            if "KeyError" in error_details:
                error_details += "\nHint: The generated code might be referencing columns that do not exist."
            if "NameError" in error_details:
                error_details += "\nHint: Ensure all required built-ins are included."
            return f"An error occurred during code execution:\n{error_details}", charts

        output = value_set_by_code("result")
        if output is None:
            output = value_set_by_code("basic_info")

        # Fallback for code that drew figures but never called plt.show().
        if capture_charts and not charts:
            for num in plt.get_fignums():
                charts.append(figure_to_image(plt.figure(num)))
    finally:
        # plt is shared module state: undo the patch and drop every open
        # figure so nothing leaks into later calls.
        plt.show = original_show
        plt.close("all")

    if interactive:
        for img in charts:
            img.show()

    if output is None:
        output = "No output variable ('result' or 'basic_info') was set by the code."
    return output, charts
|
|
|
|
|
|
|
|
def generate_interpretation(analysis_result, nl_query):
    """Ask the Groq model for a markdown interpretation of an analysis result.

    The prompt embeds ``nl_query`` and the string form of ``analysis_result``.

    Returns:
        The model's reply with surrounding whitespace stripped (an empty
        string when the reply has no content).
    """
    prompt = f"""
You are a knowledgeable data analyst. Based on the following analysis result and the user's query, provide a detailed interpretation and descriptive analysis of the results. Explain what the results mean, any insights that can be drawn, and any potential limitations.
Please format your output in markdown (including headers, bullet points, and other markdown formatting as appropriate).
User Query: "{nl_query}"
Analysis Result:
{analysis_result}
Provide a clear and detailed explanation in plain language.
"""

    chat_messages = [
        {"role": "system", "content": "You are an expert data analysis assistant who explains analysis results clearly."},
        {"role": "user", "content": prompt},
    ]
    completion = groq_client.chat.completions.create(
        model=GROQ_MODEL,
        messages=chat_messages,
        temperature=0.5,
        max_tokens=5000,
    )
    reply = completion.choices[0].message.content
    return (reply or "").strip()
|
|
|
|
|
|
|
|
def generate_and_run(nl_query, file, interactive_mode=False):
    """Run the full pipeline: load, explore, analyse, interpret.

    Steps:
        1. Load ``file`` with ``load_file``.
        2. Generate and execute exploratory code (charts not captured).
        3. Generate and execute code for ``nl_query``, with the exploratory
           output exposed to it as ``basic_info``.
        4. Ask the model to interpret the result.

    Args:
        nl_query: The user's natural-language question.
        file: Uploaded file in any form ``load_file`` accepts.
        interactive_mode: Forwarded to ``safe_exec_code`` as ``interactive``.

    Returns:
        A 5-tuple ``(result, code, preview, charts, interpretation)``.
        ``code`` is always an empty string. When the file cannot be loaded the
        tuple carries an error message, an error DataFrame and empty
        charts/interpretation.
    """
    df = load_file(file)
    if df is None:
        return "Error loading file.", "", pd.DataFrame({"Error": ["No data available."]}), [], ""

    # Untouched copy: used to build both prompts and returned for display.
    df_preview = df.copy()

    basic_code = generate_basic_understanding_code(df_preview)
    # Run the exploratory code on its own copy. Otherwise in-place changes it
    # makes (renamed, dropped or converted columns) would carry over into the
    # second step, whose prompt was built from the untouched preview.
    basic_info, basic_charts = safe_exec_code(
        basic_code, df.copy(), capture_charts=False, interactive=interactive_mode
    )

    problem_code = generate_problem_solving_code(nl_query, df_preview, basic_info)
    result, problem_charts = safe_exec_code(
        problem_code, df, capture_charts=True, interactive=interactive_mode, extra_globals={"basic_info": basic_info}
    )

    interpretation = generate_interpretation(result, nl_query)

    # The generated code is intentionally not surfaced to the caller.
    combined_code_hidden = ""

    combined_charts = basic_charts + problem_charts
    return result, combined_code_hidden, df_preview, combined_charts, interpretation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Gradio UI definition. Components are created in the order they appear here.
with gr.Blocks() as demo:
    gr.Markdown("## Dynamic Data Analysis with Two-Step Code Generation and Interpretation")

    # Tab 1: upload a file and preview it via preview_file.
    with gr.Tab("Data Upload & Preview"):
        file_input = gr.File(label="Upload CSV or Excel file (.csv, .xls, .xlsx)")
        data_preview = gr.Dataframe(label="Data Preview")
        # Refresh the preview whenever the uploaded file changes.
        file_input.change(fn=preview_file, inputs=file_input, outputs=data_preview)

    # Tab 2: enter a query and run the full pipeline on the uploaded file.
    with gr.Tab("Generate & Execute Analysis (Gradio Mode)"):
        nl_query = gr.Textbox(
            label="Enter your query",
            placeholder="e.g., Generate summary statistics and charts for Gender and Age distributions",
        )
        generate_btn = gr.Button("Generate & Execute Code")
        analysis_output = gr.Textbox(label="Analysis Result", lines=10)

        # Hidden output slot; generate_and_run always returns an empty string
        # for it, but it must exist to match the five returned values.
        code_output = gr.Code(label="Generated Code", language="python", visible=False)

        preview_output = gr.Dataframe(label="Data Preview")
        charts_output = gr.Gallery(label="Charts", show_label=True)
        interpretation_output = gr.Markdown(label="Interpretation")

        # The outputs list matches the order of generate_and_run's 5-tuple.
        # interactive_mode=True is forwarded to safe_exec_code, which then
        # also calls show() on every captured chart image.
        generate_btn.click(
            fn=lambda query, file: generate_and_run(query, file, interactive_mode=True),
            inputs=[nl_query, file_input],
            outputs=[analysis_output, code_output, preview_output, charts_output, interpretation_output],
        )

# Start the Gradio app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|
|
|