"""Lightweight execution-grounded agent for the HF Space demo."""
from __future__ import annotations
import os
import re
import subprocess
import sys
import tempfile
import uuid
from pathlib import Path
from typing import Any, Optional
from config import AGENT_EXEC_TIMEOUT, AGENT_MAX_NEW_TOKENS, AGENT_MAX_STEPS, AGENT_TEMPERATURE, DONE_MARKERS
from prompts import SYSTEM_PROMPT
class ContextManager:
    """Chat history for the agent loop, trimmed to an approximate token budget.

    The first user message is pinned and always sent right after the system
    prompt. The rest of the history is trimmed from the oldest end so that the
    newest messages fit within ``max_tokens``.
    """

    def __init__(self, system_prompt: str, max_tokens: int = 6000):
        self.system_prompt = system_prompt
        # Budget for the trimmed history only; the system prompt and the pinned
        # first message are sent on top of it.
        self.max_tokens = max_tokens
        self.messages: list[dict] = []
        self.pinned_first_msg: Optional[dict] = None

    def add_user(self, content: str) -> None:
        """Append a user turn; the very first one becomes the pinned message."""
        msg = {"role": "user", "content": content}
        if self.pinned_first_msg is None:
            self.pinned_first_msg = msg
        self.messages.append(msg)

    def add_assistant(self, content: str) -> None:
        """Append a model response."""
        self.messages.append({"role": "assistant", "content": content})

    def add_result(self, result: str) -> None:
        """Append execution output as a user turn (first 2000 characters only)."""
        self.messages.append({
            "role": "user",
            "content": f"\n[EXEC:real]\n{result[:2000]}\n",
        })

    def get_messages(self) -> list[dict]:
        """Build the prompt: system prompt, pinned first message, recent history.

        Consecutive messages with the same role are merged into one. That
        happens when a single response held several code blocks (one result
        message per block) or when trimming leaves an execution result right
        after the pinned user message; some chat templates raise an error on
        non-alternating user/assistant roles.
        """
        recent = self._trim_to_budget()
        full: list[dict] = [{"role": "system", "content": self.system_prompt}]
        if self.pinned_first_msg is not None:
            full.append(self.pinned_first_msg)
            # The pinned message is messages[0]; if it survived trimming it is
            # recent[0] (the same object) and must not be sent twice.
            if recent and recent[0] is self.pinned_first_msg:
                recent = recent[1:]
        full.extend(recent)
        return self._merge_same_role(full)

    def _trim_to_budget(self) -> list[dict]:
        """Return the newest messages whose estimated size fits ``max_tokens``.

        Size is estimated as whitespace-separated word count x 1.3. Walking
        backwards, collection stops at the first message that does not fit.
        """
        budget = self.max_tokens
        kept: list[dict] = []
        for msg in reversed(self.messages):
            cost = len(msg["content"].split()) * 1.3
            if budget - cost < 0:
                break
            kept.append(msg)
            budget -= cost
        kept.reverse()
        return kept

    @staticmethod
    def _merge_same_role(messages: list[dict]) -> list[dict]:
        """Join adjacent same-role messages with a newline, without mutating inputs."""
        merged: list[dict] = []
        for msg in messages:
            if merged and merged[-1]["role"] == msg["role"]:
                # Build a new dict so the pinned/original message objects stay intact.
                merged[-1] = {
                    "role": msg["role"],
                    "content": merged[-1]["content"] + "\n" + msg["content"],
                }
            else:
                merged.append(msg)
        return merged
def extract_code_blocks(text: str) -> list[str]:
    """Return the bodies of fenced Python code blocks in *text*, in order.

    Fences tagged ``python``, ``python3`` or ``py`` (any letter case, optional
    trailing spaces/tabs, LF or CRLF line ending) are collected. If there are
    none, untagged ``` fences are used as a fallback.
    """
    blocks = re.findall(
        r"```(?:python3?|py)[ \t]*\r?\n(.*?)```", text, re.DOTALL | re.IGNORECASE
    )
    if not blocks:
        blocks = re.findall(r"```[ \t]*\r?\n(.*?)```", text, re.DOTALL)
    return blocks
def detect_output_files(code: str) -> list[str]:
    """List file names passed as a string literal to savefig/write_html/to_csv.

    Only the first positional argument is inspected; other arguments
    (``dpi=150``, ``index=False``) and surrounding whitespace are allowed.
    Results are grouped by call kind in the order savefig, write_html, to_csv.
    Non-literal arguments (variables, f-strings) are not detected.
    """
    patterns = (
        r'savefig\(\s*["\']([^"\']+)["\']',
        r'write_html\(\s*["\']([^"\']+)["\']',
        r'to_csv\(\s*["\']([^"\']+)["\']',
    )
    files: list[str] = []
    for pattern in patterns:
        files.extend(re.findall(pattern, code))
    return files
def format_exec_result(result: dict) -> str:
    """Render an ``execute_python`` result dict as text for the model and logs.

    Failures show stderr, else stdout, else a fixed placeholder. Successes
    show stdout (or a placeholder) plus a line listing any saved file names.
    """
    if not result["success"]:
        return result["stderr"] or result["stdout"] or "(execution failed)"
    text = result["stdout"] or "(no output)"
    saved = result["files"]
    if saved:
        text = text + f"\nFiles saved: {list(saved.keys())}"
    return text
def execute_python(code: str, working_dir: str, timeout: int = 30) -> dict:
    """Run *code* as a script in a fresh Python subprocess inside *working_dir*.

    A preamble is prepended that changes into the working directory, selects
    matplotlib's non-interactive ``Agg`` backend when matplotlib is installed,
    and silences warnings. The script is written to a temporary ``.py`` file
    in *working_dir*, which is deleted afterwards.

    Args:
        code: Python source to execute.
        working_dir: Directory to run in; created if missing. A relative path
            is resolved against the caller's current directory.
        timeout: Seconds before the subprocess is killed.

    Returns:
        dict with ``stdout`` (first 3000 chars), ``stderr`` (last 1500 chars,
        so the final exception line of a long traceback is kept), ``files``
        (always ``{}`` here) and ``success`` (exit code was 0).
    """
    abs_dir = os.path.abspath(working_dir)
    os.makedirs(abs_dir, exist_ok=True)
    # The child already starts in abs_dir via cwd=, so chdir-ing to a relative
    # path from there would fail; repr() gives a valid literal for any path.
    preamble = (
        f"import os\nos.chdir({abs_dir!r})\n"
        "try:\n"
        "    import matplotlib\n"
        "    matplotlib.use('Agg')\n"
        "except ImportError:\n"
        "    pass\n"
        "import warnings\nwarnings.filterwarnings('ignore')\n"
    )
    with tempfile.NamedTemporaryFile(
        mode="w", suffix=".py", dir=abs_dir, delete=False, encoding="utf-8",
    ) as f:
        f.write(preamble + code)
        tmp_path = f.name
    try:
        proc = subprocess.run(
            [sys.executable, tmp_path],
            capture_output=True,
            text=True,
            errors="replace",  # undecodable child output must not crash the agent
            timeout=timeout,
            cwd=abs_dir,
        )
    except subprocess.TimeoutExpired:
        return {"stdout": "", "stderr": f"TimeoutError: exceeded {timeout}s", "files": {}, "success": False}
    finally:
        if os.path.exists(tmp_path):
            os.unlink(tmp_path)
    return {
        "stdout": (proc.stdout or "")[:3000],
        "stderr": (proc.stderr or "")[-1500:],
        "files": {},
        "success": proc.returncode == 0,
    }
def _read_tabular(path: Path, nrows: int = 200):
import pandas as pd
suffix = path.suffix.lower()
if suffix in (".xlsx", ".xls"):
return pd.read_excel(path, nrows=nrows)
return pd.read_csv(path, nrows=nrows)
def inspect_data(path: Path) -> dict[str, str]:
    """Summarise a tabular file for the prompt.

    Reads up to 200 rows and returns the file kind ("excel" or "csv"), one
    "name: dtype" line per column, the first five rows as text, and a note
    with the number of preview rows read.
    """
    preview = _read_tabular(path, nrows=200)
    is_excel = path.suffix.lower() in (".xlsx", ".xls")
    schema_lines = [f" {col}: {preview[col].dtype}" for col in preview.columns]
    return {
        "type": "excel" if is_excel else "csv",
        "schema": "\n".join(schema_lines),
        "sample": preview.head(5).to_string(index=False),
        "row_counts": f"preview_rows={len(preview)} (file may be larger)",
    }
def inspect_csv(path: Path) -> dict[str, str]:
    """Older name kept for callers; delegates to ``inspect_data`` unchanged."""
    summary = inspect_data(path)
    return summary
def build_user_message(data_path: Path, task: str) -> str:
    """Compose the first user turn describing the dataset and the task.

    Includes the file name and kind, column schema, sample rows, preview size,
    the task text, and a ready-to-paste pandas call for reading the file.
    """
    info = inspect_data(data_path)
    filename = data_path.name
    reader = "read_excel" if info["type"] == "excel" else "read_csv"
    # repr() yields a valid Python string literal even when the name contains
    # quotes (e.g. "Bob's sales.csv"); for ordinary names it is '<name>'.
    read_hint = f"pd.{reader}({filename!r})"
    lines = [
        f"Data source: {filename}",
        f"Working directory contains: {filename}",
        f"Type: {info['type']}",
        "",
        "Schema:",
        info["schema"],
        "",
        "Sample rows:",
        info["sample"],
        "",
        info["row_counts"],
        "",
        f"Task: {task}",
        "",
        f"Read the file with pandas: {read_hint}",
    ]
    return "\n".join(lines)
DONE_MARKERS = ("**Summary:**", "**Finding:**", "**Conclusion:**", "**Results:**")
FINISH_MARKERS = DONE_MARKERS + (
"**Answer:**",
"**ANSWER:**",
"Final Answer:",
"final answer:",
)
_GEMMA_TOKEN_RE = re.compile(r"<(?:start_of_turn|end_of_turn|turn)[^>]*>|<\|[^|]+\|>")
_THINK_RE = re.compile(r".*?", re.DOTALL | re.IGNORECASE)
def _strip_model_noise(text: str) -> str:
    """Delete ``_THINK_RE`` matches, then ``_GEMMA_TOKEN_RE`` matches, then trim whitespace."""
    without_thoughts = _THINK_RE.sub("", text)
    without_tokens = _GEMMA_TOKEN_RE.sub("", without_thoughts)
    return without_tokens.strip()
def _answer_from_stdout(stdout: str) -> str:
"""Best-effort answer from verified execution output."""
if not stdout:
return ""
label_patterns = [
r"(?:Product|product) with highest (?:total )?revenue:\s*(.+)",
r"(?:Top product|top product)(?:\s+by revenue)?:\s*(.+)",
r"(?:The answer is|Answer|Result|Final answer):\s*(.+)",
r"(?:Maximum|Max) revenue:\s*([\d.,]+)",
]
for line in stdout.splitlines():
line = line.strip()
if not line or line.startswith("Name:") or "dtype:" in line:
continue
for pat in label_patterns:
m = re.search(pat, line, re.IGNORECASE)
if m:
val = m.group(1).strip().strip(".")
if val and val.lower() not in ("nan", "none"):
return val
lines = [ln.strip() for ln in stdout.splitlines() if ln.strip() and "dtype:" not in ln]
return lines[-1] if lines else ""
def extract_answer(final_text: str, exec_outputs: list[str] | None = None) -> str:
    """Pick the final answer from the model's last response.

    Priority: an explicit "**Answer:**" or "Final Answer:" tag (case-insensitive)
    in the cleaned text; then the newest execution output that yields an
    answer; then the last non-empty line if shorter than 200 characters and
    not a code fence. Returns "" when nothing qualifies.
    """
    cleaned = _strip_model_noise(final_text)
    # IGNORECASE already covers the "**ANSWER:**" / "final answer:" spellings.
    for tag in (r"\*\*Answer:\*\*", r"Final Answer:"):
        hit = re.search(tag + r"\s*(.+?)(?:\n|$)", cleaned, re.IGNORECASE)
        if hit is None:
            continue
        candidate = hit.group(1).strip().strip("*").strip()
        if candidate and not candidate.startswith("```"):
            return candidate
    for output in reversed(exec_outputs or []):
        derived = _answer_from_stdout(output)
        if derived:
            return derived
    non_empty = [ln.strip() for ln in cleaned.splitlines() if ln.strip()]
    if not non_empty:
        return ""
    tail = non_empty[-1]
    if len(tail) < 200 and not tail.startswith("```"):
        return tail
    return ""
def extract_summary(final_text: str) -> str:
    """Return the first non-empty line after the first summary-style marker.

    Markers are tried in order (Summary, Finding, Conclusion, Results); the
    result is capped at 1500 characters, and "" is returned if none yields text.
    """
    text = _strip_model_noise(final_text)
    for marker in ("**Summary:**", "**Finding:**", "**Conclusion:**", "**Results:**"):
        _, found, rest = text.partition(marker)
        if not found:
            continue
        first_line = rest.strip().split("\n", 1)[0].strip()
        if first_line:
            return first_line[:1500]
    return ""
def generate_response(messages: list, model, tokenizer) -> str:
    """Generate one assistant reply for *messages* with a Hugging Face model.

    The prompt is built with the tokenizer's chat template. Sampling is used
    only when ``AGENT_TEMPERATURE > 0``; otherwise decoding is greedy. Only the
    newly generated tokens are decoded, and special tokens are kept so that
    callers can strip them themselves.
    """
    import torch

    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)
    do_sample = AGENT_TEMPERATURE > 0
    gen_kwargs = {
        "max_new_tokens": AGENT_MAX_NEW_TOKENS,
        "do_sample": do_sample,
        "pad_token_id": tokenizer.eos_token_id,
        # One unpadded prompt: every position is real input. Passing it
        # explicitly avoids inference from pad == eos, which is ambiguous.
        "attention_mask": torch.ones_like(input_ids),
    }
    if do_sample:
        # Temperature only affects sampling; greedy decoding ignores it.
        gen_kwargs["temperature"] = AGENT_TEMPERATURE
    with torch.no_grad():
        output_ids = model.generate(input_ids, **gen_kwargs)
    new_tokens = output_ids[0][input_ids.shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=False)
def _run_agent_events(
    model,
    tokenizer,
    data_path: Path,
    task: str,
    max_steps: int,
    progress: Optional[Any],
) -> Any:
    """Generator driving the generate -> execute loop.

    Yields ``("progress", step, max_steps, steps_markdown)`` after each step,
    then exactly one ``("final", result_dict)``.
    """
    workspace = Path(tempfile.gettempdir()) / f"datasense_{uuid.uuid4().hex[:10]}"
    workspace.mkdir(parents=True, exist_ok=True)
    # Copy the dataset so executed code works on a private copy.
    dest = workspace / data_path.name
    dest.write_bytes(data_path.read_bytes())

    context = ContextManager(system_prompt=SYSTEM_PROMPT)
    context.add_user(build_user_message(dest, task))
    step_logs: list[str] = []
    exec_outputs: list[str] = []
    final_text = ""

    for step in range(max_steps):
        step_no = step + 1
        if progress is not None:
            progress(step_no / max_steps, desc=f"Step {step_no}/{max_steps}")
        response = generate_response(context.get_messages(), model, tokenizer)
        context.add_assistant(response)
        final_text = response
        preview = _strip_model_noise(response).replace("\n", " ")[:180]
        step_logs.append(f"### Step {step_no}\n{preview}...\n")

        # Only markers outside fenced code count: a line such as
        # print("Final Answer:", x) inside a code block must still be executed.
        prose = re.sub(r"```.*?```", "", response, flags=re.DOTALL)
        if any(marker in prose for marker in FINISH_MARKERS):
            step_logs.append("✅ Agent finished.\n")
            yield ("progress", step_no, max_steps, "\n".join(step_logs))
            break

        code_blocks = extract_code_blocks(response)
        if not code_blocks:
            if exec_outputs:
                step_logs.append("ℹ️ No more code — answer from execution output.\n")
            else:
                step_logs.append("ℹ️ No code block — stopping.\n")
            yield ("progress", step_no, max_steps, "\n".join(step_logs))
            break

        for code_block in code_blocks:
            exec_result = execute_python(
                code=code_block,
                working_dir=str(workspace),
                timeout=AGENT_EXEC_TIMEOUT,
            )
            result_str = format_exec_result(exec_result)
            if exec_result["success"] and result_str:
                exec_outputs.append(result_str)
            status = "✅" if exec_result["success"] else "❌"
            step_logs.append(f"{status} **Execution**\n```\n{result_str[:1200]}\n```\n")
            context.add_result(result_str)
        yield ("progress", step_no, max_steps, "\n".join(step_logs))

    yield ("final", {
        "steps_markdown": "\n".join(step_logs),
        "final_response": final_text,
        "answer": extract_answer(final_text, exec_outputs),
        "summary": extract_summary(final_text),
        "workspace": str(workspace),
    })


def run_agent(
    model,
    tokenizer,
    data_path: Path,
    task: str,
    *,
    max_steps: int = AGENT_MAX_STEPS,
    progress: Optional[Any] = None,
    stream: bool = False,
) -> Any:
    """Run the generate -> execute loop over a private copy of *data_path*.

    With ``stream=True`` returns an iterator of ``("progress", step,
    max_steps, steps_markdown)`` tuples followed by ``("final", result)``.
    With ``stream=False`` runs to completion and returns ``result``: a dict
    with ``steps_markdown``, ``final_response``, ``answer``, ``summary`` and
    ``workspace`` keys.
    """
    # This function must not contain `yield` itself: that would make the
    # non-streaming call return a generator instead of the result dict.
    events = _run_agent_events(model, tokenizer, data_path, task, max_steps, progress)
    if stream:
        return events
    final: dict = {}
    for event in events:
        if event[0] == "final":
            final = event[1]
    return final