triflix committed on
Commit 75c3b7d · verified · 1 Parent(s): 00f7131

Create pipeline_with_agents.py

Files changed (1)
  1. pipeline_with_agents.py +649 -0
pipeline_with_agents.py ADDED
"""
Automated Data-Analysis Pipeline with Agent Prompts + Gemini (google-genai)

Fixed:
- Uses GEMINI_API_KEY from environment (no hardcoded key)
- Enforces planning agent contract: model MUST assign a `result` variable
- Hardened execution: exec -> check for `result`, or attempt safe eval for single-expression snippets
- Rejects unsafe snippets via a regex blacklist
- Logs model code when execution fails and falls back to the deterministic generator
"""

import os
import json
import argparse
import re
import logging
from typing import Any, Dict, List, Tuple, Optional

import pandas as pd
import numpy as np

# google-genai client library
from google import genai
from google.genai import types

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("pipeline")

# ---------------------------
# Agent system prompts
# ---------------------------

PROMPTS = {
    "file_ingestion": """You are a file ingestion agent. Detect file type, extract sheets if Excel, load into structured dataframe. Return metadata: file_type, sheet_names, selected_sheet.""",
    "preprocessing": """You are a preprocessing agent. Clean and normalize the dataset: handle nulls, infer types, encode categories, scale numerics. Output structured summary of cleaning steps and cleaned dataframe schema.""",
    "sampling": """You are a sampling agent. From the dataframe, generate head(5), tail(5), and a random sample. Output each in JSON structure.""",
    "classification": """You are a classification agent. Analyze given dataset samples. Identify domain (finance, sales, demographics, etc.) and suggest suitable visualization tasks (minimum 6). Return structured JSON listing tasks with fields: chart_type (one of pie, bar, line, scatter, histogram, boxplot), target_columns (array), aggregation (string or null), reasoning (short). Example output: { "domain": "sales", "tasks": [ { "chart_type": "pie", "target_columns": ["region"], "aggregation": "count", "reasoning": "distribution of region" }, ... ] }""",
    "planning": """You are a planning agent. From classification output, create at least 6 chart tasks. For each: chart_type (pie, bar, line, scatter, histogram, boxplot), target_columns, aggregation (if needed), and a Python/pandas code snippet that generates the chart-ready aggregated JSON (but code must be limited to pandas/numpy operations). IMPORTANT: Your code MUST assign the final chart data to a variable named `result` that is a list of dictionaries ready for JSON serialization. Example:
result = df.groupby('anemia_label').size().reset_index(name='value').to_dict(orient='records')
Return JSON with array 'tasks'. Example task entry: { "chart_type": "bar", "target_columns": ["month", "sales"], "aggregation": "sum", "code": "result = df.groupby('month')['sales'].sum().reset_index().to_dict(orient=\\'records\\')" }""",
    "execution": """You are an execution agent. Execute given Python code safely on provided dataset. Return structured JSON results in chart-ready format. Follow schema:
Pie → [{ "name": "...", "value": ... }]
Bar → [{ "label": "...", "metric1": ..., "metric2": ... }]
Line → [{ "x": "...", "y": ... }]
Scatter → [{ "x": ..., "y": ... }]
Histogram → [{ "bin": ..., "count": ... }]
Boxplot → [{ "category": "...", "q1": ..., "median": ..., "q3": ... }]""",
    "output": """You are an output agent. Aggregate all chart JSON objects into final structured response. Ensure at least 6 charts included. Output JSON with keys: "pie", "bar", "line", "scatter", "histogram", "boxplot".""",
}

# ---------------------------
# Utility and safety helpers
# ---------------------------

DISALLOWED_PATTERNS = [
    r"__\w+__",  # dunder access
    r"\bimport\s+os\b",
    r"\bimport\s+sys\b",
    r"\bimport\s+subprocess\b",
    r"\bimport\s+socket\b",
    r"\bimport\s+requests\b",
    r"\bopen\s*\(",
    r"\beval\s*\(",
    r"\bexec\s*\(",
    r"\bcompile\s*\(",
    r"\bsystem\s*\(",
    r"\bPopen\b",
    r"\bsh\b",
]

def code_is_safe(code: str) -> Tuple[bool, Optional[str]]:
    for pat in DISALLOWED_PATTERNS:
        if re.search(pat, code):
            return False, f"disallowed pattern: {pat}"
    return True, None
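
# Illustrative usage (not part of the pipeline): a snippet that imports os is
# rejected by the blacklist above.
#   ok, reason = code_is_safe("import os\nresult = os.listdir('.')")
#   # ok is False; reason names the first matched pattern, r"\bimport\s+os\b"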


def ensure_datetime_series(s: pd.Series) -> pd.Series:
    if not np.issubdtype(s.dtype, np.datetime64):
        try:
            s = pd.to_datetime(s, errors="coerce")
        except Exception:
            s = pd.to_datetime(s.astype(str), errors="coerce")
    return s


def simple_schema(df: pd.DataFrame) -> Dict[str, Any]:
    return {
        "columns": [
            {"name": c, "dtype": str(df[c].dtype), "n_unique": int(df[c].nunique(dropna=True))}
            for c in df.columns
        ],
        "n_rows": int(len(df)),
    }
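
# Example of the shape this returns (hypothetical frame with two columns):
#   {"columns": [{"name": "region", "dtype": "object", "n_unique": 3},
#                {"name": "sales", "dtype": "int64", "n_unique": 42}],
#    "n_rows": 100}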


def to_json_serializable(obj):
    if obj is None:
        return None
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()
    try:
        if pd.isna(obj):
            return None
    except (TypeError, ValueError):
        # pd.isna is ambiguous for list-like inputs; pass them through
        pass
    return obj
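
# Spot checks (illustrative):
#   to_json_serializable(np.int64(3))   # -> 3 (plain int)
#   to_json_serializable(np.nan)        # -> None
#   to_json_serializable("ok")          # -> "ok" (unchanged)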

# ---------------------------
# Pipeline agents (local)
# ---------------------------

def ingest_file(path: str, sheet: Optional[str] = None) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    ext = os.path.splitext(path)[1].lower()
    metadata = {"file_type": ext, "sheet_names": None, "selected_sheet": None}
    if ext in [".csv", ".txt"]:
        df = pd.read_csv(path)
        metadata["selected_sheet"] = "csv"
    elif ext in [".xls", ".xlsx"]:
        xls = pd.ExcelFile(path)
        sheets = xls.sheet_names
        metadata["sheet_names"] = sheets
        chosen = sheet if sheet and sheet in sheets else sheets[0]
        metadata["selected_sheet"] = chosen
        df = pd.read_excel(xls, sheet_name=chosen)
    else:
        raise ValueError("Unsupported file type: " + ext)
    metadata["file_type"] = ext
    return df, metadata


def preprocess_df(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    actions = []
    df = df.copy()

    # Strip column names
    df.columns = [str(c).strip() for c in df.columns]

    # Trim whitespace in object columns
    object_cols = df.select_dtypes(include="object").columns.tolist()
    for c in object_cols:
        try:
            df[c] = df[c].where(df[c].isna(), df[c].astype(str).str.strip())
        except Exception:
            pass

    # Numeric inference
    for c in df.columns:
        if df[c].dtype == object:
            # try converting to numeric
            coerced = pd.to_numeric(df[c], errors="coerce")
            non_na = coerced.notna().sum()
            if non_na >= max(1, 0.5 * len(df)):  # cast if at least 50% convertible
                df[c] = coerced
                actions.append(f"coerced {c} -> numeric")

    # Fill numeric nulls with median
    num_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    for c in num_cols:
        median = df[c].median()
        if pd.isna(median):
            median = 0
        df[c] = df[c].fillna(median)
        actions.append(f"filled numeric {c} nulls with median {median}")

    # Fill object nulls with mode
    for c in object_cols:
        try:
            mode = df[c].mode().iloc[0] if not df[c].mode().empty else ""
        except Exception:
            mode = ""
        df[c] = df[c].fillna(mode)
        actions.append(f"filled object {c} nulls with mode '{mode}'")

    schema = simple_schema(df)
    return df, {"actions": actions, "schema": schema}
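
# Illustrative (hypothetical frame): preprocess_df(pd.DataFrame({"n": ["1", "2", None]}))
# coerces "n" to numeric and fills the null with the median (1.5); the returned
# summary records each step under "actions" and the final dtypes under "schema".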


def sample_df(df: pd.DataFrame, n: int = 5) -> Dict[str, Any]:
    head = df.head(n).to_dict(orient="records")
    tail = df.tail(n).to_dict(orient="records")
    if len(df) <= n:
        rnd = df.sample(frac=1.0).to_dict(orient="records")
    else:
        rnd = df.sample(n=n, random_state=42).to_dict(orient="records")
    return {"head": head, "tail": tail, "random": rnd}


# ---------------------------
# Gemini / genai interactions
# ---------------------------

def gemini_generate_json(model: str, system_instruction: str, user_content: str, require_json: bool = True) -> Any:
    """
    Calls genai generate_content_stream with the given system prompt and user content.
    Expects the model to return JSON text. Joins chunks and returns parsed JSON or raw text.
    """
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        raise EnvironmentError("GEMINI_API_KEY not set in environment.")
    client = genai.Client(api_key=api_key)

    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=user_content)],
        )
    ]

    config = types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        response_mime_type="application/json",
        system_instruction=[types.Part.from_text(text=system_instruction)],
    )

    full_text = ""
    for chunk in client.models.generate_content_stream(model=model, contents=contents, config=config):
        # a chunk may expose .text directly or nested candidate parts
        if hasattr(chunk, "text") and chunk.text:
            full_text += chunk.text
        elif (
            chunk.candidates
            and chunk.candidates[0].content
            and chunk.candidates[0].content.parts
            and chunk.candidates[0].content.parts[0].text
        ):
            full_text += chunk.candidates[0].content.parts[0].text

    full_text = full_text.strip()
    if require_json:
        try:
            return json.loads(full_text)
        except Exception:
            # fallback: return raw text for debugging
            return {"__raw_text__": full_text}
    return full_text
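
# Illustrative call (assumes GEMINI_API_KEY is exported and google-genai is installed):
#   out = gemini_generate_json(
#       model="gemini-2.5-flash-lite",
#       system_instruction=PROMPTS["classification"],
#       user_content=json.dumps({"samples": {"head": [{"region": "EU"}]}}),
#   )
#   # out is the parsed JSON on success, or {"__raw_text__": ...} if parsing fails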

# ---------------------------
# Task execution (local, deterministic)
# ---------------------------

def generate_chart_data_from_spec(df: pd.DataFrame, spec: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]:
    """
    Deterministic generator for known chart types.
    spec expected keys: chart_type, target_columns (list), aggregation (str or null)
    Returns (chart_type, results)
    """
    chart_type = spec.get("chart_type")
    cols = spec.get("target_columns", [])
    agg = spec.get("aggregation", None)

    df_local = df.copy()

    if chart_type == "pie":
        # target_columns: [category_col] or [category_col, value_col]
        if len(cols) == 0:
            # pick the first categorical column
            cat = _pick_categorical(df_local)
            cols = [cat] if cat else []
        if len(cols) == 0:
            return "pie", []
        if len(cols) == 1:
            col = cols[0]
            series = df_local[col].astype(str).value_counts(dropna=True)
            out = [{"name": k, "value": int(v)} for k, v in series.items()]
            return "pie", out
        else:
            cat, val = cols[0], cols[1]
            grouped = df_local.groupby(cat)[val].sum().reset_index()
            out = [{"name": r[cat], "value": to_json_serializable(r[val])} for r in grouped.to_dict(orient="records")]
            return "pie", out

    if chart_type == "bar":
        # target_columns: [label_col, metric_col] or [label_col] with count
        if len(cols) == 0:
            label = _pick_categorical(df_local)
            cols = [label] if label else []
        if len(cols) == 0:
            return "bar", []
        if len(cols) == 1:
            label = cols[0]
            series = df_local[label].astype(str).value_counts().reset_index()
            series.columns = [label, "count"]
            out = [{"label": r[label], "count": int(r["count"])} for r in series.to_dict(orient="records")]
            return "bar", out
        else:
            label, metric = cols[0], cols[1]
            if agg == "mean":
                grouped = df_local.groupby(label)[metric].mean().reset_index()
            else:
                # default (None, "", "sum", or anything unrecognized) -> sum
                grouped = df_local.groupby(label)[metric].sum().reset_index()
            out = [{"label": r[label], metric: to_json_serializable(r[metric])} for r in grouped.to_dict(orient="records")]
            return "bar", out

    if chart_type == "line":
        # target_columns: [x_col, y_col]
        if len(cols) < 2:
            # pick the first numeric as y, the first date-like column (or the index) as x
            y = _pick_numeric(df_local)
            x = _pick_datetime_or_index(df_local)
            cols = [x, y]
        xcol, ycol = cols[0], cols[1]
        if ycol is None or ycol not in df_local.columns:
            return "line", []
        s_x = ensure_datetime_series(df_local[xcol]) if xcol in df_local.columns else pd.Series(range(len(df_local)))
        series = pd.DataFrame({"x": s_x, "y": df_local[ycol]})
        # sort by x when possible (e.g., datetimes)
        try:
            series = series.sort_values("x").reset_index(drop=True)
        except Exception:
            pass
        out = [
            {
                "x": to_json_serializable(r["x"].isoformat() if hasattr(r["x"], "isoformat") else r["x"]),
                "y": to_json_serializable(r["y"]),
            }
            for r in series.to_dict(orient="records")
        ]
        return "line", out

    if chart_type == "scatter":
        # target_columns: [x_col, y_col]
        if len(cols) < 2:
            x = _pick_numeric(df_local)
            y = _pick_numeric(df_local, exclude=[x]) if x else None
            cols = [x, y]
        xcol, ycol = cols[0], cols[1]
        if xcol is None or ycol is None:
            return "scatter", []
        out = [{"x": to_json_serializable(r[xcol]), "y": to_json_serializable(r[ycol])} for r in df_local[[xcol, ycol]].to_dict(orient="records")]
        return "scatter", out

    if chart_type == "histogram":
        # target_columns: [numeric_col]
        col = cols[0] if cols else _pick_numeric(df_local)
        if col is None:
            return "histogram", []
        series = df_local[col].dropna()
        counts, bin_edges = np.histogram(series, bins=10)
        out = []
        for i in range(len(counts)):
            out.append({"bin": f"{float(bin_edges[i]):.6g}-{float(bin_edges[i + 1]):.6g}", "count": int(counts[i])})
        return "histogram", out

    if chart_type == "boxplot":
        # target_columns: [category_col, numeric_col] or [numeric_col] (global box)
        if len(cols) == 0:
            num = _pick_numeric(df_local)
            if num is None:
                return "boxplot", []
            out = _box_aggregate(df_local, None, num)
            return "boxplot", out
        if len(cols) == 1:
            num = cols[0]
            out = _box_aggregate(df_local, None, num)
            return "boxplot", out
        cat, num = cols[0], cols[1]
        out = _box_aggregate(df_local, cat, num)
        return "boxplot", out

    # fallback: unknown chart type
    return chart_type or "unknown", []
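
# Illustrative spec -> output (hypothetical data, for shape only):
#   frame = pd.DataFrame({"region": ["EU", "EU", "US"], "sales": [10, 5, 7]})
#   generate_chart_data_from_spec(frame, {"chart_type": "bar",
#                                         "target_columns": ["region", "sales"],
#                                         "aggregation": "sum"})
#   # -> ("bar", [{"label": "EU", "sales": 15}, {"label": "US", "sales": 7}])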


def _pick_categorical(df: pd.DataFrame) -> Optional[str]:
    for c in df.columns:
        if df[c].dtype == object or df[c].nunique() < max(50, 0.5 * len(df)):
            return c
    return None


def _pick_numeric(df: pd.DataFrame, exclude: Optional[List[str]] = None) -> Optional[str]:
    exclude = exclude or []
    for c in df.select_dtypes(include=[np.number]).columns:
        if c not in exclude:
            return c
    # fall back to coercion
    for c in df.columns:
        try:
            coerced = pd.to_numeric(df[c], errors="coerce")
            if coerced.notna().sum() > 0:
                return c
        except Exception:
            continue
    return None


def _pick_datetime_or_index(df: pd.DataFrame) -> Optional[str]:
    for c in df.columns:
        if np.issubdtype(df[c].dtype, np.datetime64):
            return c
    # try to parse string columns
    for c in df.columns:
        try:
            parsed = pd.to_datetime(df[c], errors="coerce")
            if parsed.notna().sum() > 0:
                return c
        except Exception:
            continue
    return None


def _box_aggregate(df: pd.DataFrame, cat_col: Optional[str], num_col: str) -> List[Dict[str, Any]]:
    out = []
    if cat_col is None:
        series = df[num_col].dropna()
        q1 = float(series.quantile(0.25))
        median = float(series.quantile(0.5))
        q3 = float(series.quantile(0.75))
        out.append({"category": None, "q1": q1, "median": median, "q3": q3})
        return out
    for name, group in df.groupby(cat_col):
        ser = group[num_col].dropna()
        if len(ser) == 0:
            continue
        q1 = float(ser.quantile(0.25))
        median = float(ser.quantile(0.5))
        q3 = float(ser.quantile(0.75))
        out.append({"category": to_json_serializable(name), "q1": q1, "median": median, "q3": q3})
    return out
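
# Illustrative (hypothetical data): pandas' default linear interpolation gives
# quartiles 1.75 / 2.5 / 3.25 for the values [1, 2, 3, 4], so:
#   _box_aggregate(pd.DataFrame({"v": [1, 2, 3, 4]}), None, "v")
#   # -> [{"category": None, "q1": 1.75, "median": 2.5, "q3": 3.25}]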


# ---------------------------
# Main pipeline controller
# ---------------------------

def ensure_six_tasks(tasks: List[Dict[str, Any]], df: pd.DataFrame) -> List[Dict[str, Any]]:
    """
    Ensure at least 6 chart tasks. If fewer, append deterministic fallback tasks.
    """
    existing_types = [t.get("chart_type") for t in tasks]
    candidates = ["pie", "bar", "line", "scatter", "histogram", "boxplot"]
    out = tasks[:]
    for ct in candidates:
        if len(out) >= 6:
            break
        if ct not in existing_types:
            # create a fallback spec for this chart type
            if ct == "pie":
                col = _pick_categorical(df) or df.columns[0]
                out.append({"chart_type": "pie", "target_columns": [col], "aggregation": "count", "reasoning": "fallback pie"})
            elif ct == "bar":
                label = _pick_categorical(df) or df.columns[0]
                num = _pick_numeric(df) or df.columns[0]
                out.append({"chart_type": "bar", "target_columns": [label, num], "aggregation": "sum", "reasoning": "fallback bar"})
            elif ct == "line":
                y = _pick_numeric(df) or df.columns[0]
                x = _pick_datetime_or_index(df) or df.index.name or "index"
                if x == "index":
                    out.append({"chart_type": "line", "target_columns": [x, y], "aggregation": None, "reasoning": "fallback line on index"})
                else:
                    out.append({"chart_type": "line", "target_columns": [x, y], "aggregation": None, "reasoning": "fallback line"})
            elif ct == "scatter":
                x = _pick_numeric(df)
                y = _pick_numeric(df, exclude=[x]) or x
                out.append({"chart_type": "scatter", "target_columns": [x, y], "aggregation": None, "reasoning": "fallback scatter"})
            elif ct == "histogram":
                num = _pick_numeric(df) or df.columns[0]
                out.append({"chart_type": "histogram", "target_columns": [num], "aggregation": None, "reasoning": "fallback histogram"})
            elif ct == "boxplot":
                num = _pick_numeric(df) or df.columns[0]
                cat = _pick_categorical(df)
                if cat:
                    out.append({"chart_type": "boxplot", "target_columns": [cat, num], "aggregation": None, "reasoning": "fallback boxplot by category"})
                else:
                    out.append({"chart_type": "boxplot", "target_columns": [num], "aggregation": None, "reasoning": "fallback boxplot global"})
    return out
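
# Illustrative (hypothetical frame):
#   padded = ensure_six_tasks(
#       [{"chart_type": "pie", "target_columns": ["region"]}],
#       pd.DataFrame({"region": ["EU"], "sales": [1]}),
#   )
#   # len(padded) == 6; the five appended specs carry "fallback ..." reasoning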


def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.5-flash-lite") -> Dict[str, Any]:
    # ingest
    df, meta = ingest_file(path, sheet)
    pre_df, preprocess_meta = preprocess_df(df)
    samples = sample_df(pre_df, n=5)

    # prepare payload for the classification agent
    classification_input = json.dumps({"samples": samples, "schema": simple_schema(pre_df), "meta": meta})
    classification_output = gemini_generate_json(
        model=model,
        system_instruction=PROMPTS["classification"],
        user_content=classification_input,
        require_json=True,
    )

    # If classification_output is raw or malformed, fall back to naive classification
    if not isinstance(classification_output, dict) or "tasks" not in classification_output:
        fallback = {
            "domain": "unknown",
            "tasks": [
                {"chart_type": "pie", "target_columns": [_pick_categorical(pre_df) or pre_df.columns[0]], "aggregation": "count", "reasoning": "fallback"},
                {"chart_type": "bar", "target_columns": [_pick_categorical(pre_df) or pre_df.columns[0], _pick_numeric(pre_df) or pre_df.columns[0]], "aggregation": "sum", "reasoning": "fallback"},
            ],
        }
        classification_output = fallback

    # planning agent: ask for code snippets & structured tasks
    planning_input = json.dumps({"classification": classification_output, "schema": simple_schema(pre_df), "samples": samples})
    planning_output = gemini_generate_json(
        model=model,
        system_instruction=PROMPTS["planning"],
        user_content=planning_input,
        require_json=True,
    )

    # planning_output expected form: {"tasks": [ {chart_type, ..., code: "..."} ]}
    if isinstance(planning_output, dict) and "tasks" in planning_output:
        tasks = planning_output["tasks"]
    else:
        # if the model didn't produce a tasks array, use the classification tasks
        tasks = classification_output.get("tasks", [])

    # Ensure at least 6 tasks
    tasks = ensure_six_tasks(tasks, pre_df)

    # Execute tasks
    final = {"pie": [], "bar": [], "line": [], "scatter": [], "histogram": [], "boxplot": []}
    execution_errors = []

    for idx, task in enumerate(tasks):
        chart_type = task.get("chart_type")
        code_snippet = task.get("code")  # optional
        executed = False

        if code_snippet:
            safe, reason = code_is_safe(code_snippet)
            if not safe:
                logger.warning("Rejected unsafe code snippet: %s", reason)
            else:
                # Controlled globals for exec/eval
                allowed_globals = {
                    "__builtins__": {
                        "None": None,
                        "True": True,
                        "False": False,
                        "len": len,
                        "min": min,
                        "max": max,
                        "sum": sum,
                        "sorted": sorted,
                        "round": round,
                    },
                    "pd": pd,
                    "np": np,
                    "df": pre_df.copy(),
                }
                local_vars = {}
                try:
                    # 1) Try exec (the model should assign `result`)
                    exec(code_snippet, allowed_globals, local_vars)
                    result = None
                    if "result" in local_vars:
                        result = local_vars["result"]
                    elif "output" in local_vars:
                        result = local_vars["output"]
                    else:
                        # 2) If no explicit result, attempt eval for single-expression snippets
                        single_expr = (
                            ("\n" not in code_snippet)
                            and ("=" not in code_snippet)
                            and ("return" not in code_snippet)
                            and (not code_snippet.strip().startswith("def "))
                        )
                        if single_expr:
                            try:
                                # eval in the same controlled globals (no locals)
                                result = eval(code_snippet, allowed_globals, {})
                            except Exception:
                                result = None

                    # 3) Normalize result into a list of dicts
                    result_json = None
                    if isinstance(result, pd.DataFrame):
                        result_json = [{k: to_json_serializable(v) for k, v in r.items()} for r in result.to_dict(orient="records")]
                    elif isinstance(result, list):
                        norm = []
                        for r in result:
                            if isinstance(r, dict):
                                norm.append({k: to_json_serializable(v) for k, v in r.items()})
                            else:
                                # allow primitive lists; non-dict items are wrapped below
                                norm.append(to_json_serializable(r))
                        result_json = norm
                    elif isinstance(result, dict):
                        result_json = [{k: to_json_serializable(v) for k, v in result.items()}]
                    else:
                        # primitive or None -> invalid for a chart payload
                        result_json = None

                    if result_json is not None:
                        # ensure each element is dict-like; if not, wrap it
                        normalized = []
                        for item in result_json:
                            if isinstance(item, dict):
                                normalized.append(item)
                            else:
                                normalized.append({"value": to_json_serializable(item)})
                        final.setdefault(chart_type, []).extend(normalized)
                        executed = True
                    if not executed:
                        execution_errors.append({"task_index": idx, "reason": "result not list-of-dicts or missing", "code": code_snippet})
                except Exception as e:
                    logger.exception("Model code execution failed for task %s: %s", idx, str(e))
                    execution_errors.append({"task_index": idx, "reason": "exception during exec/eval", "exception": str(e), "code": code_snippet})

        if not executed:
            # deterministic fallback execution
            ct, res = generate_chart_data_from_spec(pre_df, task)
            final.setdefault(ct, []).extend(res)

    # Trim lists to a reasonable size
    for k in final:
        if isinstance(final[k], list):
            final[k] = final[k][:200]

    output_payload = {
        "metadata": {
            "source_file": os.path.basename(path),
            "ingestion_meta": meta,
            "preprocess_meta": preprocess_meta,
            "classification": classification_output if isinstance(classification_output, dict) else {"raw": classification_output},
            "planning_meta": planning_output if isinstance(planning_output, dict) else {"raw": planning_output},
            "execution_errors": execution_errors,
        },
        "charts": final,
    }

    return output_payload
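
# The returned payload, as a shape sketch (not real output):
#   {
#     "metadata": {"source_file": ..., "ingestion_meta": ..., "preprocess_meta": ...,
#                  "classification": ..., "planning_meta": ..., "execution_errors": [...]},
#     "charts": {"pie": [...], "bar": [...], "line": [...],
#                "scatter": [...], "histogram": [...], "boxplot": [...]},
#   }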


# ---------------------------
# CLI
# ---------------------------

def main():
    parser = argparse.ArgumentParser(description="Automated analysis pipeline that outputs frontend-ready chart JSON.")
    parser.add_argument("path", type=str, help="Path to CSV or Excel file")
    parser.add_argument("--sheet", type=str, default=None, help="Sheet name if Excel")
    parser.add_argument("--model", type=str, default="gemini-2.5-flash-lite", help="Gemini model id")
    args = parser.parse_args()

    result = process_file(args.path, sheet=args.sheet, model=args.model)
    # Print the final JSON to stdout
    print(json.dumps(result, indent=2, default=to_json_serializable))


if __name__ == "__main__":
    main()
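
# Example invocation (assumes GEMINI_API_KEY is exported; file names are hypothetical):
#   export GEMINI_API_KEY="..."
#   python pipeline_with_agents.py data.xlsx --sheet Sheet1
#   python pipeline_with_agents.py sales.csv --model gemini-2.5-flash-lite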