Spaces:
Build error
Build error
| import os | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import plotly.graph_objects as go | |
| import plotly.express as px | |
| import traceback | |
| from io import BytesIO | |
| from langchain_groq import ChatGroq | |
| from prompts import ENHANCED_SYSTEM_PROMPT, SAMPLE_QUESTIONS, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt | |
# Groq credentials come from the environment (e.g. a Hugging Face Space secret).
# The original code referenced GROQ_API_KEY without ever defining it, which
# raised NameError as soon as the module was imported.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable is not set")

# Deterministic (temperature 0) client shared by every LLM call in this module.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama-3.3-70b-versatile",
    temperature=0.0,
)
print("GROQ API initialized successfully")
def call_groq(messages):
    """Send a chat message list to the module-level Groq client and return the reply text.

    Args:
        messages: List of ``{"role": ..., "content": ...}`` dicts passed straight
            to ``llm.invoke``.

    Returns:
        The response's ``content`` attribute when present, otherwise ``str(response)``.

    Raises:
        RuntimeError: Wrapping any error raised by the client. The original
            exception is chained as ``__cause__`` so its traceback is preserved.
    """
    try:
        response = llm.invoke(messages)
    except Exception as e:
        raise RuntimeError(f"GROQ API error: {e}") from e
    return response.content if hasattr(response, "content") else str(response)
def _error_plan(message):
    """Build the plan dict used when the model's response cannot be turned into a plan."""
    return {
        "type": "error",
        "operations": [],
        "plot": None,
        "narrative": message,
        "insights_needed": False,
    }


def parse_plan(raw_text):
    """Extract the JSON plan object from an LLM response.

    Markdown code fences (```json / ```) are removed. The text between the first
    "{" and the last "}" is then decoded as JSON.

    Missing keys are filled with defaults:
    type="analysis", operations=[], plot=None, narrative="", insights_needed=False.
    Explicit nulls for "operations" and "narrative" are normalised the same way,
    so callers can iterate and print them safely.

    Args:
        raw_text: Raw model output; ``None`` is treated as an empty string.

    Returns:
        The plan dict. This function never raises: on failure it returns a plan
        with type "error" whose narrative starts with "Error parsing JSON:".
    """
    txt = str(raw_text or "").strip().replace("```json", "").replace("```", "").strip()
    start = txt.find("{")
    end = txt.rfind("}") + 1
    if start == -1 or end <= start:
        return _error_plan("Error parsing JSON: no JSON object found in the model response")
    try:
        plan = json.loads(txt[start:end])
    except ValueError as e:  # json.JSONDecodeError is a ValueError subclass
        return _error_plan(f"Error parsing JSON: {e}")

    plan.setdefault("type", "analysis")
    plan.setdefault("operations", [])
    plan.setdefault("plot", None)
    plan.setdefault("narrative", "")
    plan.setdefault("insights_needed", False)
    # setdefault keeps explicit nulls; those would break iteration/printing downstream.
    if plan["operations"] is None:
        plan["operations"] = []
    if plan["narrative"] is None:
        plan["narrative"] = ""
    return plan
def clean_numeric(df, threshold=0.5):
    """Return a copy of ``df`` with numeric-looking text columns converted to numbers.

    Each object/string column is processed as follows:

    * If any cell contains "%", the "%" is stripped. When more than ``threshold``
      of all rows then parse as numbers, the column becomes those values / 100
      (so "50%" -> 0.5).
    * Otherwise "," thousands separators and the "₹" and "$" symbols are stripped.
      The column is converted when more than ``threshold`` of all rows parse.

    In a converted column, cells that fail to parse become NaN. Columns that do
    not reach the threshold are left unchanged. The input frame is not modified.

    Args:
        df: Input DataFrame.
        threshold: Fraction of rows (strictly exceeded) that must parse before a
            column is converted. Defaults to 0.5.

    Returns:
        The cleaned copy of ``df``.
    """
    df = df.copy()
    if df.empty:
        # Nothing to convert; also avoids a 0/0 ratio on zero-row frames.
        return df
    for col in df.columns:
        series = df[col]
        if not (pd.api.types.is_string_dtype(series) or series.dtype == object):
            continue
        text = series.astype(str).str.strip()
        if text.str.contains("%", na=False, regex=False).any():
            percents = pd.to_numeric(text.str.replace("%", "", regex=False), errors="coerce")
            # notna().mean() is the parsed fraction of all rows (== sum / len(df)).
            if percents.notna().mean() > threshold:
                df[col] = percents / 100.0
                continue
        cleaned = (
            text.str.replace(",", "", regex=False)
            .str.replace("₹", "", regex=False)
            .str.replace("$", "", regex=False)
        )
        numbers = pd.to_numeric(cleaned, errors="coerce")
        if numbers.notna().mean() > threshold:
            df[col] = numbers
    return df
def generate_insights(df, dfw, plan, plot_created):
    """Ask the LLM for narrative insights about an executed plan.

    The context sent to the model contains:

    * ``describe()`` statistics for columns requested by "describe" ops, computed
      on the original ``df``;
    * a preview of the result frame ``dfw`` (first 10 rows). For groupby plans it
      is sent once, labelled "Grouped Results", not twice;
    * the chart type and title, when ``plot_created`` and the plan has a plot spec.

    The context is formatted with ``get_insights_prompt`` and sent with
    ``INSIGHTS_SYSTEM_PROMPT``.

    Returns:
        The stripped model response. If the LLM call fails, returns a short
        fallback summary that includes the error text.
    """
    context_parts = []
    grouped_added = False
    for op in plan.get("operations") or []:
        # Lower-cased for consistency with execute_plan's op matching.
        optype = str(op.get("op") or "").lower()
        if optype == "describe":
            cols = op.get("columns") or []
            if isinstance(cols, str):  # a single column name, not a list
                cols = [cols]
            for col in cols:
                if col in df.columns:
                    desc = df[col].describe()
                    context_parts.append(f"\n{col} Statistics:\n{desc.to_string()}")
        elif optype == "groupby" and not grouped_added:
            context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")
            grouped_added = True

    plot_spec = plan.get("plot")
    if plot_created and plot_spec:
        context_parts.append(f"\nChart Type: {plot_spec.get('type')}")
        context_parts.append(f"Visualization: {plot_spec.get('title')}")

    # The grouped section above already contains dfw.head(10); don't send the same table twice.
    if len(dfw) > 0 and not grouped_added:
        context_parts.append(f"\nResult Preview:\n{dfw.head(10).to_string()}")

    insights_prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))
    try:
        insights_response = call_groq([
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": insights_prompt},
        ])
        return insights_response.strip()
    except Exception as e:
        return f"Analysis completed successfully\n{len(dfw)} records in result\nError generating detailed insights: {str(e)}"
def execute_plan(df, plan):
    """Run a parsed plan's operations on a copy of ``df`` and render its optional chart.

    Args:
        df: Source DataFrame; it is never modified.
        plan: Plan dict (see parse_plan). Reads "operations", a list of op dicts
            whose "op" is describe / groupby / filter / calculate
            (case-insensitive), and "plot", a chart spec dict or None.

    Returns:
        Tuple ``(result_df, plot_png_bytes, plot_html, describe_stats)``.
        At most one of ``plot_png_bytes`` / ``plot_html`` is set: pie charts
        produce plotly HTML, other chart types produce PNG bytes.
        ``describe_stats`` maps column name -> ``Series.describe()`` output.

    Raises:
        Exception: Any error from an operation or from plotting is printed with
            its traceback and re-raised.
    """
    try:
        dfw, describe_stats = _apply_operations(df, plan.get("operations") or [])
        plot_bytes, plot_html = None, None
        plot_spec = plan.get("plot")
        if plot_spec:
            # Plans that ran "describe" chart the original data, not the working frame.
            plot_df = df if describe_stats else dfw
            plot_bytes, plot_html = _render_plot(plot_df, plot_spec)
        return dfw, plot_bytes, plot_html, describe_stats
    except Exception as e:
        print(f"EXECUTION ERROR: {e}")
        traceback.print_exc()
        raise


def _apply_operations(df, operations):
    """Apply plan operations in order to a copy of df; return (result_frame, describe_stats)."""
    dfw = df.copy()
    describe_stats = {}
    for op in operations:
        optype = str(op.get("op") or "").lower()
        if optype == "describe":
            cols = op.get("columns") or []
            if isinstance(cols, str):  # a single column name; don't iterate its characters
                cols = [cols]
            for col in cols:
                if col in dfw.columns:
                    stats = dfw[col].describe()
                    describe_stats[col] = stats
                    print(f"Described {col}")
                    print(f"\n{stats}\n")
        elif optype == "groupby":
            dfw = _apply_groupby(dfw, op)
        elif optype == "filter":
            expr = op.get("expr", "")
            if expr:
                # NOTE(security): expr is LLM-generated and evaluated by pandas;
                # review before exposing this tool to untrusted users.
                dfw = dfw.query(expr)
                print(f"Filter applied: {expr}")
        elif optype == "calculate":
            expr = op.get("expr", "")
            new_col = op.get("new_col") or "Calculated"
            dfw[new_col] = dfw.eval(expr)  # NOTE(security): LLM-generated expression
            print(f"Calculated {new_col} = {expr}")
        elif optype:
            print(f"Skipping unsupported operation: {optype}")
    return dfw, describe_stats


def _apply_groupby(dfw, op):
    """Group dfw by op["columns"]; count rows, or aggregate op["agg_col"] with op["agg_func"]."""
    cols = op.get("columns", [])
    agg_col = op.get("agg_col")
    agg_func = op.get("agg_func") or "count"
    if not cols:
        raise ValueError("No columns specified for groupby")
    if agg_func == "count" or not agg_col:
        grouped = dfw.groupby(cols).size().reset_index(name="count")
        print(f"Grouped by {cols} with count")
        return grouped
    if agg_col not in dfw.columns:
        raise ValueError(f"Column '{agg_col}' not found for aggregation")
    result_col = f"{agg_func}_{agg_col}"
    grouped = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
    print(f"Grouped by {cols}, calculated {agg_func} of {agg_col}")
    return grouped


# matplotlib-rendered chart types: figure size (inches) and the name used in log output.
_MATPLOTLIB_CHARTS = {
    "bar": ((12, 7), "bar chart"),
    "line": ((12, 7), "line chart"),
    "hist": ((11, 7), "histogram"),
    "scatter": ((11, 7), "scatter plot"),
}


def _render_plot(plot_df, plot_spec):
    """Render plot_spec from plot_df; return (png_bytes, html), at most one non-None."""
    ptype = plot_spec.get("type", "bar")
    title = plot_spec.get("title", "Chart")
    x, y = _resolve_axes(plot_df, plot_spec.get("x"), plot_spec.get("y"))
    if y is None:
        print("No suitable Y column found for plotting.")
        return None, None
    if ptype == "pie":
        html = _pie_chart_html(plot_df, x, y, title)
        print("Enhanced pie chart generated")
        return None, html
    if ptype not in _MATPLOTLIB_CHARTS:
        print(f"Unsupported chart type: {ptype}")
        return None, None
    png = _render_matplotlib_chart(ptype, plot_df, x, y, title)
    print(f"Enhanced {_MATPLOTLIB_CHARTS[ptype][1]} generated")
    return png, None


def _column_or_none(frame, name):
    """Return name if it is a column label of frame, else None (also for unhashable specs)."""
    try:
        return name if name is not None and name in frame.columns else None
    except TypeError:  # e.g. the model supplied a list instead of one column name
        return None


def _resolve_axes(plot_df, x, y):
    """Choose x/y columns, auto-picking when the spec's choice is missing or absent from plot_df.

    x falls back to the first object/category column, else the first column.
    y falls back to the first numeric column other than x, or to x when it is
    the only numeric column. y is None when plot_df has no numeric column.
    """
    x = _column_or_none(plot_df, x)
    if x is None and len(plot_df.columns) > 0:
        categorical = plot_df.select_dtypes(include=['object', 'category']).columns
        x = categorical[0] if len(categorical) > 0 else plot_df.columns[0]
    y = _column_or_none(plot_df, y)
    if y is None:
        numeric = list(plot_df.select_dtypes(include=[np.number]).columns)
        candidates = [c for c in numeric if c != x] or numeric
        y = candidates[0] if candidates else None
    return x, y


def _pie_chart_html(plot_df, x, y, title):
    """Build a plotly donut chart and return it as HTML (plotly.js loaded from CDN).

    When x values are unique and y is numeric, the frame is treated as already
    aggregated (e.g. a groupby result) and slices are sized by y. Otherwise each
    slice is the occurrence count of an x value.
    """
    x_values = plot_df[x]
    if y != x and x_values.is_unique and pd.api.types.is_numeric_dtype(plot_df[y]):
        labels, values, value_label = x_values, plot_df[y], "Value"
    else:
        counts = x_values.value_counts()
        labels, values, value_label = counts.index, counts.values, "Count"
    fig = go.Figure(data=[go.Pie(
        labels=labels,
        values=values,
        hovertemplate=f'<b>%{{label}}</b><br>{value_label}: %{{value}}<br>Percentage: %{{percent}}<extra></extra>',
        textposition='auto',
        hole=0.3,
    )])
    fig.update_layout(title=title, title_font_size=16, showlegend=True, width=950, height=550)
    return fig.to_html(include_plotlyjs='cdn')


def _render_matplotlib_chart(ptype, plot_df, x, y, title):
    """Draw a bar/line/hist/scatter chart and return it as PNG bytes (150 dpi).

    The figure is closed in ``finally`` so a drawing error cannot leak an open figure.
    """
    fig, ax = plt.subplots(figsize=_MATPLOTLIB_CHARTS[ptype][0])
    try:
        if ptype == "bar":
            _draw_bar(ax, plot_df, x, y)
        elif ptype == "line":
            _draw_line(ax, plot_df, x, y)
        elif ptype == "hist":
            _draw_hist(ax, plot_df, y)
        else:
            _draw_scatter(ax, plot_df, x, y)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
        return buf.getvalue()
    finally:
        plt.close(fig)


def _style_xticks(ax, rotation, fontsize, ha):
    """Set rotation/size/alignment of ax's x tick labels (object-API form of plt.xticks)."""
    plt.setp(ax.get_xticklabels(), rotation=rotation, ha=ha, fontsize=fontsize)


def _draw_bar(ax, plot_df, x, y):
    """Bar chart of y per x, or per row index when there is no distinct x column."""
    style = dict(color='steelblue', edgecolor='black', alpha=0.8)
    # x == y would make pandas move the y column into the index and fail.
    if x is not None and x != y:
        plot_df.plot.bar(x=x, y=y, ax=ax, legend=False, **style)
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        n_categories = plot_df[x].nunique(dropna=False)
        if n_categories > 10:
            _style_xticks(ax, 90, 9, 'right')
        elif n_categories > 5:
            _style_xticks(ax, 45, 10, 'right')
        else:
            _style_xticks(ax, 0, 10, 'center')
    else:
        plot_df[y].plot.bar(ax=ax, **style)
    ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')


def _draw_line(ax, plot_df, x, y):
    """Line chart of y against x, or against the row index when there is no distinct x."""
    style = dict(marker="o", linewidth=3, markersize=8, color='darkblue', alpha=0.8)
    if x is not None and x != y:
        plot_df.plot.line(x=x, y=y, ax=ax, **style)
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        if len(plot_df) > 15:
            _style_xticks(ax, 45, 9, 'right')
        else:
            _style_xticks(ax, 0, 10, 'center')
    else:
        plot_df[y].plot.line(ax=ax, **style)
    ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')


def _draw_hist(ax, plot_df, y):
    """25-bin histogram of the non-null values of y."""
    plot_df[y].dropna().plot.hist(ax=ax, bins=25, edgecolor='black', alpha=0.7, color='teal')
    ax.set_xlabel(y, fontsize=12, fontweight='bold')
    ax.set_ylabel("Frequency", fontsize=12, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')


def _draw_scatter(ax, plot_df, x, y):
    """Scatter plot of y against x; leaves the axes empty when there is no x column."""
    if x is not None:
        plot_df.plot.scatter(x=x, y=y, ax=ax, alpha=0.6, s=60, color='purple')
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')
def make_context(df):
    """Return a compact text summary of ``df`` for use in LLM prompts.

    The summary contains the row/column counts, the column names, a
    ``{dtype name: column count}`` mapping, and the first 3 rows rendered with
    at most 10 columns and 20 characters per cell.
    """
    sample_data = df.head(3).to_string(max_cols=10, max_colwidth=20)
    # str() so non-string column labels (e.g. ints) don't make join() raise TypeError.
    columns = ", ".join(map(str, df.columns))
    # Plain 'int64' keys and int counts read better in a prompt than dtype('int64') reprs.
    dtype_counts = {name: int(count) for name, count in df.dtypes.astype(str).value_counts().items()}
    return "\n".join([
        f"Dataset: {len(df)} rows, {len(df.columns)} columns",
        f"Columns: {columns}",
        f"Data types: {dtype_counts}",
        "Sample data:",
        sample_data,
    ])
def load_file(file_path):
    """Load a CSV or Excel file into a DataFrame, choosing the reader from the extension.

    The extension match is case-insensitive, so ``.CSV`` and ``.XLSX`` are accepted.

    Raises:
        ValueError: If the extension is not .csv, .xlsx or .xls. Errors from the
            pandas readers propagate unchanged.
    """
    ext = os.path.splitext(file_path)[1].lower()
    if ext == '.csv':
        return pd.read_csv(file_path)
    if ext in ('.xlsx', '.xls'):
        return pd.read_excel(file_path)
    raise ValueError("Unsupported file format. Please use CSV or Excel files.")
def start_agent():
    """Run the interactive SparkNova session on stdin/stdout.

    The session first prompts for a CSV/Excel path until a file loads. It then
    repeatedly asks for a question:

    * 'exit' / 'quit' ends the session;
    * 'reload' asks for a new file;
    * anything else is answered by ``_answer_question``, which writes any chart
      to chart.html (interactive) or chart.png in the current directory.

    End of input (Ctrl-D) or Ctrl-C also ends the session cleanly instead of
    crashing with a traceback.
    """
    print("=" * 80)
    print("SparkNova v5.0 – Advanced Data Analysis & Visualization")
    print("=" * 80)
    df = None
    while True:
        try:
            if df is None:
                df = _load_dataset_interactively()
                if df is None:
                    continue
            q = input("\nYour question (or 'exit'/'reload'): ").strip()
            if not q:
                continue
            command = q.lower()
            if command in ("exit", "quit"):
                print("Thank you for using SparkNova!")
                break
            if command == "reload":
                df = None
                continue
            _answer_question(df, q)
        except (EOFError, KeyboardInterrupt):
            print("\nThank you for using SparkNova!")
            break


def _load_dataset_interactively():
    """Ask for a file path, load and clean the file, and print an overview.

    Returns:
        The cleaned DataFrame, or None when the input was empty or loading
        failed (the error is printed).
    """
    # Strip quotes that terminals add when a file is dragged onto the window.
    file_path = input("\nEnter file path (CSV or Excel): ").strip().strip("\"'")
    if not file_path:
        return None
    try:
        df = clean_numeric(load_file(file_path))
    except Exception as e:
        print(f"Error loading file: {e}")
        return None
    print(f"Loaded {file_path} ({len(df)} rows × {len(df.columns)} cols)")
    print("\nFirst 5 rows:")
    print(df.head())
    print(f"\nColumn types:\n{df.dtypes}")
    print("\nSample Questions You Can Ask:")
    for i, question in enumerate(SAMPLE_QUESTIONS[:8], 1):
        print(f"{i}. {question}")
    return df


def _answer_question(df, q):
    """Turn one question into an LLM plan, execute it on df, and print/save the results.

    Failures of the LLM call or of plan execution are printed, not raised.
    """
    enhanced_prompt = get_chart_prompt(q, df.columns.tolist(), df.head(3).to_string())
    try:
        raw = call_groq([
            {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
            {"role": "user", "content": enhanced_prompt},
        ])
    except Exception as e:
        print(f"LLM call failed: {e}")
        return

    plan = parse_plan(raw)
    plan_type = plan.get("type")
    if plan_type == "explain":
        print("\nExplanation:")
        print(plan.get("narrative", ""))
        return
    if plan_type == "error":
        print("\nError:")
        print(plan.get("narrative", ""))
        return

    print("\nAnalysis Plan:")
    print(json.dumps(plan, indent=2))
    if plan.get("plot"):
        plan["plot"] = validate_plot_spec(plan["plot"], df.columns.tolist())

    try:
        print("\nExecuting operations...")
        res, plot_img, plot_html, desc_stats = execute_plan(df, plan)
        # A describe-only plan returns the unchanged frame, so skip printing it.
        if not desc_stats or len(res) != len(df):
            print("\nResult:")
            print(res.head(20))
        if plot_html:
            print("\nGenerated Interactive Chart (HTML saved as chart.html)")
            # Explicit UTF-8: titles/labels may contain characters (e.g. ₹) that the
            # platform default encoding (cp1252 on Windows) cannot encode.
            with open("chart.html", "w", encoding="utf-8") as f:
                f.write(plot_html)
        elif plot_img:
            print("\nGenerated Chart (saved as chart.png)")
            with open("chart.png", "wb") as f:
                f.write(plot_img)
        narrative = plan.get("narrative", "")
        if narrative:
            print(f"\nSummary: {narrative}")
        if plan.get("insights_needed") and (plot_html or plot_img):
            print("\nDetailed Insights:")
            print(generate_insights(df, res, plan, True))
    except Exception as e:
        print(f"Execution failed: {e}")
# Start the interactive session only when this file is run directly as a script.
if __name__ == "__main__":
    start_agent()