# fsds_cleaning_env / agents.py
# Author: israaaML
# v3: benchmark results, final report, agent/eval improvements, smoke test fixes (commit b3fc5ee)
"""Baseline agents for the FSDS Cleaning Environment.
Provides RandomAgent, HeuristicAgent, and extensible Agent interface for
benchmarking, evaluation, and RL training.
"""
from __future__ import annotations
import random
from typing import Any, Protocol, TypedDict
from fsds_cleaning_env.server.cleaning_environment import AVAILABLE_OPERATIONS
class ToolCall(TypedDict):
    """A tool invocation: tool_name and arguments for env.call_tool."""
    # Name of the tool to invoke (e.g. "profile_data", "apply_cleaning_operation").
    tool: str
    # Keyword arguments forwarded verbatim to env.call_tool.
    arguments: dict[str, Any]
# Scripted policies derived from required_ops. Format: list of (operation, column).
# Each policy follows the same cleaning order: null out invalid values, cast
# dtypes, impute missing numerics, clip outliers, normalize/impute categoricals,
# then drop duplicate rows last. A column of None marks a table-wide operation
# (only drop_duplicates here). Keys are the task_ids HeuristicAgent understands.
HEURISTIC_POLICIES: dict[str, list[tuple[str, str | None]]] = {
    "ecommerce_mobile": [
        ("replace_invalid_with_null", "country"),
        ("replace_invalid_with_null", "items_in_cart"),
        ("replace_invalid_with_null", "device_os"),
        ("cast_numeric", "items_in_cart"),
        ("cast_numeric", "order_value"),
        ("impute_numeric", "items_in_cart"),
        ("impute_numeric", "order_value"),
        ("clip_outliers_iqr", "items_in_cart"),
        ("clip_outliers_iqr", "order_value"),
        ("normalize_categories", "device_os"),
        ("normalize_categories", "country"),
        ("impute_categorical", "device_os"),
        ("impute_categorical", "country"),
        ("cast_datetime", "event_date"),
        ("drop_duplicates", None),
    ],
    "subscription_churn": [
        ("replace_invalid_with_null", "monthly_spend"),
        ("replace_invalid_with_null", "age"),
        ("replace_invalid_with_null", "tenure_months"),
        ("replace_invalid_with_null", "payment_method"),
        ("cast_numeric", "age"),
        ("cast_numeric", "monthly_spend"),
        ("cast_numeric", "tenure_months"),
        ("impute_numeric", "age"),
        ("impute_numeric", "monthly_spend"),
        ("impute_numeric", "tenure_months"),
        ("clip_outliers_iqr", "monthly_spend"),
        ("normalize_categories", "plan_type"),
        ("normalize_categories", "payment_method"),
        ("impute_categorical", "plan_type"),
        ("impute_categorical", "payment_method"),
        ("drop_duplicates", None),
    ],
    "delivery_eta": [
        ("replace_invalid_with_null", "driver_rating"),
        ("replace_invalid_with_null", "late_deliveries_last_30d"),
        ("replace_invalid_with_null", "city"),
        ("replace_invalid_with_null", "vehicle_type"),
        ("cast_numeric", "driver_rating"),
        ("cast_numeric", "delivery_distance_km"),
        ("cast_numeric", "late_deliveries_last_30d"),
        ("impute_numeric", "driver_rating"),
        ("impute_numeric", "late_deliveries_last_30d"),
        ("impute_numeric", "delivery_distance_km"),
        ("clip_outliers_iqr", "delivery_distance_km"),
        ("normalize_categories", "city"),
        ("normalize_categories", "vehicle_type"),
        ("impute_categorical", "city"),
        ("impute_categorical", "vehicle_type"),
        ("drop_duplicates", None),
    ],
}
def _extract_reward(result: dict[str, Any]) -> float:
if "reward" in result:
return float(result["reward"])
if "final_reward" in result:
return float(result["final_reward"])
return 0.0
class Agent(Protocol):
    """Protocol for agents that run episodes in the FSDS Cleaning Environment.

    Structural (duck-typed) interface: any object with a compatible
    run_episode method satisfies it — no inheritance required.
    """
    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Run one episode and return the trajectory.

        Args:
            env: Environment exposing ``reset(task_id=..., **kwargs)`` and
                ``call_tool(name, **args)`` returning a result dict.
            task_id: Identifier of the cleaning task to run.
            max_steps: Step budget for the episode.
            seed: Optional seed forwarded to ``env.reset``.
            **reset_kwargs: Extra keyword arguments forwarded to ``env.reset``.

        Returns:
            List of step dicts, each with ``tool_name``, ``reward``, ``result``.
        """
        ...
class AgentWithAct(Protocol):
    """Agent that supports per-step action selection for RL and step-by-step control.

    Use act(observation, history) to get the next tool call; run_episode uses it
    in a loop.
    """
    def act(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> ToolCall | None:
        """Return the next tool call, or None to submit and end the episode.

        Args:
            observation: Result dict from the previous step ({} on the first step).
            history: Prior steps, each with observation / tool_call / result.
        """
        ...
    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Run one episode by repeatedly calling act() until done or max_steps."""
        ...
class RandomAgent:
    """Uniform random over valid tool calls. Serves as a lower bound for evaluation."""

    def __init__(self, rng: random.Random | None = None) -> None:
        # Accept an injected (possibly seeded) RNG so runs can be reproduced.
        self._rng = rng or random.Random()

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Take uniformly random actions for up to max_steps and return the trajectory."""
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)

        # Profile once up front so cleaning operations can target real columns,
        # excluding the known prediction targets of the three tasks.
        profile = env.call_tool("profile_data")
        candidate_cols = [
            col
            for col in list(profile.get("columns", []))
            if col not in ("converted", "churned", "delivery_time_minutes")
        ]

        trajectory: list[dict[str, Any]] = []
        for _ in range(max_steps):
            action = self._rng.choice(["inspect", "clean", "gates", "submit"])
            if action == "inspect":
                tool_name = self._rng.choice(["profile_data", "preview_data", "get_task_brief"])
                result = env.call_tool(tool_name)
            elif action == "clean":
                tool_name = "apply_cleaning_operation"
                op = self._rng.choice(AVAILABLE_OPERATIONS)
                if op == "drop_duplicates":
                    # The only operation that takes no column argument.
                    result = env.call_tool(tool_name, operation=op)
                elif candidate_cols:
                    col = self._rng.choice(candidate_cols)
                    result = env.call_tool(tool_name, operation=op, column=col)
                else:
                    # No usable columns: fall back to a safe table-wide op.
                    result = env.call_tool(tool_name, operation="drop_duplicates")
            elif action == "gates":
                tool_name = "run_quality_gates"
                result = env.call_tool(tool_name)
            else:
                tool_name = "submit_solution"
                result = env.call_tool(tool_name)

            trajectory.append({
                "tool_name": tool_name,
                "reward": _extract_reward(result),
                "result": result,
            })
            # A submission that the environment marks done ends the episode.
            if action == "submit" and result.get("done", False):
                break
        return trajectory
class HeuristicAgent:
    """Rule-based agent that follows the canonical cleaning policy for each task."""

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Replay the scripted policy, then validate and submit if budget remains."""
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)
        trajectory: list[dict[str, Any]] = []

        def record(tool_name: str, **tool_kwargs: Any) -> None:
            # Invoke one tool and log its outcome on the trajectory.
            result = env.call_tool(tool_name, **tool_kwargs)
            trajectory.append({
                "tool_name": tool_name,
                "reward": _extract_reward(result),
                "result": result,
            })

        # Unknown task ids fall back to the ecommerce policy.
        policy = HEURISTIC_POLICIES.get(task_id, HEURISTIC_POLICIES["ecommerce_mobile"])
        for operation, column in policy:
            if len(trajectory) >= max_steps:
                break
            if column is None:
                record("apply_cleaning_operation", operation=operation)
            else:
                record("apply_cleaning_operation", operation=operation, column=column)

        # Validate, then submit — each only if the step budget still allows it.
        for closing_tool in ("run_quality_gates", "submit_solution"):
            if len(trajectory) < max_steps:
                record(closing_tool)
        return trajectory
def _default_parse_llm_output(text: str) -> ToolCall:
"""Parse JSON tool call from LLM output. Fallback to profile_data."""
import json
import re
match = re.search(r"\{[^{}]*\"tool\"[^{}]*\}", text, re.DOTALL)
if match:
try:
d = json.loads(match.group())
tool = d.get("tool", "profile_data")
args = d.get("arguments", {})
return {"tool": str(tool), "arguments": dict(args)}
except (json.JSONDecodeError, TypeError):
pass
return {"tool": "profile_data", "arguments": {}}
class LLMAgentAdapter:
    """Adapter for HF/LLM-based agents. Wraps a callable that produces tool calls from context.

    Usage:
        def my_model_fn(observation, history) -> str:
            # Build prompt, call model, return raw text
            return model.generate(...)
        agent = LLMAgentAdapter(generate_fn=my_model_fn, parse_fn=parse_json_tool_call)
    """

    def __init__(
        self,
        generate_fn: Any = None,
        parse_fn: Any = None,
    ) -> None:
        # Defaults keep the adapter runnable without a model: always inspect.
        self._generate_fn = generate_fn or (lambda obs, hist: '{"tool": "profile_data", "arguments": {}}')
        self._parse_fn = parse_fn or _default_parse_llm_output

    def act(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> ToolCall | None:
        """Generate raw model text for the current context and parse it into a tool call."""
        text = self._generate_fn(observation, history)
        return self._parse_fn(text)

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Run one episode by repeatedly calling act() until done or max_steps.

        Returns the trajectory: a list of step dicts with tool_name / reward / result.
        """
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)
        trajectory: list[dict[str, Any]] = []
        history: list[dict[str, Any]] = []
        observation: dict[str, Any] = {}
        for _ in range(max_steps):
            tool_call = self.act(observation, history)
            if tool_call is None:
                # act() returning None signals "stop": submit and end the episode.
                result = env.call_tool("submit_solution")
                trajectory.append({
                    "tool_name": "submit_solution",
                    # Fix: use _extract_reward for consistency with every other agent;
                    # previously only "final_reward" was read, dropping "reward" payloads.
                    "reward": _extract_reward(result),
                    "result": result,
                })
                break
            tool_name = tool_call["tool"]
            args = tool_call.get("arguments", {})
            result = env.call_tool(tool_name, **args)
            trajectory.append({"tool_name": tool_name, "reward": _extract_reward(result), "result": result})
            history.append({"observation": observation, "tool_call": tool_call, "result": result})
            # The latest result becomes the next step's observation.
            observation = result
            if result.get("done", False):
                break
        return trajectory
# System prompt shared by LLMAgent (and suitable for training data generation).
# NOTE: this is a runtime string the model is conditioned on — its exact wording,
# tool names, and JSON examples must stay in sync with the environment's tools.
SYSTEM_PROMPT = """\
You are a Data Cleaning Agent working in a Medallion data pipeline (Bronze → Silver).
Your job: inspect a dirty dataset and clean it to Silver quality by choosing \
the right tools in the right order.
## Methodology (FSDS + VDS)
1. INSPECT first: profile_data, preview_data, get_task_brief
2. CLEAN systematically: fix dtypes, strip whitespace, handle missing values, \
remove duplicates, clip outliers
3. VALIDATE before submitting: run_quality_gates to check quality gate
4. SUBMIT: submit_solution when all tests pass
## Output Format
Each turn, output exactly one JSON action:
{"tool": "<tool_name>", "arguments": {"operation": "<op>", "column": "<col_or_omit>"}}
Top-level tools: profile_data, preview_data, get_task_brief, run_quality_gates, submit_solution
Cleaning tool: apply_cleaning_operation — requires an "operation" argument.
Available operations for apply_cleaning_operation:
drop_duplicates
replace_invalid_with_null (requires "column")
cast_numeric (requires "column")
cast_datetime (requires "column")
impute_numeric (requires "column"; optional "strategy": "median"|"mean")
impute_categorical (requires "column")
normalize_categories (requires "column")
clip_outliers_iqr (requires "column")
Examples:
{"tool": "profile_data", "arguments": {}}
{"tool": "apply_cleaning_operation", "arguments": {"operation": "drop_duplicates"}}
{"tool": "apply_cleaning_operation", "arguments": {"operation": "cast_numeric", "column": "amount"}}
{"tool": "run_quality_gates", "arguments": {}}
{"tool": "submit_solution", "arguments": {}}
Think step by step. Inspect before cleaning. Validate before submitting."""
class LLMAgent:
    """Agent powered by a fine-tuned LLM checkpoint (Unsloth/HuggingFace).

    Loads the model once on first use and generates one JSON action per step
    conditioned on the current observation and episode history.

    Args:
        model_path: Path to the saved model directory (e.g. ``./data-cleaning-grpo-final``).
        max_new_tokens: Maximum tokens to generate per step.
        temperature: Sampling temperature (0.0 = greedy).
    """

    def __init__(
        self,
        model_path: str = "./data-cleaning-grpo-final",
        max_new_tokens: int = 128,
        temperature: float = 0.0,
    ) -> None:
        self.model_path = model_path
        self.max_new_tokens = max_new_tokens
        self.temperature = temperature
        # Model and tokenizer are loaded lazily on the first _generate() call.
        self._model = None
        self._tokenizer = None

    def _load(self) -> None:
        """Load the checkpoint and tokenizer once (4-bit, inference mode)."""
        from unsloth import FastLanguageModel  # type: ignore[import]

        model, tokenizer = FastLanguageModel.from_pretrained(
            model_name=self.model_path,
            max_seq_length=2048,
            load_in_4bit=True,
        )
        FastLanguageModel.for_inference(model)
        if tokenizer.pad_token is None:
            # Some checkpoints ship without a pad token; reuse EOS for padding.
            tokenizer.pad_token = tokenizer.eos_token
        self._model = model
        self._tokenizer = tokenizer
        # Fix: removed dead `import json as _json; self._json = _json` — the
        # attribute was never read (_build_user_message imports json itself).

    def _build_user_message(
        self, observation: dict[str, Any], history: list[dict[str, Any]]
    ) -> str:
        """Build the per-turn user message from the last step of history.

        ``observation`` is currently unused (context comes from ``history``);
        it is kept for signature compatibility with act-style callers.
        """
        import json as _json

        parts: list[str] = []
        if not history:
            parts.append("You just received a dirty Bronze-layer dataset. What is your first action?")
        else:
            last = history[-1]
            # Truncate the serialized result to keep the prompt short.
            obs_summary = _json.dumps(last["result"], ensure_ascii=False)[:400]
            parts.append(f"Last action: {last['tool_call']['tool']}")
            parts.append(f"Result (truncated): {obs_summary}")
        parts.append("What is your next action?")
        return "\n".join(parts)

    def _generate(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> str:
        """Generate one raw action string from the model for the current turn."""
        if self._model is None:
            self._load()
        import torch  # type: ignore[import]

        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": self._build_user_message(observation, history)},
        ]
        text = self._tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        inputs = self._tokenizer(text, return_tensors="pt").to(self._model.device)
        # Fix: only pass `temperature` when actually sampling. The original passed
        # temperature=None for greedy decoding, which explicitly overrides the
        # model's generation config instead of leaving its default untouched.
        gen_kwargs: dict[str, Any] = {
            "max_new_tokens": self.max_new_tokens,
            "do_sample": self.temperature > 0,
            "pad_token_id": self._tokenizer.eos_token_id,
        }
        if self.temperature > 0:
            gen_kwargs["temperature"] = self.temperature
        with torch.no_grad():
            output_ids = self._model.generate(**inputs, **gen_kwargs)
        # Decode only the newly generated tokens, not the echoed prompt.
        generated = output_ids[0][inputs["input_ids"].shape[-1]:]
        return self._tokenizer.decode(generated, skip_special_tokens=True)

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Run one LLM-driven episode; trajectory steps also carry the raw model output."""
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)
        trajectory: list[dict[str, Any]] = []
        history: list[dict[str, Any]] = []
        observation: dict[str, Any] = {}
        for _ in range(max_steps):
            raw = self._generate(observation, history)
            tool_call = _default_parse_llm_output(raw)
            tool_name = tool_call["tool"]
            args = tool_call.get("arguments", {})
            result = env.call_tool(tool_name, **args)
            trajectory.append({
                "tool_name": tool_name,
                "reward": _extract_reward(result),
                "result": result,
                "raw_output": raw,  # keep the raw text for debugging/evaluation
            })
            history.append({"observation": observation, "tool_call": tool_call, "result": result})
            observation = result
            if result.get("done", False):
                break
        return trajectory
# Public API of the agents module.
__all__ = [
    "Agent",
    "AgentWithAct",
    "ToolCall",
    "RandomAgent",
    "HeuristicAgent",
    "LLMAgent",
    "LLMAgentAdapter",
    "HEURISTIC_POLICIES",
    "SYSTEM_PROMPT",
]