# data-cleaning-openenv / inference.py
# Author: yashmarathe
# Commit 6d01bb5: fix: make inference.py crash-proof with multi-strategy env connection
"""
Inference Script for Data Cleaning RL Environment
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your environment configuration:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
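HF_TOKEN Your Hugging Face / API key (API_KEY is accepted as a fallback). If neither is set, a rule-based heuristic policy is used instead of the LLM.
LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image() (IMAGE_NAME is accepted as a fallback).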
- Defaults are set only for API_BASE_URL and MODEL_NAME:
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")
- The inference script must be named `inference.py` and placed in the root directory of the project
- Participants must use the OpenAI client for all LLM calls, configured with the variables above
STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
import asyncio
import json
import os
import subprocess
import sys
import time
import traceback
from typing import Any, Dict, List, Optional
from openai import OpenAI
from openenv import GenericEnvClient
# ---------------------------------------------------------------------------
# Configuration — from environment variables
# ---------------------------------------------------------------------------
# Docker image for GenericEnvClient.from_docker_image(); LOCAL_IMAGE_NAME takes precedence.
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") or os.environ.get("IMAGE_NAME")
# Key for the OpenAI-compatible endpoint. main() enables the LLM policy only when this is set.
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
# `or` (rather than a getenv default) also maps an empty string to the default.
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-4o-mini"
BENCHMARK = "data-cleaning-env"
# Per-task step budgets. The dict's insertion order doubles as the task run order.
MAX_STEPS_MAP = {"easy": 20, "medium": 40, "hard": 60, "expert": 80}
TASKS = list(MAX_STEPS_MAP)
# Handle of the uvicorn subprocess started by connect_env() (if any), so cleanup() can stop it.
_server_proc: Optional[subprocess.Popen] = None
# ---------------------------------------------------------------------------
# OpenAI tool definitions for function-calling
# ---------------------------------------------------------------------------
def _function_tool(
    name: str,
    description: str,
    properties: Optional[Dict[str, Any]] = None,
    required: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build one OpenAI function-calling tool spec whose parameters form a JSON-schema object."""
    schema = {"type": "object", "properties": properties or {}, "required": required or []}
    return {"type": "function", "function": {"name": name, "description": description, "parameters": schema}}


# Actions offered to the LLM. The chosen function name becomes the action's `action_type`.
TOOLS = [
    _function_tool(
        "fill_missing",
        "Fill missing (NaN) values in a column.",
        {
            "column": {"type": "string"},
            "strategy": {"type": "string", "enum": ["mean", "median", "mode", "constant"]},
        },
        ["column", "strategy"],
    ),
    _function_tool("drop_duplicates", "Drop exact duplicate rows."),
    _function_tool(
        "fix_type",
        "Coerce a column to a target dtype.",
        {
            "column": {"type": "string"},
            "dtype": {"type": "string", "enum": ["int", "float", "str"]},
        },
        ["column", "dtype"],
    ),
    _function_tool(
        "fix_schema_violation",
        "Clamp values that violate constraints.",
        {
            "column": {"type": "string"},
            "constraint": {"type": "string", "enum": ["non_negative", "clamp_range"]},
        },
        ["column", "constraint"],
    ),
    _function_tool(
        "standardize_categories",
        "Lowercase, strip whitespace, collapse spaces.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _function_tool(
        "fix_format_regex",
        "Regex substitution for formatting.",
        {
            "column": {"type": "string"},
            "pattern": {"type": "string"},
            "replacement": {"type": "string"},
        },
        ["column", "pattern", "replacement"],
    ),
    _function_tool(
        "deduplicate_fuzzy",
        "Replace near-duplicate strings with canonical form.",
        {"column": {"type": "string"}, "threshold": {"type": "number"}},
        ["column"],
    ),
    _function_tool(
        "profile_column",
        "Get extended stats for a column. Free.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _function_tool("done", "Signal cleaning is complete."),
]
# ---------------------------------------------------------------------------
# System prompt for the LLM
# ---------------------------------------------------------------------------
# System message that opens every LLM conversation built in run_task().
# The STRATEGY order mirrors the priority order used by heuristic_action().
# NOTE(review): `normalize` and `drop_outliers` are listed under AVOID but are not
# offered in TOOLS, so the model cannot call them anyway — presumably kept in case
# the tool list grows; confirm or drop.
# NOTE(review): the GRADING weights are assumed to match the environment's grader —
# verify against the server implementation.
SYSTEM_PROMPT = """\
You are an expert data-cleaning agent. Clean dirty tabular datasets by calling \
tool actions to maximize the composite quality score.
GRADING: accuracy(30%) + completeness(25%) + consistency(25%) + format(20%).
STRATEGY (in order):
1. fill_missing — 'median' for numeric, 'mode' for categorical
2. standardize_categories — for columns with semantic duplicates
3. fix_type — coerce columns with type errors to 'float'
4. fix_schema_violation — fix negatives with 'non_negative'
5. Call done() when no more improvements possible
AVOID: normalize, drop_outliers. Focus on columns with most issues first."""
# ---------------------------------------------------------------------------
# Logging helpers (required stdout format)
# ---------------------------------------------------------------------------
def log_start(task: str, model: str) -> None:
    """Print the [START] line that opens a task's log on stdout."""
    header = "[START] " + f"task={task} env={BENCHMARK} model={model}"
    print(header, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print one [STEP] line on stdout.

    Line breaks (``\\n`` and ``\\r``) in ``action`` and ``error`` are replaced by
    spaces, so every step stays on exactly one line. ``action`` is also cut to
    120 characters. ``reward`` is shown with two decimals, ``done`` as
    ``true``/``false``, and a ``None`` error as ``null``.
    """
    def _one_line(text: str) -> str:
        return text.replace("\n", " ").replace("\r", " ")

    action_clean = _one_line(action)[:120]
    # Strip "\r" from errors too, not just "\n": either one would split the [STEP] line.
    error_str = "null" if error is None else _one_line(error)
    print(f"[STEP] step={step} action={action_clean} reward={reward:.2f} done={str(done).lower()} error={error_str}", flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the [END] line that closes a task's log on stdout.

    ``score`` is printed with three decimals. ``rewards`` is printed as a
    comma-separated list with two decimals each, and is empty when there were no steps.
    """
    formatted = [format(r, ".2f") for r in rewards]
    flag = str(success).lower()
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={','.join(formatted)}", flush=True)
# ---------------------------------------------------------------------------
# Build observation summary for LLM
# ---------------------------------------------------------------------------
def build_user_message(obs: Dict[str, Any], task: str) -> str:
    """Summarise an observation as the user/tool message shown to the LLM.

    Args:
        obs: Observation mapping from the environment. Reads ``columns``,
            ``column_issues`` (per-column issue counts), ``column_stats``,
            ``step``, ``max_steps``, ``reward`` and ``budget_remaining``.
            Missing keys and ``None`` values are tolerated.
        task: Task name shown in the header line.

    Returns:
        A multi-line summary containing:
        - a header with the task, step and last reward;
        - one line per column listing its non-zero issue counts (or ``clean``);
        - an optional budget line;
        - a closing instruction.
    """
    cols = obs.get("columns") or []
    issues = obs.get("column_issues") or {}
    stats = obs.get("column_stats") or {}
    step = obs.get("step", 0)
    max_steps = obs.get("max_steps", 0)
    # ``reward`` can be present but None; the ``:.2f`` format below would raise on None.
    reward = obs.get("reward") or 0.0
    # (key in column_issues, label shown to the model), in display order.
    issue_labels = (
        ("missing_count", "missing"),
        ("type_errors", "type_errors"),
        ("semantic_duplicate_count", "sem_dups"),
        ("format_violation_count", "format_violations"),
    )
    lines = [f"Task: {task} | Step: {step}/{max_steps} | Last reward: {reward:.2f}", "", "Columns:"]
    for col in cols:
        ci = issues.get(col) or {}
        cs = stats.get(col) or {}
        # Absent or None counts both mean "no issue of this kind".
        parts = [f"{label}={ci[key]}" for key, label in issue_labels if (ci.get(key) or 0) > 0]
        issue_str = ", ".join(parts) if parts else "clean"
        # A column whose stats carry a mean is treated as numeric.
        kind = "numeric" if cs.get("mean") is not None else "categorical"
        lines.append(f" {col} ({kind}): [{issue_str}]")
    budget = obs.get("budget_remaining")
    if budget is not None:
        lines.append(f"\nBudget: {budget:.2f}")
    lines.append("\nChoose the best next action. Call done() if all issues are resolved.")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# LLM action selection
# ---------------------------------------------------------------------------
def llm_choose_action(client: OpenAI, messages: List[Dict[str, Any]]) -> tuple:
    """Ask the LLM for the next cleaning action via forced function calling.

    Args:
        client: OpenAI-compatible client (main() builds it from API_BASE_URL / API_KEY).
        messages: The conversation so far (system prompt, observations and
            previous tool calls/results).

    Returns:
        ``(payload, action_str, tool_call)``:
        - ``payload`` is the action dict for ``env.step``: ``action_type`` is the
          function name, plus any recognised argument fields.
        - ``action_str`` is ``name(<raw JSON arguments>)``, used for logging.
        - ``tool_call`` is the raw tool-call object. The caller uses its ``id``
          to append the tool result to the conversation.

    Raises:
        ValueError: the response has no choices or no tool call, the arguments
            are not valid JSON (``json.JSONDecodeError``), or they decode to
            something other than a JSON object.
        Exception: anything raised by the API request itself.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        tools=TOOLS,
        tool_choice="required",  # the model must answer with a tool call
        temperature=0.0,
    )
    if not response.choices:
        raise ValueError("No choices in response")
    choice = response.choices[0]
    if not choice.message.tool_calls:
        raise ValueError("No tool calls in response")
    # Only the first tool call is used: the caller takes one env step per LLM turn.
    tc = choice.message.tool_calls[0]
    args = json.loads(tc.function.arguments or "{}")
    if not isinstance(args, dict):
        # e.g. "null" or a JSON list. The membership test and indexing below
        # would otherwise fail with an unhelpful TypeError.
        raise ValueError(
            f"Tool arguments must be a JSON object, got {type(args).__name__}: {tc.function.arguments!r}"
        )
    # Argument names forwarded to the environment; anything else is dropped.
    forwarded = (
        "column", "strategy", "dtype", "method", "constraint",
        "new_name", "datetime_format", "threshold", "delimiter",
        "column2", "merge_strategy", "pattern", "replacement",
    )
    payload: Dict[str, Any] = {"action_type": tc.function.name}
    payload.update({field: args[field] for field in forwarded if field in args})
    action_str = f"{tc.function.name}({tc.function.arguments})"
    return payload, action_str, tc
# ---------------------------------------------------------------------------
# Heuristic fallback (when no LLM key)
# ---------------------------------------------------------------------------
def heuristic_action(obs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pick the next cleaning action with a fixed rule-based policy.

    Used when no LLM key is configured or the LLM call fails. Issue categories
    are handled in priority order:
    1. missing values;
    2. semantic duplicates;
    3. type errors;
    4. format violations on numeric columns.
    Within a category, the first affected column in ``obs["columns"]`` order wins.

    Args:
        obs: Observation mapping; reads ``columns``, ``column_issues`` and
            ``column_stats``. Missing or ``None`` entries are treated as empty.

    Returns:
        An action payload dict, or ``None`` when no rule applies.
    """
    issues = obs.get("column_issues") or {}
    columns = obs.get("columns") or []
    stats = obs.get("column_stats") or {}

    def count(col: str, key: str) -> Any:
        # Absent or None counts both mean "no issues of this kind".
        return (issues.get(col) or {}).get(key) or 0

    def is_numeric(col: str) -> bool:
        # A column whose stats carry a mean is treated as numeric.
        return (stats.get(col) or {}).get("mean") is not None

    for col in columns:
        if count(col, "missing_count") > 0:
            strategy = "median" if is_numeric(col) else "mode"
            return {"action_type": "fill_missing", "column": col, "strategy": strategy}
    for col in columns:
        if count(col, "semantic_duplicate_count") > 0:
            return {"action_type": "standardize_categories", "column": col}
    for col in columns:
        if count(col, "type_errors") > 0:
            return {"action_type": "fix_type", "column": col, "dtype": "float"}
    for col in columns:
        if count(col, "format_violation_count") > 0 and is_numeric(col):
            return {"action_type": "fix_schema_violation", "column": col, "constraint": "non_negative"}
    return None
# ---------------------------------------------------------------------------
# Run one task episode
# ---------------------------------------------------------------------------
def _heuristic_payload(obs: Dict[str, Any]) -> tuple:
    """Return ``(payload, action_str)`` from heuristic_action(), or a ``done`` action when it has nothing to do."""
    payload = heuristic_action(obs) or {"action_type": "done"}
    return payload, json.dumps(payload, separators=(",", ":"))


async def run_task(env: GenericEnvClient, client: Optional[OpenAI], task: str, use_llm: bool) -> tuple:
    """Run one episode of ``task`` and log it in the required stdout format.

    Resets the environment and prints [START]. Then, on each step, it:
    - chooses an action: an LLM function call when ``use_llm``, falling back
      to the heuristic policy if the call fails; heuristic only otherwise;
    - steps the environment;
    - prints one [STEP] line.
    This continues until the episode is done or the task's step budget from
    ``MAX_STEPS_MAP`` (default 20) is used up. [END] is not printed here; the
    caller logs it from the returned values.

    An exception inside the step loop ends the episode early and is logged as a
    final [STEP] with ``done=true`` and a 0.0 reward. Exceptions from
    ``env.reset`` or from building the first prompt propagate before [START].

    Returns:
        ``(score, steps_taken, rewards)``:
        - ``score`` is clamped to [0, 1];
        - ``steps_taken`` is the number of [STEP] lines printed;
        - ``rewards`` is the list of per-step rewards.
    """
    max_steps = MAX_STEPS_MAP.get(task, 20)
    result = await env.reset(task=task)
    obs = result.observation
    rewards: List[float] = []
    steps_taken = 0
    messages: List[Dict[str, Any]] = []
    if use_llm:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_message(obs, task)},
        ]
    log_start(task=task, model=MODEL_NAME)
    try:
        for step in range(1, max_steps + 1):
            if result.done:
                break
            error: Optional[str] = None
            tc = None
            if use_llm:
                try:
                    action_payload, action_str, tc = llm_choose_action(client, messages)
                except Exception as exc:
                    error = f"LLM error: {exc}"
                    action_payload, action_str = _heuristic_payload(obs)
            else:
                action_payload, action_str = _heuristic_payload(obs)

            result = await env.step(action_payload)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            if use_llm:
                if tc is not None:
                    # Record the model's tool call and reply to it (by tool_call_id)
                    # with the new observation.
                    messages.append({
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [{
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                        }],
                    })
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": build_user_message(obs, task),
                    })
                else:
                    # The LLM call failed and a heuristic action was executed instead.
                    # Report that and show the new state. Otherwise the model's next
                    # turn would reason over an observation that is no longer current.
                    messages.append({
                        "role": "user",
                        "content": f"Fallback action applied: {action_str}\n\n{build_user_message(obs, task)}",
                    })
            if done:
                break
    except Exception as exc:
        log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(exc))
        rewards.append(0.0)
        steps_taken += 1
    # Score = sum of step rewards divided by (0.01 * max_steps), clamped to [0, 1].
    # NOTE(review): the 0.01 factor presumably matches the env's per-step reward
    # scale — confirm against the grader.
    total_reward = sum(rewards)
    score = min(max(total_reward / max(max_steps * 0.01, 0.01), 0.0), 1.0)
    return score, steps_taken, rewards
# ---------------------------------------------------------------------------
# Environment connection — try multiple strategies
# ---------------------------------------------------------------------------
async def connect_env() -> GenericEnvClient:
    """Connect to the data-cleaning environment, trying several strategies in order.

    1. ``GenericEnvClient.from_docker_image(IMAGE_NAME)`` when an image name is set.
    2. An already-running server on localhost ports 7860, 8000 or 8080,
       detected via ``GET /health``.
    3. The hosted Hugging Face Space.
    4. A local uvicorn server started as a subprocess on port 8765. Its handle
       is stored in ``_server_proc`` so ``cleanup()`` can stop it.

    Diagnostics are printed to stderr, because stdout is reserved for the
    [START]/[STEP]/[END] lines.

    Returns:
        A connected ``GenericEnvClient``.

    Raises:
        ImportError: ``requests`` is unavailable and the Docker strategy did
            not succeed.
        RuntimeError: the local server exited early or did not report healthy
            within about 60 polls.
        Exception: anything raised by the final ``env.connect()``.
    """
    global _server_proc

    def log(msg: str) -> None:
        print(msg, file=sys.stderr, flush=True)

    # Strategy 1: from_docker_image if IMAGE_NAME is set
    if IMAGE_NAME:
        log(f"[ENV] Connecting via from_docker_image({IMAGE_NAME})...")
        try:
            env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
            log("[ENV] Docker connection successful!")
            return env
        except Exception as exc:
            log(f"[ENV] Docker connection failed: {exc}")
            log("[ENV] Falling back to other strategies...")

    import requests  # only needed by the HTTP-based strategies below

    # Strategy 2: a server may already be running (e.g. started by the validator)
    for port in (7860, 8000, 8080):
        base_url = f"http://localhost:{port}"
        try:
            healthy = requests.get(f"{base_url}/health", timeout=3).status_code == 200
        except Exception:
            continue  # nothing reachable on this port
        if not healthy:
            continue
        log(f"[ENV] Found running server at localhost:{port}")
        try:
            env = GenericEnvClient(base_url=base_url)
            await env.connect()
        except Exception as exc:
            # Surface the reason instead of silently moving on to the next strategy.
            log(f"[ENV] WebSocket connection to localhost:{port} failed: {exc}")
            continue
        log(f"[ENV] WebSocket connected to localhost:{port}!")
        return env

    # Strategy 3: hosted HF Space
    hf_url = "https://yashmarathe-data-cleaning-openenv.hf.space"
    try:
        if requests.get(f"{hf_url}/health", timeout=10).status_code == 200:
            log("[ENV] Connecting to HF Space...")
            env = GenericEnvClient(base_url=hf_url)
            await env.connect()
            log("[ENV] HF Space WebSocket connected!")
            return env
    except Exception as exc:
        log(f"[ENV] HF Space connection failed: {exc}")

    # Strategy 4: start a local server
    log("[ENV] Starting local server...")
    _server_proc = subprocess.Popen(
        [sys.executable, "-m", "uvicorn",
         "data_cleaning_env.server.app:app",
         "--host", "0.0.0.0", "--port", "8765"],
        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
    )
    for i in range(60):
        # Fail fast if the server already died (e.g. an import error) instead of
        # polling for a full minute.
        exit_code = _server_proc.poll()
        if exit_code is not None:
            raise RuntimeError(
                f"All connection strategies failed (local server exited with code {exit_code})"
            )
        try:
            if requests.get("http://localhost:8765/health", timeout=2).status_code == 200:
                log(f"[ENV] Local server ready after {i+1}s")
                break
        except Exception:
            pass  # not accepting connections yet
        # Yield to the event loop instead of blocking it with time.sleep().
        await asyncio.sleep(1)
    else:
        raise RuntimeError("All connection strategies failed")
    env = GenericEnvClient(base_url="http://localhost:8765")
    await env.connect()
    log("[ENV] Local server WebSocket connected!")
    return env
def cleanup():
    """Stop the local environment server started by ``connect_env``, if any.

    Asks the process to exit (``Popen.terminate``) and waits up to 5 s. If that
    raises or times out, it force-kills the process and waits again so the
    child is reaped rather than left behind.
    ``_server_proc`` is always reset to ``None``, so calling this twice is harmless.
    """
    global _server_proc
    proc = _server_proc
    if proc is None:
        return
    _server_proc = None
    try:
        proc.terminate()
        proc.wait(timeout=5)
    except Exception:
        # terminate() failed or the process ignored it: force-kill, then reap.
        try:
            proc.kill()
            proc.wait(timeout=5)
        except Exception:
            pass  # best effort during shutdown; nothing more can be done here
# ---------------------------------------------------------------------------
# Main — wrapped in try/except to ALWAYS emit [START]/[END] for every task
# ---------------------------------------------------------------------------
async def main() -> None:
    """Run every task in ``TASKS`` and print the required [START]/[STEP]/[END] lines.

    Uses the LLM policy when ``API_KEY`` is set and the heuristic policy
    otherwise. A [START]/[END] pair is printed for every task, even when the
    environment connection or an individual task fails. Diagnostics go to
    stderr, so stdout carries only the structured lines. These include the
    config line, errors and the final score summary. The env client and any
    server subprocess started by ``connect_env`` are always cleaned up.
    """
    use_llm = bool(API_KEY)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if use_llm else None
    print(
        f"[CONFIG] API_BASE_URL={API_BASE_URL} MODEL={MODEL_NAME} USE_LLM={use_llm} IMAGE={IMAGE_NAME}",
        file=sys.stderr, flush=True,
    )
    env = None
    try:
        try:
            env = await connect_env()
        except Exception as exc:
            # Connection completely failed — still print a START/END pair per task.
            print(f"FATAL: Could not connect to environment: {exc}", file=sys.stderr, flush=True)
            traceback.print_exc()
            for task in TASKS:
                log_start(task=task, model=MODEL_NAME)
                log_end(success=False, steps=0, score=0.0, rewards=[])
            return

        scores: Dict[str, float] = {}
        for task in TASKS:
            try:
                score, steps, rewards = await run_task(env, client, task, use_llm)
            except Exception as exc:
                # run_task catches errors inside its step loop, so an exception here
                # normally comes from reset / the first prompt, i.e. before [START].
                log_start(task=task, model=MODEL_NAME)
                log_end(success=False, steps=0, score=0.0, rewards=[])
                print(f"ERROR in task {task}: {exc}", file=sys.stderr, flush=True)
                scores[task] = 0.0
            else:
                log_end(success=score > 0.0, steps=steps, score=score, rewards=rewards)
                scores[task] = round(score, 4)
        print(f"\nFinal scores:\n{json.dumps(scores, indent=2)}", file=sys.stderr, flush=True)
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error: {e}", file=sys.stderr, flush=True)
        cleanup()


if __name__ == "__main__":
    asyncio.run(main())