dataclean-env / inference.py
Anuj424614's picture
fix: inference.py comply with hackathon spec (env vars, [END] format, HF_TOKEN validation)
2f5d42a verified
"""Baseline inference script for DataClean-Env hackathon.
Runs an LLM-based data-cleaning agent through all three tasks
(easy_contacts, medium_employees, hard_patients), collects scores,
and emits [START]/[STEP]/[END] lines to stdout per OpenEnv spec.
Environment variables:
API_BASE_URL - LLM endpoint URL (optional, default: https://api.openai.com/v1)
MODEL_NAME - model identifier (optional, default: gpt-4.1-mini)
HF_TOKEN - Hugging Face API token (mandatory, no default)
ENV_BASE_URL - environment server URL (optional, default: http://localhost:8000)
Usage:
API_BASE_URL=... MODEL_NAME=... HF_TOKEN=... python inference.py
"""
from __future__ import annotations
import json
import os
import re
import signal
import sys
import time
import traceback
from typing import Any, Dict, List, Optional
from openai import OpenAI
from dataclean_env import DataCleanEnv, DataCleanAction
from dataclean_env.models import DataCleanObservation
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint, model and credentials are taken from the process environment.
# HF_TOKEN has no default; async_main refuses to run when it is missing.
API_BASE_URL: str = os.environ.get("API_BASE_URL", "https://api.openai.com/v1")
MODEL_NAME: str = os.environ.get("MODEL_NAME", "gpt-4.1-mini")
HF_TOKEN: Optional[str] = os.environ.get("HF_TOKEN")

# Base URL used for the direct HTTP connection to the environment server.
ENV_BASE_URL: str = os.environ.get("ENV_BASE_URL", "http://localhost:8000")

# Benchmark identity and the tasks run, in order.
BENCHMARK: str = "dataclean_env"
TASKS: List[str] = ["easy_contacts", "medium_employees", "hard_patients"]

# Sampling settings passed to every chat-completion request.
SEED: int = 42
TEMPERATURE: float = 0.0
MAX_TOKENS: int = 1024

# Wall-clock limit for the whole run, enforced through SIGALRM (~18.3 minutes).
GLOBAL_TIMEOUT_SECONDS: int = 1100
# A task counts as successful when its clamped final score reaches this value.
SUCCESS_SCORE_THRESHOLD: float = 0.1
# ---------------------------------------------------------------------------
# Timeout handler
# ---------------------------------------------------------------------------
class GlobalTimeoutError(Exception):
    """Raised by the SIGALRM handler once the global time limit has elapsed."""
def _timeout_handler(signum: int, frame: Any) -> None:
    """Signal handler that aborts the run by raising GlobalTimeoutError.

    Args:
        signum: Number of the delivered signal (unused).
        frame: Stack frame that was interrupted (unused).

    Raises:
        GlobalTimeoutError: Always.
    """
    timeout = GlobalTimeoutError("Global timeout reached")
    raise timeout
# ---------------------------------------------------------------------------
# System prompt
# ---------------------------------------------------------------------------
# System message placed first in every task conversation (see run_task).
# It enumerates the action types and their params, and asks the model to reply
# with a bare JSON object carrying "action_type" and "params" — the shape that
# _parse_action looks for.
# NOTE: the text is sent to the model verbatim; any edit changes agent behavior.
SYSTEM_PROMPT: str = """You are a data-cleaning agent. Your job is to fix quality issues in a tabular dataset.
## Available Actions
Respond with ONLY a JSON object (no markdown, no explanation) containing:
- "action_type": one of the types below
- "params": a dict of parameters for that action
### Action Types
1. **fix_value** - Fix an incorrect value in a cell.
params: {"row_id": <int>, "column": "<col_name>", "new_value": "<corrected_value>"}
2. **delete_row** - Delete a row (e.g. junk/duplicate that cannot be merged).
params: {"row_id": <int>}
3. **fill_missing** - Fill a missing (null) value.
params: {"row_id": <int>, "column": "<col_name>", "value": "<fill_value>"}
4. **standardize_format** - Standardize the format of ALL values in a column (column-level, not row-level).
params: {"column": "<col_name>", "format_type": "<format>"}
format_type options: "date:YYYY-MM-DD", "phone:US", "phone:E164", "name:title_case",
"email:lowercase", "zip:5digit", "currency:float", "state:abbreviation"
5. **merge_duplicates** - Merge two duplicate rows (keeps the first, deletes the second).
params: {"row_id1": <int>, "row_id2": <int>, "strategy": "<merge_strategy>"}
strategy options: "keep_first", "keep_second", "merge_prefer_nonnull", "merge_prefer_row1", "merge_prefer_row2"
6. **flag_anomaly** - Flag a suspicious value for review.
params: {"row_id": <int>, "column": "<col_name>", "reason": "<why>"}
7. **split_column** - Split a column into multiple columns.
params: {"column": "<col_name>", "delimiter": "<delim>", "new_names": ["<name1>", "<name2>"]}
8. **rename_column** - Rename a column.
params: {"old_name": "<old>", "new_name": "<new>"}
9. **cast_type** - Cast a column to a different type.
params: {"column": "<col_name>", "target_type": "<type>"}
target_type options: "int", "float", "str", "bool", "date"
10. **escalate_to_human** - Escalate an ambiguous cell to human review when you are uncertain.
params: {"row_id": <int>, "column": "<col_name>", "confidence": <float 0-1>, "reason": "<why>"}
11. **mark_complete** - Signal that you believe the dataset is clean.
params: {}
## Important Rules
- Always reference rows by their **row_id** (the first column shown), NOT by row index.
- Examine the quality issues carefully and fix the most impactful ones first.
- When all issues are resolved (or you cannot fix more), use mark_complete.
- Respond with ONLY the JSON object. No extra text."""
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _build_user_prompt(obs: DataCleanObservation) -> str:
    """Render one observation as the user message for the next LLM turn.

    Sections appear in a fixed order: step counter, optional budget line,
    data summary, grouped quality issues (at most 15 examples overall), the
    previous action's result, and the current rows as a pipe-delimited table
    headed by ``row_id`` followed by ``obs.columns``.

    Args:
        obs: Observation returned by the environment's reset/step call.

    Returns:
        The sections joined with newlines.
    """
    example_cap = 15

    def _table_line(cells) -> str:
        # One pipe-delimited table line; every cell is stringified.
        return "| " + " | ".join(str(cell) for cell in cells) + " |"

    sections: List[str] = [
        f"## Step {obs.step_number} / {obs.max_steps} "
        f"({obs.steps_remaining} remaining)"
    ]

    # Budget line is optional: skipped when the attribute is absent or None.
    budget_left = getattr(obs, "budget_remaining", None)
    if budget_left is not None:
        budget_total = budget_left + (obs.budget_spent or 0)
        sections.append(
            f"- Budget: {budget_left:.1f} / "
            f"{budget_total:.1f} remaining"
        )

    summary = obs.data_summary
    sections.append(
        f"\n## Data Summary\n"
        f"- Rows: {summary.row_count}, Columns: {summary.column_count}\n"
        f"- Total cells: {summary.total_cells}, Null cells: {summary.null_count}\n"
        f"- Quality issues: {summary.issue_count}\n"
        f"- Columns: {', '.join(summary.columns)}"
    )

    if obs.issue_groups:
        sections.append("\n## Quality Issues (grouped)")
        listed = 0
        for group in obs.issue_groups:
            sections.append(f"\n### {group.issue_type} ({group.count} issues)")
            for example in group.examples:
                if listed >= example_cap:
                    break
                hint = ""
                if example.suggestion:
                    hint = f" -> suggestion: {example.suggestion}"
                sections.append(
                    f" - Row {example.row_id}, col '{example.column}': "
                    f"{example.description}"
                    + hint
                )
                listed += 1
            if listed >= example_cap:
                # Cap reached: report how many issues were left out, then stop
                # walking the remaining groups.
                hidden = obs.issues_remaining - listed
                if hidden > 0:
                    sections.append(f" ... and {hidden} more issues not shown")
                break

    outcome = obs.last_action_result
    if outcome is not None:
        sections.append(
            f"\n## Last Action Result\n"
            f"- Action: {outcome.action}, Status: {outcome.status}\n"
            f"- Message: {outcome.message}\n"
            f"- Cells modified: {outcome.cells_modified}"
        )

    if obs.rows:
        header = ["row_id"] + obs.columns
        sections.append("\n## Current Data (row_id is first column)")
        sections.append(_table_line(header))
        sections.append(_table_line(["---"] * len(header)))
        sections.extend(_table_line(row) for row in obs.rows)

    return "\n".join(sections)
def _normalize_params(action_type: str, params: Dict[str, Any]) -> Dict[str, Any]:
    """Map common parameter aliases onto the canonical names.

    Handles typical LLM slips such as ``value`` instead of ``new_value`` for
    fix_value, or ``row_id_1``/``row_id_2`` instead of ``row_id1``/``row_id2``
    for merge_duplicates. An alias is only renamed when its canonical key is
    not already present, so an explicit canonical value always wins.

    Args:
        action_type: Action the params belong to; selects the alias set.
        params: Raw params decoded from the model's JSON. Anything that cannot
            be turned into a dict (e.g. ``None`` from ``"params": null``, or a
            bare string/number) is treated as "no params".

    Returns:
        A new dict with aliases renamed; the input is never mutated.
    """
    try:
        normalized = dict(params)
    except (TypeError, ValueError):
        # Model output is untrusted: a null or scalar "params" must not crash
        # the run, so fall back to an empty parameter set.
        return {}

    # (alias, canonical) pairs, applied in order. The first two apply to every
    # action type.
    aliases = [("row", "row_id"), ("col", "column")]
    if action_type == "fix_value":
        aliases.append(("value", "new_value"))
    elif action_type == "fill_missing":
        # fill_missing uses "value" canonically; some models send "fill_value".
        aliases.append(("fill_value", "value"))
    elif action_type == "merge_duplicates":
        aliases.extend([
            ("row_id_1", "row_id1"),
            ("row_id_2", "row_id2"),
            ("row1", "row_id1"),
            ("row2", "row_id2"),
        ])

    for alias, canonical in aliases:
        if alias in normalized and canonical not in normalized:
            normalized[canonical] = normalized.pop(alias)
    return normalized
def _parse_action(response_text: str) -> DataCleanAction:
    """Turn the raw LLM reply into a DataCleanAction.

    Candidates are tried in order until one decodes to a JSON object that has
    an "action_type" key:

    1. the whole reply (after stripping a surrounding markdown code fence);
    2. the first brace-delimited span mentioning "action_type" that contains
       no nested braces;
    3. the widest brace-delimited span mentioning "action_type" (covers
       params that themselves contain objects).

    A missing, null or non-object "params" value is treated as an empty dict
    so that malformed model output cannot raise from this function.

    Args:
        response_text: Assistant message content.

    Returns:
        The parsed action, or a ``mark_complete`` action when nothing usable
        could be extracted (a warning is written to stderr in that case).
    """
    text = response_text.strip()
    # Strip markdown code fences if present.
    if text.startswith("```"):
        text = re.sub(r"^```(?:json)?\s*", "", text)
        text = re.sub(r"\s*```$", "", text)
        text = text.strip()

    candidates = [text]
    for pattern in (
        r"\{[^{}]*\"action_type\"[^{}]*\}",
        r"\{.*\"action_type\".*\}",
    ):
        match = re.search(pattern, text, re.DOTALL)
        if match:
            candidates.append(match.group())

    for candidate in candidates:
        try:
            data = json.loads(candidate)
        except json.JSONDecodeError:
            continue
        if not isinstance(data, dict) or "action_type" not in data:
            continue
        action_type = data["action_type"]
        raw_params = data.get("params")
        if not isinstance(raw_params, dict):
            # e.g. {"action_type": "mark_complete", "params": null}
            raw_params = {}
        return DataCleanAction(
            action_type=action_type,
            params=_normalize_params(action_type, raw_params),
        )

    # Fallback
    print(" [WARN] Could not parse LLM response, falling back to mark_complete", file=sys.stderr)
    print(f" [WARN] Raw response: {text[:200]}", file=sys.stderr)
    return DataCleanAction(action_type="mark_complete", params={})
def _call_llm(
    client: OpenAI,
    messages: List[Dict[str, str]],
    retry: bool = True,
) -> str:
    """Request one chat completion and return the assistant message content.

    Args:
        client: OpenAI-compatible client.
        messages: Conversation so far (system + user/assistant turns).
        retry: When true, a failed call is retried once after a 1 s pause.

    Returns:
        The message content ("" when the API returns no content). If the call
        and its retry both fail, a literal ``mark_complete`` JSON action is
        returned so the episode can end cleanly.

    Raises:
        GlobalTimeoutError: Propagated untouched when the global alarm fires
            during the request.
    """
    try:
        response = client.chat.completions.create(
            model=MODEL_NAME,
            messages=messages,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
            seed=SEED,
        )
        content = response.choices[0].message.content
        return content if content is not None else ""
    except GlobalTimeoutError:
        # The SIGALRM handler raises inside whatever code is running. Treating
        # it as an LLM failure would retry and silently discard the one-shot
        # global timeout.
        raise
    except Exception as e:
        print(f" [ERROR] LLM call failed: {e}", file=sys.stderr)
        if retry:
            print(" [INFO] Retrying once...", file=sys.stderr)
            time.sleep(1)
            return _call_llm(client, messages, retry=False)
        print(" [WARN] Retry failed, using fallback action", file=sys.stderr)
        return '{"action_type": "mark_complete", "params": {}}'
# ---------------------------------------------------------------------------
# Stdout logging (mandatory format)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Write the [START] record for a task to stdout and flush immediately."""
    record = f"[START] task={task} env={env} model={model}"
    print(record, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Write one [STEP] record to stdout and flush immediately.

    The record must occupy exactly one line, so line breaks inside ``action``
    or ``error`` (exception messages are often multi-line) are replaced by
    single spaces before printing.

    Args:
        step: 1-based step number.
        action: Human-readable rendering of the action taken.
        reward: Reward for the step, printed with two decimals.
        done: Whether the episode ended on this step; printed as true/false.
        error: Error message for the step; ``None`` or empty prints ``null``.
    """
    def _single_line(text: str) -> str:
        # Join the pieces of a multi-line string with spaces; single-line
        # input is returned unchanged.
        return " ".join(text.splitlines())

    error_val = _single_line(error) if error else ""
    if not error_val:
        error_val = "null"
    action_val = _single_line(action)
    done_val = str(done).lower()
    print(
        f"[STEP] step={step} action={action_val} reward={reward:.2f} done={done_val} error={error_val}",
        flush=True,
    )
def log_end(success: bool, steps: int, rewards: List[float]) -> None:
    """Write the [END] record for a task to stdout and flush immediately.

    Args:
        success: Printed as lowercase true/false.
        steps: Number of steps taken.
        rewards: Per-step rewards, printed comma-separated with two decimals.
    """
    formatted = [f"{value:.2f}" for value in rewards]
    rewards_str = ",".join(formatted)
    success_val = str(success).lower()
    print(
        f"[END] success={success_val} steps={steps} rewards={rewards_str}",
        flush=True,
    )
# ---------------------------------------------------------------------------
# Task runner
# ---------------------------------------------------------------------------
def run_task(
    client: OpenAI,
    env_base_url: str,
    task_id: str,
) -> float:
    """Run a single data-cleaning task and return its final score.

    Emits one [START] line, one [STEP] line per environment step and exactly
    one [END] line to stdout; diagnostics go to stderr.

    Args:
        client: OpenAI-compatible client used for the agent's LLM calls.
        env_base_url: Base URL of the environment server.
        task_id: Task to reset the environment to.

    Returns:
        The reward of the last recorded step, clamped to [0, 1]. A step whose
        ``env.step`` call failed is recorded with reward 0.0. If the task
        aborts with an error before the loop completes, 0.0 is returned.

    Raises:
        GlobalTimeoutError: Re-raised (after [END] has been written) so the
            caller can stop the whole run instead of starting the next task.
    """
    print(f"\n Task: {task_id}", file=sys.stderr)
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    # After trimming, keep the system prompt plus this many trailing messages.
    # The count is odd so the kept slice starts with a *user* message: some
    # chat templates reject an assistant turn directly after the system
    # prompt. 11 = five earlier exchanges + the current user prompt.
    history_tail = 11
    log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
    try:
        with DataCleanEnv(base_url=env_base_url).sync() as env:
            result = env.reset(seed=SEED, task_id=task_id)
            obs: DataCleanObservation = result.observation
            print(f" Initial issues: {obs.data_summary.issue_count}", file=sys.stderr)
            print(f" Max steps: {obs.max_steps}", file=sys.stderr)
            messages: List[Dict[str, str]] = [
                {"role": "system", "content": SYSTEM_PROMPT},
            ]
            step = 0
            done = False
            while not done:
                step += 1
                user_msg = _build_user_prompt(obs)
                messages.append({"role": "user", "content": user_msg})
                # Keep the conversation manageable (never more than 13
                # messages per request).
                if len(messages) > 13:
                    messages = [messages[0]] + messages[-history_tail:]
                llm_response = _call_llm(client, messages)
                messages.append({"role": "assistant", "content": llm_response})
                action = _parse_action(llm_response)
                action_str = f"{action.action_type}({json.dumps(action.params) if action.params else ''})"
                error: Optional[str] = None
                try:
                    result = env.step(action)
                    obs = result.observation
                    done = result.done
                    reward = result.reward if result.reward is not None else 0.0
                except GlobalTimeoutError:
                    # Not an environment failure: let the global timeout
                    # unwind the whole task.
                    raise
                except Exception as e:
                    print(f" [ERROR] Environment step failed: {e}", file=sys.stderr)
                    traceback.print_exc(file=sys.stderr)
                    reward = 0.0
                    error = str(e)
                    done = True
                rewards.append(float(reward))
                steps_taken = step
                log_step(step=step, action=action_str, reward=float(reward), done=done, error=error)
            # Final score is the last *recorded* reward. Reading result.reward
            # here would reuse the previous step's reward whenever the last
            # env.step call failed.
            score = rewards[-1] if rewards else 0.0
            score = min(max(score, 0.0), 1.0)  # clamp to [0, 1]
            success = score >= SUCCESS_SCORE_THRESHOLD
    except GlobalTimeoutError:
        # Swallowing this would let the caller start the next task with the
        # one-shot alarm already spent.
        raise
    except Exception as e:
        print(f" [ERROR] Task failed: {e}", file=sys.stderr)
        traceback.print_exc(file=sys.stderr)
    finally:
        # [END] is written on every exit path, including a timeout.
        log_end(success=success, steps=steps_taken, rewards=rewards)
        print(f" Final score: {score:.4f}", file=sys.stderr)
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
async def async_main() -> int:
    """Run all tasks and emit [START]/[STEP]/[END] to stdout.

    Arms a process-wide SIGALRM timeout where the platform supports it, runs
    every task in TASKS in order, and prints a per-task and average score
    summary to stderr. Tasks that were never started because the global
    timeout fired are reported with a score of 0.0 so the average covers the
    whole benchmark.

    Returns:
        Process exit code (always 0 when the run completes).

    Raises:
        ValueError: If HF_TOKEN is unset or empty.
    """
    # An empty token is as unusable as a missing one.
    if not HF_TOKEN:
        raise ValueError("HF_TOKEN environment variable is required")
    # Set global timeout (Unix only; no-op on Windows)
    has_alarm = hasattr(signal, "SIGALRM")
    if has_alarm:
        signal.signal(signal.SIGALRM, _timeout_handler)
        signal.alarm(GLOBAL_TIMEOUT_SECONDS)
    scores: Dict[str, float] = {}
    try:
        # Initialize OpenAI client per hackathon spec
        client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
        print(f"LLM endpoint: {API_BASE_URL}", file=sys.stderr)
        print(f"Model: {MODEL_NAME}", file=sys.stderr)
        print(f"Environment: {ENV_BASE_URL}", file=sys.stderr)
        print(f"Tasks: {TASKS}", file=sys.stderr)
        for task_id in TASKS:
            try:
                score = run_task(client, ENV_BASE_URL, task_id)
                scores[task_id] = score
            except GlobalTimeoutError:
                print(f"\n[TIMEOUT] Global timeout reached during task '{task_id}'", file=sys.stderr)
                scores[task_id] = 0.0
                break
            except Exception as e:
                print(f"\n[ERROR] Task '{task_id}' failed: {e}", file=sys.stderr)
                traceback.print_exc(file=sys.stderr)
                scores[task_id] = 0.0
    finally:
        # Cancel the alarm on every exit path so it cannot fire later.
        if has_alarm:
            signal.alarm(0)
    # Tasks skipped after a timeout count as 0.0 rather than being left out
    # of the average.
    for task_id in TASKS:
        scores.setdefault(task_id, 0.0)
    # Summary to stderr
    avg_score = sum(scores.values()) / len(scores) if scores else 0.0
    print(f"\n RESULTS", file=sys.stderr)
    for task_id, score in scores.items():
        print(f" {task_id}: {score:.4f}", file=sys.stderr)
    print(f" Average: {avg_score:.4f}", file=sys.stderr)
    return 0
def main() -> int:
    """Synchronous entry point: run async_main to completion and return its exit code."""
    from asyncio import run

    exit_code = run(async_main())
    return exit_code
# Script entry point: exit the process with main()'s return code.
if __name__ == "__main__":
    raise SystemExit(main())