"""Run LLM baselines against the DataClean-Env server. Requires: 1. A running DataClean-Env server (e.g. `python -m dataclean_env.server`) 2. An LLM inference endpoint (vLLM, TGI, OpenAI-compatible, etc.) Environment variables: API_BASE_URL - DataClean-Env server URL (default: http://localhost:8000) MODEL_NAME - Model identifier for the LLM endpoint (e.g. "meta-llama/Llama-3-8B") HF_TOKEN - HuggingFace token (if needed for gated models) LLM_BASE_URL - LLM inference endpoint (default: http://localhost:8001/v1) Usage: API_BASE_URL=http://localhost:8000 MODEL_NAME=gpt-4 python3 scripts/run_baselines.py """ from __future__ import annotations import json import os import sys from typing import Any, Dict, List # --------------------------------------------------------------------------- # Bootstrap: install openenv mock if the real package is absent. # Note: the DataCleanEnv *client* needs a real running server at runtime, # but the mock lets the module import succeed for validation and --help. # --------------------------------------------------------------------------- def _ensure_openenv_mock() -> None: """Install a lightweight openenv mock into sys.modules if needed.""" try: import openenv.core.env_server # noqa: F401 return except ImportError: pass from types import ModuleType class _Base: def __init__(self, **kw: object) -> None: for k, v in kw.items(): setattr(self, k, v) class _Environment: def __init__(self) -> None: pass def __class_getitem__(cls, item): # type: ignore[override] return cls class _EnvClient: def __init__(self, *a: object, **kw: object) -> None: pass def __class_getitem__(cls, item): # type: ignore[override] return cls names = [ "openenv", "openenv.core", "openenv.core.env_server", "openenv.core.env_server.types", "openenv.core.env_client", "openenv.core.client_types", ] mods = {n: ModuleType(n) for n in names} for n, m in mods.items(): sys.modules[n] = m mods["openenv"].core = mods["openenv.core"] # type: ignore[attr-defined] mods["openenv.core"].env_server = mods["openenv.core.env_server"] # type: ignore[attr-defined] mods["openenv.core"].env_client = mods["openenv.core.env_client"] # type: ignore[attr-defined] mods["openenv.core"].client_types = mods["openenv.core.client_types"] # type: ignore[attr-defined] for attr in ("Action", "Observation", "State"): setattr(mods["openenv.core.env_server"], attr, type(attr, (_Base,), {})) setattr(mods["openenv.core.env_server"], "Environment", _Environment) setattr(mods["openenv.core.env_server.types"], "EnvironmentMetadata", _Base) setattr(mods["openenv.core.env_client"], "EnvClient", _EnvClient) setattr(mods["openenv.core.client_types"], "StepResult", _Base) _ensure_openenv_mock() # --------------------------------------------------------------------------- # Configuration from environment # --------------------------------------------------------------------------- API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000") LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "http://localhost:8001/v1") MODEL_NAME = os.environ.get("MODEL_NAME", "") HF_TOKEN = os.environ.get("HF_TOKEN", "") TASK_IDS = ["easy_contacts", "medium_employees", "hard_patients"] # --------------------------------------------------------------------------- # Validate prerequisites # --------------------------------------------------------------------------- def _check_prerequisites() -> bool: """Check that required config is available. Returns True if OK.""" ok = True if not MODEL_NAME: print("ERROR: MODEL_NAME env var is not set.") print(" Example: MODEL_NAME=gpt-4 python3 scripts/run_baselines.py") ok = False try: from dataclean_env.client import DataCleanEnv # noqa: F401 except ImportError as exc: print(f"ERROR: Cannot import DataCleanEnv client: {exc}") print(" Install the package: pip install -e .") ok = False try: import httpx # noqa: F401 except ImportError: print("WARNING: httpx not installed. Install with: pip install httpx") print(" The client depends on httpx for HTTP transport.") ok = False return ok # --------------------------------------------------------------------------- # LLM interaction (stub -- replace with your inference logic) # --------------------------------------------------------------------------- def call_llm(prompt: str) -> str: """Call the LLM endpoint and return the completion text. This is a stub. Replace the body with your preferred inference method: - OpenAI-compatible: POST to LLM_BASE_URL/chat/completions - HuggingFace TGI: POST to LLM_BASE_URL/generate - vLLM: POST to LLM_BASE_URL/chat/completions The prompt contains the observation as JSON. The LLM should return a JSON object with keys "action_type" and "params". """ try: import httpx except ImportError: raise RuntimeError("httpx is required. Install with: pip install httpx") headers: Dict[str, str] = {"Content-Type": "application/json"} if HF_TOKEN: headers["Authorization"] = f"Bearer {HF_TOKEN}" payload = { "model": MODEL_NAME, "messages": [ { "role": "system", "content": ( "You are a data cleaning agent. Given a dataset observation, " "return a JSON action with keys 'action_type' and 'params'. " "Available actions: fix_value, delete_row, fill_missing, " "standardize_format, merge_duplicates, flag_anomaly, " "split_column, rename_column, cast_type, escalate_to_human, " "mark_complete." ), }, {"role": "user", "content": prompt}, ], "temperature": 0.0, "max_tokens": 512, } resp = httpx.post( f"{LLM_BASE_URL}/chat/completions", json=payload, headers=headers, timeout=60.0, ) resp.raise_for_status() data = resp.json() return data["choices"][0]["message"]["content"] def parse_llm_action(text: str) -> Dict[str, Any]: """Parse an LLM response into an action dict. Expects JSON with "action_type" and "params" keys. Falls back to mark_complete if parsing fails. """ # Try to extract JSON from the response (handle markdown code blocks) cleaned = text.strip() if "```" in cleaned: # Extract content between first pair of triple backticks parts = cleaned.split("```") if len(parts) >= 3: cleaned = parts[1] # Remove optional language tag on first line if cleaned.startswith("json"): cleaned = cleaned[4:] cleaned = cleaned.strip() try: parsed = json.loads(cleaned) if "action_type" in parsed: return parsed except (json.JSONDecodeError, TypeError): pass # Fallback: mark_complete print(f" WARNING: Could not parse LLM response, falling back to mark_complete") return {"action_type": "mark_complete", "params": {}} # --------------------------------------------------------------------------- # Run one episode # --------------------------------------------------------------------------- def run_episode(task_id: str) -> float: """Run one LLM-driven episode. Returns the final score.""" from dataclean_env.client import DataCleanEnv from dataclean_env.models import DataCleanAction with DataCleanEnv(base_url=API_BASE_URL).sync() as env: result = env.reset(task_id=task_id) obs = result.observation step = 0 while not obs.done: # Build prompt from observation prompt_data = { "task_id": obs.task_id, "step": obs.step_number, "steps_remaining": obs.steps_remaining, "row_count": obs.row_count, "columns": obs.columns, "issues_remaining": obs.issues_remaining, "quality_issues": [ { "row_id": qi.row_id, "column": qi.column, "issue_type": qi.issue_type, "description": qi.description, "suggestion": qi.suggestion, } for qi in obs.quality_issues[:20] # Cap for context length ], "rows": obs.rows[:15], # Cap for context length } prompt = json.dumps(prompt_data, indent=2, default=str) # Get LLM action llm_text = call_llm(prompt) action_dict = parse_llm_action(llm_text) action = DataCleanAction( action_type=action_dict["action_type"], params=action_dict.get("params", {}), ) result = env.step(action) obs = result.observation step += 1 print(f" Step {step}: {action_dict['action_type']} -> reward={obs.reward}") return obs.reward # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main() -> None: if not _check_prerequisites(): print("\nFix the issues above and try again.") sys.exit(1) print(f"Model: {MODEL_NAME}") print(f"Server: {API_BASE_URL}") print(f"LLM API: {LLM_BASE_URL}") print() results: Dict[str, float] = {} for task_id in TASK_IDS: print(f"--- {task_id} ---") try: score = run_episode(task_id) results[task_id] = score print(f" Final score: {score:.4f}") except Exception as exc: print(f" ERROR: {exc}") results[task_id] = -1.0 # Print summary table print("\n" + "=" * 50) print(f"Baseline Results: {MODEL_NAME}") print("=" * 50) print(f"{'Task':<25} {'Score':>10}") print("-" * 50) for task_id in TASK_IDS: s = results.get(task_id, -1.0) score_str = f"{s:.4f}" if s >= 0 else "ERROR" print(f"{task_id:<25} {score_str:>10}") valid = [s for s in results.values() if s >= 0] if valid: mean = sum(valid) / len(valid) print("-" * 50) print(f"{'Mean':<25} {mean:>10.4f}") print() if __name__ == "__main__": main()