"""Lightweight execution-grounded agent for the HF Space demo.""" from __future__ import annotations import os import re import subprocess import sys import tempfile import uuid from pathlib import Path from typing import Any, Optional from config import AGENT_EXEC_TIMEOUT, AGENT_MAX_NEW_TOKENS, AGENT_MAX_STEPS, AGENT_TEMPERATURE, DONE_MARKERS from prompts import SYSTEM_PROMPT class ContextManager: def __init__(self, system_prompt: str, max_tokens: int = 6000): self.system_prompt = system_prompt self.max_tokens = max_tokens self.messages: list[dict] = [] self.pinned_first_msg: Optional[dict] = None def add_user(self, content: str) -> None: msg = {"role": "user", "content": content} if self.pinned_first_msg is None: self.pinned_first_msg = msg self.messages.append(msg) def add_assistant(self, content: str) -> None: self.messages.append({"role": "assistant", "content": content}) def add_result(self, result: str) -> None: self.messages.append({ "role": "user", "content": f"\n[EXEC:real]\n{result[:2000]}\n", }) def get_messages(self) -> list[dict]: recent = self._trim_to_budget() full: list[dict] = [{"role": "system", "content": self.system_prompt}] if self.pinned_first_msg: full.append(self.pinned_first_msg) if recent and recent[0].get("content") == self.pinned_first_msg.get("content"): recent = recent[1:] full.extend(recent) return full def _trim_to_budget(self) -> list[dict]: budget = self.max_tokens trimmed: list[dict] = [] for msg in reversed(self.messages): tokens = len(msg["content"].split()) * 1.3 if budget - tokens < 0: break trimmed.insert(0, msg) budget -= tokens return trimmed def extract_code_blocks(text: str) -> list[str]: blocks = re.findall(r"```python\n(.*?)```", text, re.DOTALL) if not blocks: blocks = re.findall(r"```\n(.*?)```", text, re.DOTALL) return blocks def detect_output_files(code: str) -> list[str]: files: list[str] = [] for pattern in ( r'savefig\(["\']([^"\']+)["\']\)', r'write_html\(["\']([^"\']+)["\']\)', r'to_csv\(["\']([^"\']+)["\']\)', ): files.extend(re.findall(pattern, code)) return files def format_exec_result(result: dict) -> str: if result["success"]: out = result["stdout"] or "(no output)" if result["files"]: out += f"\nFiles saved: {list(result['files'].keys())}" else: out = result["stderr"] or result["stdout"] or "(execution failed)" return out def execute_python(code: str, working_dir: str, timeout: int = 30) -> dict: os.makedirs(working_dir, exist_ok=True) safe_dir = working_dir.replace("\\", "/").replace("'", "\\'") preamble = ( f"import os\nos.chdir('{safe_dir}')\n" "import matplotlib\nmatplotlib.use('Agg')\n" "import warnings\nwarnings.filterwarnings('ignore')\n" ) with tempfile.NamedTemporaryFile( mode="w", suffix=".py", dir=working_dir, delete=False, encoding="utf-8", ) as f: f.write(preamble + code) tmp_path = f.name try: proc = subprocess.run( [sys.executable, tmp_path], capture_output=True, text=True, timeout=timeout, cwd=working_dir, ) return { "stdout": (proc.stdout or "")[:3000], "stderr": (proc.stderr or "")[:1500], "files": {}, "success": proc.returncode == 0, } except subprocess.TimeoutExpired: return {"stdout": "", "stderr": f"TimeoutError: exceeded {timeout}s", "files": {}, "success": False} finally: if os.path.exists(tmp_path): os.unlink(tmp_path) def _read_tabular(path: Path, nrows: int = 200): import pandas as pd suffix = path.suffix.lower() if suffix in (".xlsx", ".xls"): return pd.read_excel(path, nrows=nrows) return pd.read_csv(path, nrows=nrows) def inspect_data(path: Path) -> dict[str, str]: df = _read_tabular(path, nrows=200) schema = "\n".join(f" {c}: {df[c].dtype}" for c in df.columns) sample = df.head(5).to_string(index=False) kind = "excel" if path.suffix.lower() in (".xlsx", ".xls") else "csv" return { "type": kind, "schema": schema, "sample": sample, "row_counts": f"preview_rows={len(df)} (file may be larger)", } def inspect_csv(path: Path) -> dict[str, str]: return inspect_data(path) def build_user_message(data_path: Path, task: str) -> str: info = inspect_data(data_path) filename = data_path.name read_hint = ( f"pd.read_excel('{filename}')" if info["type"] == "excel" else f"pd.read_csv('{filename}')" ) lines = [ f"Data source: {filename}", f"Working directory contains: {filename}", f"Type: {info['type']}", "", "Schema:", info["schema"], "", "Sample rows:", info["sample"], "", info["row_counts"], "", f"Task: {task}", "", f"Read the file with pandas: {read_hint}", ] return "\n".join(lines) DONE_MARKERS = ("**Summary:**", "**Finding:**", "**Conclusion:**", "**Results:**") FINISH_MARKERS = DONE_MARKERS + ( "**Answer:**", "**ANSWER:**", "Final Answer:", "final answer:", ) _GEMMA_TOKEN_RE = re.compile(r"<(?:start_of_turn|end_of_turn|turn)[^>]*>|<\|[^|]+\|>") _THINK_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE) def _strip_model_noise(text: str) -> str: text = _THINK_RE.sub("", text) text = _GEMMA_TOKEN_RE.sub("", text) return text.strip() def _answer_from_stdout(stdout: str) -> str: """Best-effort answer from verified execution output.""" if not stdout: return "" label_patterns = [ r"(?:Product|product) with highest (?:total )?revenue:\s*(.+)", r"(?:Top product|top product)(?:\s+by revenue)?:\s*(.+)", r"(?:The answer is|Answer|Result|Final answer):\s*(.+)", r"(?:Maximum|Max) revenue:\s*([\d.,]+)", ] for line in stdout.splitlines(): line = line.strip() if not line or line.startswith("Name:") or "dtype:" in line: continue for pat in label_patterns: m = re.search(pat, line, re.IGNORECASE) if m: val = m.group(1).strip().strip(".") if val and val.lower() not in ("nan", "none"): return val lines = [ln.strip() for ln in stdout.splitlines() if ln.strip() and "dtype:" not in ln] return lines[-1] if lines else "" def extract_answer(final_text: str, exec_outputs: list[str] | None = None) -> str: """Parse answer: **Answer:** / Final Answer: → execution stdout → last line.""" exec_outputs = exec_outputs or [] cleaned = _strip_model_noise(final_text) tag_patterns = [ r"\*\*Answer:\*\*\s*(.+?)(?:\n|$)", r"\*\*ANSWER:\*\*\s*(.+?)(?:\n|$)", r"Final Answer:\s*(.+?)(?:\n|$)", r"final answer:\s*(.+?)(?:\n|$)", ] for pat in tag_patterns: m = re.search(pat, cleaned, re.IGNORECASE) if m: ans = m.group(1).strip().strip("*").strip() if ans and not ans.startswith("```"): return ans for stdout in reversed(exec_outputs): from_exec = _answer_from_stdout(stdout) if from_exec: return from_exec lines = [ln.strip() for ln in cleaned.splitlines() if ln.strip()] if lines: last = lines[-1] if len(last) < 200 and not last.startswith("```"): return last return "" def extract_summary(final_text: str) -> str: cleaned = _strip_model_noise(final_text) for prefix in ("**Summary:**", "**Finding:**", "**Conclusion:**", "**Results:**"): if prefix in cleaned: tail = cleaned.split(prefix, 1)[1].strip() line = tail.split("\n")[0].strip() if line: return line[:1500] return "" def generate_response(messages: list, model, tokenizer) -> str: import torch input_ids = tokenizer.apply_chat_template( messages, tokenize=True, add_generation_prompt=True, return_tensors="pt", ).to(model.device) with torch.no_grad(): output_ids = model.generate( input_ids, max_new_tokens=AGENT_MAX_NEW_TOKENS, temperature=AGENT_TEMPERATURE, do_sample=AGENT_TEMPERATURE > 0, pad_token_id=tokenizer.eos_token_id, ) return tokenizer.decode(output_ids[0][input_ids.shape[-1] :], skip_special_tokens=False) def run_agent( model, tokenizer, data_path: Path, task: str, *, max_steps: int = AGENT_MAX_STEPS, progress: Optional[Any] = None, stream: bool = False, ) -> dict: """Run generate → execute loop. Returns steps log + final text.""" workspace = Path(tempfile.gettempdir()) / f"datasense_{uuid.uuid4().hex[:10]}" workspace.mkdir(parents=True, exist_ok=True) # Copy dataset into isolated workspace dest = workspace / data_path.name dest.write_bytes(data_path.read_bytes()) context = ContextManager(system_prompt=SYSTEM_PROMPT) context.add_user(build_user_message(dest, task)) step_logs: list[str] = [] exec_outputs: list[str] = [] final_text = "" for step in range(max_steps): if progress is not None: progress((step + 1) / max_steps, desc=f"Step {step + 1}/{max_steps}") response = generate_response(context.get_messages(), model, tokenizer) context.add_assistant(response) final_text = response preview = _strip_model_noise(response).replace("\n", " ")[:180] step_logs.append(f"### Step {step + 1}\n{preview}...\n") if any(m in response for m in FINISH_MARKERS): step_logs.append("✅ Agent finished.\n") if stream: yield ("progress", step + 1, max_steps, "\n".join(step_logs)) break code_blocks = extract_code_blocks(response) if not code_blocks: if exec_outputs: step_logs.append("ℹ️ No more code — answer from execution output.\n") else: step_logs.append("ℹ️ No code block — stopping.\n") if stream: yield ("progress", step + 1, max_steps, "\n".join(step_logs)) break result_str = "" for code_block in code_blocks: out_files = detect_output_files(code_block) result = execute_python( code=code_block, working_dir=str(workspace), timeout=AGENT_EXEC_TIMEOUT, ) result_str = format_exec_result(result) if result["success"] and result_str: exec_outputs.append(result_str) status = "✅" if result["success"] else "❌" step_logs.append(f"{status} **Execution**\n```\n{result_str[:1200]}\n```\n") context.add_result(result_str) if stream: yield ("progress", step + 1, max_steps, "\n".join(step_logs)) answer = extract_answer(final_text, exec_outputs) summary = extract_summary(final_text) result = { "steps_markdown": "\n".join(step_logs), "final_response": final_text, "answer": answer, "summary": summary, "workspace": str(workspace), } if stream: yield ("final", result) else: return result