Spaces:
Running on Zero
Running on Zero
| """Lightweight execution-grounded agent for the HF Space demo.""" | |
| from __future__ import annotations | |
| import os | |
| import re | |
| import subprocess | |
| import sys | |
| import tempfile | |
| import uuid | |
| from pathlib import Path | |
| from typing import Any, Optional | |
| from config import AGENT_EXEC_TIMEOUT, AGENT_MAX_NEW_TOKENS, AGENT_MAX_STEPS, AGENT_TEMPERATURE, DONE_MARKERS | |
| from prompts import SYSTEM_PROMPT | |
class ContextManager:
    """Chat history for the agent, trimmed to an approximate token budget.

    The system prompt and the first user message (the pinned message) are
    included on every call to ``get_messages``; the remaining history is
    trimmed oldest-first so its estimated size fits in ``max_tokens``.
    """

    def __init__(self, system_prompt: str, max_tokens: int = 6000):
        self.system_prompt = system_prompt
        self.max_tokens = max_tokens
        self.messages: list[dict] = []
        # First user message ever added; always re-sent right after the system prompt.
        self.pinned_first_msg: Optional[dict] = None

    def add_user(self, content: str) -> None:
        """Append a user turn; the first one added becomes the pinned message."""
        msg = {"role": "user", "content": content}
        if self.pinned_first_msg is None:
            self.pinned_first_msg = msg
        self.messages.append(msg)

    def add_assistant(self, content: str) -> None:
        """Append a model turn."""
        self.messages.append({"role": "assistant", "content": content})

    def add_result(self, result: str) -> None:
        """Append execution output as a user turn (first 2000 characters only)."""
        self.messages.append({
            "role": "user",
            "content": f"<result>\n[EXEC:real]\n{result[:2000]}\n</result>",
        })

    def get_messages(self) -> list[dict]:
        """Return system message, pinned message, then the trimmed history.

        Consecutive messages with the same role are merged (joined by a blank
        line) so roles alternate user/assistant; strict chat templates reject
        two consecutive turns of the same role. This arises when several
        results follow one assistant turn, or when trimming leaves a result
        directly after the pinned user message. Stored messages are never
        mutated; merged entries are new dicts.
        """
        recent = self._trim_to_budget()
        body: list[dict] = []
        if self.pinned_first_msg is not None:
            body.append(self.pinned_first_msg)
            # The pinned message is messages[0]; drop it from the tail only if
            # it survived trimming (identity, not content, comparison).
            if recent and recent[0] is self.pinned_first_msg:
                recent = recent[1:]
        body.extend(recent)

        merged: list[dict] = []
        for msg in body:
            if merged and merged[-1]["role"] == msg["role"]:
                merged[-1] = {
                    "role": msg["role"],
                    "content": f"{merged[-1]['content']}\n\n{msg['content']}",
                }
            else:
                merged.append(msg)
        return [{"role": "system", "content": self.system_prompt}, *merged]

    def _trim_to_budget(self) -> list[dict]:
        """Return the longest suffix of ``messages`` that fits in ``max_tokens``.

        Tokens are estimated as 1.3 x the number of whitespace-separated words.
        The newest message is always kept, even if it alone exceeds the
        budget, so the latest turn is never silently dropped.
        """
        budget = self.max_tokens
        kept: list[dict] = []
        for msg in reversed(self.messages):
            tokens = len(msg["content"].split()) * 1.3
            if kept and tokens > budget:
                break
            kept.append(msg)
            budget -= tokens
        kept.reverse()
        return kept
def extract_code_blocks(text: str) -> list[str]:
    """Return the bodies of fenced code blocks in *text*.

    Python-tagged fences (```python, ```py, ```python3; tag matched
    case-insensitively) are preferred. Only when none are present are untagged
    ``` fences used, so fences tagged with another language (```bash, ```text)
    are never executed. Spaces/tabs after the tag and CRLF line endings are
    accepted.
    """
    blocks = re.findall(
        r"```(?:python3?|py)[ \t]*\r?\n(.*?)```", text, re.DOTALL | re.IGNORECASE,
    )
    if not blocks:
        blocks = re.findall(r"```[ \t]*\r?\n(.*?)```", text, re.DOTALL)
    return blocks
def detect_output_files(code: str) -> list[str]:
    """Return filenames passed as a string literal to savefig/write_html/to_csv.

    Only the first argument is inspected, and it may be followed by further
    arguments, e.g. ``savefig("a.png", dpi=150)`` or
    ``to_csv("out.csv", index=False)``. Results are grouped by call type in
    the order savefig, write_html, to_csv.
    """
    files: list[str] = []
    for pattern in (
        # No trailing "\)" so extra arguments after the filename are allowed.
        r'savefig\(\s*["\']([^"\']+)["\']',
        r'write_html\(\s*["\']([^"\']+)["\']',
        r'to_csv\(\s*["\']([^"\']+)["\']',
    ):
        files.extend(re.findall(pattern, code))
    return files
def format_exec_result(result: dict) -> str:
    """Render an execution-result dict as text for the model.

    Failure: stderr, falling back to stdout, then to "(execution failed)".
    Success: stdout (or "(no output)"), plus a "Files saved: [...]" line
    listing the keys of ``result["files"]`` when that mapping is non-empty.
    """
    if not result["success"]:
        return result["stderr"] or result["stdout"] or "(execution failed)"
    text = result["stdout"] or "(no output)"
    saved = result["files"]
    if saved:
        text = f"{text}\nFiles saved: {list(saved.keys())}"
    return text
def execute_python(code: str, working_dir: str, timeout: int = 30) -> dict:
    """Run *code* in a fresh Python subprocess inside *working_dir*.

    A preamble chdirs into *working_dir*, selects matplotlib's non-interactive
    ``Agg`` backend when matplotlib is importable, and silences warnings. The
    script is written to a temporary ``.py`` file in *working_dir* and
    deleted afterwards.

    Returns a dict with:
      - ``stdout``: first 3000 characters of standard output;
      - ``stderr``: LAST 1500 characters of standard error (the exception
        message sits at the end of a traceback);
      - ``files``: always an empty dict here;
      - ``success``: True iff the process exited with code 0.
    On timeout, ``success`` is False and ``stderr`` names the limit.
    """
    workdir = os.fspath(working_dir)
    os.makedirs(workdir, exist_ok=True)
    # repr() produces a valid Python string literal for any path characters.
    preamble = (
        f"import os\nos.chdir({workdir!r})\n"
        "try:\n"
        "    import matplotlib\n"
        "    matplotlib.use('Agg')\n"
        "except ImportError:\n"
        "    pass\n"
        "import warnings\nwarnings.filterwarnings('ignore')\n"
    )
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", dir=workdir, delete=False, encoding="utf-8",
    ) as f:
        f.write(preamble + code)
        tmp_path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            errors="replace",  # undecodable bytes must not crash the parent
            timeout=timeout,
            cwd=workdir,
            stdin=subprocess.DEVNULL,  # input() fails fast instead of hanging
        )
        return {
            "stdout": (proc.stdout or "")[:3000],
            "stderr": (proc.stderr or "")[-1500:],
            "files": {},
            "success": proc.returncode == 0,
        }
    except subprocess.TimeoutExpired:
        return {"stdout": "", "stderr": f"TimeoutError: exceeded {timeout}s", "files": {}, "success": False}
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
| def _read_tabular(path: Path, nrows: int = 200): | |
| import pandas as pd | |
| suffix = path.suffix.lower() | |
| if suffix in (".xlsx", ".xls"): | |
| return pd.read_excel(path, nrows=nrows) | |
| return pd.read_csv(path, nrows=nrows) | |
def inspect_data(path: Path) -> dict[str, str]:
    """Summarise a tabular file for the prompt.

    Reads up to 200 rows via ``_read_tabular`` and returns a dict of strings:
    ``type`` ("excel" or "csv", by suffix), ``schema`` (one " name: dtype"
    line per column), ``sample`` (first 5 rows, no index) and ``row_counts``
    (how many rows were actually read).
    """
    frame = _read_tabular(path, nrows=200)
    is_excel = path.suffix.lower() in (".xlsx", ".xls")
    schema_lines = [f" {name}: {frame[name].dtype}" for name in frame.columns]
    return {
        "type": "excel" if is_excel else "csv",
        "schema": "\n".join(schema_lines),
        "sample": frame.head(5).to_string(index=False),
        "row_counts": f"preview_rows={len(frame)} (file may be larger)",
    }
def inspect_csv(path: Path) -> dict[str, str]:
    """Backward-compatible alias for ``inspect_data`` (handles Excel too)."""
    summary = inspect_data(path)
    return summary
def build_user_message(data_path: Path, task: str) -> str:
    """Compose the first user message: file overview, schema, sample rows, task.

    Ends with a pandas hint (``pd.read_excel`` for Excel, otherwise
    ``pd.read_csv``). The filename is rendered with ``repr`` so the hint is
    valid Python even when the name contains quotes or backslashes; for
    ordinary names this gives the same ``'name.csv'`` form as before.
    """
    info = inspect_data(data_path)
    filename = data_path.name
    reader = "read_excel" if info["type"] == "excel" else "read_csv"
    read_hint = f"pd.{reader}({filename!r})"
    lines = [
        f"Data source: {filename}",
        f"Working directory contains: {filename}",
        f"Type: {info['type']}",
        "",
        "Schema:",
        info["schema"],
        "",
        "Sample rows:",
        info["sample"],
        "",
        info["row_counts"],
        "",
        f"Task: {task}",
        "",
        f"Read the file with pandas: {read_hint}",
    ]
    return "\n".join(lines)
# Section headers treated as "done" markers; also the prefixes extract_summary looks for.
# NOTE(review): this rebinds DONE_MARKERS, which is also imported from config at
# the top of the file, so the config value is shadowed for all code below.
# Confirm which definition is intended and keep only one.
DONE_MARKERS = ("**Summary:**", "**Finding:**", "**Conclusion:**", "**Results:**")
# run_agent stops as soon as a response contains any of these substrings
# (case-sensitive `in` check, hence both spellings), before running any code.
FINISH_MARKERS = DONE_MARKERS + (
    "**Answer:**",
    "**ANSWER:**",
    "Final Answer:",
    "final answer:",
)
# Chat-template control tokens (<start_of_turn...>, <end_of_turn>, <turn...>,
# <|...|>) that survive decoding because generate_response keeps special tokens.
_GEMMA_TOKEN_RE = re.compile(r"<(?:start_of_turn|end_of_turn|turn)[^>]*>|<\|[^|]+\|>")
# Text between <think> and </think> (non-greedy, case-insensitive, spans newlines).
_THINK_RE = re.compile(r"<think>.*?</think>", re.DOTALL | re.IGNORECASE)
def _strip_model_noise(text: str) -> str:
    """Remove <think>...</think> spans, then chat-template tokens, then trim."""
    without_think = _THINK_RE.sub("", text)
    return _GEMMA_TOKEN_RE.sub("", without_think).strip()
| def _answer_from_stdout(stdout: str) -> str: | |
| """Best-effort answer from verified execution output.""" | |
| if not stdout: | |
| return "" | |
| label_patterns = [ | |
| r"(?:Product|product) with highest (?:total )?revenue:\s*(.+)", | |
| r"(?:Top product|top product)(?:\s+by revenue)?:\s*(.+)", | |
| r"(?:The answer is|Answer|Result|Final answer):\s*(.+)", | |
| r"(?:Maximum|Max) revenue:\s*([\d.,]+)", | |
| ] | |
| for line in stdout.splitlines(): | |
| line = line.strip() | |
| if not line or line.startswith("Name:") or "dtype:" in line: | |
| continue | |
| for pat in label_patterns: | |
| m = re.search(pat, line, re.IGNORECASE) | |
| if m: | |
| val = m.group(1).strip().strip(".") | |
| if val and val.lower() not in ("nan", "none"): | |
| return val | |
| lines = [ln.strip() for ln in stdout.splitlines() if ln.strip() and "dtype:" not in ln] | |
| return lines[-1] if lines else "" | |
def extract_answer(final_text: str, exec_outputs: list[str] | None = None) -> str:
    """Pick the final answer: explicit tag -> execution stdout -> last line.

    1. A tagged value in the cleaned model text (case-insensitive, first
       occurrence per pattern, patterns tried in order):
       ``**Answer:** X`` / ``**Answer**: X``, then ``Final Answer: X`` /
       ``**Final Answer**: X``. Surrounding ``*`` are stripped; values that
       start with a code fence are ignored.
    2. Otherwise the newest entry of *exec_outputs* from which
       ``_answer_from_stdout`` extracts a value.
    3. Otherwise the last non-blank line of the cleaned text, if it is shorter
       than 200 characters and not a code fence; else "".
    """
    cleaned = _strip_model_noise(final_text)
    # IGNORECASE already covers "**ANSWER:**" and "final answer:".
    tag_patterns = (
        r"\*\*Answer(?::\*\*|\*\*:)\s*(.+?)(?:\n|$)",
        r"Final Answer(?:\*\*)?:\s*(.+?)(?:\n|$)",
    )
    for pat in tag_patterns:
        m = re.search(pat, cleaned, re.IGNORECASE)
        if m:
            ans = m.group(1).strip().strip("*").strip()
            if ans and not ans.startswith("```"):
                return ans
    for stdout in reversed(exec_outputs or []):
        from_exec = _answer_from_stdout(stdout)
        if from_exec:
            return from_exec
    lines = [ln.strip() for ln in cleaned.splitlines() if ln.strip()]
    if lines:
        last = lines[-1]
        if len(last) < 200 and not last.startswith("```"):
            return last
    return ""
def extract_summary(final_text: str) -> str:
    """Return the first line of text following a DONE_MARKERS header.

    Markers are tried in DONE_MARKERS order. For the first marker present in
    the cleaned text that is followed by any non-whitespace text, the first
    non-blank line after it is returned, capped at 1500 characters. Returns
    "" when no marker qualifies.
    """
    cleaned = _strip_model_noise(final_text)
    # Use the shared marker tuple instead of repeating its literals here.
    for prefix in DONE_MARKERS:
        if prefix in cleaned:
            tail = cleaned.partition(prefix)[2].strip()
            first_line = tail.partition("\n")[0].strip()
            if first_line:
                return first_line[:1500]
    return ""
def generate_response(messages: list, model, tokenizer) -> str:
    """Generate one assistant turn for *messages* and return it as raw text.

    The tokenizer's chat template is applied with a generation prompt; only
    the newly generated tokens are decoded, with special tokens kept (callers
    strip them). Sampling is enabled only when AGENT_TEMPERATURE > 0, and
    ``temperature`` is passed only in that case: transformers warns when
    sampling parameters are combined with ``do_sample=False``.
    """
    import torch

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    do_sample = AGENT_TEMPERATURE > 0
    gen_kwargs = {
        "max_new_tokens": AGENT_MAX_NEW_TOKENS,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        # Single unpadded prompt: every position is real input. Passing the
        # mask explicitly avoids generate() guessing it when pad == eos.
        "attention_mask": torch.ones_like(input_ids),
    }
    if do_sample:
        gen_kwargs["temperature"] = AGENT_TEMPERATURE
    with torch.no_grad():
        output_ids = model.generate(input_ids, **gen_kwargs)
    prompt_len = input_ids.shape[-1]
    return tokenizer.decode(output_ids[0][prompt_len:], skip_special_tokens=False)
def run_agent(
    model,
    tokenizer,
    data_path: Path,
    task: str,
    *,
    max_steps: int = AGENT_MAX_STEPS,
    progress: Optional[Any] = None,
    stream: bool = False,
) -> Any:
    """Run the generate -> execute loop on a private copy of *data_path*.

    stream=False (default): runs to completion and returns the result dict
    with keys ``steps_markdown``, ``final_response``, ``answer``, ``summary``
    and ``workspace``.

    stream=True: returns a generator yielding
    ``("progress", step, max_steps, steps_markdown)`` after each step and
    finally ``("final", result_dict)``.

    *progress*, if given, is called as ``progress(fraction, desc=...)`` at the
    start of each step.
    """
    # The loop lives in a separate generator: a `yield` in this function would
    # make every call (including stream=False) return a generator, not a dict.
    events = _run_agent_events(
        model, tokenizer, data_path, task, max_steps=max_steps, progress=progress,
    )
    if stream:
        return events
    for event in events:
        if event[0] == "final":
            return event[1]
    raise RuntimeError("agent loop finished without producing a result")


def _run_agent_events(model, tokenizer, data_path, task, *, max_steps, progress):
    """Generator behind run_agent: yields progress tuples, then ("final", result)."""
    workspace = Path(tempfile.gettempdir()) / f"datasense_{uuid.uuid4().hex[:10]}"
    workspace.mkdir(parents=True, exist_ok=True)
    # Copy the dataset into the fresh workspace; generated code runs there.
    dest = workspace / data_path.name
    dest.write_bytes(data_path.read_bytes())

    context = ContextManager(system_prompt=SYSTEM_PROMPT)
    context.add_user(build_user_message(dest, task))
    step_logs: list[str] = []
    exec_outputs: list[str] = []
    final_text = ""

    def snapshot(step: int) -> tuple:
        """Progress event for the 0-based *step* with the log so far."""
        return ("progress", step + 1, max_steps, "\n".join(step_logs))

    for step in range(max_steps):
        if progress is not None:
            progress((step + 1) / max_steps, desc=f"Step {step + 1}/{max_steps}")
        response = generate_response(context.get_messages(), model, tokenizer)
        context.add_assistant(response)
        final_text = response
        preview = _strip_model_noise(response).replace("\n", " ")[:180]
        step_logs.append(f"### Step {step + 1}\n{preview}...\n")
        # Checked before any code runs: a finish marker ends the loop even if
        # the same response also contains code blocks.
        if any(m in response for m in FINISH_MARKERS):
            step_logs.append("✅ Agent finished.\n")
            yield snapshot(step)
            break
        code_blocks = extract_code_blocks(response)
        if not code_blocks:
            if exec_outputs:
                step_logs.append("ℹ️ No more code — answer from execution output.\n")
            else:
                step_logs.append("ℹ️ No code block — stopping.\n")
            yield snapshot(step)
            break
        # Each block runs in its own subprocess; every result is fed back.
        for code_block in code_blocks:
            result = execute_python(
                code=code_block,
                working_dir=str(workspace),
                timeout=AGENT_EXEC_TIMEOUT,
            )
            result_str = format_exec_result(result)
            if result["success"] and result_str:
                exec_outputs.append(result_str)
            status = "✅" if result["success"] else "❌"
            step_logs.append(f"{status} **Execution**\n```\n{result_str[:1200]}\n```\n")
            context.add_result(result_str)
        yield snapshot(step)

    final = {
        "steps_markdown": "\n".join(step_logs),
        "final_response": final_text,
        "answer": extract_answer(final_text, exec_outputs),
        "summary": extract_summary(final_text),
        "workspace": str(workspace),
    }
    yield ("final", final)