"""
Inference Script for Data Cleaning RL Environment
===================================
MANDATORY
- Before submitting, ensure the following variables are defined in your
  environment configuration:
    API_BASE_URL      The API endpoint for the LLM.
    MODEL_NAME        The model identifier to use for inference.
    HF_TOKEN          Your Hugging Face / API key.
    LOCAL_IMAGE_NAME  The name of the local image to use for the environment
                      if you are using from_docker_image()

- Defaults are set only for API_BASE_URL and MODEL_NAME:
    API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
    MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini")

- The inference script must be named `inference.py` and placed in the root
  directory of the project
- Participants must use OpenAI Client for all LLM calls using above variables

STDOUT FORMAT
- The script must emit exactly three line types to stdout, in this order:

    [START] task=<task> env=<benchmark> model=<model>
    [STEP]  step=<n> action=<action> reward=<0.00> done=<true|false> error=<msg|null>
    [END]   success=<true|false> steps=<n> score=<0.000> rewards=<r1,r2,...>

- Every other diagnostic line ([ENV], [CONFIG], errors, final scores) is
  written to stderr so that stdout carries only the three record types above.
"""

import asyncio
import json
import os
import subprocess
import sys
import time
import traceback
from typing import Any, Dict, List, Optional, Tuple

from openai import OpenAI
from openenv import GenericEnvClient

# ---------------------------------------------------------------------------
# Configuration — from environment variables
# ---------------------------------------------------------------------------
IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME") or os.getenv("IMAGE_NAME")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.getenv("MODEL_NAME") or "gpt-4o-mini"
BENCHMARK = "data-cleaning-env"
TASKS = ["easy", "medium", "hard", "expert"]
MAX_STEPS_MAP = {"easy": 20, "medium": 40, "hard": 60, "expert": 80}

# Track server subprocess for cleanup
_server_proc: Optional[subprocess.Popen] = None


# ---------------------------------------------------------------------------
# OpenAI tool definitions for function-calling
# ---------------------------------------------------------------------------
def _tool(
    name: str,
    description: str,
    properties: Optional[Dict[str, Any]] = None,
    required: Optional[List[str]] = None,
) -> Dict[str, Any]:
    """Build one OpenAI function-calling tool schema entry."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": dict(properties or {}),
                "required": list(required or []),
            },
        },
    }


TOOLS = [
    _tool(
        "fill_missing",
        "Fill missing (NaN) values in a column.",
        {
            "column": {"type": "string"},
            "strategy": {"type": "string", "enum": ["mean", "median", "mode", "constant"]},
        },
        ["column", "strategy"],
    ),
    _tool("drop_duplicates", "Drop exact duplicate rows."),
    _tool(
        "fix_type",
        "Coerce a column to a target dtype.",
        {
            "column": {"type": "string"},
            "dtype": {"type": "string", "enum": ["int", "float", "str"]},
        },
        ["column", "dtype"],
    ),
    _tool(
        "fix_schema_violation",
        "Clamp values that violate constraints.",
        {
            "column": {"type": "string"},
            "constraint": {"type": "string", "enum": ["non_negative", "clamp_range"]},
        },
        ["column", "constraint"],
    ),
    _tool(
        "standardize_categories",
        "Lowercase, strip whitespace, collapse spaces.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _tool(
        "fix_format_regex",
        "Regex substitution for formatting.",
        {
            "column": {"type": "string"},
            "pattern": {"type": "string"},
            "replacement": {"type": "string"},
        },
        ["column", "pattern", "replacement"],
    ),
    _tool(
        "deduplicate_fuzzy",
        "Replace near-duplicate strings with canonical form.",
        {
            "column": {"type": "string"},
            "threshold": {"type": "number"},
        },
        ["column"],
    ),
    _tool(
        "profile_column",
        "Get extended stats for a column. Free.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _tool("done", "Signal cleaning is complete."),
]

# Argument names copied from an LLM tool call into the environment action.
_ACTION_FIELDS = (
    "column", "strategy", "dtype", "method", "constraint", "new_name",
    "datetime_format", "threshold", "delimiter", "column2", "merge_strategy",
    "pattern", "replacement",
)

# (issue key in the observation, label shown to the LLM)
_ISSUE_LABELS = (
    ("missing_count", "missing"),
    ("type_errors", "type_errors"),
    ("semantic_duplicate_count", "sem_dups"),
    ("format_violation_count", "format_violations"),
)

# ---------------------------------------------------------------------------
# System prompt for the LLM
# ---------------------------------------------------------------------------
SYSTEM_PROMPT = """\
You are an expert data-cleaning agent. Clean dirty tabular datasets by calling \
tool actions to maximize the composite quality score.

GRADING: accuracy(30%) + completeness(25%) + consistency(25%) + format(20%).

STRATEGY (in order):
1. fill_missing — 'median' for numeric, 'mode' for categorical
2. standardize_categories — for columns with semantic duplicates
3. fix_type — coerce columns with type errors to 'float'
4. fix_schema_violation — fix negatives with 'non_negative'
5. Call done() when no more improvements possible

AVOID: normalize, drop_outliers. Focus on columns with most issues first."""


# ---------------------------------------------------------------------------
# Logging helpers (required stdout format)
# ---------------------------------------------------------------------------
def _diag(message: str) -> None:
    """Write a diagnostic line to stderr, keeping stdout to the 3 record types."""
    print(message, file=sys.stderr, flush=True)


def log_start(task: str, model: str) -> None:
    """Emit the [START] record for one task episode."""
    print(f"[START] task={task} env={BENCHMARK} model={model}", flush=True)


def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] record on a single line.

    Newlines and carriage returns are flattened in both the action and the
    error so a record can never spill onto a second line; the action is also
    truncated to 120 characters.
    """
    action_clean = action.replace("\n", " ").replace("\r", " ")[:120]
    if error is None:
        error_str = "null"
    else:
        error_str = error.replace("\n", " ").replace("\r", " ")
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_str}",
        flush=True,
    )


def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] record with the comma-separated per-step rewards."""
    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
    print(
        f"[END] success={str(success).lower()} steps={steps} "
        f"score={score:.3f} rewards={rewards_str}",
        flush=True,
    )


# ---------------------------------------------------------------------------
# Build observation summary for LLM
# ---------------------------------------------------------------------------
def build_user_message(obs: Dict[str, Any], task: str) -> str:
    """Render an observation as a compact text summary for the LLM.

    Reads ``columns``, ``column_issues``, ``column_stats``, ``step``,
    ``max_steps``, ``reward`` and (optionally) ``budget_remaining`` from
    *obs*. A column is labelled numeric when its stats contain a non-null
    ``mean``; otherwise categorical.
    """
    issues = obs.get("column_issues", {})
    stats = obs.get("column_stats", {})
    step = obs.get("step", 0)
    max_steps = obs.get("max_steps", 0)
    # The key may be present with a null value; formatting None with :.2f
    # would raise, so fall back to 0.0 exactly like run_task does.
    reward = obs.get("reward") or 0.0

    lines = [
        f"Task: {task} | Step: {step}/{max_steps} | Last reward: {reward:.2f}",
        "",
        "Columns:",
    ]
    for col in obs.get("columns", []):
        col_issues = issues.get(col, {})
        parts = [
            f"{label}={col_issues[key]}"
            for key, label in _ISSUE_LABELS
            if col_issues.get(key, 0) > 0
        ]
        issue_str = ", ".join(parts) if parts else "clean"
        has_mean = stats.get(col, {}).get("mean") is not None
        kind = "numeric" if has_mean else "categorical"
        lines.append(f" {col} ({kind}): [{issue_str}]")

    budget = obs.get("budget_remaining")
    if budget is not None:
        lines.append(f"\nBudget: {budget:.2f}")
    lines.append("\nChoose the best next action. Call done() if all issues are resolved.")
    return "\n".join(lines)


# ---------------------------------------------------------------------------
# LLM action selection
# ---------------------------------------------------------------------------
def llm_choose_action(
    client: OpenAI, messages: List[Dict[str, Any]]
) -> Tuple[Dict[str, Any], str, Any]:
    """Ask the LLM for the next action via a forced tool call.

    Returns:
        ``(action_dict, action_string, tool_call_obj)`` where *action_dict*
        is the environment payload built from the first tool call.

    Raises:
        ValueError: if the response has no tool call, or the tool arguments
            are not a JSON object. JSON decoding errors propagate as well;
            the caller treats any exception as an LLM failure.
    """
    response = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        tools=TOOLS,
        tool_choice="required",
        temperature=0.0,
    )
    tool_calls = response.choices[0].message.tool_calls
    if not tool_calls:
        raise ValueError("No tool calls in response")

    # Only the first tool call is executed, one environment step per turn.
    tc = tool_calls[0]
    args = json.loads(tc.function.arguments or "{}")
    if not isinstance(args, dict):
        raise ValueError("Tool call arguments are not a JSON object")

    payload: Dict[str, Any] = {"action_type": tc.function.name}
    payload.update({name: args[name] for name in _ACTION_FIELDS if name in args})
    action_str = f"{tc.function.name}({tc.function.arguments})"
    return payload, action_str, tc


# ---------------------------------------------------------------------------
# Heuristic fallback (when no LLM key)
# ---------------------------------------------------------------------------
def heuristic_action(obs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pick the next action from fixed rules, or None when nothing applies.

    Priority order, each scanned over the columns in observation order:
    missing values, semantic duplicates, type errors, then format violations
    on numeric columns (a column counts as numeric when its stats have a
    non-null ``mean``).
    """
    issues = obs.get("column_issues", {})
    columns = obs.get("columns", [])
    stats = obs.get("column_stats", {})

    def count(col: str, key: str) -> Any:
        return issues.get(col, {}).get(key, 0)

    def is_numeric(col: str) -> bool:
        return stats.get(col, {}).get("mean") is not None

    for col in columns:
        if count(col, "missing_count") > 0:
            strategy = "median" if is_numeric(col) else "mode"
            return {"action_type": "fill_missing", "column": col, "strategy": strategy}
    for col in columns:
        if count(col, "semantic_duplicate_count") > 0:
            return {"action_type": "standardize_categories", "column": col}
    for col in columns:
        if count(col, "type_errors") > 0:
            return {"action_type": "fix_type", "column": col, "dtype": "float"}
    for col in columns:
        if count(col, "format_violation_count") > 0 and is_numeric(col):
            return {
                "action_type": "fix_schema_violation",
                "column": col,
                "constraint": "non_negative",
            }
    return None


# ---------------------------------------------------------------------------
# Run one task episode
# ---------------------------------------------------------------------------
async def run_task(
    env: GenericEnvClient, client: Optional[OpenAI], task: str, use_llm: bool
) -> Tuple[float, int, List[float]]:
    """Run a single task episode. Returns ``(score, steps, rewards)``.

    Emits [START] once the environment has been reset and one [STEP] per
    action; the caller emits [END]. When *use_llm* is true the LLM picks each
    action and the heuristic is used only for steps where the LLM call fails.
    An exception raised inside the step loop is logged as a final error step
    rather than propagated; exceptions from the reset propagate to the caller.
    """
    max_steps = MAX_STEPS_MAP.get(task, 20)
    result = await env.reset(task=task)
    obs = result.observation
    rewards: List[float] = []
    steps_taken = 0
    messages: List[Dict[str, Any]] = []
    if use_llm:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_message(obs, task)},
        ]

    log_start(task=task, model=MODEL_NAME)
    try:
        for step in range(1, max_steps + 1):
            if result.done:
                break

            action_payload: Optional[Dict[str, Any]] = None
            action_str = ""
            error: Optional[str] = None
            tc = None

            if use_llm:
                try:
                    action_payload, action_str, tc = llm_choose_action(client, messages)
                except Exception as exc:
                    # Any LLM failure degrades to the heuristic for this step.
                    error = f"LLM error: {exc}"

            if action_payload is None:
                action_payload = heuristic_action(obs) or {"action_type": "done"}
                action_str = json.dumps(action_payload, separators=(",", ":"))

            result = await env.step(action_payload)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            # Update LLM conversation
            if use_llm:
                if tc is not None:
                    messages.append({
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [{
                            "id": tc.id,
                            "type": "function",
                            "function": {
                                "name": tc.function.name,
                                "arguments": tc.function.arguments,
                            },
                        }],
                    })
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": build_user_message(obs, task),
                    })
                else:
                    # The heuristic acted on the LLM's behalf; without this the
                    # next LLM call would reason about a stale observation.
                    messages.append({
                        "role": "user",
                        "content": build_user_message(obs, task),
                    })

            if done:
                break
    except Exception as exc:
        log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(exc))
        rewards.append(0.0)
        steps_taken += 1

    # Score = summed reward divided by max(max_steps * 0.01, 0.01),
    # clamped to [0, 1].
    total_reward = sum(rewards)
    score = min(max(total_reward / max(max_steps * 0.01, 0.01), 0.0), 1.0)
    return score, steps_taken, rewards


# ---------------------------------------------------------------------------
# Environment connection — try multiple strategies
# ---------------------------------------------------------------------------
async def connect_env() -> GenericEnvClient:
    """Connect to the environment. Tries multiple strategies in order.

    1. ``from_docker_image`` when IMAGE_NAME is set.
    2. An already-running server on localhost ports 7860, 8000, 8080.
    3. The hosted HF Space.
    4. A locally spawned uvicorn server on port 8765 (tracked in
       ``_server_proc`` so ``cleanup`` can stop it).

    Raises:
        RuntimeError: if the local server exits or is not healthy within
            about 60 polls.
        ImportError: if ``requests`` is unavailable once strategy 1 has been
            skipped or has failed.
    """
    global _server_proc

    # Strategy 1: from_docker_image if IMAGE_NAME is set
    if IMAGE_NAME:
        _diag(f"[ENV] Connecting via from_docker_image({IMAGE_NAME})...")
        try:
            env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
            _diag("[ENV] Docker connection successful!")
            return env
        except Exception as exc:
            _diag(f"[ENV] Docker connection failed: {exc}")
            _diag("[ENV] Falling back to other strategies...")

    # Imported lazily so the Docker path does not depend on requests.
    import requests

    # Strategy 2: Try connecting to common ports (validator may already have server running)
    for port in (7860, 8000, 8080):
        base_url = f"http://localhost:{port}"
        try:
            if requests.get(f"{base_url}/health", timeout=3).status_code != 200:
                continue
            _diag(f"[ENV] Found running server at localhost:{port}")
            env = GenericEnvClient(base_url=base_url)
            await env.connect()
            _diag(f"[ENV] WebSocket connected to localhost:{port}!")
            return env
        except Exception as exc:
            _diag(f"[ENV] localhost:{port} not usable: {exc}")

    # Strategy 3: Try HF Space
    hf_url = "https://yashmarathe-data-cleaning-openenv.hf.space"
    try:
        if requests.get(f"{hf_url}/health", timeout=10).status_code == 200:
            _diag("[ENV] Connecting to HF Space...")
            env = GenericEnvClient(base_url=hf_url)
            await env.connect()
            _diag("[ENV] HF Space WebSocket connected!")
            return env
    except Exception as exc:
        _diag(f"[ENV] HF Space connection failed: {exc}")

    # Strategy 4: Start local server
    _diag("[ENV] Starting local server...")
    _server_proc = subprocess.Popen(
        [
            sys.executable, "-m", "uvicorn", "data_cleaning_env.server.app:app",
            "--host", "0.0.0.0", "--port", "8765",
        ],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
    local_url = "http://localhost:8765"
    for attempt in range(1, 61):
        # Fail fast instead of polling a server that has already died.
        if _server_proc.poll() is not None:
            raise RuntimeError(
                "All connection strategies failed "
                f"(local server exited with code {_server_proc.returncode})"
            )
        try:
            if requests.get(f"{local_url}/health", timeout=2).status_code == 200:
                _diag(f"[ENV] Local server ready after {attempt}s")
                break
        except requests.RequestException:
            pass  # not accepting connections yet; keep polling
        # Non-blocking wait: this coroutine must not stall the event loop.
        await asyncio.sleep(1)
    else:
        raise RuntimeError("All connection strategies failed")

    env = GenericEnvClient(base_url=local_url)
    await env.connect()
    _diag("[ENV] Local server WebSocket connected!")
    return env


def cleanup() -> None:
    """Clean up server process if we started one."""
    global _server_proc
    proc, _server_proc = _server_proc, None
    if proc is None:
        return
    try:
        proc.terminate()
        proc.wait(timeout=5)
    except Exception:
        # Best effort: escalate to kill and ignore a process that is gone.
        try:
            proc.kill()
        except Exception:
            pass


# ---------------------------------------------------------------------------
# Main — wrapped in try/except to ALWAYS emit [START]/[END] for every task
# ---------------------------------------------------------------------------
async def main() -> None:
    """Connect to the environment and run every task in TASKS.

    Every task gets exactly one [START] and one [END] record on stdout, even
    when the connection or the task itself fails. The LLM is used only when
    an API key is configured; otherwise the heuristic policy drives the run.
    """
    use_llm = bool(API_KEY)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if use_llm else None
    _diag(
        f"[CONFIG] API_BASE_URL={API_BASE_URL} MODEL={MODEL_NAME} "
        f"USE_LLM={use_llm} IMAGE={IMAGE_NAME}"
    )

    env = None
    try:
        try:
            env = await connect_env()
        except Exception as exc:
            # Connection completely failed — emit START/END for all tasks
            _diag(f"FATAL: Could not connect to environment: {exc}")
            traceback.print_exc()
            for task in TASKS:
                log_start(task=task, model=MODEL_NAME)
                log_end(success=False, steps=0, score=0.0, rewards=[])
            return

        scores: Dict[str, float] = {}
        for task in TASKS:
            try:
                score, steps, rewards = await run_task(env, client, task, use_llm)
            except Exception as exc:
                # run_task only propagates failures raised before its own
                # [START] record, so emitting the pair here cannot duplicate it.
                log_start(task=task, model=MODEL_NAME)
                log_end(success=False, steps=0, score=0.0, rewards=[])
                _diag(f"ERROR in task {task}: {exc}")
                scores[task] = 0.0
            else:
                log_end(success=score > 0.0, steps=steps, score=score, rewards=rewards)
                scores[task] = round(score, 4)
        _diag(f"\nFinal scores:\n{json.dumps(scores, indent=2)}")
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                _diag(f"[DEBUG] env.close() error: {e}")
        cleanup()


if __name__ == "__main__":
    asyncio.run(main())