"""SparkNova command-line agent.

Loads a CSV/Excel file, asks a Groq-hosted LLM to turn a natural-language
question into a JSON analysis plan, executes that plan with pandas, and
optionally renders a chart plus LLM-written insights.
"""

import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
import traceback
from io import BytesIO
from langchain_groq import ChatGroq
from prompts import (
    ENHANCED_SYSTEM_PROMPT,
    SAMPLE_QUESTIONS,
    get_chart_prompt,
    validate_plot_spec,
    INSIGHTS_SYSTEM_PROMPT,
    get_insights_prompt,
)

# The key is read from the environment; it was previously referenced without
# ever being defined, which raised NameError at import time.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")

llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama-3.3-70b-versatile",
    temperature=0.0,
)
print("GROQ API initialized successfully")


def call_groq(messages):
    """Send ``messages`` to the LLM and return the reply text.

    Returns ``res.content`` when the response object has that attribute,
    otherwise ``str(res)``.

    Raises:
        RuntimeError: if the underlying call fails; the original exception
            is attached as ``__cause__``.
    """
    try:
        res = llm.invoke(messages)
        return res.content if hasattr(res, "content") else str(res)
    except Exception as e:
        raise RuntimeError(f"GROQ API error: {e}") from e


def parse_plan(raw_text):
    """Extract the JSON plan object from an LLM reply.

    Markdown code fences are stripped, then the text between the first ``{``
    and the last ``}`` is parsed as JSON. Missing keys are filled with
    defaults. On any parse failure a plan with ``type == "error"`` is
    returned instead of raising.
    """
    txt = raw_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        start = txt.index("{")
        end = txt.rindex("}") + 1
        plan = json.loads(txt[start:end])
    except (ValueError, RecursionError) as e:
        # ValueError covers both a missing brace (index/rindex) and
        # json.JSONDecodeError, which subclasses it.
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing JSON: {str(e)}",
            "insights_needed": False,
        }
    plan.setdefault("type", "analysis")
    plan.setdefault("operations", [])
    plan.setdefault("plot", None)
    plan.setdefault("narrative", "")
    plan.setdefault("insights_needed", False)
    return plan


def clean_numeric(df):
    """Return a copy of ``df`` with numeric-looking text columns converted.

    For each string/object column:
      * if any value contains ``%`` and more than half of the rows parse as
        numbers once ``%`` is removed, the column becomes that number / 100;
      * otherwise ``,``, ``₹`` and ``$`` are removed and, if more than half
        of the rows parse as numbers, the column becomes numeric.
    Values that fail to parse in a converted column become NaN.
    """
    df = df.copy()
    n_rows = len(df)
    if n_rows == 0:
        # Nothing to infer from, and the ratio below would be 0/0.
        return df

    for col in df.columns:
        if not (pd.api.types.is_string_dtype(df[col]) or df[col].dtype == object):
            continue
        s = df[col].astype(str).str.strip()

        if s.str.contains("%", na=False).any():
            numeric_vals = pd.to_numeric(
                s.str.replace("%", "", regex=False), errors="coerce"
            )
            if numeric_vals.notna().sum() / n_rows > 0.5:
                df[col] = numeric_vals / 100.0
                continue

        cleaned = (
            s.str.replace(",", "", regex=False)
            .str.replace("₹", "", regex=False)
            .str.replace("$", "", regex=False)
        )
        numeric_vals = pd.to_numeric(cleaned, errors="coerce")
        if numeric_vals.notna().sum() / n_rows > 0.5:
            df[col] = numeric_vals
    return df


def generate_insights(df, dfw, plan, plot_created):
    """Ask the LLM for narrative insights about an executed plan.

    Args:
        df: the original dataset (used for ``describe`` statistics).
        dfw: the result frame produced by ``execute_plan``.
        plan: the parsed plan dictionary.
        plot_created: whether a chart was produced for this plan.

    Returns:
        The stripped LLM reply, or a fallback summary string if the LLM
        call fails.
    """
    context_parts = []
    for op in plan.get("operations", []):
        if op.get("op") == "describe":
            for col in op.get("columns", []):
                if col in df.columns:
                    desc = df[col].describe()
                    context_parts.append(f"\n{col} Statistics:\n{desc.to_string()}")
        elif op.get("op") == "groupby":
            context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")

    plot_spec = plan.get("plot")
    if plot_created and plot_spec:
        context_parts.append(f"\nChart Type: {plot_spec.get('type')}")
        context_parts.append(f"Visualization: {plot_spec.get('title')}")

    if len(dfw) > 0:
        context_parts.append(f"\nResult Preview:\n{dfw.head(10).to_string()}")

    insights_prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))
    try:
        insights_response = call_groq([
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": insights_prompt},
        ])
        return insights_response.strip()
    except Exception as e:
        return f"Analysis completed successfully\n{len(dfw)} records in result\nError generating detailed insights: {str(e)}"


def _apply_operation(dfw, op, describe_stats):
    """Apply one plan operation and return the (possibly new) working frame.

    ``describe`` results are stored into ``describe_stats`` in place; the
    other operations return a transformed frame. Unknown operation names are
    ignored.
    """
    optype = op.get("op", "").lower()

    if optype == "describe":
        for col in op.get("columns", []):
            if col in dfw.columns:
                stats = dfw[col].describe()
                describe_stats[col] = stats
                print(f"Described {col}")
                print(f"\n{stats}\n")

    elif optype == "groupby":
        cols = op.get("columns", [])
        agg_col = op.get("agg_col")
        agg_func = op.get("agg_func", "count")
        if not cols:
            raise ValueError("No columns specified for groupby")
        if agg_func == "count" or not agg_col:
            dfw = dfw.groupby(cols).size().reset_index(name="count")
            print(f"Grouped by {cols} with count")
        else:
            if agg_col not in dfw.columns:
                raise ValueError(f"Column '{agg_col}' not found for aggregation")
            result_col = f"{agg_func}_{agg_col}"
            dfw = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
            print(f"Grouped by {cols}, calculated {agg_func} of {agg_col}")

    elif optype == "filter":
        expr = op.get("expr", "")
        if expr:
            # SECURITY NOTE(review): `expr` comes from LLM output and
            # DataFrame.query evaluates it as code; treat as untrusted input.
            dfw = dfw.query(expr)
            print(f"Filter applied: {expr}")

    elif optype == "calculate":
        expr = op.get("expr", "")
        new_col = op.get("new_col", "Calculated")
        # SECURITY NOTE(review): same concern as `filter` - DataFrame.eval
        # runs an LLM-supplied expression.
        dfw[new_col] = dfw.eval(expr)
        print(f"Calculated {new_col} = {expr}")

    return dfw


def _has_columns(plot_df, x, y):
    """Return True when both ``x`` and ``y`` are truthy column names of ``plot_df``."""
    return bool(x and x in plot_df.columns and y and y in plot_df.columns)


def _figure_to_png(fig):
    """Render a Matplotlib figure to PNG bytes (150 dpi, tight bounding box)."""
    fig.tight_layout()
    buf = BytesIO()
    fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    return buf.getvalue()


def _render_pie(plot_df, x, y, title):
    """Build an interactive Plotly donut chart and return it as HTML."""
    if x and x in plot_df.columns:
        value_counts = plot_df[x].value_counts()
        pie = go.Pie(
            labels=value_counts.index,
            values=value_counts.values,
            # Plotly hover text uses <br> for line breaks.
            hovertemplate='%{label}<br>Count: %{value}<br>Percentage: %{percent}',
            textposition='auto',
            hole=0.3,
        )
    else:
        df_pie = plot_df[y].value_counts()
        pie = go.Pie(labels=df_pie.index, values=df_pie.values, hole=0.3)

    fig = go.Figure(data=[pie])
    fig.update_layout(
        title=title,
        title_font_size=16,
        showlegend=True,
        width=950,
        height=550,
    )
    return fig.to_html(include_plotlyjs='cdn')


def _render_bar(plot_df, x, y, title):
    """Draw a bar chart and return PNG bytes."""
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        if _has_columns(plot_df, x, y):
            plot_df.plot.bar(x=x, y=y, ax=ax, legend=False, color='steelblue',
                             edgecolor='black', alpha=0.8)
            ax.set_xlabel(x, fontsize=12, fontweight='bold')
            # Rotate tick labels more aggressively as categories grow.
            n_categories = len(plot_df[x].unique())
            if n_categories > 10:
                plt.xticks(rotation=90, ha='right', fontsize=9)
            elif n_categories > 5:
                plt.xticks(rotation=45, ha='right', fontsize=10)
            else:
                plt.xticks(rotation=0, fontsize=10)
        else:
            plot_df[y].plot.bar(ax=ax, color='steelblue', edgecolor='black', alpha=0.8)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_ylabel(y, fontsize=12, fontweight='bold')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        return _figure_to_png(fig)
    finally:
        # Close even when plotting raises so figures do not accumulate.
        plt.close(fig)


def _render_line(plot_df, x, y, title):
    """Draw a line chart and return PNG bytes."""
    fig, ax = plt.subplots(figsize=(12, 7))
    try:
        if _has_columns(plot_df, x, y):
            plot_df.plot.line(x=x, y=y, ax=ax, marker="o", linewidth=3,
                              markersize=8, color='darkblue', alpha=0.8)
            ax.set_xlabel(x, fontsize=12, fontweight='bold')
            if len(plot_df) > 15:
                plt.xticks(rotation=45, ha='right', fontsize=9)
            else:
                plt.xticks(rotation=0, fontsize=10)
        else:
            plot_df[y].plot.line(ax=ax, marker="o", linewidth=3, markersize=8,
                                 color='darkblue', alpha=0.8)
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_ylabel(y, fontsize=12, fontweight='bold')
        ax.grid(True, alpha=0.3, linestyle='--')
        return _figure_to_png(fig)
    finally:
        plt.close(fig)


def _render_hist(plot_df, x, y, title):
    """Draw a 25-bin histogram of column ``y`` and return PNG bytes."""
    fig, ax = plt.subplots(figsize=(11, 7))
    try:
        plot_df[y].dropna().plot.hist(ax=ax, bins=25, edgecolor='black',
                                      alpha=0.7, color='teal')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.set_xlabel(y, fontsize=12, fontweight='bold')
        ax.set_ylabel("Frequency", fontsize=12, fontweight='bold')
        ax.grid(axis='y', alpha=0.3, linestyle='--')
        return _figure_to_png(fig)
    finally:
        plt.close(fig)


def _render_scatter(plot_df, x, y, title):
    """Draw a scatter plot and return PNG bytes.

    When ``x``/``y`` are not both valid columns, only the titled, gridded
    empty axes are rendered.
    """
    fig, ax = plt.subplots(figsize=(11, 7))
    try:
        if _has_columns(plot_df, x, y):
            plot_df.plot.scatter(x=x, y=y, ax=ax, alpha=0.6, s=60, color='purple')
            ax.set_xlabel(x, fontsize=12, fontweight='bold')
            ax.set_ylabel(y, fontsize=12, fontweight='bold')
        ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
        ax.grid(True, alpha=0.3, linestyle='--')
        return _figure_to_png(fig)
    finally:
        plt.close(fig)


# Chart type -> (renderer returning PNG bytes, success message).
_PNG_RENDERERS = {
    "bar": (_render_bar, "Enhanced bar chart generated"),
    "line": (_render_line, "Enhanced line chart generated"),
    "hist": (_render_hist, "Enhanced histogram generated"),
    "scatter": (_render_scatter, "Enhanced scatter plot generated"),
}


def execute_plan(df, plan):
    """Run a parsed plan against ``df`` and optionally render its chart.

    Operations are applied in order to a copy of ``df``. If the plan has a
    plot spec, missing ``x``/``y`` are inferred (first categorical column /
    first numeric column) and the chart is rendered. Charts are drawn from
    the original ``df`` when any ``describe`` ran, otherwise from the result
    frame.

    Returns:
        ``(result_df, plot_bytes, plot_html, describe_stats)`` where
        ``plot_bytes`` is PNG data (bar/line/hist/scatter) or None,
        ``plot_html`` is Plotly HTML (pie) or None, and ``describe_stats``
        maps column name to its ``describe()`` Series.

    Raises:
        Any exception from the operations or plotting is logged with a
        traceback and re-raised unchanged.
    """
    dfw = df.copy()
    plot_bytes = None
    plot_html = None
    describe_stats = {}
    try:
        for op in plan.get("operations", []):
            dfw = _apply_operation(dfw, op, describe_stats)

        plot_spec = plan.get("plot")
        if plot_spec:
            ptype = plot_spec.get("type", "bar")
            x = plot_spec.get("x")
            y = plot_spec.get("y")
            title = plot_spec.get("title", "Chart")
            plot_df = df if describe_stats else dfw

            if not x and len(plot_df.columns) > 0:
                categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
                x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
            if not y:
                numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
                y = numeric_cols[0] if len(numeric_cols) > 0 else None

            # A pie chart over a valid x column only counts x values, so it
            # must not be blocked by the absence of a numeric y column.
            pie_by_x = ptype == "pie" and bool(x) and x in plot_df.columns
            if not y and not pie_by_x:
                print("No suitable Y column found for plotting.")
            elif ptype == "pie":
                plot_html = _render_pie(plot_df, x, y, title)
                print("Enhanced pie chart generated")
            elif ptype in _PNG_RENDERERS:
                renderer, message = _PNG_RENDERERS[ptype]
                plot_bytes = renderer(plot_df, x, y, title)
                print(message)

        return dfw, plot_bytes, plot_html, describe_stats
    except Exception as e:
        print(f"EXECUTION ERROR: {e}")
        traceback.print_exc()
        raise


def make_context(df):
    """Return a short text summary of ``df``: size, columns, dtypes, 3-row sample."""
    sample_data = df.head(3).to_string(max_cols=10, max_colwidth=20)
    # Column labels may be non-strings (e.g. ints from header-less files).
    column_names = ', '.join(map(str, df.columns))
    return f"""Dataset: {len(df)} rows, {len(df.columns)} columns
Columns: {column_names}
Data types: {df.dtypes.value_counts().to_dict()}
Sample data:
{sample_data}"""


def load_file(file_path):
    """Load a CSV or Excel file into a DataFrame.

    The extension check is case-insensitive.

    Raises:
        ValueError: if the extension is not .csv, .xlsx or .xls.
    """
    lowered = file_path.lower()
    if lowered.endswith('.csv'):
        return pd.read_csv(file_path)
    if lowered.endswith(('.xlsx', '.xls')):
        return pd.read_excel(file_path)
    raise ValueError("Unsupported file format. Please use CSV or Excel files.")


def _read_line(prompt):
    """Read and strip one line from stdin; return None when input is closed."""
    try:
        return input(prompt).strip()
    except EOFError:
        return None


def start_agent():
    """Run the interactive loop: load a file, then answer questions about it.

    Typing ``exit``/``quit`` ends the session and ``reload`` prompts for a
    new file. Charts are written to ``chart.html`` (pie) or ``chart.png``
    in the current directory.
    """
    print("=" * 80)
    print("SparkNova v5.0 – Advanced Data Analysis & Visualization")
    print("=" * 80)

    df = None
    while True:
        if df is None:
            file_path = _read_line("\nEnter file path (CSV or Excel): ")
            if file_path is None:
                print("Thank you for using SparkNova!")
                break
            if not file_path:
                continue
            try:
                loaded = clean_numeric(load_file(file_path))
                print(f"Loaded {file_path} ({len(loaded)} rows × {len(loaded.columns)} cols)")
                print("\nFirst 5 rows:")
                print(loaded.head())
                print(f"\nColumn types:\n{loaded.dtypes}")
                print("\nSample Questions You Can Ask:")
                for i, question in enumerate(SAMPLE_QUESTIONS[:8], 1):
                    print(f"{i}. {question}")
            except Exception as e:
                print(f"Error loading file: {e}")
                continue
            # Commit only after every load step succeeded, so a failure
            # never leaves a half-processed frame active.
            df = loaded

        q = _read_line("\nYour question (or 'exit'/'reload'): ")
        if q is None:
            print("Thank you for using SparkNova!")
            break
        if not q:
            continue
        if q.lower() in ("exit", "quit"):
            print("Thank you for using SparkNova!")
            break
        if q.lower() == "reload":
            df = None
            continue

        enhanced_prompt = get_chart_prompt(q, df.columns.tolist(), df.head(3).to_string())
        try:
            raw = call_groq([
                {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
                {"role": "user", "content": enhanced_prompt},
            ])
        except Exception as e:
            print(f"LLM call failed: {e}")
            continue

        plan = parse_plan(raw)
        if plan.get("type") == "explain":
            print("\nExplanation:")
            print(plan.get("narrative", ""))
            continue
        if plan.get("type") == "error":
            print("\nError:")
            print(plan.get("narrative", ""))
            continue

        print("\nAnalysis Plan:")
        print(json.dumps(plan, indent=2))
        if plan.get("plot"):
            plan["plot"] = validate_plot_spec(plan["plot"], df.columns.tolist())

        try:
            print("\nExecuting operations...")
            res, plot_img, plot_html, desc_stats = execute_plan(df, plan)

            # Skip the table when the plan only described columns and left
            # the frame unchanged (the statistics were already printed).
            if not desc_stats or len(res) != len(df):
                print("\nResult:")
                print(res.head(20))

            if plot_html:
                print("\nGenerated Interactive Chart (HTML saved as chart.html)")
                # Explicit UTF-8: labels may contain characters the platform
                # default encoding cannot represent.
                with open("chart.html", "w", encoding="utf-8") as f:
                    f.write(plot_html)
            elif plot_img:
                print("\nGenerated Chart (saved as chart.png)")
                with open("chart.png", "wb") as f:
                    f.write(plot_img)

            narrative = plan.get("narrative", "")
            if narrative:
                print(f"\nSummary: {narrative}")

            if plan.get("insights_needed") and (plot_html or plot_img):
                print("\nDetailed Insights:")
                insights = generate_insights(df, res, plan, True)
                print(insights)
        except Exception as e:
            print(f"Execution failed: {e}")
            continue


if __name__ == "__main__":
    start_agent()