triflix commited on
Commit
1035208
Β·
verified Β·
1 Parent(s): 17ab777

Update pipeline_with_agents.py

Browse files
Files changed (1) hide show
  1. pipeline_with_agents.py +106 -39
pipeline_with_agents.py CHANGED
@@ -1,17 +1,12 @@
1
- # (This file is verbatim copy of the pipeline you provided β€” unchanged.)
2
- # Save exactly as provided by you.
3
- # -------------------------
4
- # Paste the full content you provided earlier here.
5
- # -------------------------
6
  """
7
  Automated Data-Analysis Pipeline with Agent Prompts + Gemini (google-genai)
8
 
9
- Fixed:
10
- - Uses GEMINI_API_KEY from environment (no hardcoded key)
11
- - Enforces planning agent contract: model MUST assign `result` variable
12
- - Hardened execution: exec -> check for `result` or attempt safe eval for single-expression snippets
13
- - Rejects unsafe snippets via regex blacklist
14
- - Logs model code when execution fails and falls back to deterministic generator
15
  """
16
 
17
  import os
@@ -34,25 +29,60 @@ logging.basicConfig(level=logging.INFO)
34
  logger = logging.getLogger("pipeline")
35
 
36
  # ---------------------------
37
- # Agent system prompts
38
  # ---------------------------
39
 
40
  PROMPTS = {
41
- "file_ingestion": """You are a file ingestion agent. Detect file type, extract sheets if Excel, load into structured dataframe. Return metadata: file_type, sheet_names, selected_sheet.""",
42
- "preprocessing": """You are a preprocessing agent. Clean and normalize the dataset: handle nulls, infer types, encode categories, scale numerics. Output structured summary of cleaning steps and cleaned dataframe schema.""",
43
- "sampling": """You are a sampling agent. From the dataframe, generate head(5), tail(5), and a random sample. Output each in JSON structure.""",
44
- "classification": """You are a classification agent. Analyze given dataset samples. Identify domain (finance, sales, demographics, etc.) and suggest suitable visualization tasks (minimum 6). Return structured JSON listing tasks with fields: chart_type (one of pie, bar, line, scatter, histogram, boxplot), target_columns (array), aggregation (string or null), reasoning (short). Example output: { "domain": "sales", "tasks": [ { "chart_type": "pie", "target_columns": ["region"], "aggregation": "count", "reasoning": "distribution of region" }, ... ] }""",
45
- "planning": """You are a planning agent. From classification output, create at least 6 chart tasks. For each: chart_type (pie, bar, line, scatter, histogram, boxplot), target_columns, aggregation (if needed), and a Python/pandas code snippet that generates the chart-ready aggregated JSON (but code must be limited to pandas/numpy operations). IMPORTANT: Your code MUST assign the final chart data to a variable named `result` that is a list of dictionaries ready for JSON serialization. Example:
46
- result = df.groupby('anemia_label').size().reset_index(name='value').to_dict(orient='records')
47
- Return JSON with array 'tasks'. Example task entry: { "chart_type": "bar", "target_columns": ["month", "sales"], "aggregation": "sum", "code": "result = df.groupby('month')['sales'].sum().reset_index().to_dict(orient=\\'records\\')" }""",
48
- "execution": """You are an execution agent. Execute given Python code safely on provided dataset. Return structured JSON results in chart-ready format. Follow schema:
49
- Pie β†’ [{ "name": "...", "value": ... }]
50
- Bar β†’ [{ "label": "...", "metric1": ..., "metric2": ... }]
51
- Line β†’ [{ "x": "...", "y": ... }]
52
- Scatter β†’ [{ "x": ..., "y": ... }]
53
- Histogram β†’ [{ "bin": ..., "count": ... }]
54
- Boxplot β†’ [{ "category": "...", "q1": ..., "median": ..., "q3": ... }]""",
55
- "output": """You are an output agent. Aggregate all chart JSON objects into final structured response. Ensure at least 6 charts included. Output JSON with keys: "pie", "bar", "line", "scatter", "histogram", "boxplot"."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  }
57
 
58
  # ---------------------------
@@ -73,13 +103,24 @@ DISALLOWED_PATTERNS = [
73
  r"\bsystem\s*\(",
74
  r"\bPopen\b",
75
  r"\bsh\b",
 
 
 
 
 
 
 
 
 
 
 
76
  ]
77
 
78
 
79
  def code_is_safe(code: str) -> Tuple[bool, Optional[str]]:
80
  lowered = code
81
  for pat in DISALLOWED_PATTERNS:
82
- if re.search(pat, lowered):
83
  return False, f"disallowed pattern: {pat}"
84
  return True, None
85
 
@@ -203,8 +244,9 @@ def gemini_generate_json(model: str, system_instruction: str, user_content: str,
203
  """
204
  Calls genai generate_content_stream with given system prompt and user content.
205
  Expects the model to return JSON text. Joins chunks and returns parsed JSON or raw text.
 
206
  """
207
- api_key = "AIzaSyDfy0E-9b2XjoYHrHX2C1nVLHWyrWUFkMs"
208
  if not api_key:
209
  raise EnvironmentError("GEMINI_API_KEY not set in environment.")
210
  client = genai.Client(api_key=api_key)
@@ -503,12 +545,38 @@ def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.
503
  # If model didn't produce tasks array, use classification tasks
504
  tasks = classification_output.get("tasks", [])
505
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
506
  # Ensure at least 6 tasks
507
  tasks = ensure_six_tasks(tasks, pre_df)
508
 
509
  # Execute tasks
510
- final = {"pie": [], "bar": [], "line": [], "scatter": [], "histogram": [], "boxplot": []}
511
- execution_errors = []
512
 
513
  for idx, task in enumerate(tasks):
514
  chart_type = task.get("chart_type")
@@ -518,7 +586,7 @@ def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.
518
  if code_snippet:
519
  safe, reason = code_is_safe(code_snippet)
520
  if not safe:
521
- logger.warning("Rejected unsafe code snippet: %s", reason)
522
  else:
523
  # Controlled globals for exec/eval
524
  allowed_globals = {
@@ -537,7 +605,7 @@ def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.
537
  "np": np,
538
  "df": pre_df.copy(),
539
  }
540
- local_vars = {}
541
  try:
542
  # 1) Try exec (model should assign `result`)
543
  exec(code_snippet, allowed_globals, local_vars)
@@ -564,19 +632,18 @@ def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.
564
  # 3) Normalize result into list-of-dicts
565
  result_json = None
566
  if isinstance(result, pd.DataFrame):
567
- result_json = [ {k: to_json_serializable(v) for k,v in r.items()} for r in result.to_dict(orient="records") ]
568
  elif isinstance(result, list):
569
  norm = []
570
- valid = True
571
  for r in result:
572
  if isinstance(r, dict):
573
- norm.append({k: to_json_serializable(v) for k,v in r.items()})
574
  else:
575
  # allow primitive lists but wrap as dict with value key
576
- norm.append(to_json_serializable(r))
577
  result_json = norm
578
  elif isinstance(result, dict):
579
- result_json = [{k: to_json_serializable(v) for k,v in result.items()}]
580
  else:
581
  # primitive or None -> invalid for chart payload
582
  result_json = None
@@ -597,7 +664,7 @@ def process_file(path: str, sheet: Optional[str] = None, model: str = "gemini-2.
597
  final.setdefault(chart_type, []).extend(normalized)
598
  executed = True
599
  if not executed:
600
- execution_errors.append({"task_index": idx, "reason": "result not list-of-dicts or missing", "code": code_snippet})
601
  except Exception as e:
602
  logger.exception("Model code execution failed for task %s: %s", idx, str(e))
603
  execution_errors.append({"task_index": idx, "reason": "exception during exec/eval", "exception": str(e), "code": code_snippet})
 
 
 
 
 
 
1
  """
2
  Automated Data-Analysis Pipeline with Agent Prompts + Gemini (google-genai)
3
 
4
+ Changes applied:
5
+ - Use GEMINI_API_KEY from environment (no hardcoded key)
6
+ - Stronger, model-proof PROMPTS that forbid plotting and require `result` assignment
7
+ - Extended DISALLOWED_PATTERNS to block plotting libraries and plotting methods
8
+ - Validation step after planning: drop model-provided code that lacks `result` or uses plotting tokens; record execution_errors
9
+ - Execution still performs safety checks and falls back to deterministic generators when needed
10
  """
11
 
12
  import os
 
29
  logger = logging.getLogger("pipeline")
30
 
31
  # ---------------------------
32
+ # Agent system prompts (strict, plotting banned)
33
  # ---------------------------
34
 
35
  PROMPTS = {
36
+ "file_ingestion": (
37
+ "You are a file ingestion agent. Detect file type; if Excel enumerate sheets and pick the specified sheet or default to the first. "
38
+ "Load the chosen sheet into a pandas DataFrame and return only metadata (no narrative): "
39
+ '{"file_type":"<.csv|.xlsx|...>", "sheet_names":[...], "selected_sheet":"..."}'
40
+ ),
41
+ "preprocessing": (
42
+ "You are a preprocessing agent. Clean and normalize the dataset deterministically. "
43
+ "Operations allowed: trim strings, coerce numeric columns with pandas.to_numeric, fill numeric NaNs with median, fill object NaNs with mode, "
44
+ "generate one-line schema summary. RETURN JSON only: {\"actions\": [...], \"schema\": {\"columns\":[{\"name\":\"...\",\"dtype\":\"...\",\"n_unique\":N},...], \"n_rows\":N}}. "
45
+ "Do NOT print or return any code, diagrams, or explanations."
46
+ ),
47
+ "sampling": (
48
+ "You are a sampling agent. From the cleaned dataframe produce three JSON arrays: head(5), tail(5), random(5). "
49
+ "Return JSON: {\"head\": [...], \"tail\": [...], \"random\": [...]} where each array contains row dicts. Do NOT include extra fields."
50
+ ),
51
+ "classification": (
52
+ "You are a classification agent. Examine provided samples and schema. Identify dataset domain (one-word) and propose at least SIX visualization tasks. "
53
+ "Each task must be a JSON object: {\"task_id\":\"tN\",\"chart_type\":\"pie|bar|line|scatter|histogram|boxplot\",\"target_columns\":[...],"
54
+ "\"aggregation\": null|\"count\"|\"sum\"|\"mean\",\"reasoning\":\"one-sentence\"}. "
55
+ "Return JSON exactly: {\"domain\":\"...\",\"tasks\":[...]} and nothing else. Do NOT include code. Do NOT recommend plotting libraries."
56
+ ),
57
+ "planning": (
58
+ "You are a planning agent. Input: the classification JSON + schema + small samples. Produce at least SIX task entries. "
59
+ "For each task output a Python/pandas code snippet that uses ONLY pandas and numpy (and the dataframe variable `df`) and assigns the final result to a variable named `result`. "
60
+ "REQUIREMENTS for the code string: "
61
+ " - Must NOT import or reference matplotlib, seaborn, plotly, altair, bokeh, or any plotting functions. "
62
+ " - Must NOT call pandas plotting methods (e.g. .plot(), .hist() wrapper that uses matplotlib). "
63
+ " - Must NOT use eval/exec/compile or open(). "
64
+ " - Allowed names: df, pd, np, len, sum, min, max, round, sorted. "
65
+ " - The code must produce `result` as a list of dictionaries ready for JSON serialization (use .to_dict(orient='records') or list comprehension). "
66
+ " - Return JSON exactly: {\"tasks\":[ {\"task_id\":\"t1\",\"chart_type\":\"pie\",\"target_columns\":[\"colA\"],"
67
+ "\"aggregation\":\"count\",\"reasoning\":\"...\",\"code\":\"result = df.groupby('colA').size().reset_index(name=\\'value\\').to_dict(orient=\\'records\\')\" }, ... ] }"
68
+ ),
69
+ "execution": (
70
+ "You are an execution agent. You will run model-provided code in a restricted execution environment WITHOUT plotting libraries. "
71
+ "The executor expects the code to assign a variable named `result` containing a list of dicts. "
72
+ "Rules: do not rely on plotting functions. Use pandas/numpy for aggregation and numeric work only. "
73
+ "Schema expectations per chart type (examples only): "
74
+ " Pie β†’ [{\"name\":\"...\",\"value\":number}], "
75
+ " Bar β†’ [{\"label\":\"...\",\"metric1\":number, ...}], "
76
+ " Line β†’ [{\"x\":...,\"y\":...}] (x may be ISO string), "
77
+ " Scatter β†’ [{\"x\":number,\"y\":number}], "
78
+ " Histogram β†’ [{\"bin\":\"start-end\",\"count\":number}], "
79
+ " Boxplot β†’ [{\"category\":\"...\",\"q1\":number,\"median\":number,\"q3\":number}]. "
80
+ "Return nothing else; the pipeline will read `result` after execution. If you must provide example code show it only as a code string and follow the allowed-names rule."
81
+ ),
82
+ "output": (
83
+ "You are an output agent. Aggregate final chart JSON objects into a single JSON object with keys: "
84
+ '"pie","bar","line","scatter","histogram","boxplot". Each key maps to an array (may be empty). Output JSON only.'
85
+ )
86
  }
87
 
88
  # ---------------------------
 
103
  r"\bsystem\s*\(",
104
  r"\bPopen\b",
105
  r"\bsh\b",
106
+ # plotting libraries / functions
107
+ r"\bmatplotlib\b",
108
+ r"\bseaborn\b",
109
+ r"\bplotly\b",
110
+ r"\baltair\b",
111
+ r"\bbokeh\b",
112
+ r"\.plot\s*\(",
113
+ r"\.hist\s*\(",
114
+ r"\.boxplot\s*\(",
115
+ r"\bpyplot\b",
116
+ r"\bplt\b",
117
  ]
118
 
119
 
120
  def code_is_safe(code: str) -> Tuple[bool, Optional[str]]:
121
  lowered = code
122
  for pat in DISALLOWED_PATTERNS:
123
+ if re.search(pat, lowered, flags=re.I):
124
  return False, f"disallowed pattern: {pat}"
125
  return True, None
126
 
 
244
  """
245
  Calls genai generate_content_stream with given system prompt and user content.
246
  Expects the model to return JSON text. Joins chunks and returns parsed JSON or raw text.
247
+ Uses GEMINI_API_KEY from environment.
248
  """
249
+ api_key = os.environ.get("GEMINI_API_KEY")
250
  if not api_key:
251
  raise EnvironmentError("GEMINI_API_KEY not set in environment.")
252
  client = genai.Client(api_key=api_key)
 
545
  # If model didn't produce tasks array, use classification tasks
546
  tasks = classification_output.get("tasks", [])
547
 
548
+ # Execution errors list (populate during validation/execution)
549
+ execution_errors: List[Dict[str, Any]] = []
550
+
551
+ # Validate model-provided code before execution:
552
+ # - require 'result' assignment inside code
553
+ # - drop code that contains plotting tokens or disallowed patterns
554
+ plotting_disallowed_re = re.compile(r"(matplotlib|seaborn|plotly|altair|bokeh|\.plot\s*\(|\.hist\s*\(|\.boxplot\s*\(|plt\b|pyplot\b)", flags=re.I)
555
+ for i, t in enumerate(tasks):
556
+ code = t.get("code", "") or ""
557
+ if code:
558
+ # 1) basic presence of `result`
559
+ if "result" not in code:
560
+ t.pop("code", None)
561
+ execution_errors.append({"task_index": i, "task_id": t.get("task_id"), "reason": "missing 'result' assignment - code dropped"})
562
+ continue
563
+ # 2) plotting tokens check
564
+ if plotting_disallowed_re.search(code):
565
+ t.pop("code", None)
566
+ execution_errors.append({"task_index": i, "task_id": t.get("task_id"), "reason": "plotting functions not allowed - code dropped"})
567
+ continue
568
+ # 3) disallowed patterns check
569
+ safe, reason = code_is_safe(code)
570
+ if not safe:
571
+ t.pop("code", None)
572
+ execution_errors.append({"task_index": i, "task_id": t.get("task_id"), "reason": f"disallowed pattern - code dropped: {reason}"})
573
+ continue
574
+
575
  # Ensure at least 6 tasks
576
  tasks = ensure_six_tasks(tasks, pre_df)
577
 
578
  # Execute tasks
579
+ final: Dict[str, List[Dict[str, Any]]] = {"pie": [], "bar": [], "line": [], "scatter": [], "histogram": [], "boxplot": []}
 
580
 
581
  for idx, task in enumerate(tasks):
582
  chart_type = task.get("chart_type")
 
586
  if code_snippet:
587
  safe, reason = code_is_safe(code_snippet)
588
  if not safe:
589
+ logger.warning("Rejected unsafe code snippet at execution time: %s", reason)
590
  else:
591
  # Controlled globals for exec/eval
592
  allowed_globals = {
 
605
  "np": np,
606
  "df": pre_df.copy(),
607
  }
608
+ local_vars: Dict[str, Any] = {}
609
  try:
610
  # 1) Try exec (model should assign `result`)
611
  exec(code_snippet, allowed_globals, local_vars)
 
632
  # 3) Normalize result into list-of-dicts
633
  result_json = None
634
  if isinstance(result, pd.DataFrame):
635
+ result_json = [{k: to_json_serializable(v) for k, v in r.items()} for r in result.to_dict(orient="records")]
636
  elif isinstance(result, list):
637
  norm = []
 
638
  for r in result:
639
  if isinstance(r, dict):
640
+ norm.append({k: to_json_serializable(v) for k, v in r.items()})
641
  else:
642
  # allow primitive lists but wrap as dict with value key
643
+ norm.append({"value": to_json_serializable(r)})
644
  result_json = norm
645
  elif isinstance(result, dict):
646
+ result_json = [{k: to_json_serializable(v) for k, v in result.items()}]
647
  else:
648
  # primitive or None -> invalid for chart payload
649
  result_json = None
 
664
  final.setdefault(chart_type, []).extend(normalized)
665
  executed = True
666
  if not executed:
667
+ execution_errors.append({"task_index": idx, "reason": "result not list-of-dicts or missing after exec", "code": code_snippet})
668
  except Exception as e:
669
  logger.exception("Model code execution failed for task %s: %s", idx, str(e))
670
  execution_errors.append({"task_index": idx, "reason": "exception during exec/eval", "exception": str(e), "code": code_snippet})