Spaces:
Sleeping
Sleeping
Upload 5 files
Browse files- ai_agent.py +346 -0
- app.py +312 -0
- data_engine.py +533 -0
- prompts.py +168 -0
- sparknova.py +412 -0
ai_agent.py
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import plotly.express as px
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
from langchain_groq import ChatGroq
|
| 8 |
+
from prompts import ENHANCED_SYSTEM_PROMPT, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt
|
| 9 |
+
|
| 10 |
+
# SECURITY: the API key was previously hard-coded here and committed to the
# repository — that key must be revoked. Read it from the environment instead
# (set GROQ_API_KEY in the Space/host secrets); empty string if unset so
# initialize_llm() fails gracefully and returns None.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")
|
| 11 |
+
|
| 12 |
+
def initialize_llm():
    """Create and return the Groq chat model, or None on any failure.

    Returns:
        A configured ``ChatGroq`` client (llama-3.3-70b-versatile, deterministic
        temperature 0.0), or ``None`` if construction fails (missing key,
        network/auth error) so the UI can degrade gracefully.
    """
    try:
        # Some langchain integrations read the key from the environment,
        # so mirror it there in addition to passing it explicitly.
        os.environ["GROQ_API_KEY"] = GROQ_API_KEY
        return ChatGroq(
            model="llama-3.3-70b-versatile",
            api_key=GROQ_API_KEY,
            temperature=0.0,
        )
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # still propagate; any initialization error yields None.
        return None
|
| 19 |
+
|
| 20 |
+
def parse_plan(raw_text):
    """Extract the JSON plan embedded in an LLM reply.

    Strips markdown code fences, takes the outermost ``{...}`` span, parses
    it, and fills in any missing keys with safe defaults. On any failure an
    error-typed plan is returned instead of raising.
    """
    cleaned = raw_text.strip().replace("```json", "").replace("```", "").strip()
    defaults = {
        "type": "analysis",
        "operations": [],
        "plot": None,
        "narrative": "",
        "insights_needed": False,
    }
    try:
        json_body = cleaned[cleaned.index("{"):cleaned.rindex("}") + 1]
        plan = json.loads(json_body)
        for key, fallback in defaults.items():
            plan.setdefault(key, fallback)
        return plan
    except Exception as exc:
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing response: {str(exc)}",
            "insights_needed": False,
        }
|
| 40 |
+
|
| 41 |
+
def _filter_by_value(dfw, column, value):
    """Row-filter on one column: case-insensitive substring match for object
    (string) columns, exact equality otherwise. (Deduplicates the fallback
    logic that was previously copy-pasted in two branches of the filter op.)"""
    if dfw[column].dtype == 'object':
        return dfw[dfw[column].str.contains(str(value), case=False, na=False)]
    return dfw[dfw[column] == value]


def execute_plan(df, plan):
    """Apply the plan's operations sequentially to a copy of ``df``.

    Supported operations: ``describe``, ``groupby``, ``filter``, ``calculate``
    and ``count``. Operations referencing unknown columns are silently skipped.

    Args:
        df: source DataFrame (never mutated; work happens on a copy).
        plan: dict as produced by ``parse_plan`` with an ``operations`` list.

    Returns:
        (working_df, describe_stats) — the transformed DataFrame and a dict
        mapping labels to scalars/Series produced by describe/count ops.

    Raises:
        Exception: wraps any operation failure with an "Execution error" prefix.
    """
    dfw = df.copy()
    describe_stats = {}

    try:
        for op in plan.get("operations", []):
            optype = op.get("op", "").lower()

            if optype == "describe":
                for col in op.get("columns", []):
                    if col in dfw.columns:
                        describe_stats[col] = dfw[col].describe()

            elif optype == "groupby":
                cols = op.get("columns", [])
                agg_col = op.get("agg_col")
                agg_func = op.get("agg_func", "count")
                if cols and all(c in dfw.columns for c in cols):
                    if agg_func == "count" or not agg_col:
                        dfw = dfw.groupby(cols).size().reset_index(name="count")
                    elif agg_col in dfw.columns:
                        result_col = f"{agg_func}_{agg_col}"
                        dfw = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)

            elif optype == "filter":
                expr = op.get("expr", "")
                column = op.get("column")
                value = op.get("value")
                # Truthy only when a simple column/value fallback is possible.
                can_fallback = column and column in dfw.columns and value
                if expr:
                    try:
                        dfw = dfw.query(expr)
                    except Exception:
                        # query() rejected the expression — fall back to the
                        # simple column/value match, if one was supplied.
                        if can_fallback:
                            dfw = _filter_by_value(dfw, column, value)
                elif can_fallback:
                    dfw = _filter_by_value(dfw, column, value)

            elif optype == "calculate":
                expr = op.get("expr", "")
                new_col = op.get("new_col", "Calculated")
                if expr:
                    try:
                        dfw[new_col] = dfw.eval(expr)
                    except Exception:  # narrowed from a bare `except:`
                        # Heuristic fallback: if the expression mentions
                        # std/mean together with a numeric column name,
                        # compute that aggregate for the first matching column.
                        if "std" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].std()
                                    break
                        elif "mean" in expr:
                            for col in dfw.select_dtypes(include=[np.number]).columns:
                                if col in expr:
                                    dfw[new_col] = dfw[col].mean()
                                    break

            elif optype == "count":
                column = op.get("column")
                value = op.get("value")
                if column and column in dfw.columns:
                    if value:
                        # Handle both string and numeric columns dynamically.
                        if dfw[column].dtype == 'object':
                            count_result = dfw[column].str.contains(str(value), case=False, na=False).sum()
                        else:
                            count_result = (dfw[column] == value).sum()
                        describe_stats[f"count_{column}_{value}"] = count_result
                    else:
                        # No target value: report every unique value's frequency.
                        describe_stats[f"values_{column}"] = dfw[column].value_counts()

        return dfw, describe_stats
    except Exception as e:
        # `from e` preserves the original traceback for debugging.
        raise Exception(f"Execution error: {str(e)}") from e
|
| 129 |
+
|
| 130 |
+
def create_chart(df, selected_columns=None, chart_type="bar", title=None):
    """Build a fallback Plotly chart from the DataFrame and selected columns.

    Priority: two selected columns -> scatter/line/bar of col0 vs col1;
    one selected column -> histogram (numeric) or top-10 bar (categorical);
    otherwise auto-pick from the frame's dtypes.

    Returns:
        A plotly Figure (900x500), or None if nothing plottable / on error.
    """
    try:
        if selected_columns and len(selected_columns) >= 2:
            x_col, y_col = selected_columns[0], selected_columns[1]
            if x_col in df.columns and y_col in df.columns:
                # Rows are capped (head) to keep the figure responsive.
                if chart_type == "scatter":
                    fig = px.scatter(df.head(100), x=x_col, y=y_col, title=title or f"{x_col} vs {y_col}".replace(f"{x_col} vs {y_col}", f"{y_col} vs {x_col}"))
                elif chart_type == "line":
                    fig = px.line(df.head(50), x=x_col, y=y_col, title=title or f"{y_col} over {x_col}", markers=True)
                else:
                    fig = px.bar(df.head(50), x=x_col, y=y_col, title=title or f"{y_col} by {x_col}")
                fig.update_layout(width=900, height=500)
                return fig
        if selected_columns and len(selected_columns) == 1:
            col = selected_columns[0]
            if col in df.columns:
                if pd.api.types.is_numeric_dtype(df[col]):
                    fig = px.histogram(df, x=col, title=f"Distribution of {col}")
                else:
                    value_counts = df[col].value_counts().head(10)
                    fig = px.bar(x=value_counts.index, y=value_counts.values, title=f"Top Values in {col}")
                fig.update_layout(width=900, height=500)
                return fig
        # No usable selection: auto-pick based on available dtypes.
        categorical_cols = df.select_dtypes(include=['object', 'category']).columns
        numeric_cols = df.select_dtypes(include=[np.number]).columns
        if len(numeric_cols) > 1:
            fig = px.scatter(df.head(100), x=numeric_cols[0], y=numeric_cols[1], title=f"{numeric_cols[1]} vs {numeric_cols[0]}")
        elif len(categorical_cols) > 0 and len(numeric_cols) > 0:
            fig = px.bar(df.head(50), x=categorical_cols[0], y=numeric_cols[0], title=f"{numeric_cols[0]} by {categorical_cols[0]}")
        elif len(categorical_cols) > 0:
            value_counts = df[categorical_cols[0]].value_counts().head(10)
            fig = px.pie(values=value_counts.values, names=value_counts.index, title=f"Distribution of {categorical_cols[0]}")
        elif len(numeric_cols) > 0:
            fig = px.histogram(df, x=numeric_cols[0], title=f"Distribution of {numeric_cols[0]}")
        else:
            return None
        fig.update_layout(width=900, height=500)
        return fig
    except Exception:
        # Narrowed from a bare `except:`; charting is best-effort, so any
        # plotting failure degrades to "no chart" rather than crashing the UI.
        return None
|
| 170 |
+
|
| 171 |
+
def create_plot(df, dfw, plan, describe_stats, selected_columns=None):
    """Render the chart requested by the plan's ``plot`` spec.

    Args:
        df: the original (pre-operations) DataFrame.
        dfw: the DataFrame after execute_plan's operations.
        plan: parsed plan; only its ``plot`` dict is used here.
        describe_stats: stats dict from execute_plan — when non-empty the
            plot is drawn from the original ``df`` rather than ``dfw``.
        selected_columns: kept for interface compatibility; not used here.

    Returns:
        A plotly Figure (900x500), or None when there is no/invalid spec
        or plotting fails.
    """
    plot_spec = plan.get("plot")
    if not plot_spec:
        return None
    ptype = plot_spec.get("type", "bar")
    title = plot_spec.get("title", "Chart")
    # When describe/count stats exist, the working frame may be aggregated
    # away — plot against the original data instead.
    plot_df = df if describe_stats else dfw
    x = plot_spec.get("x")
    y = plot_spec.get("y")

    # Fill unspecified axes: first categorical column for x, first numeric for y.
    if not x and len(plot_df.columns) > 0:
        categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
        x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]
    if not y:
        numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
        y = numeric_cols[0] if len(numeric_cols) > 0 else None

    try:
        if ptype == "pie" and x and x in plot_df.columns:
            value_counts = plot_df[x].value_counts()
            fig = go.Figure(data=[go.Pie(labels=value_counts.index, values=value_counts.values, hole=0.3)])
            fig.update_layout(title=title, width=900, height=500)
            return fig
        elif ptype == "bar" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.bar(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "line" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.line(plot_df, x=x, y=y, title=title, markers=True)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "hist" and y and y in plot_df.columns:
            # Histograms bin the value column, so the y-spec becomes the x axis.
            fig = px.histogram(plot_df, x=y, title=title, nbins=30)
            fig.update_layout(width=900, height=500)
            return fig
        elif ptype == "scatter" and x and x in plot_df.columns and y and y in plot_df.columns:
            fig = px.scatter(plot_df, x=x, y=y, title=title)
            fig.update_layout(width=900, height=500)
            return fig
    except Exception:
        # Narrowed from a bare `except:`; fall through to "no figure".
        pass
    return None
|
| 213 |
+
|
| 214 |
+
def generate_insights(df, dfw, plan, llm):
    """Ask the LLM for narrative insights about the executed plan.

    Builds a textual context from describe/groupby operations, then sends it
    to ``llm``. Any failure is returned as an error string rather than raised.
    """
    try:
        context_parts = []
        for op in plan.get("operations", []):
            op_name = op.get("op")
            if op_name == "describe":
                for col in op.get("columns", []):
                    if col in df.columns:
                        summary = df[col].describe()
                        context_parts.append(f"\n{col} Statistics:\n{summary.to_string()}")
            elif op_name == "groupby":
                context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")

        user_prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))
        messages = [
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": user_prompt},
        ]
        response = llm.invoke(messages)

        if hasattr(response, 'content'):
            return response.content
        return str(response)
    except Exception as e:
        return f"Error generating insights: {str(e)}"
|
| 237 |
+
|
| 238 |
+
def analyze_question(question, selected_columns, uploaded_df, llm):
    """End-to-end natural-language analysis of the uploaded dataset.

    Asks the LLM for a JSON plan (via prompts), executes it with
    execute_plan, formats a textual response, and optionally builds a chart
    and an insights paragraph.

    Args:
        question: free-text user question.
        selected_columns: optional UI column subset; restricts the analysis.
        uploaded_df: the DataFrame to analyze, or None if nothing uploaded.
        llm: chat model from initialize_llm(), or None on init failure.

    Returns:
        Tuple of (response_text, fig, result_table); fig and result_table may
        each be None.
    """
    # Guard clauses: return a user-facing message instead of raising.
    if llm is None:
        return "API not initialized. Please restart.", None, None

    if uploaded_df is None:
        return "Please upload a dataset first.", None, None

    if not question.strip():
        return "Please enter a question.", None, None

    try:
        # Restrict to the user's column selection, if any.
        df_to_analyze = uploaded_df[selected_columns] if selected_columns else uploaded_df

        # Small sample for the prompt; truncated to keep token usage low.
        sample_data = df_to_analyze.head(3).to_string(max_cols=10, max_colwidth=20)

        if selected_columns:
            column_context = f"Selected columns for analysis: {', '.join(selected_columns)}\n"
        else:
            column_context = ""

        # data_ctx = f"""{column_context}Dataset: {len(df_to_analyze)} rows, {len(df_to_analyze.columns)} columns
        # Columns: {', '.join(df_to_analyze.columns)}
        # Sample data:
        # {sample_data}"""

        enhanced_prompt = get_chart_prompt(question, df_to_analyze.columns.tolist(), sample_data)

        messages = [
            {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
            {"role": "user", "content": enhanced_prompt}
        ]

        response = llm.invoke(messages)
        raw_text = response.content if hasattr(response, 'content') else str(response)

        # Turn the raw LLM reply into a structured plan (never raises).
        plan = parse_plan(raw_text)

        if plan.get("type") == "explain":
            return plan.get("narrative", "No explanation provided"), None, None

        if plan.get("type") == "error":
            return plan.get("narrative", "Error occurred"), None, None

        # Drop/repair plot axes that reference non-existent columns.
        if plan.get("plot"):
            plan["plot"] = validate_plot_spec(plan["plot"], df_to_analyze.columns.tolist())

        dfw, describe_stats = execute_plan(df_to_analyze, plan)

        # Only use narrative for explain type, avoid hallucinations for data operations
        has_data_operations = any(col.startswith(("count_", "values_")) for col in describe_stats.keys()) if describe_stats else False
        has_filtered_data = len(dfw) != len(df_to_analyze)

        if has_data_operations:
            response_text = "Analysis completed."
        elif has_filtered_data:
            response_text = f"Filter applied. Found {len(dfw)} matching rows out of {len(df_to_analyze)} total rows."
        else:
            response_text = plan.get("narrative", "Analysis complete")

        # Append the concrete numbers/series computed by execute_plan.
        if describe_stats:
            response_text += "\n\nResults:\n"
            for col, stats in describe_stats.items():
                if col.startswith("count_"):
                    # Extract column and value from key for dynamic display
                    parts = col.replace('count_', '').split('_', 1)
                    if len(parts) == 2:
                        column_name, value_name = parts
                        response_text += f"\nCount of '{value_name}' in {column_name}: {int(stats)}\n"
                    else:
                        response_text += f"\n{col}: {int(stats) if isinstance(stats, (int, float, np.integer)) else stats}\n"
                elif col.startswith("values_"):
                    # Show all values in the column
                    column_name = col.replace('values_', '')
                    if hasattr(stats, 'to_string'):
                        response_text += f"\nAll values in {column_name}:\n{stats.to_string()}\n"
                    else:
                        response_text += f"\nAll values in {column_name}: {stats}\n"
                else:
                    if hasattr(stats, 'to_string'):
                        response_text += f"\n{col}:\n{stats.to_string()}\n"
                    else:
                        response_text += f"\n{col}: {stats}\n"

        # Chart: try the plan's explicit spec first, then the auto fallback.
        fig = None
        if plan.get("plot"):
            fig = create_plot(df_to_analyze, dfw, plan, describe_stats, selected_columns)

        if fig is None:
            fig = create_chart(df_to_analyze, selected_columns)

        if fig:
            response_text += "\n\nChart generated successfully!"
            if selected_columns and len(selected_columns) >= 1:
                response_text += f"\nUsing selected columns: {', '.join(selected_columns)}"

        # Insights are only generated when the plan asked for them AND a
        # chart exists (they describe what the chart shows).
        if plan.get("insights_needed") and fig:
            insights = generate_insights(df_to_analyze, dfw, plan, llm)
            response_text += f"\n\nKey Insights:\n{insights}"

        # Show a data preview when rows were filtered, or when no stats were
        # produced (plain transformed data).
        result_table = None
        if len(dfw) != len(df_to_analyze):
            result_table = dfw.head(50)
        elif not describe_stats and len(dfw) > 0:
            result_table = dfw.head(50)

        return response_text, fig, result_table

    except Exception as e:
        return f"Error during analysis: {str(e)}", None, None
|
app.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import os
|
| 4 |
+
from data_engine import (
|
| 5 |
+
clean_numeric, run_analysis, create_visualization, handle_missing_data,
|
| 6 |
+
undo_last_change, undo_all_changes, download_dataset,
|
| 7 |
+
display_data_format, display_text_format
|
| 8 |
+
)
|
| 9 |
+
from ai_agent import initialize_llm, analyze_question
|
| 10 |
+
from prompts import SAMPLE_QUESTIONS
|
| 11 |
+
|
| 12 |
+
# --- Application-wide state (module-level globals mutated by the handlers) ---
llm = None            # Groq chat model; assigned in the __main__ block via initialize_llm()
uploaded_df = None    # current working DataFrame (reflects applied data-handling changes)
original_df = None    # pristine copy captured at upload time; used by "Undo All"
dataset_name = None   # basename of the uploaded file; used when downloading
change_history = []   # undo history passed to/returned by data_engine helpers — exact contents defined in data_engine
|
| 17 |
+
|
| 18 |
+
def upload_dataset(file):
    """Load a CSV/Excel upload into the global state and refresh the widgets.

    Args:
        file: Gradio file object (has a ``.name`` path) or None.

    Returns:
        (info_text, preview_update, column_choices_update, visibility_update)
        matching the event wiring in the Blocks layout.
    """
    global uploaded_df, original_df, dataset_name
    if file is None:
        return "No file uploaded", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
    try:
        dataset_name = os.path.basename(file.name)
        # FIX: compare the extension case-insensitively so ".CSV"/".XLSX"
        # uploads are accepted too.
        lower_name = file.name.lower()
        if lower_name.endswith('.csv'):
            uploaded_df = pd.read_csv(file.name)
        elif lower_name.endswith(('.xlsx', '.xls')):
            uploaded_df = pd.read_excel(file.name)
        else:
            return "Unsupported file format. Please upload CSV or Excel files.", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
        uploaded_df = clean_numeric(uploaded_df)
        original_df = uploaded_df.copy()  # snapshot for "Undo All"
        info_text = f" Dataset Loaded: {dataset_name} ({uploaded_df.shape[0]} rows × {uploaded_df.shape[1]} columns)"
        return info_text, gr.update(visible=False), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(visible=True)
    except Exception as e:
        return f"Error loading file: {str(e)}", gr.update(visible=False), gr.update(choices=[]), gr.update(visible=False)
|
| 36 |
+
|
| 37 |
+
def clear_dataset():
    """Reset every dataset-related global and hide the data widgets."""
    global uploaded_df, original_df, dataset_name, change_history
    uploaded_df = None
    original_df = None
    dataset_name = None
    change_history = []
    return (
        "Dataset cleared. Please upload a new file.",
        gr.update(visible=False),
        gr.update(choices=[], value=[]),
        gr.update(visible=False),
    )
|
| 44 |
+
|
| 45 |
+
def update_preview(format_type, selected_columns):
    """Route the dataset preview to the table or text widget per the format.

    Returns (table_value, text_value, table_vis, text_vis, heading_vis)
    matching the preview event wiring.
    """
    if format_type == "None":
        return None, "", gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)
    if format_type == "DataFrame":
        table = display_data_format(format_type, selected_columns, uploaded_df)
        return table, "", gr.update(visible=True), gr.update(visible=False), gr.update(visible=True)
    # JSON / Dictionary formats render as text.
    text = display_text_format(format_type, selected_columns, uploaded_df)
    return None, text, gr.update(visible=False), gr.update(visible=True), gr.update(visible=True)
|
| 52 |
+
|
| 53 |
+
def handle_analysis_change(analysis_type, selected_columns):
    """Run the selected analysis and toggle the result widgets accordingly."""
    result_text, data_table = run_analysis(analysis_type, selected_columns, uploaded_df)
    has_output = bool(result_text) and bool(result_text.strip()) and analysis_type != "None"
    if not has_output:
        # Nothing to show: clear and hide everything.
        return gr.update(value="", visible=False), gr.update(visible=False), gr.update(visible=False)
    if data_table is None:
        return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(visible=False)
    return gr.update(value=result_text, visible=True), gr.update(visible=True), gr.update(value=data_table, visible=True)
|
| 62 |
+
|
| 63 |
+
def handle_viz_change(viz_type, selected_columns):
    """Build the requested chart and show/hide the plot + explanation widgets."""
    result = create_visualization(viz_type, selected_columns, uploaded_df)
    # Guard: the engine must return a (fig, explanation, chart_obj) triple.
    if not result or len(result) != 3:
        return None, gr.update(visible=False), "Error in visualization", gr.update(visible=False)
    fig, explanation, chart_obj = result
    if explanation and fig is not None:
        return fig, gr.update(visible=True), explanation, gr.update(visible=True)
    return None, gr.update(visible=False), explanation or "Error in visualization", gr.update(visible=False)
|
| 73 |
+
|
| 74 |
+
def show_constant_input(method):
    """Reveal the constant-value textbox only for the "Constant Fill" method."""
    needs_constant = method == "Constant Fill"
    return gr.update(visible=needs_constant)
|
| 76 |
+
|
| 77 |
+
def handle_data_and_refresh(method, selected_columns, constant_value, analysis_type):
    """Apply a missing-data handling method, then refresh dependent widgets.

    Delegates the actual fill/drop to data_engine.handle_missing_data and,
    when the "Missing Values" analysis is active, recomputes its report so
    the UI reflects the new state.

    Returns (result_text, result_vis, column_choices, analysis_text, analysis_vis)
    matching the apply_btn wiring.
    """
    global uploaded_df, change_history

    result, uploaded_df, change_history = handle_missing_data(method, selected_columns, constant_value, uploaded_df, change_history)

    if analysis_type == "Missing Values" and uploaded_df is not None:
        analysis_result = "Missing Values Analysis:\n" + "=" * 30 + "\n\n"

        # Text tokens treated as "pseudo-missing" in addition to real NaN.
        patterns = ['UNKNOWN', 'unknown', 'ERROR', 'error', 'NULL', 'null', 'NA', 'na', 'N/A',
                    'Not Given', 'not given', 'NOT GIVEN', '', ' ', '-', '?', 'NaN', 'nan',
                    'None', 'none', 'NONE', '#N/A', 'n/a', 'N.A.', 'n.a.']

        for col in uploaded_df.columns:
            nan_count = uploaded_df[col].isnull().sum()
            pseudo_missing_count = 0

            non_null_data = uploaded_df[col].dropna()
            if len(non_null_data) > 0:
                # Compare stripped, lower-cased strings against the patterns.
                col_str = non_null_data.astype(str).str.strip()
                empty_count = (col_str == '').sum()
                pattern_count = 0
                for pattern in patterns:
                    if pattern != '':
                        pattern_count += (col_str.str.lower() == pattern.lower()).sum()
                pseudo_missing_count = empty_count + pattern_count

            total_missing = nan_count + pseudo_missing_count
            # NOTE(review): divides by len(uploaded_df) — would raise
            # ZeroDivisionError on an empty frame; confirm upstream prevents that.
            missing_percent = (total_missing / len(uploaded_df)) * 100

            if total_missing > 0:
                details = []
                if nan_count > 0:
                    details.append(f"{nan_count} NaN")
                if pseudo_missing_count > 0:
                    details.append(f"{pseudo_missing_count} text-missing")
                detail_str = f" ({', '.join(details)})"
            else:
                detail_str = ""

            analysis_result += f"{col}: {total_missing} missing ({missing_percent:.2f}%){detail_str}\n"

        return result, gr.update(visible=True), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(value=analysis_result, visible=True), gr.update(visible=True)

    # NOTE(review): if handle_missing_data ever returns uploaded_df=None,
    # list(uploaded_df.columns) below would raise — confirm it always
    # returns a DataFrame.
    return result, gr.update(visible=True), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(), gr.update()
|
| 121 |
+
|
| 122 |
+
def handle_undo_and_refresh(analysis_type, is_undo_all=False):
    """Revert the last (or every) data-handling change and refresh widgets.

    ``is_undo_all=True`` restores the pristine upload snapshot; otherwise a
    single change is popped from the history.
    """
    global uploaded_df, change_history

    # Pick the undo strategy and its source frame, then apply it.
    undo_fn = undo_all_changes if is_undo_all else undo_last_change
    source_df = original_df if is_undo_all else uploaded_df
    result, uploaded_df, change_history = undo_fn(source_df, change_history)

    if analysis_type == "Missing Values" and uploaded_df is not None:
        # Refresh the missing-values report against the reverted data.
        result_text, data_table = run_analysis(analysis_type, [], uploaded_df)
        return result, gr.update(visible=True), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(value=result_text, visible=True), gr.update(visible=True)

    return result, gr.update(visible=True), gr.update(choices=list(uploaded_df.columns), value=[]), gr.update(), gr.update()
|
| 135 |
+
|
| 136 |
+
def handle_question_analysis(question, selected_columns):
    """Delegate a free-text question to the AI agent using the current globals."""
    outcome = analyze_question(question, selected_columns, uploaded_df, llm)
    return outcome
|
| 138 |
+
|
| 139 |
+
custom_css = """
|
| 140 |
+
.gradio-container {
|
| 141 |
+
max-width: 1400px !important;
|
| 142 |
+
margin: 0 auto !important;
|
| 143 |
+
}
|
| 144 |
+
.header-box {
|
| 145 |
+
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
| 146 |
+
border-radius: 15px;
|
| 147 |
+
padding: 25px;
|
| 148 |
+
margin: 20px auto;
|
| 149 |
+
text-align: center;
|
| 150 |
+
box-shadow: 0 8px 32px rgba(0, 0, 0, 0.1);
|
| 151 |
+
}
|
| 152 |
+
.header-title {
|
| 153 |
+
font-size: 36px;
|
| 154 |
+
font-weight: bold;
|
| 155 |
+
color: white;
|
| 156 |
+
margin: 0;
|
| 157 |
+
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
|
| 158 |
+
}
|
| 159 |
+
.section-box {
|
| 160 |
+
background-color: #f8f9fa;
|
| 161 |
+
padding: 20px;
|
| 162 |
+
border-radius: 12px;
|
| 163 |
+
margin: 15px 0;
|
| 164 |
+
border: 1px solid #e9ecef;
|
| 165 |
+
}
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
# --- Gradio UI definition -----------------------------------------------------
# Left column: upload, column selection, analysis/viz/data-handling controls.
# Right column: previews, analysis results, sample questions, AI Q&A tabs.
with gr.Blocks(theme=gr.themes.Soft(), css=custom_css) as demo:

    # Header banner (styled by .header-box / .header-title in custom_css).
    gr.HTML("""
    <div class="header-box">
        <h1 class="header-title">SparkNova</h1>
        <p style="color: white; font-size: 18px; margin: 10px 0 0 0;">Advanced Data Analysis Platform</p>
    </div>
    """)

    with gr.Row():
        with gr.Column(scale=1):
            gr.Markdown("### Upload Dataset")
            file_input = gr.File(label="Choose CSV or Excel File", file_types=[".csv", ".xlsx", ".xls"])
            dataset_info = gr.Markdown()

            with gr.Row():
                clear_btn = gr.Button("Clear Dataset", variant="secondary", size="sm")

            # Hidden until a dataset is loaded (upload_dataset toggles it).
            column_selector = gr.CheckboxGroup(
                label="Select Columns (optional - for multi-column charts)",
                choices=[],
                visible=False
            )

            format_selector = gr.Dropdown(
                choices=["None", "DataFrame", "JSON", "Dictionary"],
                value="None",
                label="Display Format"
            )

            gr.Markdown("### Choose an Analysis Type")
            analysis_selector = gr.Dropdown(
                choices=["None", "Summary", "Describe", "Top 5 Rows", "Bottom 5 Rows", "Missing Values", "Group & Aggregate", "Calculate Expressions", "Highest Correlation"],
                value="None",
                label="Analysis Type"
            )

            gr.Markdown("### Visualization Types")
            viz_selector = gr.Dropdown(
                choices=["None", "Bar Chart", "Line Chart", "Scatter Plot", "Pie Chart", "Histogram", "Box Plot", "Heat Map"],
                value="None",
                label="Chart Type"
            )

            gr.Markdown("### Handling Data")
            data_handler = gr.Dropdown(
                choices=["None", "Forward Fill", "Backward Fill", "Constant Fill", "Mean Fill", "Median Fill", "Mode Fill", "Drop Columns"],
                value="None",
                label="Data Handling Method"
            )

            # Only shown when "Constant Fill" is selected (show_constant_input).
            constant_input = gr.Textbox(
                label="Constant Value (for Constant Fill)",
                placeholder="Enter value to fill missing data",
                visible=False
            )

            with gr.Row():
                apply_btn = gr.Button("Apply Change", variant="primary", size="sm")
                undo_last_btn = gr.Button("Undo Last", variant="secondary", size="sm")

            with gr.Row():
                undo_all_btn = gr.Button("Undo All", variant="secondary", size="sm")
                download_btn = gr.Button("Download", variant="secondary", size="sm")

            data_handling_output = gr.Textbox(label="Data Handling Results", lines=3, visible=False, interactive=False)
            download_file = gr.File(label="Download Modified Dataset", visible=False)

        with gr.Column(scale=2):
            preview_heading = gr.Markdown("### Dataset Preview", visible=False)
            dataset_preview = gr.Dataframe(wrap=True, visible=False)
            text_preview = gr.Textbox(label="Text Preview", lines=15, visible=False)

            analysis_heading = gr.Markdown("### Analysis Results", visible=False)
            analysis_output = gr.Textbox(label="Analysis Output", lines=10, visible=False, interactive=False)
            analysis_data_table = gr.Dataframe(label="Data Table", visible=False, wrap=True)
            chart_output_new = gr.Plot(label="Chart", visible=False)
            chart_explanation = gr.Textbox(label="Chart Analysis", lines=5, visible=False, interactive=False)

            # Static sample questions rendered three per column.
            gr.Markdown("### Sample Questions")
            with gr.Row():
                for i in range(0, len(SAMPLE_QUESTIONS), 3):
                    with gr.Column():
                        for j in range(3):
                            if i + j < len(SAMPLE_QUESTIONS):
                                gr.Markdown(f"• {SAMPLE_QUESTIONS[i + j]}")

            gr.Markdown("### Ask Your Question")
            user_question = gr.Textbox(
                label="Enter your question",
                placeholder="Ask anything about your data...",
                lines=3
            )
            submit_btn = gr.Button("Analyze", variant="primary", size="lg")

            gr.Markdown("### Analysis Results")
            with gr.Tabs():
                with gr.Tab("Response"):
                    output_text = gr.Textbox(
                        label="Analysis Response",
                        interactive=False,
                        lines=15,
                        show_copy_button=True
                    )

                with gr.Tab("Visualization"):
                    chart_output = gr.Plot(label="Generated Chart")

                with gr.Tab("Data"):
                    result_table = gr.Dataframe(label="Result Data", wrap=True)

    # --- Event wiring ---------------------------------------------------------
    # NOTE(review): column_selector appears twice in these outputs lists while
    # upload_dataset/clear_dataset return (info, preview_vis, choices, vis) —
    # confirm the duplicate target is intentional and accepted by this Gradio
    # version (a dataset_preview/visibility target may have been meant).
    file_input.change(upload_dataset, inputs=file_input, outputs=[dataset_info, dataset_preview, column_selector, column_selector])
    clear_btn.click(clear_dataset, outputs=[dataset_info, dataset_preview, column_selector, column_selector])

    # NOTE(review): dataset_preview and text_preview are each listed twice
    # (value update + visibility update) — same caveat as above.
    format_selector.change(update_preview, inputs=[format_selector, column_selector], outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading])
    column_selector.change(update_preview, inputs=[format_selector, column_selector], outputs=[dataset_preview, text_preview, dataset_preview, text_preview, preview_heading])

    analysis_selector.change(handle_analysis_change, inputs=[analysis_selector, column_selector], outputs=[analysis_output, analysis_heading, analysis_data_table])
    column_selector.change(handle_analysis_change, inputs=[analysis_selector, column_selector], outputs=[analysis_output, analysis_heading, analysis_data_table])

    viz_selector.change(handle_viz_change, inputs=[viz_selector, column_selector], outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation])
    column_selector.change(handle_viz_change, inputs=[viz_selector, column_selector], outputs=[chart_output_new, chart_output_new, chart_explanation, chart_explanation])

    submit_btn.click(handle_question_analysis, inputs=[user_question, column_selector], outputs=[output_text, chart_output, result_table])

    data_handler.change(show_constant_input, inputs=data_handler, outputs=constant_input)
    apply_btn.click(handle_data_and_refresh, inputs=[data_handler, column_selector, constant_input, analysis_selector], outputs=[data_handling_output, data_handling_output, column_selector, analysis_output, analysis_heading])

    # Lambdas bind the is_undo_all flag for the shared undo handler.
    undo_last_btn.click(lambda analysis_type: handle_undo_and_refresh(analysis_type, False), inputs=[analysis_selector], outputs=[data_handling_output, data_handling_output, column_selector, analysis_output, analysis_heading])
    undo_all_btn.click(lambda analysis_type: handle_undo_and_refresh(analysis_type, True), inputs=[analysis_selector], outputs=[data_handling_output, data_handling_output, column_selector, analysis_output, analysis_heading])

    def handle_download():
        # Write the current (modified) dataset to disk and expose it;
        # hidden again if download_dataset returned a falsy path.
        filepath = download_dataset(uploaded_df, dataset_name)
        return gr.File(value=filepath, visible=bool(filepath))

    download_btn.click(handle_download, outputs=download_file)

    gr.HTML("<div style='text-align: center; margin-top: 20px; color: #666;'>Powered by GROQ LLM & Gradio</div>")

# Entry point: initialize the LLM (assigns the module-level `llm` global used
# by handle_question_analysis) before launching the app.
if __name__ == "__main__":
    llm = initialize_llm()
    if not llm:
        print("Warning: Failed to initialize GROQ API")

    demo.launch(show_error=True, share=False)
|
data_engine.py
ADDED
|
@@ -0,0 +1,533 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pandas as pd
|
| 2 |
+
import numpy as np
|
| 3 |
+
import plotly.express as px
|
| 4 |
+
import plotly.graph_objects as go
|
| 5 |
+
import os
|
| 6 |
+
import tempfile
|
| 7 |
+
|
| 8 |
+
def clean_numeric(df):
    """Coerce object columns that are mostly numeric text into numeric dtypes.

    Two common cases are handled:
    - percentage strings ("12%") become fractions (0.12);
    - currency / thousand-separated strings ("₹1,000", "$5") become floats.

    A column is converted only when more than half of its values parse as
    numbers, so genuinely textual columns are left untouched.

    Parameters
    ----------
    df : pandas.DataFrame
        Input data; not mutated (a copy is returned).

    Returns
    -------
    pandas.DataFrame
        Copy of *df* with convertible columns converted.
    """
    df = df.copy()
    if len(df) == 0:
        # Guard: the >50% parse-rate check below divides by len(df),
        # which would raise ZeroDivisionError on an empty frame.
        return df
    for col in df.columns:
        if pd.api.types.is_string_dtype(df[col]) or df[col].dtype == object:
            s = df[col].astype(str).str.strip()
            if s.str.contains("%", na=False).any():
                numeric_vals = pd.to_numeric(s.str.replace("%", "", regex=False), errors="coerce")
                if numeric_vals.notna().sum() / len(df) > 0.5:
                    # Percentages are stored as fractions (50% -> 0.5).
                    df[col] = numeric_vals / 100.0
                continue
            # Strip thousands separators and common currency symbols.
            cleaned = s.str.replace(",", "", regex=False).str.replace("₹", "", regex=False).str.replace("$", "", regex=False)
            numeric_vals = pd.to_numeric(cleaned, errors="coerce")
            if numeric_vals.notna().sum() / len(df) > 0.5:
                df[col] = numeric_vals
    return df
|
| 23 |
+
|
| 24 |
+
def run_analysis(analysis_type, selected_columns, uploaded_df):
    """Run one named analysis over the dataset and return displayable results.

    Parameters
    ----------
    analysis_type : str or None
        One of: "Summary", "Describe", "Top 5 Rows", "Bottom 5 Rows",
        "Missing Values", "Highest Correlation", "Group & Aggregate",
        "Calculate Expressions", or "None".
    selected_columns : list[str]
        Columns the column-scoped analyses operate on; ignored by the
        whole-dataset analyses listed in ``whole_dataset_analyses``.
    uploaded_df : pandas.DataFrame or None
        The currently loaded dataset.

    Returns
    -------
    tuple[str, pandas.DataFrame or None]
        (text result, optional dataframe for the data-table widget).
        Errors are returned as text rather than raised.
    """
    if uploaded_df is None:
        return "Please upload a dataset first.", None
    if analysis_type == "None" or analysis_type is None:
        return "", None

    # NOTE(review): 'title'-specific debug print inherited from the original
    # code — appears to be a leftover diagnostic for one dataset.
    if 'title' in uploaded_df.columns:
        title_nulls = uploaded_df['title'].isnull().sum()
        print(f"DEBUG: Title column has {title_nulls} null values at analysis time")

    # These analyses always look at the full dataset, regardless of selection.
    whole_dataset_analyses = ["Summary", "Top 5 Rows", "Bottom 5 Rows", "Missing Values"]
    if analysis_type in whole_dataset_analyses:
        df_to_analyze = uploaded_df
    else:
        if not selected_columns:
            return f"Please select columns for {analysis_type} analysis.", None
        df_to_analyze = uploaded_df[selected_columns]

    try:
        if analysis_type == "Summary":
            # High-level shape/type overview of the whole dataset.
            numeric_cols = uploaded_df.select_dtypes(include=[np.number]).columns
            categorical_cols = uploaded_df.select_dtypes(include=['object', 'category']).columns
            result = f"Dataset Summary:\nRows: {len(uploaded_df):,}\nColumns: {len(uploaded_df.columns)}\nNumeric Columns: {len(numeric_cols)}\nText Columns: {len(categorical_cols)}\n\n"
            if len(numeric_cols) > 0:
                result += "Numeric Columns: " + ", ".join(numeric_cols.tolist()) + "\n"
            if len(categorical_cols) > 0:
                result += "Text Columns: " + ", ".join(categorical_cols.tolist())
            return result, None

        elif analysis_type == "Describe":
            # Per-column statistics: describe() for numeric columns,
            # uniqueness/mode/top-values for categorical ones.
            result = "Column Description:\n" + "=" * 30 + "\n\n"
            for col in selected_columns:
                if col in df_to_analyze.columns:
                    result += f"Column: {col}\n"
                    if pd.api.types.is_numeric_dtype(df_to_analyze[col]):
                        stats = df_to_analyze[col].describe()
                        result += f" Type: Numeric\n Count: {stats['count']:.0f}\n Mean: {stats['mean']:.3f}\n Std: {stats['std']:.3f}\n Min: {stats['min']:.3f}\n 25%: {stats['25%']:.3f}\n 50%: {stats['50%']:.3f}\n 75%: {stats['75%']:.3f}\n Max: {stats['max']:.3f}\n\n"
                    else:
                        unique_count = df_to_analyze[col].nunique()
                        null_count = df_to_analyze[col].isnull().sum()
                        # mode() can be empty when the column is all-NaN.
                        most_common = df_to_analyze[col].mode().iloc[0] if len(df_to_analyze[col].mode()) > 0 else "N/A"
                        result += f" Type: Categorical/Text\n Unique Values: {unique_count}\n Missing Values: {null_count}\n Most Common: {most_common}\n"
                        top_values = df_to_analyze[col].value_counts().head(5)
                        result += " Top Values:\n"
                        for val, count in top_values.items():
                            result += f" {val}: {count} times\n"
                        result += "\n"
            return result, None

        elif analysis_type == "Top 5 Rows":
            return "Top 5 Rows - See data table below", df_to_analyze.head(5)

        elif analysis_type == "Bottom 5 Rows":
            return "Bottom 5 Rows - See data table below", df_to_analyze.tail(5)

        elif analysis_type == "Missing Values":
            # Counts both genuine NaN and textual placeholders ("UNKNOWN",
            # "N/A", "-", empty strings, ...) as missing values.
            result = "Missing Values Analysis:\n" + "=" * 30 + "\n\n"
            patterns = ['UNKNOWN', 'unknown', 'ERROR', 'error', 'NULL', 'null', 'NA', 'na', 'N/A',
                        'Not Given', 'not given', 'NOT GIVEN', '', ' ', '-', '?', 'NaN', 'nan',
                        'None', 'none', 'NONE', '#N/A', 'n/a', 'N.A.', 'n.a.']

            for col in uploaded_df.columns:
                nan_count = uploaded_df[col].isnull().sum()
                pseudo_missing_count = 0

                non_null_data = uploaded_df[col].dropna()
                if len(non_null_data) > 0:
                    col_str = non_null_data.astype(str).str.strip()
                    # Empty-after-strip values are counted separately so the
                    # case-insensitive pattern match below can skip ''.
                    empty_count = (col_str == '').sum()
                    pattern_count = 0
                    for pattern in patterns:
                        if pattern != '':
                            pattern_count += (col_str.str.lower() == pattern.lower()).sum()
                    pseudo_missing_count = empty_count + pattern_count

                total_missing = nan_count + pseudo_missing_count
                missing_percent = (total_missing / len(uploaded_df)) * 100

                if col == 'title':
                    print(f"DEBUG: Title analysis - NaN: {nan_count}, Pseudo: {pseudo_missing_count}, Total: {total_missing}")

                if total_missing > 0:
                    # Break the total into NaN vs. textual-placeholder counts.
                    details = []
                    if nan_count > 0:
                        details.append(f"{nan_count} NaN")
                    if pseudo_missing_count > 0:
                        details.append(f"{pseudo_missing_count} text-missing")
                    detail_str = f" ({', '.join(details)})"
                else:
                    detail_str = ""

                result += f"{col}: {total_missing} missing ({missing_percent:.2f}%){detail_str}\n"

            return result, None

        elif analysis_type == "Highest Correlation":
            # Rank all numeric column pairs by absolute Pearson correlation.
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns
            if len(numeric_cols) < 2:
                return "Need at least 2 numeric columns for correlation analysis.", None
            corr_matrix = df_to_analyze[numeric_cols].corr()
            result = "Highest Correlations:\n" + "=" * 25 + "\n\n"
            correlations = []
            # Upper triangle only — each pair listed once.
            for i in range(len(corr_matrix.columns)):
                for j in range(i+1, len(corr_matrix.columns)):
                    col1, col2 = corr_matrix.columns[i], corr_matrix.columns[j]
                    corr_val = corr_matrix.iloc[i, j]
                    correlations.append((abs(corr_val), col1, col2, corr_val))
            correlations.sort(reverse=True)
            for _, col1, col2, corr_val in correlations[:10]:
                result += f"{col1} ↔ {col2}: {corr_val:.3f}\n"
            return result, None

        elif analysis_type == "Group & Aggregate":
            # NOTE(review): this guard is unreachable — empty selections for
            # non-whole-dataset analyses already returned above.
            if not selected_columns:
                result = "Please select columns for grouping and aggregation."
            else:
                categorical_cols = [col for col in selected_columns if not pd.api.types.is_numeric_dtype(df_to_analyze[col])]
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_analyze[col])]

                if categorical_cols and numeric_cols:
                    # Group by the first categorical column, aggregate the
                    # first numeric one.
                    group_col = categorical_cols[0]
                    agg_col = numeric_cols[0]
                    grouped = df_to_analyze.groupby(group_col)[agg_col].agg(['count', 'mean', 'sum']).round(2)
                    result = f"Group & Aggregate Analysis:\n" + "=" * 35 + "\n\n"
                    result += f"Grouped by: {group_col}\nAggregated: {agg_col}\n\n"
                    result += grouped.to_string()
                elif categorical_cols:
                    # No numeric column selected: fall back to value counts.
                    group_col = categorical_cols[0]
                    grouped = df_to_analyze[group_col].value_counts()
                    result = f"Group Count Analysis:\n" + "=" * 25 + "\n\n"
                    result += grouped.to_string()
                else:
                    result = "Please select at least one categorical column for grouping."
            return result, None

        elif analysis_type == "Calculate Expressions":
            # Demo calculation: sum/difference of the first two numeric columns.
            numeric_cols = df_to_analyze.select_dtypes(include=[np.number]).columns

            if len(numeric_cols) >= 2:
                col1, col2 = numeric_cols[0], numeric_cols[1]
                df_calc = df_to_analyze.copy()
                df_calc['Sum'] = df_calc[col1] + df_calc[col2]
                df_calc['Difference'] = df_calc[col1] - df_calc[col2]

                result = f"Calculated Expressions:\n" + "=" * 30 + "\n\n"
                result += f"Using columns: {col1} and {col2}\n\n"
                result += f"New calculated columns:\nSum = {col1} + {col2}\nDifference = {col1} - {col2}\n\n"
                result += "Sample results:\n"
                result += df_calc[['Sum', 'Difference']].head().to_string()
            else:
                result = "Need at least 2 numeric columns for calculations."
            return result, None

        else:
            return f"Analysis type '{analysis_type}' is under development.", None

    except Exception as e:
        # Analyses are best-effort: surface the error as display text.
        return f"Error in analysis: {str(e)}", None
|
| 182 |
+
|
| 183 |
+
def create_chart_explanation(viz_type, df_to_plot, selected_columns, fig_data=None):
    """Build a short textual summary to display beside a rendered chart.

    Parameters
    ----------
    viz_type : str
        Chart kind; only "Bar Chart" and "Line Chart" get detailed text,
        every other kind falls through to a generic description.
    df_to_plot : pandas.DataFrame
        The data that was plotted.
    selected_columns : list[str]
        Columns used for the chart (x first, then y).
    fig_data : pandas.DataFrame, optional
        For "Line Chart": the melted crosstab with a 'Count' column.

    Returns
    -------
    str
        Multi-line explanation text (never raises).
    """
    try:
        if viz_type == "Bar Chart" and len(selected_columns) >= 2:
            x_col, y_col = selected_columns[0], selected_columns[1]
            if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                # Numeric y: report the row with the largest y value.
                max_val_idx = df_to_plot[y_col].idxmax()
                max_category = df_to_plot.loc[max_val_idx, x_col]
                max_value = df_to_plot[y_col].max()
                y_mean = df_to_plot[y_col].mean()
            else:
                # Categorical y: report the x group with the most y entries.
                grouped = df_to_plot.groupby(x_col)[y_col].count()
                max_category = grouped.idxmax()
                max_value = grouped.max()
                y_mean = grouped.mean()
            return f"BAR CHART: {y_col} by {x_col}\nHighest: {max_category} ({max_value:.2f})\nAverage: {y_mean:.2f}\nCategories: {df_to_plot[x_col].nunique()}"
        elif viz_type == "Line Chart" and fig_data is not None:
            max_combo = fig_data.loc[fig_data['Count'].idxmax()]
            min_combo = fig_data.loc[fig_data['Count'].idxmin()]
            return f"LINE CHART: Distribution\nHighest: {max_combo[selected_columns[1]]} in {max_combo[selected_columns[0]]} ({max_combo['Count']})\nLowest: {min_combo[selected_columns[1]]} in {min_combo[selected_columns[0]]} ({min_combo['Count']})\nTotal: {len(df_to_plot)}"
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit are
        # no longer swallowed; any data-shape problem still falls through to
        # the generic description below.
        pass
    return f"{viz_type} visualization\nShows data patterns and relationships"
|
| 205 |
+
|
| 206 |
+
def create_visualization(viz_type, selected_columns, uploaded_df):
    """Create a plotly figure of the requested kind from the selected columns.

    Parameters
    ----------
    viz_type : str
        One of: "Bar Chart", "Pie Chart", "Scatter Plot", "Line Chart",
        "Histogram", "Heat Map", "Box Plot", or "None".
    selected_columns : list[str]
        Columns to plot; most chart kinds use the first as x and the
        second as y, with an optional third as a color dimension.
    uploaded_df : pandas.DataFrame or None
        The currently loaded dataset.

    Returns
    -------
    tuple[plotly figure or None, str, plotly figure or None]
        (figure, explanation text, figure again). The figure is returned
        twice because two Gradio outputs are wired to it; errors are
        returned as (None, message, None) rather than raised.
    """
    if uploaded_df is None or viz_type == "None":
        return None, "", None
    if not selected_columns:
        return None, "Please select columns for visualization.", None
    df_to_plot = uploaded_df[selected_columns]

    try:
        if viz_type == "Bar Chart":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None

                # Handle different data type combinations
                if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                    # Numeric Y-axis: use as-is (capped at 100 rows to keep
                    # the chart readable)
                    plot_data = df_to_plot.head(100)
                    fig = px.bar(plot_data, x=x_col, y=y_col, color=color_col, title=f"{y_col} by {x_col}")
                else:
                    # Non-numeric Y-axis: count occurrences
                    if pd.api.types.is_numeric_dtype(df_to_plot[x_col]):
                        # If X is numeric, group and count Y values
                        grouped = df_to_plot.groupby(x_col)[y_col].count().reset_index()
                        grouped.columns = [x_col, f'Count of {y_col}']
                        fig = px.bar(grouped, x=x_col, y=f'Count of {y_col}', title=f"Count of {y_col} by {x_col}")
                    else:
                        # Both categorical: cross-tabulation
                        crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                        crosstab_reset = crosstab.reset_index().melt(id_vars=[x_col], var_name=y_col, value_name='Count')
                        fig = px.bar(crosstab_reset, x=x_col, y='Count', color=y_col, title=f"{y_col} distribution by {x_col}")

                explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns)
            else:
                # Single column: histogram for numeric data, value-count
                # bars for categorical data.
                col = selected_columns[0]
                if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                    fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}")
                else:
                    value_counts = df_to_plot[col].value_counts().head(15)
                    fig = px.bar(x=value_counts.index, y=value_counts.values, title=f"Top Values in {col}")
                explanation = f"Chart showing distribution of {col}"
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        elif viz_type == "Pie Chart":
            col = selected_columns[0]
            if len(selected_columns) >= 2 and pd.api.types.is_numeric_dtype(df_to_plot[selected_columns[1]]):
                # Two columns, numeric second: slice sizes are sums per category.
                grouped_data = df_to_plot.groupby(col)[selected_columns[1]].sum().reset_index()
                fig = px.pie(grouped_data, values=selected_columns[1], names=col, title=f"Total {selected_columns[1]} by {col}")
                legend_title = f"{col} Categories"
            else:
                # Single column: slice sizes are value counts (top 10).
                value_counts = df_to_plot[col].value_counts().head(10)
                fig = px.pie(values=value_counts.values, names=value_counts.index, title=f"Distribution of {col}")
                legend_title = f"{col} Values"

            # Place the legend vertically centered to the right of the pie.
            fig.update_layout(
                width=800,
                height=500,
                showlegend=True,
                legend=dict(
                    title=dict(text=legend_title, font=dict(size=14, color="black")),
                    orientation="v",
                    yanchor="middle",
                    y=0.5,
                    xanchor="left",
                    x=1.05,
                    font=dict(size=12)
                )
            )
            explanation = f"PIE CHART: {col} Distribution\nShows proportion of each category\nUse to understand category distribution patterns"
            return fig, explanation, fig

        elif viz_type == "Scatter Plot":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]
                color_col = selected_columns[2] if len(selected_columns) > 2 else None

                # Check if both columns are suitable for scatter plot
                if not (pd.api.types.is_numeric_dtype(df_to_plot[x_col]) and pd.api.types.is_numeric_dtype(df_to_plot[y_col])):
                    return None, f"Scatter plot requires numeric data. {x_col} and {y_col} must be numeric.", None

                fig = px.scatter(df_to_plot, x=x_col, y=y_col, color=color_col, title=f"{y_col} vs {x_col}")
                explanation = f"Scatter plot showing relationship between {x_col} and {y_col}"
            else:
                return None, "Scatter plot requires at least 2 columns.", None
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        elif viz_type == "Line Chart":
            if len(selected_columns) >= 2:
                x_col, y_col = selected_columns[0], selected_columns[1]

                if pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                    # Numeric Y: sort by X and plot trend
                    sorted_data = df_to_plot.sort_values(x_col)
                    fig = px.line(sorted_data, x=x_col, y=y_col, title=f"Trend of {y_col} over {x_col}", markers=True)
                    explanation = f"Line chart showing trend of {y_col} over {x_col}"
                else:
                    # Non-numeric Y: create cross-tabulation and plot one
                    # line per Y category.
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    melted = pd.melt(crosstab.reset_index(), id_vars=[x_col], var_name=y_col, value_name='Count')
                    fig = px.line(melted, x=x_col, y='Count', color=y_col, title=f"Distribution of {y_col} across {x_col}", markers=True)
                    explanation = create_chart_explanation(viz_type, df_to_plot, selected_columns, melted)
            else:
                return None, "Line chart requires at least 2 columns.", None
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        elif viz_type == "Histogram":
            col = selected_columns[0]
            if pd.api.types.is_numeric_dtype(df_to_plot[col]):
                fig = px.histogram(df_to_plot, x=col, title=f"Distribution of {col}", nbins=30)
                explanation = f"Histogram showing distribution of {col}"
            else:
                return None, f"Histogram requires numeric data. Try Bar Chart instead.", None
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        elif viz_type == "Heat Map":
            if len(selected_columns) >= 2:
                numeric_cols = [col for col in selected_columns if pd.api.types.is_numeric_dtype(df_to_plot[col])]
                if len(numeric_cols) >= 2:
                    # At least two numeric columns: correlation heatmap.
                    corr_matrix = df_to_plot[numeric_cols].corr()
                    fig = px.imshow(corr_matrix, text_auto=True, aspect="auto", title="Correlation Heatmap", color_continuous_scale='RdBu')
                    explanation = f"Heatmap showing correlations between numeric columns"
                else:
                    # Otherwise: categorical cross-tabulation heatmap.
                    x_col, y_col = selected_columns[0], selected_columns[1]
                    crosstab = pd.crosstab(df_to_plot[x_col], df_to_plot[y_col])
                    fig = px.imshow(crosstab.values, x=crosstab.columns, y=crosstab.index, text_auto=True, aspect="auto", title=f"Cross-tabulation: {y_col} vs {x_col}")
                    explanation = f"Heatmap showing cross-tabulation between {x_col} and {y_col}"
            else:
                return None, "Heat map requires at least 2 columns.", None
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        elif viz_type == "Box Plot":
            if len(selected_columns) >= 1:
                y_col = selected_columns[0]
                if not pd.api.types.is_numeric_dtype(df_to_plot[y_col]):
                    return None, f"Box plot requires numeric Y-axis. {y_col} is not numeric.", None

                # Optional second column groups the boxes along X.
                x_col = selected_columns[1] if len(selected_columns) > 1 else None
                fig = px.box(df_to_plot, x=x_col, y=y_col, title=f"Box Plot of {y_col}" + (f" by {x_col}" if x_col else ""))
                explanation = f"Box plot showing distribution of {y_col}" + (f" grouped by {x_col}" if x_col else "")
            else:
                return None, "Box plot requires at least 1 column.", None
            fig.update_layout(width=800, height=500)
            return fig, explanation, fig

        else:
            return None, f"Visualization type '{viz_type}' is under development.", None

    except Exception as e:
        # Chart construction is best-effort: surface the error as text.
        return None, f"Error creating visualization: {str(e)}", None
|
| 359 |
+
|
| 360 |
+
# Textual placeholders that datasets commonly use instead of real NaN values.
# (Previously this list was duplicated verbatim in four branches below.)
_MISSING_TOKENS = ['UNKNOWN', 'unknown', 'ERROR', 'error', 'NULL', 'null', 'NA', 'na', 'N/A',
                   'Not Given', 'not given', 'NOT GIVEN', '', ' ', '-', '?', 'NaN', 'nan',
                   'None', 'none', 'NONE', '#N/A', 'n/a', 'N.A.', 'n.a.']


def _nullify_missing_tokens(series):
    """Replace textual missing-value placeholders (and '') with np.nan."""
    for token in _MISSING_TOKENS:
        series = series.replace(token, np.nan)
    return series.replace('', np.nan)


def handle_missing_data(method, selected_columns, constant_value, uploaded_df, change_history):
    """Apply a missing-data strategy to the selected columns.

    Parameters
    ----------
    method : str
        One of "Forward Fill", "Backward Fill", "Constant Fill", "Mean Fill",
        "Median Fill", "Mode Fill", "Drop Columns", "Clean All Missing",
        or "None".
    selected_columns : list[str]
        Columns the strategy is applied to.
    constant_value : str
        Fill value for "Constant Fill"; blank defaults to "Unknown".
    uploaded_df : pandas.DataFrame or None
        The current working dataset.
    change_history : list[pandas.DataFrame]
        Undo stack; a snapshot of the dataset is pushed before any change.

    Returns
    -------
    tuple[str, pandas.DataFrame, list]
        (status message, updated dataframe, updated change history).
        Errors are reported in the message rather than raised.
    """
    print(f"DEBUG: Starting {method} on columns {selected_columns}")

    if uploaded_df is None:
        return "Please upload a dataset first.", uploaded_df, change_history
    if method == "None":
        return "", uploaded_df, change_history
    if not selected_columns:
        return "Please select columns to apply data handling.", uploaded_df, change_history

    try:
        # Snapshot the current state so this operation can be undone.
        change_history.append(uploaded_df.copy())
        df_copy = uploaded_df.copy()

        if method == "Clean All Missing":
            return "Clean All Missing is not available", uploaded_df, change_history

        processed_columns = []
        dropped_columns = []

        for col in selected_columns:
            if col not in df_copy.columns:
                continue

            if method == "Forward Fill":
                # NOTE(review): dataset-specific workaround inherited from the
                # original code — 'title' is skipped; confirm still needed.
                if col == 'title':
                    print(f"DEBUG: Skipping title column due to data inconsistencies")
                    continue

                if df_copy[col].dtype == 'object':
                    df_copy[col] = _nullify_missing_tokens(df_copy[col])

                df_copy[col] = df_copy[col].ffill()
                processed_columns.append(col)
            elif method == "Backward Fill":
                if df_copy[col].dtype == 'object':
                    df_copy[col] = _nullify_missing_tokens(df_copy[col])

                df_copy[col] = df_copy[col].bfill()
                processed_columns.append(col)
            elif method == "Constant Fill":
                if df_copy[col].dtype == 'object':
                    df_copy[col] = _nullify_missing_tokens(df_copy[col])

                fill_val = constant_value.strip() if constant_value else "Unknown"
                df_copy[col] = df_copy[col].fillna(fill_val)
                processed_columns.append(col)
            elif method == "Mean Fill":
                if pd.api.types.is_numeric_dtype(df_copy[col]):
                    if not df_copy[col].isna().all():
                        mean_val = df_copy[col].mean()
                        df_copy[col] = df_copy[col].fillna(mean_val)
                        processed_columns.append(col)
                else:
                    # Non-numeric column: coerce to numbers first, then fill.
                    numeric_col = pd.to_numeric(df_copy[col], errors='coerce')
                    if not numeric_col.isna().all():
                        mean_val = numeric_col.mean()
                        df_copy[col] = numeric_col.fillna(mean_val)
                        processed_columns.append(col)
            elif method == "Median Fill":
                if pd.api.types.is_numeric_dtype(df_copy[col]):
                    if not df_copy[col].isna().all():
                        median_val = df_copy[col].median()
                        df_copy[col] = df_copy[col].fillna(median_val)
                        processed_columns.append(col)
                else:
                    numeric_col = pd.to_numeric(df_copy[col], errors='coerce')
                    if not numeric_col.isna().all():
                        median_val = numeric_col.median()
                        df_copy[col] = numeric_col.fillna(median_val)
                        processed_columns.append(col)
            elif method == "Mode Fill":
                # The mode is computed over values that are neither NaN nor a
                # textual placeholder, then used to replace both.
                valid_values = df_copy[col][~df_copy[col].isin(_MISSING_TOKENS) & df_copy[col].notna()]

                if len(valid_values) > 0:
                    mode_value = valid_values.mode()
                    if len(mode_value) > 0:
                        most_common = mode_value.iloc[0]
                        print(f"DEBUG: Mode Fill - Most common value for {col}: {most_common}")

                        for token in _MISSING_TOKENS:
                            df_copy[col] = df_copy[col].replace(token, most_common)

                        df_copy[col] = df_copy[col].fillna(most_common)

                        processed_columns.append(col)
            elif method == "Drop Columns":
                df_copy = df_copy.drop(columns=[col])
                dropped_columns.append(col)

        uploaded_df = df_copy
        remaining_cols = [col for col in selected_columns if col not in dropped_columns]

        if 'title' in uploaded_df.columns:
            title_check = uploaded_df['title'].astype(str).str.contains('UNKNOWN', case=False, na=False).sum()
            print(f"DEBUG: After update, title has {title_check} UNKNOWN values")

        # Build a human-readable status report.
        if processed_columns:
            result = f"Applied {method} to: {', '.join(processed_columns)}"
            for col in processed_columns:
                if col in uploaded_df.columns:
                    after_missing = uploaded_df[col].isnull().sum()
                    result += f"\n- {col}: {after_missing} missing values remaining"
        elif dropped_columns:
            result = f"Dropped columns: {', '.join(dropped_columns)}"
        else:
            result = "No columns processed - check column selection or data types"

        return result, uploaded_df, change_history

    except Exception as e:
        return f"Error: {str(e)}", uploaded_df, change_history
|
| 491 |
+
|
| 492 |
+
def undo_last_change(uploaded_df, change_history):
    """Revert the dataset to the most recent snapshot on the undo stack.

    Returns a (message, dataframe, history) tuple; when the stack is empty
    the dataset is returned unchanged.
    """
    if not change_history:
        return "No changes to undo.", uploaded_df, change_history
    restored = change_history.pop()
    rows, cols = restored.shape
    message = f"Undid last change. Dataset now has {rows} rows × {cols} columns"
    return message, restored, change_history
|
| 497 |
+
|
| 498 |
+
def undo_all_changes(original_df, change_history):
    """Discard every modification and restore the originally uploaded dataset.

    Returns a (message, dataframe, history) tuple; the undo stack is cleared
    on success.
    """
    if original_df is None:
        return "No original dataset to restore.", None, change_history
    restored = original_df.copy()
    rows, cols = restored.shape
    message = f"Dataset restored to original state ({rows} rows × {cols} columns)"
    return message, restored, []
|
| 504 |
+
|
| 505 |
+
def download_dataset(uploaded_df, dataset_name):
    """Write the current dataset to a CSV file in the system temp directory.

    Parameters
    ----------
    uploaded_df : pandas.DataFrame or None
        Dataset to export; None yields no file.
    dataset_name : str or None
        Original upload filename; its spreadsheet extension (if any) is
        stripped and "_modified.csv" appended.

    Returns
    -------
    str or None
        Path of the written CSV, or None when no dataset is loaded.
    """
    if uploaded_df is None:
        return None

    if dataset_name:
        # Strip only a trailing spreadsheet extension. The previous chained
        # str.replace calls also removed '.csv'/'.xlsx'/'.xls' occurring
        # mid-name (e.g. "my.csvfile" -> "myfile").
        base_name, ext = os.path.splitext(dataset_name)
        if ext.lower() not in ('.csv', '.xlsx', '.xls'):
            base_name = dataset_name
        filename = f"{base_name}_modified.csv"
    else:
        filename = "modified_dataset.csv"

    temp_dir = tempfile.gettempdir()
    filepath = os.path.join(temp_dir, filename)
    uploaded_df.to_csv(filepath, index=False)
    return filepath
|
| 519 |
+
|
| 520 |
+
def display_data_format(format_type, selected_columns, uploaded_df):
    """Return up to 100 rows for the DataFrame preview widget.

    Non-"DataFrame" formats (and a missing dataset) yield None so the
    table component is left empty.
    """
    if uploaded_df is None or format_type == "None":
        return None
    subset = uploaded_df[selected_columns] if selected_columns else uploaded_df
    if format_type == "DataFrame":
        return subset.head(100)
    return None
|
| 525 |
+
|
| 526 |
+
def display_text_format(format_type, selected_columns, uploaded_df):
    """Render the first 20 rows of the dataset as JSON or a dict string.

    Parameters
    ----------
    format_type : str
        "JSON", "Dictionary", or anything else (including "None").
    selected_columns : list[str] or None
        Optional column subset; falsy means all columns.
    uploaded_df : pandas.DataFrame or None
        The currently loaded dataset.

    Returns
    -------
    str
        Rendered text, or "" for unsupported formats. (Previously the
        function fell off the end and returned None for e.g. "DataFrame",
        inconsistent with the "" returned by the guard clauses.)
    """
    if uploaded_df is None or format_type == "None":
        return ""
    df_to_show = uploaded_df[selected_columns] if selected_columns else uploaded_df
    if format_type == "JSON":
        return df_to_show.head(20).to_json(orient='records', indent=2)
    elif format_type == "Dictionary":
        return str(df_to_show.head(20).to_dict(orient='records'))
    return ""
|
prompts.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# System prompt steering the LLM to emit analysis plans as strict JSON.
# Consumed by sparknova.call_groq; the JSON "response formats" below define
# the plan schema that sparknova.parse_plan / execute_plan expect.
ENHANCED_SYSTEM_PROMPT = """You are a data analysis assistant. Respond ONLY in valid JSON format.
RULES:
1. DATASET ANALYSIS QUESTIONS (always process these): patterns, trends, insights, statistics, correlations, distributions, summaries, comparisons, relationships, data quality, outliers, analysis, exploration, findings, recommendations
2. NON-DATASET QUESTIONS (reject these): general knowledge, current events, personal questions, definitions unrelated to the data
3. Parse user queries to identify: column names, values, conditions, operations
4. For complex queries with multiple conditions, use multiple operations in sequence
5. Always use exact column names from the available columns list
OPERATIONS:
- filter: Use "expr" for conditions like "column_name > 100" or "column" + "value" for exact matches
- count: Count specific values in columns
- describe: Statistical summary
- groupby: Group and aggregate data
- calculate: Mathematical operations
FOR MULTI-CONDITION QUERIES:
- Step 1: Filter data based on conditions
- Step 2: Perform count/analysis on filtered data
CHART CREATION RULES:
- For visualization requests: Always include "plot" object
- For informational queries: Set "plot": null
RESPONSE FORMATS:
1. INFORMATIONAL (no visualization):
{
"type": "explain",
"operations": [],
"plot": null,
"narrative": "detailed answer",
"insights_needed": false
}
2. STATISTICAL DESCRIPTION:
{
"type": "describe",
"operations": [{"op": "describe", "columns": ["col1", "col2"]}],
"plot": null,
"narrative": "statistical summary",
"insights_needed": false
}
3. VISUALIZATION REQUEST:
{
"type": "analysis",
"operations": [
{"op": "groupby", "columns": ["category"], "agg_col": "value", "agg_func": "sum"}
],
"plot": {
"type": "bar|line|pie|hist|scatter",
"x": "category",
"y": "sum_value",
"title": "Chart Title"
},
"narrative": "brief explanation",
"insights_needed": true
}
4. FILTERING:
{
"type": "analysis",
"operations": [{"op": "filter", "column": "column_name", "value": "specific_value"}],
"plot": null,
"narrative": "filtered data explanation",
"insights_needed": false
}
5. CALCULATIONS:
{
"type": "analysis",
"operations": [{"op": "calculate", "expr": "Col1 * Col2", "new_col": "Product"}],
"plot": null,
"narrative": "calculation explanation",
"insights_needed": false
}
6. COUNT VALUES:
{
"type": "analysis",
"operations": [{"op": "count", "column": "column_name", "value": "specific_value"}],
"plot": null,
"narrative": "count result explanation",
"insights_needed": false
}
7. SHOW ALL VALUES:
{
"type": "analysis",
"operations": [{"op": "count", "column": "column_name"}],
"plot": null,
"narrative": "showing all unique values",
"insights_needed": false
}
8. MULTI-CONDITION QUERIES:
{
"type": "analysis",
"operations": [
{"op": "filter", "expr": "column_name > value"},
{"op": "count", "column": "another_column", "value": "target_value"}
],
"plot": null,
"narrative": "",
"insights_needed": false
}
CHART TYPES:
- "bar": For categorical comparisons
- "line": For trends over time/sequence
- "pie": For proportions/percentages
- "hist": For distributions
- "scatter": For correlations
Always ensure column names exist in the dataset before referencing them.
"""

# System prompt for the secondary LLM call that turns executed-plan context
# into narrative insights (see sparknova.generate_insights).
INSIGHTS_SYSTEM_PROMPT = "You are a data insights expert. Analyze the provided data context and generate meaningful insights about patterns, trends, relationships, and key findings. Focus on actionable insights that help understand the data better. Provide clear, specific observations based on the actual data values and statistics shown."

# Starter questions shown to the user after a dataset is loaded
# (sparknova.start_agent prints the first 8).
SAMPLE_QUESTIONS = [
    "What are the key patterns in this dataset?",
    "Show me insights about this data",
    "What trends can you identify?",
    "Analyze the relationships between columns",
    "What are the main findings from this data?",
    "Describe the data distribution and patterns",
    "What recommendations can you make?",
    "Find correlations in the dataset",
    "Summarize the key statistics"
]
|
| 117 |
+
|
| 118 |
+
def get_chart_prompt(question, columns, data_sample):
    """Build the per-question user prompt paired with ENHANCED_SYSTEM_PROMPT.

    Embeds the user's question, the exact column names, and a small data
    sample so the model can reference real columns in its JSON plan.
    """
    column_list = ', '.join(columns)
    return f"""
User Query: {question}
Dataset Information:
Available Columns: {column_list}
Sample Data:
{data_sample}
CRITICAL INSTRUCTIONS:
1. If the question contains ANY of these keywords, it's a DATASET ANALYSIS question - ALWAYS process it:
- patterns, trends, insights, analysis, statistics, correlations, relationships
- distribution, summary, compare, explore, findings, recommendations
- data, dataset, columns, values, records, rows
- show, find, count, filter, group, calculate, describe
2. ONLY reject questions about: presidents, weather, news, definitions, general knowledge
3. For dataset analysis questions:
- Use describe operations for exploratory questions
- Set "insights_needed": true for pattern/trend questions
- Create appropriate operations based on available columns
4. ALWAYS use exact column names from: {column_list}
5. For vague questions like "analyze this data", use describe on all key columns
"""
|
| 139 |
+
|
| 140 |
+
def validate_plot_spec(plot_spec, available_columns):
    """Repair a plot spec whose axis columns are absent from the dataset.

    When 'x' (or 'y') names a column that does not exist, substitute the
    first available column whose name looks categorical (for 'x') or
    quantitative (for 'y'). The spec is mutated in place and returned;
    a falsy spec is returned unchanged.
    """
    if not plot_spec:
        return plot_spec

    def _first_match(keywords):
        # First available column whose lowercase name contains any keyword.
        for candidate in available_columns:
            lowered = candidate.lower()
            if any(kw in lowered for kw in keywords):
                return candidate
        return None

    x_col = plot_spec.get('x')
    if x_col and x_col not in available_columns:
        replacement = _first_match(('name', 'category', 'type', 'group'))
        if replacement is not None:
            plot_spec['x'] = replacement

    y_col = plot_spec.get('y')
    if y_col and y_col not in available_columns:
        replacement = _first_match(('value', 'amount', 'count', 'price', 'sales'))
        if replacement is not None:
            plot_spec['y'] = replacement

    return plot_spec
|
| 160 |
+
|
| 161 |
+
def get_insights_prompt(context_parts, narrative):
    """Assemble the user prompt for the insight-generation LLM call.

    ``context_parts`` (stats dumps, result previews, chart metadata) are
    joined with newlines and framed alongside the plan's narrative.
    """
    insights_context = "\n".join(context_parts)
    return (
        "Based on this analysis, provide 4-6 detailed bullet points "
        "explaining key insights, patterns, and findings.\n"
        "Analysis Context:\n"
        f"{insights_context}\n"
        "Original Question Context:\n"
        f"{narrative}\n"
        "Provide insights as bullet points."
    )
|
sparknova.py
ADDED
|
@@ -0,0 +1,412 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import numpy as np
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import plotly.graph_objects as go
|
| 7 |
+
import plotly.express as px
|
| 8 |
+
import traceback
|
| 9 |
+
from io import BytesIO
|
| 10 |
+
from langchain_groq import ChatGroq
|
| 11 |
+
from prompts import ENHANCED_SYSTEM_PROMPT, SAMPLE_QUESTIONS, get_chart_prompt, validate_plot_spec, INSIGHTS_SYSTEM_PROMPT, get_insights_prompt
|
| 12 |
+
|
| 13 |
+
# SECURITY FIX: the Groq API key was previously hard-coded in this file and
# committed to source control. Read it from the environment instead (e.g. a
# Hugging Face Space secret). The previously committed key must be rotated.
GROQ_API_KEY = os.environ.get("GROQ_API_KEY", "")

if not GROQ_API_KEY:
    # Fail loudly at startup rather than mysteriously on the first LLM call.
    print("WARNING: GROQ_API_KEY is not set; LLM calls will fail until it is provided.")

# Deterministic (temperature 0.0) chat model used for plan generation.
llm = ChatGroq(
    api_key=GROQ_API_KEY,
    model="llama-3.3-70b-versatile",
    temperature=0.0
)

print("GROQ API initialized successfully")
|
| 22 |
+
|
| 23 |
+
def call_groq(messages):
    """Send a chat-message list to the Groq LLM and return the reply text.

    Args:
        messages: list of {"role": ..., "content": ...} dicts.

    Returns:
        The response's ``content`` attribute when present, otherwise its
        string representation.

    Raises:
        RuntimeError: wrapping any underlying client failure. Fix: the
        original exception is now chained (``from e``) so the real cause
        survives the wrap and appears in tracebacks.
    """
    try:
        res = llm.invoke(messages)
        return res.content if hasattr(res, "content") else str(res)
    except Exception as e:
        raise RuntimeError(f"GROQ API error: {e}") from e
|
| 29 |
+
|
| 30 |
+
def parse_plan(raw_text):
    """Extract the JSON analysis plan from raw LLM output.

    Strips markdown code fences, locates the outermost ``{...}`` span, and
    parses it, filling in defaults for any missing plan keys. On any
    failure an 'error'-typed plan is returned instead of raising, so the
    caller can surface the message to the user.
    """
    txt = raw_text.strip().replace("```json", "").replace("```", "").strip()
    try:
        payload = txt[txt.index("{"):txt.rindex("}") + 1]
        plan = json.loads(payload)
    except Exception as e:
        return {
            "type": "error",
            "operations": [],
            "plot": None,
            "narrative": f"Error parsing JSON: {str(e)}",
            "insights_needed": False,
        }

    defaults = {
        "type": "analysis",
        "operations": [],
        "plot": None,
        "narrative": "",
        "insights_needed": False,
    }
    for key, value in defaults.items():
        plan.setdefault(key, value)
    return plan
|
| 50 |
+
|
| 51 |
+
def clean_numeric(df):
    """Coerce string columns that are mostly numeric into numeric dtype.

    Percentage strings ("12%") become fractions (0.12); grouping commas and
    currency symbols ("," "₹" "$") are stripped before conversion. A column
    is converted only when more than half its values parse as numbers;
    otherwise it is left untouched. Operates on (and returns) a copy.
    """
    out = df.copy()
    total_rows = len(out)
    for name in out.columns:
        column = out[name]
        if not (pd.api.types.is_string_dtype(column) or column.dtype == object):
            continue
        stripped = column.astype(str).str.strip()

        # Percentage columns: strip "%" and rescale into [0, 1].
        if stripped.str.contains("%", na=False).any():
            as_fraction = pd.to_numeric(
                stripped.str.replace("%", "", regex=False), errors="coerce"
            )
            if as_fraction.notna().sum() / total_rows > 0.5:
                out[name] = as_fraction / 100.0
                continue

        # Plain numbers with grouping/currency decoration.
        without_symbols = (
            stripped.str.replace(",", "", regex=False)
            .str.replace("₹", "", regex=False)
            .str.replace("$", "", regex=False)
        )
        as_number = pd.to_numeric(without_symbols, errors="coerce")
        if as_number.notna().sum() / total_rows > 0.5:
            out[name] = as_number
    return out
|
| 66 |
+
|
| 67 |
+
def generate_insights(df, dfw, plan, plot_created):
    """Ask the LLM for bullet-point insights about an executed analysis plan.

    Args:
        df: the original (full) DataFrame.
        dfw: the working/result DataFrame produced by execute_plan.
        plan: the parsed plan dict (keys: operations, plot, narrative, ...).
        plot_created: whether a chart was actually rendered for this plan.

    Returns:
        The LLM's insight text (stripped), or a short fallback summary
        string if the LLM call fails — this function never raises.
    """
    # Build textual context describing what the plan computed.
    context_parts = []
    for op in plan.get("operations", []):
        if op.get("op") == "describe":
            # Per-column summary statistics from the ORIGINAL frame.
            cols = op.get("columns", [])
            for col in cols:
                if col in df.columns:
                    desc = df[col].describe()
                    context_parts.append(f"\n{col} Statistics:\n{desc.to_string()}")
        elif op.get("op") == "groupby":
            # Grouped output lives in the working frame.
            context_parts.append(f"\nGrouped Results:\n{dfw.head(10).to_string()}")

    # Mention the chart only when one was actually generated.
    plot_spec = plan.get("plot")
    if plot_created and plot_spec:
        context_parts.append(f"\nChart Type: {plot_spec.get('type')}")
        context_parts.append(f"Visualization: {plot_spec.get('title')}")

    if len(dfw) > 0:
        context_parts.append(f"\nResult Preview:\n{dfw.head(10).to_string()}")

    insights_prompt = get_insights_prompt(context_parts, plan.get('narrative', ''))

    try:
        insights_response = call_groq([
            {"role": "system", "content": INSIGHTS_SYSTEM_PROMPT},
            {"role": "user", "content": insights_prompt}
        ])
        return insights_response.strip()
    except Exception as e:
        # Best-effort fallback: insight generation must never crash the app.
        return f"Analysis completed successfully\n{len(dfw)} records in result\nError generating detailed insights: {str(e)}"
|
| 97 |
+
|
| 98 |
+
def _fig_to_png_bytes():
    """Serialize the current matplotlib figure to PNG bytes and close it."""
    buf = BytesIO()
    plt.savefig(buf, format="png", dpi=150, bbox_inches="tight")
    buf.seek(0)
    data = buf.read()
    plt.close()
    return data


def execute_plan(df, plan):
    """Run a parsed plan's operations against ``df`` and render any chart.

    Args:
        df: source DataFrame (never mutated; all work happens on a copy).
        plan: dict with "operations" (list of op dicts: describe / groupby /
            filter / calculate) and an optional "plot" spec.

    Returns:
        (result_df, plot_png_bytes_or_None, plot_html_or_None, describe_stats)
        where describe_stats maps column name -> pandas describe() Series.

    Raises:
        Re-raises any execution error after printing a traceback.

    Fixes vs. the original:
      * The "filter" op now also honors the exact-match form
        {"op": "filter", "column": ..., "value": ...} that
        ENHANCED_SYSTEM_PROMPT instructs the model to emit (format 4);
        previously that form was silently ignored (a no-op filter).
      * The four duplicated PNG-export blocks are factored into
        _fig_to_png_bytes().
    NOTE(review): the "count" op advertised by the prompt is still not
    implemented here — unknown ops fall through unchanged; confirm intent.
    """
    dfw = df.copy()
    plot_bytes = None
    plot_html = None
    describe_stats = {}

    try:
        for op in plan.get("operations", []):
            optype = op.get("op", "").lower()

            if optype == "describe":
                # Summary stats are collected but leave the frame untouched.
                for col in op.get("columns", []):
                    if col in dfw.columns:
                        stats = dfw[col].describe()
                        describe_stats[col] = stats
                        print(f"Described {col}")
                        print(f"\n{stats}\n")
                continue

            elif optype == "groupby":
                cols = op.get("columns", [])
                agg_col = op.get("agg_col")
                agg_func = op.get("agg_func", "count")

                if not cols:
                    raise ValueError("No columns specified for groupby")

                if agg_func == "count" or not agg_col:
                    dfw = dfw.groupby(cols).size().reset_index(name="count")
                    print(f"Grouped by {cols} with count")
                else:
                    if agg_col not in dfw.columns:
                        raise ValueError(f"Column '{agg_col}' not found for aggregation")
                    result_col = f"{agg_func}_{agg_col}"
                    dfw = dfw.groupby(cols)[agg_col].agg(agg_func).reset_index(name=result_col)
                    print(f"Grouped by {cols}, calculated {agg_func} of {agg_col}")

            elif optype == "filter":
                expr = op.get("expr", "")
                if expr:
                    dfw = dfw.query(expr)
                    print(f"Filter applied: {expr}")
                else:
                    # FIX: exact-match filter form promised by the prompt.
                    column = op.get("column")
                    if column and column in dfw.columns and "value" in op:
                        dfw = dfw[dfw[column] == op["value"]]
                        print(f"Filter applied: {column} == {op['value']!r}")

            elif optype == "calculate":
                expr = op.get("expr", "")
                new_col = op.get("new_col", "Calculated")
                dfw[new_col] = dfw.eval(expr)
                print(f"Calculated {new_col} = {expr}")

        plot_spec = plan.get("plot")
        if plot_spec:
            ptype = plot_spec.get("type", "bar")
            x = plot_spec.get("x")
            y = plot_spec.get("y")
            title = plot_spec.get("title", "Chart")

            # Describe-only plans plot against the full dataset; otherwise
            # use the (possibly filtered/grouped) working frame.
            plot_df = df if describe_stats else dfw

            if not x and len(plot_df.columns) > 0:
                categorical_cols = plot_df.select_dtypes(include=['object', 'category']).columns
                x = categorical_cols[0] if len(categorical_cols) > 0 else plot_df.columns[0]

            if not y:
                numeric_cols = plot_df.select_dtypes(include=[np.number]).columns
                y = numeric_cols[0] if len(numeric_cols) > 0 else None

            if not y:
                print("No suitable Y column found for plotting.")

            elif ptype == "pie":
                # Interactive plotly donut chart, returned as embeddable HTML.
                if x and x in plot_df.columns:
                    value_counts = plot_df[x].value_counts()
                    fig = go.Figure(data=[go.Pie(
                        labels=value_counts.index,
                        values=value_counts.values,
                        hovertemplate='<b>%{label}</b><br>Count: %{value}<br>Percentage: %{percent}<extra></extra>',
                        textposition='auto',
                        hole=0.3
                    )])
                else:
                    df_pie = plot_df[y].value_counts()
                    fig = go.Figure(data=[go.Pie(
                        labels=df_pie.index,
                        values=df_pie.values,
                        hole=0.3
                    )])

                fig.update_layout(
                    title=title,
                    title_font_size=16,
                    showlegend=True,
                    width=950,
                    height=550
                )
                plot_html = fig.to_html(include_plotlyjs='cdn')
                print("Enhanced pie chart generated")

            elif ptype == "bar":
                fig, ax = plt.subplots(figsize=(12, 7))

                if x and x in plot_df.columns and y and y in plot_df.columns:
                    plot_df.plot.bar(x=x, y=y, ax=ax, legend=False, color='steelblue', edgecolor='black', alpha=0.8)
                    ax.set_xlabel(x, fontsize=12, fontweight='bold')

                    # Rotate x labels progressively as category count grows.
                    n_categories = len(plot_df[x].unique())
                    if n_categories > 10:
                        plt.xticks(rotation=90, ha='right', fontsize=9)
                    elif n_categories > 5:
                        plt.xticks(rotation=45, ha='right', fontsize=10)
                    else:
                        plt.xticks(rotation=0, fontsize=10)
                else:
                    plot_df[y].plot.bar(ax=ax, color='steelblue', edgecolor='black', alpha=0.8)

                ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
                ax.set_ylabel(y, fontsize=12, fontweight='bold')
                ax.grid(axis='y', alpha=0.3, linestyle='--')
                plt.tight_layout()
                plot_bytes = _fig_to_png_bytes()
                print("Enhanced bar chart generated")

            elif ptype == "line":
                fig, ax = plt.subplots(figsize=(12, 7))

                if x and x in plot_df.columns and y and y in plot_df.columns:
                    plot_df.plot.line(x=x, y=y, ax=ax, marker="o", linewidth=3,
                                      markersize=8, color='darkblue', alpha=0.8)
                    ax.set_xlabel(x, fontsize=12, fontweight='bold')

                    if len(plot_df) > 15:
                        plt.xticks(rotation=45, ha='right', fontsize=9)
                    else:
                        plt.xticks(rotation=0, fontsize=10)
                else:
                    plot_df[y].plot.line(ax=ax, marker="o", linewidth=3,
                                         markersize=8, color='darkblue', alpha=0.8)

                ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
                ax.set_ylabel(y, fontsize=12, fontweight='bold')
                ax.grid(True, alpha=0.3, linestyle='--')
                plt.tight_layout()
                plot_bytes = _fig_to_png_bytes()
                print("Enhanced line chart generated")

            elif ptype == "hist":
                fig, ax = plt.subplots(figsize=(11, 7))

                plot_df[y].dropna().plot.hist(ax=ax, bins=25, edgecolor='black',
                                              alpha=0.7, color='teal')
                ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
                ax.set_xlabel(y, fontsize=12, fontweight='bold')
                ax.set_ylabel("Frequency", fontsize=12, fontweight='bold')
                ax.grid(axis='y', alpha=0.3, linestyle='--')
                plt.tight_layout()
                plot_bytes = _fig_to_png_bytes()
                print("Enhanced histogram generated")

            elif ptype == "scatter":
                fig, ax = plt.subplots(figsize=(11, 7))

                if x and x in plot_df.columns and y and y in plot_df.columns:
                    plot_df.plot.scatter(x=x, y=y, ax=ax, alpha=0.6, s=60, color='purple')
                    ax.set_xlabel(x, fontsize=12, fontweight='bold')
                    ax.set_ylabel(y, fontsize=12, fontweight='bold')

                ax.set_title(title, fontsize=16, fontweight='bold', pad=20)
                ax.grid(True, alpha=0.3, linestyle='--')
                plt.tight_layout()
                plot_bytes = _fig_to_png_bytes()
                print("Enhanced scatter plot generated")

        return dfw, plot_bytes, plot_html, describe_stats

    except Exception as e:
        print(f"EXECUTION ERROR: {e}")
        traceback.print_exc()
        raise
|
| 295 |
+
|
| 296 |
+
def make_context(df):
    """Summarize a DataFrame (size, columns, dtypes, sample) for the LLM."""
    preview = df.head(3).to_string(max_cols=10, max_colwidth=20)
    return (
        f"Dataset: {len(df)} rows, {len(df.columns)} columns\n"
        f"Columns: {', '.join(df.columns)}\n"
        f"Data types: {df.dtypes.value_counts().to_dict()}\n"
        "Sample data:\n"
        f"{preview}"
    )
|
| 303 |
+
|
| 304 |
+
def load_file(file_path):
    """Load a CSV or Excel file into a DataFrame based on its extension.

    Raises:
        ValueError: when the path ends in neither .csv nor .xlsx/.xls.
    """
    if file_path.endswith('.csv'):
        return pd.read_csv(file_path)
    if file_path.endswith(('.xlsx', '.xls')):
        return pd.read_excel(file_path)
    raise ValueError("Unsupported file format. Please use CSV or Excel files.")
|
| 311 |
+
|
| 312 |
+
def start_agent():
    """Interactive CLI loop: load a dataset, then answer questions about it.

    Flow per iteration: (1) prompt for a file path until a dataset loads,
    (2) read a question ('exit'/'quit' stops, 'reload' picks a new file),
    (3) ask the LLM for a JSON plan, (4) execute the plan, print results,
    save any chart to chart.html / chart.png, and optionally print insights.
    """
    print("=" * 80)
    print("SparkNova v5.0 – Advanced Data Analysis & Visualization")
    print("=" * 80)

    # Currently loaded dataset; None means "prompt for a file next".
    df = None

    while True:
        if df is None:
            file_path = input("\nEnter file path (CSV or Excel): ").strip()
            if not file_path:
                continue

            try:
                df = load_file(file_path)
                # Normalize numeric-looking string columns up front.
                df = clean_numeric(df)
                print(f"Loaded {file_path} ({len(df)} rows × {len(df.columns)} cols)")
                print("\nFirst 5 rows:")
                print(df.head())
                print(f"\nColumn types:\n{df.dtypes}")

                print("\nSample Questions You Can Ask:")
                for i, question in enumerate(SAMPLE_QUESTIONS[:8], 1):
                    print(f"{i}. {question}")

                # NOTE(review): data_ctx is built but never used below —
                # presumably intended for the LLM prompt; confirm and remove
                # or wire it in.
                data_ctx = make_context(df)
            except Exception as e:
                print(f"Error loading file: {e}")
                continue

        q = input("\nYour question (or 'exit'/'reload'): ").strip()
        if not q:
            continue
        if q.lower() in ("exit", "quit"):
            print("Thank you for using SparkNova!")
            break
        if q.lower() == "reload":
            # Drop the dataset so the next iteration re-prompts for a file.
            df = None
            continue

        # Build the user prompt with real column names and a small sample.
        enhanced_prompt = get_chart_prompt(q, df.columns.tolist(), df.head(3).to_string())

        try:
            raw = call_groq([
                {"role": "system", "content": ENHANCED_SYSTEM_PROMPT},
                {"role": "user", "content": enhanced_prompt}
            ])
        except Exception as e:
            print(f"LLM call failed: {e}")
            continue

        plan = parse_plan(raw)

        # Pure-text answers and parse failures short-circuit before execution.
        if plan.get("type") == "explain":
            print("\nExplanation:")
            print(plan.get("narrative", ""))
            continue

        if plan.get("type") == "error":
            print("\nError:")
            print(plan.get("narrative", ""))
            continue

        print("\nAnalysis Plan:")
        print(json.dumps(plan, indent=2))

        # Repair plot axis columns the model may have hallucinated.
        if plan.get("plot"):
            plan["plot"] = validate_plot_spec(plan["plot"], df.columns.tolist())

        try:
            print("\nExecuting operations...")
            res, plot_img, plot_html, desc_stats = execute_plan(df, plan)

            # Skip the preview when the plan was describe-only (result frame
            # unchanged) — stats were already printed during execution.
            if not desc_stats or len(res) != len(df):
                print("\nResult:")
                print(res.head(20))

            if plot_html:
                print("\nGenerated Interactive Chart (HTML saved as chart.html)")
                with open("chart.html", "w") as f:
                    f.write(plot_html)
            elif plot_img:
                print("\nGenerated Chart (saved as chart.png)")
                with open("chart.png", "wb") as f:
                    f.write(plot_img)

            narrative = plan.get("narrative", "")
            if narrative:
                print(f"\nSummary: {narrative}")

            # Insights only make sense once a chart was actually produced.
            if plan.get("insights_needed") and (plot_html or plot_img):
                print("\nDetailed Insights:")
                insights = generate_insights(df, res, plan, True)
                print(insights)

        except Exception as e:
            print(f"Execution failed: {e}")
            continue


if __name__ == "__main__":
    start_agent()
|