triflix committed on
Commit
ea760a5
·
verified ·
1 Parent(s): bda6ce3

Create pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +486 -0
pipeline.py ADDED
@@ -0,0 +1,486 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app/pipeline.py
2
+ import os
3
+ import json
4
+ import logging
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
+
8
+ import pandas as pd
9
+ import numpy as np
10
+
11
# Module-level logger used by the pipeline functions in this file.
logger = logging.getLogger("pipeline")
# Emit INFO and above (e.g. the "using deterministic fallbacks" notice).
logger.setLevel(logging.INFO)
13
+
14
+ # ---------------------------
15
+ # Safety / helpers
16
+ # ---------------------------
17
+
18
# Regular expressions that a code snippet must not match before it is executed.
# NOTE(review): this is a deny-list screen only; it is not a sandbox.
DISALLOWED_PATTERNS = [
    r"__\w+__",
    r"\bimport\s+os\b",
    r"\bimport\s+sys\b",
    r"\bimport\s+subprocess\b",
    r"\bimport\s+socket\b",
    r"\bimport\s+requests\b",
    r"\bopen\s*\(",
    r"\beval\s*\(",
    r"\bexec\s*\(",
    r"\bcompile\s*\(",
    r"\bsystem\s*\(",
    r"\bPopen\b",
    r"\bsh\b",
]


def code_is_safe(code: str) -> Tuple[bool, Optional[str]]:
    """Screen *code* against ``DISALLOWED_PATTERNS``.

    Returns:
        ``(True, None)`` when no pattern matches; otherwise ``(False, reason)``
        where *reason* names the first pattern (in list order) that matched.
    """
    offending = next(
        (pattern for pattern in DISALLOWED_PATTERNS if re.search(pattern, code)),
        None,
    )
    if offending is None:
        return True, None
    return False, f"disallowed pattern: {offending}"
29
+
30
def to_json_serializable(obj):
    """Convert a single value into something ``json.dumps`` can encode.

    Conversions applied:
      * ``np.ndarray``            -> nested Python list
      * scalar missing values     -> ``None`` (``None``, NaN, ``pd.NaT``, ``pd.NA``)
      * ``np.bool_``              -> ``bool``
      * ``np.integer``            -> ``int``
      * ``np.floating``           -> ``float``
      * ``pd.Timestamp``          -> ISO-8601 string
    Anything else (including lists and dicts) is returned unchanged.
    """
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    # pd.isna() returns an *array* for list-likes, whose truth value is
    # ambiguous, so only ask it about scalars. Checking this before the NumPy
    # float branch also maps np.float64("nan") to None instead of emitting NaN,
    # which is not valid JSON.
    if pd.api.types.is_scalar(obj) and pd.isna(obj):
        return None
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, pd.Timestamp):
        return obj.isoformat()
    return obj
40
+
41
def simple_schema(df: pd.DataFrame) -> Dict[str, Any]:
    """Summarise *df* as a small, JSON-friendly schema.

    Returns a dict with ``"columns"`` (one entry per column giving its name,
    dtype string and count of distinct non-null values) and ``"n_rows"``.
    """
    column_info = []
    for col_name in df.columns:
        values = df[col_name]
        column_info.append(
            {
                "name": col_name,
                "dtype": str(values.dtype),
                "n_unique": int(values.nunique(dropna=True)),
            }
        )
    return {"columns": column_info, "n_rows": int(len(df))}
49
+
50
+ # ---------------------------
51
+ # Ingest / preprocess / sampling
52
+ # ---------------------------
53
+
54
def ingest_file(path: str, sheet: Optional[str] = None) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """Load a tabular file into a DataFrame.

    Args:
        path: Path to a ``.csv``/``.txt`` file (read with ``pd.read_csv``) or an
            ``.xls``/``.xlsx`` workbook.
        sheet: Workbook sheet to read. When omitted or not present in the
            workbook, the first sheet is used.

    Returns:
        ``(df, metadata)`` where *metadata* has the keys ``file_type`` (the
        lower-cased extension), ``sheet_names`` (``None`` for text files) and
        ``selected_sheet`` (``"csv"`` for text files).

    Raises:
        ValueError: If the file extension is not supported.
    """
    ext = os.path.splitext(path)[1].lower()
    metadata = {"file_type": ext, "sheet_names": None, "selected_sheet": None}

    if ext in (".csv", ".txt"):
        df = pd.read_csv(path)
        metadata["selected_sheet"] = "csv"
    elif ext in (".xls", ".xlsx"):
        # The context manager closes the workbook handle even if reading fails.
        with pd.ExcelFile(path) as xls:
            sheets = xls.sheet_names
            chosen = sheet if sheet and sheet in sheets else sheets[0]
            df = pd.read_excel(xls, sheet_name=chosen)
        metadata["sheet_names"] = sheets
        metadata["selected_sheet"] = chosen
    else:
        raise ValueError("Unsupported file type: " + ext)

    return df, metadata
71
+
72
def preprocess_df(df: pd.DataFrame) -> Tuple[pd.DataFrame, Dict[str, Any]]:
    """Clean a freshly ingested frame and report what was done.

    Steps, applied to a copy of *df*:
      1. Column names are stringified and stripped.
      2. Non-null values in object columns are stringified and stripped.
      3. Object columns where at least half the rows (and at least one) parse
         as numbers are converted to numeric; unparseable cells become NaN.
      4. Numeric nulls are filled with the column median (0 if undefined).
      5. Nulls in the remaining object columns are filled with the column mode
         ("" if there is none).

    Returns:
        ``(clean_df, {"actions": [...], "schema": simple_schema(clean_df)})``.
    """
    actions: List[str] = []
    df = df.copy()
    df.columns = [str(c).strip() for c in df.columns]

    for c in df.select_dtypes(include="object").columns.tolist():
        try:
            df[c] = df[c].where(df[c].isna(), df[c].astype(str).str.strip())
        except Exception:
            # Best-effort cleanup: keep the column as-is, but leave a trace.
            logger.warning("could not strip whitespace in column %s", c, exc_info=True)

    for c in df.columns:
        if df[c].dtype == object:
            coerced = pd.to_numeric(df[c], errors="coerce")
            if coerced.notna().sum() >= max(1, 0.5 * len(df)):
                df[c] = coerced
                actions.append(f"coerced {c} -> numeric")

    for c in df.select_dtypes(include=[np.number]).columns.tolist():
        median = df[c].median()
        if pd.isna(median):
            median = 0
        df[c] = df[c].fillna(median)
        actions.append(f"filled numeric {c} nulls with median {median}")

    # Re-select the object columns here: columns converted to numeric above are
    # no longer text and must not be reported as mode-filled.
    for c in df.select_dtypes(include="object").columns.tolist():
        try:
            modes = df[c].mode()
            mode = modes.iloc[0] if not modes.empty else ""
        except Exception:
            mode = ""
        df[c] = df[c].fillna(mode)
        actions.append(f"filled object {c} nulls with mode '{mode}'")

    schema = simple_schema(df)
    return df, {"actions": actions, "schema": schema}
110
+
111
def sample_df(df: pd.DataFrame, n: int = 5) -> Dict[str, Any]:
    """Return head, tail and random row samples of *df* as lists of records.

    The random sample is drawn with a fixed seed so that repeated calls on the
    same frame give the same rows. Frames with ``n`` rows or fewer yield all
    rows, in seeded shuffled order.
    """
    sample_size = min(n, len(df))
    return {
        "head": df.head(n).to_dict(orient="records"),
        "tail": df.tail(n).to_dict(orient="records"),
        "random": df.sample(n=sample_size, random_state=42).to_dict(orient="records"),
    }
116
+
117
+ # ---------------------------
118
+ # Deterministic chart generator (fallback)
119
+ # ---------------------------
120
+
121
+ def _pick_categorical(df: pd.DataFrame) -> Optional[str]:
122
+ for c in df.columns:
123
+ if df[c].dtype == object or df[c].nunique() < max(50, 0.5 * len(df)):
124
+ return c
125
+ return None
126
+
127
+ def _pick_numeric(df: pd.DataFrame, exclude: List[str] = []) -> Optional[str]:
128
+ for c in df.select_dtypes(include=[np.number]).columns:
129
+ if c not in exclude:
130
+ return c
131
+ for c in df.columns:
132
+ try:
133
+ coerced = pd.to_numeric(df[c], errors="coerce")
134
+ if coerced.notna().sum() > 0:
135
+ return c
136
+ except Exception:
137
+ continue
138
+ return None
139
+
140
+ def _pick_datetime_or_index(df: pd.DataFrame) -> Optional[str]:
141
+ for c in df.columns:
142
+ if np.issubdtype(df[c].dtype, np.datetime64):
143
+ return c
144
+ for c in df.columns:
145
+ try:
146
+ parsed = pd.to_datetime(df[c], errors="coerce")
147
+ if parsed.notna().sum() > 0:
148
+ return c
149
+ except Exception:
150
+ continue
151
+ return None
152
+
153
+ def _box_aggregate(df: pd.DataFrame, cat_col: Optional[str], num_col: str) -> List[Dict[str, Any]]:
154
+ out = []
155
+ if cat_col is None:
156
+ series = df[num_col].dropna()
157
+ if len(series) == 0:
158
+ return out
159
+ q1 = float(series.quantile(0.25))
160
+ median = float(series.quantile(0.5))
161
+ q3 = float(series.quantile(0.75))
162
+ out.append({"category": None, "q1": q1, "median": median, "q3": q3})
163
+ return out
164
+ for name, group in df.groupby(cat_col):
165
+ ser = group[num_col].dropna()
166
+ if len(ser) == 0:
167
+ continue
168
+ q1 = float(ser.quantile(0.25))
169
+ median = float(ser.quantile(0.5))
170
+ q3 = float(ser.quantile(0.75))
171
+ out.append({"category": to_json_serializable(name), "q1": q1, "median": median, "q3": q3})
172
+ return out
173
+
174
def generate_chart_data_from_spec(df: pd.DataFrame, spec: Dict[str, Any]) -> Tuple[str, List[Dict[str, Any]]]:
    """Build chart data for one task spec without running model-written code.

    Args:
        df: Source frame (not modified).
        spec: Task dict. Reads ``chart_type`` (``pie``, ``bar``, ``line``,
            ``scatter``, ``histogram`` or ``boxplot``), ``target_columns``
            (missing, ``None`` or too short means "pick columns automatically")
            and ``aggregation`` (``"mean"`` or sum; used by two-column bars).

    Returns:
        ``(chart_type, records)``. *records* is empty when the chart type is
        unknown or when no suitable column could be picked automatically.
        An absent chart type is reported as ``"unknown"``.

    Raises:
        KeyError: If the spec names a column that is not in *df*. The one
            exception is the x column of a line chart, which falls back to the
            row position when it is not a column.
    """
    chart_type = spec.get("chart_type")
    # A model may emit `"target_columns": null`; treat it like an empty list.
    cols = spec.get("target_columns") or []
    agg = spec.get("aggregation")

    if chart_type == "pie":
        if len(cols) == 0:
            picked = _pick_categorical(df)
            if picked is None:
                return "pie", []
            cols = [picked]
        if len(cols) == 1:
            counts = df[cols[0]].astype(str).value_counts(dropna=True)
            return "pie", [{"name": k, "value": int(v)} for k, v in counts.items()]
        cat, val = cols[0], cols[1]
        grouped = df.groupby(cat)[val].sum().reset_index()
        return "pie", [
            {"name": r[cat], "value": to_json_serializable(r[val])}
            for r in grouped.to_dict(orient="records")
        ]

    if chart_type == "bar":
        if len(cols) == 0:
            picked = _pick_categorical(df)
            if picked is None:
                return "bar", []
            cols = [picked]
        if len(cols) == 1:
            label = cols[0]
            counts = df[label].astype(str).value_counts().reset_index()
            counts.columns = [label, "count"]
            return "bar", [
                {"label": r[label], "count": int(r["count"])}
                for r in counts.to_dict(orient="records")
            ]
        label, metric = cols[0], cols[1]
        by_label = df.groupby(label)[metric]
        # Any aggregation other than "mean" is treated as a sum.
        grouped = (by_label.mean() if agg == "mean" else by_label.sum()).reset_index()
        return "bar", [
            {"label": r[label], metric: to_json_serializable(r[metric])}
            for r in grouped.to_dict(orient="records")
        ]

    if chart_type == "line":
        if len(cols) < 2:
            cols = [_pick_datetime_or_index(df), _pick_numeric(df)]
        xcol, ycol = cols[0], cols[1]
        if ycol is None:
            return "line", []
        if xcol is not None and xcol in df.columns:
            s_x = pd.to_datetime(df[xcol], errors="coerce")
        else:
            # Positional x axis. It must share df's index, otherwise the frame
            # below aligns on labels and pairs x values with the wrong rows.
            s_x = pd.Series(range(len(df)), index=df.index)
        points = pd.DataFrame({"x": s_x, "y": df[ycol]})
        try:
            points = points.sort_values("x").reset_index(drop=True)
        except Exception:
            # Unsortable x values: keep the original row order.
            pass
        out = []
        for r in points.to_dict(orient="records"):
            raw_x = r["x"]
            if hasattr(raw_x, "isoformat"):
                # NaT also has isoformat() but renders as the string "NaT".
                x_val = None if pd.isna(raw_x) else raw_x.isoformat()
            else:
                x_val = raw_x
            out.append({"x": to_json_serializable(x_val), "y": to_json_serializable(r["y"])})
        return "line", out

    if chart_type == "scatter":
        if len(cols) < 2:
            x = _pick_numeric(df)
            y = _pick_numeric(df, exclude=[x]) if x is not None else None
            cols = [x, y]
        xcol, ycol = cols[0], cols[1]
        if xcol is None or ycol is None:
            return "scatter", []
        return "scatter", [
            {"x": to_json_serializable(r[xcol]), "y": to_json_serializable(r[ycol])}
            for r in df[[xcol, ycol]].to_dict(orient="records")
        ]

    if chart_type == "histogram":
        col = cols[0] if cols else _pick_numeric(df)
        if col is None:
            return "histogram", []
        # Coerce so that text-typed numbers can still be binned.
        values = pd.to_numeric(df[col], errors="coerce").dropna()
        if values.empty:
            return "histogram", []
        counts, bin_edges = np.histogram(values, bins=10)
        return "histogram", [
            {"bin": f"{float(lo):.6g}-{float(hi):.6g}", "count": int(count)}
            for lo, hi, count in zip(bin_edges[:-1], bin_edges[1:], counts)
        ]

    if chart_type == "boxplot":
        if len(cols) >= 2:
            return "boxplot", _box_aggregate(df, cols[0], cols[1])
        num = cols[0] if cols else _pick_numeric(df)
        if num is None:
            return "boxplot", []
        return "boxplot", _box_aggregate(df, None, num)

    return chart_type or "unknown", []
258
+
259
+ # ---------------------------
260
+ # Gemini wrapper (optional)
261
+ # ---------------------------
262
+
263
def gemini_generate_json(model: str, system_instruction: str, user_content: str, require_json: bool = True) -> Any:
    """Call Gemini through ``google-genai`` and return its (JSON) answer.

    The API key is read from the ``GEMINI_API_KEY`` environment variable.

    Args:
        model: Model name passed to the API.
        system_instruction: System prompt text.
        user_content: User message text.
        require_json: When true, the streamed text is parsed with
            ``json.loads``; when false, the stripped text is returned.

    Returns:
        * ``{"__use_fallback__": True}`` when no key is configured, the
          ``google-genai`` import fails, or the API call raises. Callers should
          then use their deterministic fallback.
        * The parsed JSON value, or ``{"__raw_text__": text}`` if the response
          is not valid JSON (``require_json=True``).
        * The raw response text (``require_json=False``).
    """
    # Credentials must come from the environment, never from source control.
    # NOTE(review): a key that was previously committed in this function should
    # be treated as leaked and revoked.
    api_key = os.environ.get("GEMINI_API_KEY")
    if not api_key:
        logger.info("GEMINI_API_KEY not set; using deterministic fallbacks.")
        return {"__use_fallback__": True}

    try:
        # Imported lazily so the module works without the optional dependency.
        from google import genai
        from google.genai import types
    except Exception as e:
        logger.exception("google-genai not available: %s", e)
        return {"__use_fallback__": True}

    client = genai.Client(api_key=api_key)
    contents = [
        types.Content(
            role="user",
            parts=[types.Part.from_text(text=user_content)],
        )
    ]
    config = types.GenerateContentConfig(
        thinking_config=types.ThinkingConfig(thinking_budget=0),
        response_mime_type="application/json",
        system_instruction=[types.Part.from_text(text=system_instruction)],
    )

    try:
        pieces = []
        for chunk in client.models.generate_content_stream(model=model, contents=contents, config=config):
            if hasattr(chunk, "text") and chunk.text:
                pieces.append(chunk.text)
            elif (
                chunk.candidates
                and chunk.candidates[0].content
                and chunk.candidates[0].content.parts
                and chunk.candidates[0].content.parts[0].text
            ):
                pieces.append(chunk.candidates[0].content.parts[0].text)
        full_text = "".join(pieces).strip()
        if require_json:
            try:
                return json.loads(full_text)
            except Exception:
                # Hand the unparsed text back so the caller can inspect it.
                return {"__raw_text__": full_text}
        return full_text
    except Exception as e:
        logger.exception("genai call failed: %s", e)
        return {"__use_fallback__": True}
315
+
316
+ # ---------------------------
317
+ # Controller
318
+ # ---------------------------
319
+
320
# Model name used by process_file() when the caller does not pass `model`.
DEFAULT_MODEL = "gemini-2.5-flash-lite"
321
+
322
def ensure_six_tasks(tasks: List[Dict[str, Any]], df: pd.DataFrame) -> List[Dict[str, Any]]:
    """Pad *tasks* with deterministic fallback tasks for missing chart types.

    Chart types are considered in the order pie, bar, line, scatter, histogram,
    boxplot. A fallback task is appended for each type not already present,
    until the list holds six tasks. The input list is not modified.

    Args:
        tasks: Existing task dicts. Entries that are not dicts are kept but
            ignored when working out which chart types are present.
        df: Frame used to choose the fallback columns.

    Returns:
        A new list holding the original tasks followed by any fallbacks. If
        *df* has no columns, nothing is added.
    """
    candidates = ("pie", "bar", "line", "scatter", "histogram", "boxplot")
    out = list(tasks)
    existing_types = [t.get("chart_type") for t in out if isinstance(t, dict)]
    missing = [ct for ct in candidates if ct not in existing_types]

    if len(out) >= 6 or not missing or len(df.columns) == 0:
        return out

    # Scan the frame once; the pickers do not depend on the chart type.
    first_col = df.columns[0]
    cat = _pick_categorical(df)
    num = _pick_numeric(df)
    cat_or_first = cat if cat is not None else first_col
    num_or_first = num if num is not None else first_col

    for ct in missing:
        if len(out) >= 6:
            break
        if ct == "pie":
            out.append({"chart_type": "pie", "target_columns": [cat_or_first], "aggregation": "count", "reasoning": "fallback pie"})
        elif ct == "bar":
            out.append({"chart_type": "bar", "target_columns": [cat_or_first, num_or_first], "aggregation": "sum", "reasoning": "fallback bar"})
        elif ct == "line":
            x = _pick_datetime_or_index(df) or df.index.name or "index"
            reasoning = "fallback line on index" if x == "index" else "fallback line"
            out.append({"chart_type": "line", "target_columns": [x, num_or_first], "aggregation": None, "reasoning": reasoning})
        elif ct == "scatter":
            # Fall back to the first column like the other chart types do, so
            # the task never carries None as a column name.
            x = num_or_first
            y = _pick_numeric(df, exclude=[x]) or x
            out.append({"chart_type": "scatter", "target_columns": [x, y], "aggregation": None, "reasoning": "fallback scatter"})
        elif ct == "histogram":
            out.append({"chart_type": "histogram", "target_columns": [num_or_first], "aggregation": None, "reasoning": "fallback histogram"})
        elif ct == "boxplot":
            if cat:
                out.append({"chart_type": "boxplot", "target_columns": [cat, num_or_first], "aggregation": None, "reasoning": "fallback boxplot by category"})
            else:
                out.append({"chart_type": "boxplot", "target_columns": [num_or_first], "aggregation": None, "reasoning": "fallback boxplot global"})
    return out
359
+
360
def _run_task_code(code_snippet: str, df: pd.DataFrame) -> Any:
    """Execute a screened snippet against a copy of *df* and return its result.

    The result is the snippet's ``result`` variable, else its ``output``
    variable, else (for a one-line expression) the value of that expression,
    else ``None``.
    """
    # SECURITY: this runs model-generated code. The deny-list screen and the
    # trimmed builtins reduce risk but are not a sandbox (pd/np are exposed).
    # Run this in an isolated process/container if inputs are untrusted.
    allowed_globals = {
        "__builtins__": {"None": None, "True": True, "False": False, "len": len, "min": min, "max": max, "sum": sum, "sorted": sorted, "round": round},
        "pd": pd, "np": np, "df": df.copy(),
    }
    local_vars: Dict[str, Any] = {}
    exec(code_snippet, allowed_globals, local_vars)
    if "result" in local_vars:
        return local_vars["result"]
    if "output" in local_vars:
        return local_vars["output"]

    single_expr = (
        "\n" not in code_snippet
        and "=" not in code_snippet
        and "return" not in code_snippet
        and not code_snippet.strip().startswith("def ")
    )
    if single_expr:
        try:
            return eval(code_snippet, allowed_globals, {})
        except Exception:
            return None
    return None


def _normalize_exec_result(result: Any) -> Optional[List[Dict[str, Any]]]:
    """Turn a snippet result into a list of JSON-friendly dicts, or ``None`` if unusable."""
    if isinstance(result, pd.DataFrame):
        rows = result.to_dict(orient="records")
    elif isinstance(result, list):
        rows = result
    elif isinstance(result, dict):
        rows = [result]
    else:
        return None

    normalized = []
    for row in rows:
        if isinstance(row, dict):
            normalized.append({k: to_json_serializable(v) for k, v in row.items()})
        else:
            # Bare values are wrapped so every chart record is a dict.
            normalized.append({"value": to_json_serializable(row)})
    return normalized


def process_file(path: str, sheet: Optional[str] = None, model: str = DEFAULT_MODEL) -> Dict[str, Any]:
    """Run the whole pipeline on one file and return chart data plus metadata.

    Stages: ingest -> preprocess -> sample -> classification (model) ->
    planning (model) -> pad to six tasks -> per-task execution. A task's
    model-written ``code`` is used when it passes ``code_is_safe`` and yields a
    usable result; otherwise the deterministic generator is used. Failures of
    either path are recorded in ``metadata.execution_errors`` rather than
    aborting the run.

    Args:
        path: Input file path (see ``ingest_file`` for supported types).
        sheet: Optional workbook sheet name.
        model: Model name passed to ``gemini_generate_json``.

    Returns:
        ``{"metadata": {...}, "charts": {chart_type: [records, ...]}}``, with
        each chart's record list truncated to 200 entries.

    Raises:
        ValueError: Propagated from ``ingest_file`` for unsupported file types.
    """
    df, meta = ingest_file(path, sheet)
    pre_df, preprocess_meta = preprocess_df(df)
    samples = sample_df(pre_df, n=5)
    schema = simple_schema(pre_df)

    # default=str: sample rows can hold values json cannot encode natively
    # (e.g. timestamps); they are only prompt text, so stringifying is fine.
    classification_input = json.dumps({"samples": samples, "schema": schema, "meta": meta}, default=str)
    classification_output = gemini_generate_json(
        model=model,
        system_instruction="You are a classification agent. Identify domain and chart tasks.",
        user_content=classification_input,
        require_json=True,
    )

    if not isinstance(classification_output, dict) or "tasks" not in classification_output:
        cat_or_first = _pick_categorical(pre_df) or pre_df.columns[0]
        num_or_first = _pick_numeric(pre_df) or pre_df.columns[0]
        classification_output = {
            "domain": "unknown",
            "tasks": [
                {"chart_type": "pie", "target_columns": [cat_or_first], "aggregation": "count", "reasoning": "fallback"},
                {"chart_type": "bar", "target_columns": [cat_or_first, num_or_first], "aggregation": "sum", "reasoning": "fallback"},
            ],
        }

    planning_input = json.dumps({"classification": classification_output, "schema": schema, "samples": samples}, default=str)
    planning_output = gemini_generate_json(
        model=model,
        system_instruction="You are a planning agent. Produce tasks with code assigning `result` variable.",
        user_content=planning_input,
        require_json=True,
    )

    if isinstance(planning_output, dict) and "tasks" in planning_output:
        tasks = planning_output["tasks"]
    else:
        tasks = classification_output.get("tasks", [])
    # Model output is untrusted: keep only well-formed task dicts.
    if not isinstance(tasks, list):
        tasks = []
    tasks = [t for t in tasks if isinstance(t, dict)]

    tasks = ensure_six_tasks(tasks, pre_df)

    final: Dict[Any, List[Any]] = {"pie": [], "bar": [], "line": [], "scatter": [], "histogram": [], "boxplot": []}
    execution_errors = []

    for idx, task in enumerate(tasks):
        chart_type = task.get("chart_type")
        code_snippet = task.get("code")
        executed = False

        if isinstance(code_snippet, str) and code_snippet:
            safe, reason = code_is_safe(code_snippet)
            if not safe:
                logger.warning("Rejected unsafe code snippet: %s", reason)
            else:
                try:
                    rows = _normalize_exec_result(_run_task_code(code_snippet, pre_df))
                    if rows is not None:
                        final.setdefault(chart_type, []).extend(rows)
                        executed = True
                    else:
                        execution_errors.append({"task_index": idx, "reason": "result not list-of-dicts or missing", "code": code_snippet})
                except Exception as e:
                    logger.exception("Model code execution failed for task %s: %s", idx, str(e))
                    execution_errors.append({"task_index": idx, "reason": "exception during exec/eval", "exception": str(e), "code": code_snippet})

        if not executed:
            # A bad spec (e.g. a column name the model invented) must not abort
            # the whole run; record it and move on to the next task.
            try:
                ct, res = generate_chart_data_from_spec(pre_df, task)
                final.setdefault(ct, []).extend(res)
            except Exception as e:
                logger.exception("Deterministic chart generation failed for task %s: %s", idx, str(e))
                execution_errors.append({"task_index": idx, "reason": "exception during fallback chart generation", "exception": str(e)})

    # Cap every chart's payload size.
    for k in final:
        if isinstance(final[k], list):
            final[k] = final[k][:200]

    return {
        "metadata": {
            "source_file": os.path.basename(path),
            "ingestion_meta": meta,
            "preprocess_meta": preprocess_meta,
            "classification": classification_output if isinstance(classification_output, dict) else {"raw": classification_output},
            "planning_meta": planning_output if isinstance(planning_output, dict) else {"raw": planning_output},
            "execution_errors": execution_errors,
        },
        "charts": final,
    }