"""Baseline agents for the FSDS Cleaning Environment. Provides RandomAgent, HeuristicAgent, and extensible Agent interface for benchmarking, evaluation, and RL training. """ from __future__ import annotations import random from typing import Any, Protocol, TypedDict from fsds_cleaning_env.server.cleaning_environment import AVAILABLE_OPERATIONS class ToolCall(TypedDict): """A tool invocation: tool_name and arguments for env.call_tool.""" tool: str arguments: dict[str, Any] # Scripted policies derived from required_ops. Format: list of (operation, column). HEURISTIC_POLICIES: dict[str, list[tuple[str, str | None]]] = { "ecommerce_mobile": [ ("replace_invalid_with_null", "country"), ("replace_invalid_with_null", "items_in_cart"), ("replace_invalid_with_null", "device_os"), ("cast_numeric", "items_in_cart"), ("cast_numeric", "order_value"), ("impute_numeric", "items_in_cart"), ("impute_numeric", "order_value"), ("clip_outliers_iqr", "items_in_cart"), ("clip_outliers_iqr", "order_value"), ("normalize_categories", "device_os"), ("normalize_categories", "country"), ("impute_categorical", "device_os"), ("impute_categorical", "country"), ("cast_datetime", "event_date"), ("drop_duplicates", None), ], "subscription_churn": [ ("replace_invalid_with_null", "monthly_spend"), ("replace_invalid_with_null", "age"), ("replace_invalid_with_null", "tenure_months"), ("replace_invalid_with_null", "payment_method"), ("cast_numeric", "age"), ("cast_numeric", "monthly_spend"), ("cast_numeric", "tenure_months"), ("impute_numeric", "age"), ("impute_numeric", "monthly_spend"), ("impute_numeric", "tenure_months"), ("clip_outliers_iqr", "monthly_spend"), ("normalize_categories", "plan_type"), ("normalize_categories", "payment_method"), ("impute_categorical", "plan_type"), ("impute_categorical", "payment_method"), ("drop_duplicates", None), ], "delivery_eta": [ ("replace_invalid_with_null", "driver_rating"), ("replace_invalid_with_null", "late_deliveries_last_30d"), ("replace_invalid_with_null", 
"city"), ("replace_invalid_with_null", "vehicle_type"), ("cast_numeric", "driver_rating"), ("cast_numeric", "delivery_distance_km"), ("cast_numeric", "late_deliveries_last_30d"), ("impute_numeric", "driver_rating"), ("impute_numeric", "late_deliveries_last_30d"), ("impute_numeric", "delivery_distance_km"), ("clip_outliers_iqr", "delivery_distance_km"), ("normalize_categories", "city"), ("normalize_categories", "vehicle_type"), ("impute_categorical", "city"), ("impute_categorical", "vehicle_type"), ("drop_duplicates", None), ], } def _extract_reward(result: dict[str, Any]) -> float: if "reward" in result: return float(result["reward"]) if "final_reward" in result: return float(result["final_reward"]) return 0.0 class Agent(Protocol): """Protocol for agents that run episodes in the FSDS Cleaning Environment.""" def run_episode( self, env: Any, task_id: str, max_steps: int = 18, seed: int | None = None, **reset_kwargs: Any, ) -> list[dict[str, Any]]: """Run one episode and return the trajectory (list of step dicts with reward, result, tool_name).""" ... class AgentWithAct(Protocol): """Agent that supports per-step action selection for RL and step-by-step control. Use act(observation, history) to get the next tool call; run_episode uses it in a loop. """ def act(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> ToolCall | None: """Return the next tool call, or None to submit and end the episode.""" ... def run_episode( self, env: Any, task_id: str, max_steps: int = 18, seed: int | None = None, **reset_kwargs: Any, ) -> list[dict[str, Any]]: """Run one episode by repeatedly calling act() until done or max_steps.""" ... class RandomAgent: """Uniform random over valid tool calls. 
class RandomAgent:
    """Policy that samples tool calls uniformly at random.

    Serves as a lower-bound baseline when evaluating smarter agents.
    """

    def __init__(self, rng: random.Random | None = None) -> None:
        # Accept an injected RNG so runs can be made reproducible.
        self._rng = random.Random() if rng is None else rng

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Play one episode of random tool calls and return the trajectory."""
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)

        # One unlogged profiling call so random cleaning ops can target real
        # columns; known target/label columns are never cleaned.
        profile = env.call_tool("profile_data")
        target_labels = ("converted", "churned", "delivery_time_minutes")
        candidate_cols = [c for c in list(profile.get("columns", [])) if c not in target_labels]

        episode_log: list[dict[str, Any]] = []
        finished = False
        for _ in range(max_steps):
            kind = self._rng.choice(["inspect", "clean", "gates", "submit"])
            if kind == "inspect":
                tool = self._rng.choice(["profile_data", "preview_data", "get_task_brief"])
                outcome = env.call_tool(tool)
            elif kind == "clean":
                tool = "apply_cleaning_operation"
                op = self._rng.choice(AVAILABLE_OPERATIONS)
                if op == "drop_duplicates":
                    outcome = env.call_tool(tool, operation=op)
                elif candidate_cols:
                    outcome = env.call_tool(tool, operation=op, column=self._rng.choice(candidate_cols))
                else:
                    # No usable columns: fall back to the only column-free op.
                    outcome = env.call_tool(tool, operation="drop_duplicates")
            elif kind == "gates":
                tool = "run_quality_gates"
                outcome = env.call_tool(tool)
            else:  # submit
                tool = "submit_solution"
                outcome = env.call_tool(tool)
                finished = outcome.get("done", False)
            episode_log.append({
                "tool_name": tool,
                "reward": _extract_reward(outcome),
                "result": outcome,
            })
            if finished:
                break
        return episode_log
class HeuristicAgent:
    """Rule-based agent that follows the canonical cleaning policy for each task."""

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Apply the scripted policy, then validate and submit within the step budget.

        Unknown task_ids fall back to the "ecommerce_mobile" policy.
        Returns the trajectory as a list of step dicts (tool_name, reward, result).
        """
        policy = HEURISTIC_POLICIES.get(task_id, HEURISTIC_POLICIES["ecommerce_mobile"])
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)
        trajectory: list[dict[str, Any]] = []
        for operation, column in policy:
            if len(trajectory) >= max_steps:
                break
            kwargs: dict[str, Any] = {"operation": operation}
            if column is not None:
                kwargs["column"] = column
            result = env.call_tool("apply_cleaning_operation", **kwargs)
            trajectory.append({
                "tool_name": "apply_cleaning_operation",
                "reward": _extract_reward(result),
                "result": result,
            })
        # Validate first, then submit — each only if the step budget allows.
        for final_tool in ("run_quality_gates", "submit_solution"):
            if len(trajectory) >= max_steps:
                break
            result = env.call_tool(final_tool)
            trajectory.append({
                "tool_name": final_tool,
                "reward": _extract_reward(result),
                "result": result,
            })
        return trajectory


def _default_parse_llm_output(text: str) -> ToolCall:
    """Parse the first JSON tool call found in LLM output.

    The pattern tolerates one level of brace nesting so that an "arguments"
    object (e.g. {"operation": ..., "column": ...} or even {}) is captured.
    The previous pattern `\\{[^{}]*"tool"[^{}]*\\}` could never span a nested
    arguments object, so every well-formed action fell back to profile_data.

    Falls back to a safe profile_data call when no valid JSON action is found.
    """
    import json
    import re

    # Outer braces with either plain chars or one nested {...} group on each
    # side of the mandatory "tool" key.
    pattern = r"\{(?:[^{}]|\{[^{}]*\})*\"tool\"(?:[^{}]|\{[^{}]*\})*\}"
    match = re.search(pattern, text, re.DOTALL)
    if match:
        try:
            payload = json.loads(match.group())
            return {
                "tool": str(payload.get("tool", "profile_data")),
                "arguments": dict(payload.get("arguments", {})),
            }
        except (json.JSONDecodeError, TypeError):
            pass
    return {"tool": "profile_data", "arguments": {}}
class LLMAgentAdapter:
    """Adapter for HF/LLM-based agents.

    Wraps a callable that produces tool calls from context.

    Usage:
        def my_model_fn(observation, history) -> str:
            # Build prompt, call model, return raw text
            return model.generate(...)

        agent = LLMAgentAdapter(generate_fn=my_model_fn, parse_fn=parse_json_tool_call)
    """

    def __init__(
        self,
        generate_fn: Any = None,
        parse_fn: Any = None,
    ) -> None:
        # Defaults keep the adapter runnable without a model: the stub
        # generator always proposes a harmless profile_data call.
        self._generate_fn = generate_fn or (lambda obs, hist: '{"tool": "profile_data", "arguments": {}}')
        self._parse_fn = parse_fn or _default_parse_llm_output

    def act(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> ToolCall | None:
        """Generate raw model text for this step and parse it into a tool call."""
        text = self._generate_fn(observation, history)
        return self._parse_fn(text)

    def run_episode(
        self,
        env: Any,
        task_id: str,
        max_steps: int = 18,
        seed: int | None = None,
        **reset_kwargs: Any,
    ) -> list[dict[str, Any]]:
        """Drive act() in a loop until it returns None, the env reports done, or max_steps."""
        reset_kwargs["seed"] = seed
        env.reset(task_id=task_id, **reset_kwargs)
        trajectory: list[dict[str, Any]] = []
        history: list[dict[str, Any]] = []
        observation: dict[str, Any] = {}
        for _ in range(max_steps):
            tool_call = self.act(observation, history)
            if tool_call is None:
                # Agent chose to stop: submit and record the final step.
                # Fix: use _extract_reward like every other step (it prefers
                # "reward" and falls back to "final_reward") instead of a
                # bare result.get("final_reward", 0.0).
                result = env.call_tool("submit_solution")
                trajectory.append({
                    "tool_name": "submit_solution",
                    "reward": _extract_reward(result),
                    "result": result,
                })
                break
            tool_name = tool_call["tool"]
            args = tool_call.get("arguments", {})
            result = env.call_tool(tool_name, **args)
            trajectory.append({"tool_name": tool_name, "reward": _extract_reward(result), "result": result})
            history.append({"observation": observation, "tool_call": tool_call, "result": result})
            observation = result
            if result.get("done", False):
                break
        return trajectory
# System prompt for LLM-backed agents: describes the inspect → clean →
# validate → submit workflow and the exact JSON action format the model
# must emit each turn (parsed by _default_parse_llm_output).
SYSTEM_PROMPT = """\
You are a Data Cleaning Agent working in a Medallion data pipeline (Bronze → Silver).
Your job: inspect a dirty dataset and clean it to Silver quality by choosing \
the right tools in the right order.

## Methodology (FSDS + VDS)
1. INSPECT first: profile_data, preview_data, get_task_brief
2. CLEAN systematically: fix dtypes, strip whitespace, handle missing values, \
remove duplicates, clip outliers
3. VALIDATE before submitting: run_quality_gates to check quality gate
4. SUBMIT: submit_solution when all tests pass

## Output Format
Each turn, output exactly one JSON action:
{"tool": "", "arguments": {"operation": "", "column": ""}}

Top-level tools: profile_data, preview_data, get_task_brief, run_quality_gates, submit_solution
Cleaning tool: apply_cleaning_operation — requires an "operation" argument.

Available operations for apply_cleaning_operation:
drop_duplicates
replace_invalid_with_null (requires "column")
cast_numeric (requires "column")
cast_datetime (requires "column")
impute_numeric (requires "column"; optional "strategy": "median"|"mean")
impute_categorical (requires "column")
normalize_categories (requires "column")
clip_outliers_iqr (requires "column")

Examples:
{"tool": "profile_data", "arguments": {}}
{"tool": "apply_cleaning_operation", "arguments": {"operation": "drop_duplicates"}}
{"tool": "apply_cleaning_operation", "arguments": {"operation": "cast_numeric", "column": "amount"}}
{"tool": "run_quality_gates", "arguments": {}}
{"tool": "submit_solution", "arguments": {}}

Think step by step. Inspect before cleaning. Validate before submitting."""
""" def __init__( self, model_path: str = "./data-cleaning-grpo-final", max_new_tokens: int = 128, temperature: float = 0.0, ) -> None: self.model_path = model_path self.max_new_tokens = max_new_tokens self.temperature = temperature self._model = None self._tokenizer = None def _load(self) -> None: import json as _json from unsloth import FastLanguageModel # type: ignore[import] model, tokenizer = FastLanguageModel.from_pretrained( model_name=self.model_path, max_seq_length=2048, load_in_4bit=True, ) FastLanguageModel.for_inference(model) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token self._model = model self._tokenizer = tokenizer self._json = _json def _build_user_message( self, observation: dict[str, Any], history: list[dict[str, Any]] ) -> str: import json as _json parts: list[str] = [] if not history: parts.append("You just received a dirty Bronze-layer dataset. What is your first action?") else: last = history[-1] obs_summary = _json.dumps(last["result"], ensure_ascii=False)[:400] parts.append(f"Last action: {last['tool_call']['tool']}") parts.append(f"Result (truncated): {obs_summary}") parts.append("What is your next action?") return "\n".join(parts) def _generate(self, observation: dict[str, Any], history: list[dict[str, Any]]) -> str: if self._model is None: self._load() import torch # type: ignore[import] messages = [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": self._build_user_message(observation, history)}, ] text = self._tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = self._tokenizer(text, return_tensors="pt").to(self._model.device) with torch.no_grad(): output_ids = self._model.generate( **inputs, max_new_tokens=self.max_new_tokens, temperature=self.temperature if self.temperature > 0 else None, do_sample=self.temperature > 0, pad_token_id=self._tokenizer.eos_token_id, ) generated = output_ids[0][inputs["input_ids"].shape[-1]:] return 
self._tokenizer.decode(generated, skip_special_tokens=True) def run_episode( self, env: Any, task_id: str, max_steps: int = 18, seed: int | None = None, **reset_kwargs: Any, ) -> list[dict[str, Any]]: reset_kwargs["seed"] = seed env.reset(task_id=task_id, **reset_kwargs) trajectory: list[dict[str, Any]] = [] history: list[dict[str, Any]] = [] observation: dict[str, Any] = {} for _ in range(max_steps): raw = self._generate(observation, history) tool_call = _default_parse_llm_output(raw) tool_name = tool_call["tool"] args = tool_call.get("arguments", {}) result = env.call_tool(tool_name, **args) trajectory.append({ "tool_name": tool_name, "reward": _extract_reward(result), "result": result, "raw_output": raw, }) history.append({"observation": observation, "tool_call": tool_call, "result": result}) observation = result if result.get("done", False): break return trajectory __all__ = [ "Agent", "AgentWithAct", "ToolCall", "RandomAgent", "HeuristicAgent", "LLMAgent", "LLMAgentAdapter", "HEURISTIC_POLICIES", "SYSTEM_PROMPT", ]