# SPARKNOVA / Sparknova.py
# Author: Tamannathakur
# Commit e58620e: "Update Sparknova.py" (verified)
import os
import json
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.express as px
import traceback
from io import BytesIO
from langchain_groq import ChatGroq
from prompts import ENHANCED_SYSTEM_PROMPT, SAMPLE_QUESTIONS, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt
# Groq chat model shared by every LLM call in this module. The API key is
# read from the environment so it never has to be hard-coded in source.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY")
if not GROQ_API_KEY:
    # Fail fast at start-up with an actionable message.
    raise RuntimeError(
        "GROQ_API_KEY is not set. Export your Groq API key as the "
        "GROQ_API_KEY environment variable before starting SparkNova."
    )
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama-3.3-70b-versatile",
    temperature=0.0,  # lowest sampling temperature, for the most repeatable plans
)
print("GROQ API initialized successfully")
def call_groq(messages):
    """Send a chat request to the module-level Groq model and return its text.

    Args:
        messages: Chat messages passed straight to ``llm.invoke``; this module
            builds them as ``{"role": ..., "content": ...}`` dicts.

    Returns:
        The response's ``content`` attribute when it has one, otherwise
        ``str(response)``.

    Raises:
        RuntimeError: Wrapping any exception raised by ``llm.invoke``. The
            original exception is chained as ``__cause__`` so its traceback
            is preserved.
    """
    try:
        res = llm.invoke(messages)
    except Exception as e:
        raise RuntimeError(f"GROQ API error: {e}") from e
    return res.content if hasattr(res, "content") else str(res)
def parse_plan(raw_text):
    """Extract the JSON analysis plan from a raw LLM reply.

    Markdown code fences are removed, then the text from the first ``{`` to
    the last ``}`` is parsed as JSON. Missing keys, and keys whose value is
    JSON ``null``, are filled with defaults so downstream code can iterate
    ``operations`` and read ``narrative`` without further checks.

    Args:
        raw_text: The model's reply text.

    Returns:
        The plan dict. If no JSON object can be parsed, a plan with
        ``type == "error"`` whose ``narrative`` describes the parse failure.
    """
    txt = raw_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        start = txt.index("{")
        end = txt.rindex("}") + 1
        plan = json.loads(txt[start:end])
    except ValueError as e:
        # str.index/rindex raise ValueError when no braces exist;
        # json.JSONDecodeError is a ValueError subclass.
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing JSON: {str(e)}",
            "insights_needed": False,
        }
    defaults = {
        "type": "analysis",
        "operations": [],
        "plot": None,
        "narrative": "",
        "insights_needed": False,
    }
    for key, default in defaults.items():
        # Also replaces explicit nulls: e.g. "operations": null would
        # otherwise survive setdefault and break iteration later.
        if plan.get(key) is None:
            plan[key] = default
    return plan
def clean_numeric(df, *, threshold=0.5, currency_symbols=("₹", "$")):
    """Return a copy of *df* with numeric-looking text columns converted.

    Only object/string columns are inspected. For each one:

    * If any cell contains ``%``, the cells are parsed with ``%`` and ``,``
      removed. When more than *threshold* of all rows parse, the column is
      replaced by those numbers divided by 100 (``"12.5%"`` -> ``0.125``).
    * Otherwise (or if that did not convert), thousands separators ``,`` and
      each of *currency_symbols* are stripped. When more than *threshold* of
      all rows parse, the column becomes numeric.

    Cells that do not parse become NaN in a converted column.

    Args:
        df: Input frame; it is not modified.
        threshold: Minimum fraction of rows (counted over all rows,
            including missing values) that must parse for a column to be
            converted.
        currency_symbols: Symbols removed before numeric parsing.

    Returns:
        A new DataFrame.
    """
    df = df.copy()
    n_rows = len(df)
    if n_rows == 0:
        # Nothing to convert; also avoids a 0/0 ratio below.
        return df
    for col in df.columns:
        column = df[col]
        if not (pd.api.types.is_string_dtype(column) or column.dtype == object):
            continue
        text = column.astype(str).str.strip()
        if text.str.contains("%", na=False).any():
            # Strip thousands separators here too, consistent with the
            # currency branch below ("1,000%" -> 10.0).
            pct_text = text.str.replace("%", "", regex=False).str.replace(",", "", regex=False)
            pct_vals = pd.to_numeric(pct_text, errors="coerce")
            if pct_vals.notna().sum() / n_rows > threshold:
                df[col] = pct_vals / 100.0
                continue
        cleaned = text.str.replace(",", "", regex=False)
        for symbol in currency_symbols:
            cleaned = cleaned.str.replace(symbol, "", regex=False)
        numeric_vals = pd.to_numeric(cleaned, errors="coerce")
        if numeric_vals.notna().sum() / n_rows > threshold:
            df[col] = numeric_vals
    return df
def generate_insights(df, dfw, plan, plot_created):
    """Ask the LLM for narrative insights about a completed analysis.

    Builds a text context from the plan's operations, the chart metadata and
    the result frame, formats it with ``get_insights_prompt`` and sends it to
    Groq together with ``INSIGHTS_SYSTEM_PROMPT``.

    Args:
        df: Full dataset; ``describe`` operations are summarised from it.
        dfw: Result frame produced by ``execute_plan``.
        plan: Parsed plan dict (see ``parse_plan``).
        plot_created: True when a chart was produced; the chart type and
            title are only added to the context in that case.

    Returns:
        The stripped LLM reply, or a short fallback text that includes the
        error message when the LLM call fails.
    """
    context_parts = []
    # dfw is the final result of all operations, so include it at most once,
    # however many groupby steps the plan has.
    result_shown = False
    for op in plan.get("operations") or []:
        # Case-insensitive, matching how execute_plan dispatches operations.
        op_type = (op.get("op") or "").lower()
        if op_type == "describe":
            for col in op.get("columns") or []:
                if col in df.columns:
                    desc = df[col].describe()
                    context_parts.append(f"\n{col} Statistics:\n{desc.to_string()}")
        elif op_type == "groupby" and not result_shown:
            context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")
            result_shown = True
    plot_spec = plan.get("plot")
    if plot_created and plot_spec:
        context_parts.append(f"\nChart Type: {plot_spec.get('type')}")
        context_parts.append(f"Visualization: {plot_spec.get('title')}")
    if len(dfw) > 0 and not result_shown:
        context_parts.append(f"\nResult Preview:\n{dfw.head(10).to_string()}")
    insights_prompt = get_insights_prompt(context_parts, plan.get("narrative", ""))
    try:
        insights_response = call_groq([
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": insights_prompt},
        ])
    except Exception as e:
        return f"Analysis completed successfully\n{len(dfw)} records in result\nError generating detailed insights: {str(e)}"
    return insights_response.strip()
def _has_column(frame, name):
    """Return True if *name* is a column label of *frame*.

    Unhashable spec values (e.g. a list given as ``y``) count as "not a
    column" instead of raising ``TypeError``.
    """
    try:
        return name in frame.columns
    except TypeError:
        return False


def _apply_operation(dfw, op, describe_stats):
    """Apply one plan operation and return the (possibly new) working frame.

    ``describe`` stores each column's ``Series.describe()`` in
    *describe_stats* (mutated in place) and returns *dfw* unchanged.
    Unrecognised operation types are ignored.

    Raises:
        ValueError: For a groupby without columns or with an unknown
            aggregation column.
    """
    optype = (op.get("op") or "").lower()
    if optype == "describe":
        for col in op.get("columns") or []:
            if col in dfw.columns:
                stats = dfw[col].describe()
                describe_stats[col] = stats
                print(f"Described {col}")
                print(f"\n{stats}\n")
        return dfw
    if optype == "groupby":
        cols = op.get("columns", [])
        agg_col = op.get("agg_col")
        agg_func = op.get("agg_func", "count")
        if not cols:
            raise ValueError("No columns specified for groupby")
        if agg_func == "count" or not agg_col:
            grouped = dfw.groupby(cols).size().reset_index(name="count")
            print(f"Grouped by {cols} with count")
            return grouped
        if agg_col not in dfw.columns:
            raise ValueError(f"Column '{agg_col}' not found for aggregation")
        result_col = f"{agg_func}_{agg_col}"
        grouped = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
        print(f"Grouped by {cols}, calculated {agg_func} of {agg_col}")
        return grouped
    if optype == "filter":
        expr = op.get("expr", "")
        if expr:
            # NOTE(review): expr is LLM-generated. DataFrame.query/eval are not
            # a security sandbox; consider allow-listing columns/operators.
            dfw = dfw.query(expr)
            print(f"Filter applied: {expr}")
        return dfw
    if optype == "calculate":
        expr = op.get("expr", "")
        new_col = op.get("new_col", "Calculated")
        # NOTE(review): same caveat as "filter" -- expr comes from the LLM.
        dfw[new_col] = dfw.eval(expr)
        print(f"Calculated {new_col} = {expr}")
    return dfw


def _resolve_plot_columns(plot_df, x, y):
    """Choose x/y chart columns, filling gaps from *plot_df*.

    * Falsy *x* -> first object/category column, else the first column.
    * Falsy *y* -> first numeric column (None when there is none).
    * *y* naming no column of *plot_df* -> first numeric column, when one
      exists. This covers plans whose y is the pre-aggregation name
      (e.g. ``sales``) after a groupby renamed it (e.g. ``sum_sales``).
    """
    if not x and len(plot_df.columns) > 0:
        categorical_cols = plot_df.select_dtypes(include=["object", "category"]).columns
        x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
    numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
    if not y:
        y = numeric_cols[0] if len(numeric_cols) > 0 else None
    elif not _has_column(plot_df, y) and len(numeric_cols) > 0:
        y = numeric_cols[0]
    return x, y


def _render_pie(plot_df, x, y, title):
    """Build an interactive Plotly donut chart and return it as HTML."""
    if x and _has_column(plot_df, x):
        categories = plot_df[x]
        if (
            _has_column(plot_df, y)
            and y != x
            and pd.api.types.is_numeric_dtype(plot_df[y])
            and categories.is_unique
        ):
            # One row per category (e.g. a groupby result): slice sizes come
            # from y. Counting x here would give every slice the value 1.
            pie = go.Pie(
                labels=categories,
                values=plot_df[y],
                hovertemplate='<b>%{label}</b><br>Value: %{value}<br>Percentage: %{percent}<extra></extra>',
                textposition='auto',
                hole=0.3,
            )
        else:
            value_counts = categories.value_counts()
            pie = go.Pie(
                labels=value_counts.index,
                values=value_counts.values,
                hovertemplate='<b>%{label}</b><br>Count: %{value}<br>Percentage: %{percent}<extra></extra>',
                textposition='auto',
                hole=0.3,
            )
    else:
        df_pie = plot_df[y].value_counts()
        pie = go.Pie(labels=df_pie.index, values=df_pie.values, hole=0.3)
    fig = go.Figure(data=[pie])
    fig.update_layout(
        title=title,
        title_font_size=16,
        showlegend=True,
        width=950,
        height=550,
    )
    return fig.to_html(include_plotlyjs='cdn')


def _draw_bar(ax, plot_df, x, y, title):
    """Draw a bar chart of y (against x when both are columns) on *ax*."""
    if x and _has_column(plot_df, x) and y and _has_column(plot_df, y):
        plot_df.plot.bar(x=x, y=y, ax=ax, legend=False, color='steelblue', edgecolor='black', alpha=0.8)
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        # Rotate tick labels further as the number of categories grows.
        n_categories = len(plot_df[x].unique())
        if n_categories > 10:
            plt.xticks(rotation=90, ha='right', fontsize=9)
        elif n_categories > 5:
            plt.xticks(rotation=45, ha='right', fontsize=10)
        else:
            plt.xticks(rotation=0, fontsize=10)
    else:
        plot_df[y].plot.bar(ax=ax, color='steelblue', edgecolor='black', alpha=0.8)
    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')


def _draw_line(ax, plot_df, x, y, title):
    """Draw a line chart of y (against x when both are columns) on *ax*."""
    if x and _has_column(plot_df, x) and y and _has_column(plot_df, y):
        plot_df.plot.line(x=x, y=y, ax=ax, marker="o", linewidth=3,
                          markersize=8, color='darkblue', alpha=0.8)
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        if len(plot_df) > 15:
            plt.xticks(rotation=45, ha='right', fontsize=9)
        else:
            plt.xticks(rotation=0, fontsize=10)
    else:
        plot_df[y].plot.line(ax=ax, marker="o", linewidth=3,
                             markersize=8, color='darkblue', alpha=0.8)
    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.grid(True, alpha=0.3, linestyle='--')


def _draw_hist(ax, plot_df, x, y, title):
    """Draw a 25-bin histogram of the non-null values of y on *ax*."""
    plot_df[y].dropna().plot.hist(ax=ax, bins=25, edgecolor='black',
                                  alpha=0.7, color='teal')
    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.set_xlabel(y, fontsize=12, fontweight='bold')
    ax.set_ylabel("Frequency", fontsize=12, fontweight='bold')
    ax.grid(axis='y', alpha=0.3, linestyle='--')


def _draw_scatter(ax, plot_df, x, y, title):
    """Draw a scatter of y vs x on *ax* (axes only titled if either is missing)."""
    if x and _has_column(plot_df, x) and y and _has_column(plot_df, y):
        plot_df.plot.scatter(x=x, y=y, ax=ax, alpha=0.6, s=60, color='purple')
        ax.set_xlabel(x, fontsize=12, fontweight='bold')
        ax.set_ylabel(y, fontsize=12, fontweight='bold')
    ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
    ax.grid(True, alpha=0.3, linestyle='--')


# Matplotlib chart types: (draw function, figure size, completion message).
_MPL_CHARTS = {
    "bar": (_draw_bar, (12, 7), "Enhanced bar chart generated"),
    "line": (_draw_line, (12, 7), "Enhanced line chart generated"),
    "hist": (_draw_hist, (11, 7), "Enhanced histogram generated"),
    "scatter": (_draw_scatter, (11, 7), "Enhanced scatter plot generated"),
}


def _render_plot(plot_df, plot_spec):
    """Render *plot_spec* from *plot_df* and return ``(png_bytes, html)``.

    Pie charts come back as Plotly HTML and the Matplotlib types as PNG
    bytes; the unused slot is None. An unknown chart type, or a frame with
    no usable y column, yields ``(None, None)``.
    """
    ptype = plot_spec.get("type", "bar")
    title = plot_spec.get("title", "Chart")
    x, y = _resolve_plot_columns(plot_df, plot_spec.get("x"), plot_spec.get("y"))
    if not y:
        print("No suitable Y column found for plotting.")
        return None, None
    if ptype == "pie":
        html = _render_pie(plot_df, x, y, title)
        print("Enhanced pie chart generated")
        return None, html
    chart = _MPL_CHARTS.get(ptype) if isinstance(ptype, str) else None
    if chart is None:
        return None, None
    draw, figsize, done_message = chart
    fig, ax = plt.subplots(figsize=figsize)
    try:
        draw(ax, plot_df, x, y, title)
        fig.tight_layout()
        buf = BytesIO()
        fig.savefig(buf, format="png", dpi=150, bbox_inches="tight")
        png = buf.getvalue()
    finally:
        # Release the figure even when drawing fails, so repeated questions
        # do not accumulate open Matplotlib figures.
        plt.close(fig)
    print(done_message)
    return png, None


def execute_plan(df, plan):
    """Run a parsed analysis plan against *df* and optionally render a chart.

    Args:
        df: Input dataset; it is copied, never modified.
        plan: Plan dict as produced by ``parse_plan``. ``plan["operations"]``
            is a list of ``{"op": "describe" | "groupby" | "filter" |
            "calculate", ...}`` dicts applied in order (op names are
            case-insensitive); ``plan["plot"]`` is an optional chart spec
            with ``type``/``x``/``y``/``title`` keys.

    Returns:
        ``(result_df, png_bytes, html, describe_stats)``. *png_bytes* is set
        for Matplotlib charts (bar/line/hist/scatter), *html* for Plotly pie
        charts, and *describe_stats* maps column name to its ``describe()``
        Series.

    Raises:
        Any exception from an operation or from plotting (e.g. ``ValueError``
        for a groupby without columns). It is printed with its traceback and
        re-raised.
    """
    dfw = df.copy()
    plot_bytes = None
    plot_html = None
    describe_stats = {}
    try:
        for op in plan.get("operations") or []:
            dfw = _apply_operation(dfw, op, describe_stats)
        plot_spec = plan.get("plot")
        if plot_spec:
            # When any describe op ran, the chart is drawn from the full input
            # frame rather than the working frame.
            plot_df = df if describe_stats else dfw
            plot_bytes, plot_html = _render_plot(plot_df, plot_spec)
        return dfw, plot_bytes, plot_html, describe_stats
    except Exception as e:
        print(f"EXECUTION ERROR: {e}")
        traceback.print_exc()
        raise
def make_context(df):
    """Summarise a DataFrame as plain text for inclusion in an LLM prompt.

    The summary lists the row/column counts, column labels, a count of
    columns per dtype, and the first three rows.

    Column labels are converted with ``str`` so frames with non-string
    labels (e.g. integer headers) do not raise ``TypeError``.
    """
    sample_data = df.head(3).to_string(max_cols=10, max_colwidth=20)
    column_list = ", ".join(str(col) for col in df.columns)
    lines = [
        f"Dataset: {len(df)} rows, {len(df.columns)} columns",
        f"Columns: {column_list}",
        f"Data types: {df.dtypes.value_counts().to_dict()}",
        "Sample data:",
        sample_data,
    ]
    return "\n".join(lines)
def load_file(file_path):
    """Load a CSV or Excel file into a DataFrame.

    Args:
        file_path: Path as ``str`` or any ``os.PathLike``. The extension
            check is case-insensitive (``.csv``, ``.xlsx``, ``.xls``).

    Returns:
        The DataFrame read by ``pd.read_csv`` or ``pd.read_excel``.

    Raises:
        ValueError: If the extension is not one of the supported ones.
    """
    path = os.fspath(file_path)
    lowered = path.lower()
    if lowered.endswith(".csv"):
        return pd.read_csv(path)
    if lowered.endswith((".xlsx", ".xls")):
        return pd.read_excel(path)
    raise ValueError("Unsupported file format. Please use CSV or Excel files.")
def _read_input(prompt):
    """Prompt the user and return the raw reply, or None on EOF / Ctrl-C."""
    try:
        return input(prompt)
    except (EOFError, KeyboardInterrupt):
        print()
        return None


def _load_dataset(file_path):
    """Load and clean *file_path*, print a preview, and return the frame.

    Raises whatever ``load_file`` or ``clean_numeric`` raise.
    """
    df = clean_numeric(load_file(file_path))
    print(f"Loaded {file_path} ({len(df)} rows × {len(df.columns)} cols)")
    print("\nFirst 5 rows:")
    print(df.head())
    print(f"\nColumn types:\n{df.dtypes}")
    print("\nSample Questions You Can Ask:")
    for i, question in enumerate(SAMPLE_QUESTIONS[:8], 1):
        print(f"{i}. {question}")
    return df


def _answer_question(df, q):
    """Turn one question into a plan, execute it, and report the results.

    Charts are written to ``chart.html`` (Plotly) or ``chart.png``
    (Matplotlib) in the current working directory.
    """
    enhanced_prompt = get_chart_prompt(q, df.columns.tolist(), df.head(3).to_string())
    try:
        raw = call_groq([
            {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
            {"role": "user", "content": enhanced_prompt},
        ])
    except Exception as e:
        print(f"LLM call failed: {e}")
        return
    plan = parse_plan(raw)
    if plan.get("type") == "explain":
        print("\nExplanation:")
        print(plan.get("narrative", ""))
        return
    if plan.get("type") == "error":
        print("\nError:")
        print(plan.get("narrative", ""))
        return
    print("\nAnalysis Plan:")
    print(json.dumps(plan, indent=2))
    if plan.get("plot"):
        plan["plot"] = validate_plot_spec(plan["plot"], df.columns.tolist())
    try:
        print("\nExecuting operations...")
        res, plot_img, plot_html, desc_stats = execute_plan(df, plan)
        # The frame is not echoed when describe stats were produced and the
        # row count is unchanged.
        if not desc_stats or len(res) != len(df):
            print("\nResult:")
            print(res.head(20))
        if plot_html:
            print("\nGenerated Interactive Chart (HTML saved as chart.html)")
            # Explicit UTF-8: titles and labels may contain non-ASCII text
            # (e.g. currency symbols) that the locale default cannot encode.
            with open("chart.html", "w", encoding="utf-8") as f:
                f.write(plot_html)
        elif plot_img:
            print("\nGenerated Chart (saved as chart.png)")
            with open("chart.png", "wb") as f:
                f.write(plot_img)
        narrative = plan.get("narrative", "")
        if narrative:
            print(f"\nSummary: {narrative}")
        if plan.get("insights_needed") and (plot_html or plot_img):
            print("\nDetailed Insights:")
            insights = generate_insights(df, res, plan, True)
            print(insights)
    except Exception as e:
        print(f"Execution failed: {e}")


def start_agent():
    """Run the interactive SparkNova console session.

    Asks for a data file until one loads successfully, then answers
    questions about it. ``exit``/``quit`` end the session and ``reload``
    asks for a new file. End-of-input or Ctrl-C at a prompt also ends the
    session.
    """
    print("=" * 80)
    print("SparkNova v5.0 – Advanced Data Analysis & Visualization")
    print("=" * 80)
    df = None
    while True:
        if df is None:
            file_path = _read_input("\nEnter file path (CSV or Excel): ")
            if file_path is None:
                break
            # Strip whitespace and surrounding quotes, e.g. a path pasted
            # as "C:\data\sales.csv".
            file_path = file_path.strip().strip("\"'")
            if not file_path:
                continue
            try:
                # Adopt the frame only after loading, cleaning and preview all
                # succeed, so a failure never leaves a half-loaded dataset.
                df = _load_dataset(file_path)
            except Exception as e:
                print(f"Error loading file: {e}")
                continue
        q = _read_input("\nYour question (or 'exit'/'reload'): ")
        if q is None:
            break
        q = q.strip()
        if not q:
            continue
        if q.lower() in ("exit", "quit"):
            break
        if q.lower() == "reload":
            df = None
            continue
        _answer_question(df, q)
    print("Thank you for using SparkNova!")
if __name__ == "__main__":
    # Launch the interactive console only when this file is executed
    # directly, not when it is imported as a module.
    start_agent()