medibill / scripts /run_baselines.py
Anuj424614's picture
Upload folder using huggingface_hub
a09b1f5 verified
"""Run LLM baselines against the DataClean-Env server.
Requires:
1. A running DataClean-Env server (e.g. `python -m dataclean_env.server`)
2. An LLM inference endpoint (vLLM, TGI, OpenAI-compatible, etc.)
Environment variables:
API_BASE_URL - DataClean-Env server URL (default: http://localhost:8000)
MODEL_NAME - Model identifier for the LLM endpoint (e.g. "meta-llama/Llama-3-8B")
HF_TOKEN - HuggingFace token (if needed for gated models)
LLM_BASE_URL - LLM inference endpoint (default: http://localhost:8001/v1)
Usage:
API_BASE_URL=http://localhost:8000 MODEL_NAME=gpt-4 python3 scripts/run_baselines.py
"""
from __future__ import annotations
import json
import os
import sys
from typing import Any, Dict, List
# ---------------------------------------------------------------------------
# Bootstrap: install openenv mock if the real package is absent.
# Note: the DataCleanEnv *client* needs a real running server at runtime,
# but the mock lets the module import succeed for validation and --help.
# ---------------------------------------------------------------------------
def _ensure_openenv_mock() -> None:
"""Install a lightweight openenv mock into sys.modules if needed."""
try:
import openenv.core.env_server # noqa: F401
return
except ImportError:
pass
from types import ModuleType
class _Base:
def __init__(self, **kw: object) -> None:
for k, v in kw.items():
setattr(self, k, v)
class _Environment:
def __init__(self) -> None:
pass
def __class_getitem__(cls, item): # type: ignore[override]
return cls
class _EnvClient:
def __init__(self, *a: object, **kw: object) -> None:
pass
def __class_getitem__(cls, item): # type: ignore[override]
return cls
names = [
"openenv", "openenv.core", "openenv.core.env_server",
"openenv.core.env_server.types", "openenv.core.env_client",
"openenv.core.client_types",
]
mods = {n: ModuleType(n) for n in names}
for n, m in mods.items():
sys.modules[n] = m
mods["openenv"].core = mods["openenv.core"] # type: ignore[attr-defined]
mods["openenv.core"].env_server = mods["openenv.core.env_server"] # type: ignore[attr-defined]
mods["openenv.core"].env_client = mods["openenv.core.env_client"] # type: ignore[attr-defined]
mods["openenv.core"].client_types = mods["openenv.core.client_types"] # type: ignore[attr-defined]
for attr in ("Action", "Observation", "State"):
setattr(mods["openenv.core.env_server"], attr, type(attr, (_Base,), {}))
setattr(mods["openenv.core.env_server"], "Environment", _Environment)
setattr(mods["openenv.core.env_server.types"], "EnvironmentMetadata", _Base)
setattr(mods["openenv.core.env_client"], "EnvClient", _EnvClient)
setattr(mods["openenv.core.client_types"], "StepResult", _Base)
_ensure_openenv_mock()
# ---------------------------------------------------------------------------
# Configuration from environment
# ---------------------------------------------------------------------------
API_BASE_URL = os.environ.get("API_BASE_URL", "http://localhost:8000")
LLM_BASE_URL = os.environ.get("LLM_BASE_URL", "http://localhost:8001/v1")
MODEL_NAME = os.environ.get("MODEL_NAME", "")
HF_TOKEN = os.environ.get("HF_TOKEN", "")
TASK_IDS = ["easy_contacts", "medium_employees", "hard_patients"]
# ---------------------------------------------------------------------------
# Validate prerequisites
# ---------------------------------------------------------------------------
def _check_prerequisites() -> bool:
"""Check that required config is available. Returns True if OK."""
ok = True
if not MODEL_NAME:
print("ERROR: MODEL_NAME env var is not set.")
print(" Example: MODEL_NAME=gpt-4 python3 scripts/run_baselines.py")
ok = False
try:
from dataclean_env.client import DataCleanEnv # noqa: F401
except ImportError as exc:
print(f"ERROR: Cannot import DataCleanEnv client: {exc}")
print(" Install the package: pip install -e .")
ok = False
try:
import httpx # noqa: F401
except ImportError:
print("WARNING: httpx not installed. Install with: pip install httpx")
print(" The client depends on httpx for HTTP transport.")
ok = False
return ok
# ---------------------------------------------------------------------------
# LLM interaction (stub -- replace with your inference logic)
# ---------------------------------------------------------------------------
def call_llm(prompt: str) -> str:
"""Call the LLM endpoint and return the completion text.
This is a stub. Replace the body with your preferred inference method:
- OpenAI-compatible: POST to LLM_BASE_URL/chat/completions
- HuggingFace TGI: POST to LLM_BASE_URL/generate
- vLLM: POST to LLM_BASE_URL/chat/completions
The prompt contains the observation as JSON. The LLM should return a
JSON object with keys "action_type" and "params".
"""
try:
import httpx
except ImportError:
raise RuntimeError("httpx is required. Install with: pip install httpx")
headers: Dict[str, str] = {"Content-Type": "application/json"}
if HF_TOKEN:
headers["Authorization"] = f"Bearer {HF_TOKEN}"
payload = {
"model": MODEL_NAME,
"messages": [
{
"role": "system",
"content": (
"You are a data cleaning agent. Given a dataset observation, "
"return a JSON action with keys 'action_type' and 'params'. "
"Available actions: fix_value, delete_row, fill_missing, "
"standardize_format, merge_duplicates, flag_anomaly, "
"split_column, rename_column, cast_type, escalate_to_human, "
"mark_complete."
),
},
{"role": "user", "content": prompt},
],
"temperature": 0.0,
"max_tokens": 512,
}
resp = httpx.post(
f"{LLM_BASE_URL}/chat/completions",
json=payload,
headers=headers,
timeout=60.0,
)
resp.raise_for_status()
data = resp.json()
return data["choices"][0]["message"]["content"]
def parse_llm_action(text: str) -> Dict[str, Any]:
"""Parse an LLM response into an action dict.
Expects JSON with "action_type" and "params" keys.
Falls back to mark_complete if parsing fails.
"""
# Try to extract JSON from the response (handle markdown code blocks)
cleaned = text.strip()
if "```" in cleaned:
# Extract content between first pair of triple backticks
parts = cleaned.split("```")
if len(parts) >= 3:
cleaned = parts[1]
# Remove optional language tag on first line
if cleaned.startswith("json"):
cleaned = cleaned[4:]
cleaned = cleaned.strip()
try:
parsed = json.loads(cleaned)
if "action_type" in parsed:
return parsed
except (json.JSONDecodeError, TypeError):
pass
# Fallback: mark_complete
print(f" WARNING: Could not parse LLM response, falling back to mark_complete")
return {"action_type": "mark_complete", "params": {}}
# ---------------------------------------------------------------------------
# Run one episode
# ---------------------------------------------------------------------------
def run_episode(task_id: str) -> float:
"""Run one LLM-driven episode. Returns the final score."""
from dataclean_env.client import DataCleanEnv
from dataclean_env.models import DataCleanAction
with DataCleanEnv(base_url=API_BASE_URL).sync() as env:
result = env.reset(task_id=task_id)
obs = result.observation
step = 0
while not obs.done:
# Build prompt from observation
prompt_data = {
"task_id": obs.task_id,
"step": obs.step_number,
"steps_remaining": obs.steps_remaining,
"row_count": obs.row_count,
"columns": obs.columns,
"issues_remaining": obs.issues_remaining,
"quality_issues": [
{
"row_id": qi.row_id,
"column": qi.column,
"issue_type": qi.issue_type,
"description": qi.description,
"suggestion": qi.suggestion,
}
for qi in obs.quality_issues[:20] # Cap for context length
],
"rows": obs.rows[:15], # Cap for context length
}
prompt = json.dumps(prompt_data, indent=2, default=str)
# Get LLM action
llm_text = call_llm(prompt)
action_dict = parse_llm_action(llm_text)
action = DataCleanAction(
action_type=action_dict["action_type"],
params=action_dict.get("params", {}),
)
result = env.step(action)
obs = result.observation
step += 1
print(f" Step {step}: {action_dict['action_type']} -> reward={obs.reward}")
return obs.reward
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
if not _check_prerequisites():
print("\nFix the issues above and try again.")
sys.exit(1)
print(f"Model: {MODEL_NAME}")
print(f"Server: {API_BASE_URL}")
print(f"LLM API: {LLM_BASE_URL}")
print()
results: Dict[str, float] = {}
for task_id in TASK_IDS:
print(f"--- {task_id} ---")
try:
score = run_episode(task_id)
results[task_id] = score
print(f" Final score: {score:.4f}")
except Exception as exc:
print(f" ERROR: {exc}")
results[task_id] = -1.0
# Print summary table
print("\n" + "=" * 50)
print(f"Baseline Results: {MODEL_NAME}")
print("=" * 50)
print(f"{'Task':<25} {'Score':>10}")
print("-" * 50)
for task_id in TASK_IDS:
s = results.get(task_id, -1.0)
score_str = f"{s:.4f}" if s >= 0 else "ERROR"
print(f"{task_id:<25} {score_str:>10}")
valid = [s for s in results.values() if s >= 0]
if valid:
mean = sum(valid) / len(valid)
print("-" * 50)
print(f"{'Mean':<25} {mean:>10.4f}")
print()
if __name__ == "__main__":
main()