Spaces:
Sleeping
Sleeping
| """ | |
| Inference Script for Data Cleaning RL Environment | |
| =================================== | |
| MANDATORY | |
| - Before submitting, ensure the following variables are defined in your environment configuration: | |
| API_BASE_URL The API endpoint for the LLM. | |
| MODEL_NAME The model identifier to use for inference. | |
| HF_TOKEN Your Hugging Face / API key. | |
| LOCAL_IMAGE_NAME The name of the local image to use for the environment if you are using from_docker_image() | |
| - Defaults are set only for API_BASE_URL and MODEL_NAME: | |
| API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") | |
| MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o-mini") | |
| - The inference script must be named `inference.py` and placed in the root directory of the project | |
| - Participants must use OpenAI Client for all LLM calls using above variables | |
| STDOUT FORMAT | |
| - The script must emit exactly three line types to stdout, in this order: | |
| [START] task=<task_name> env=<benchmark> model=<model_name> | |
| [STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null> | |
| [END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn> | |
| """ | |
| import asyncio | |
| import json | |
| import os | |
| import subprocess | |
| import sys | |
| import time | |
| import traceback | |
| from typing import Any, Dict, List, Optional | |
| from openai import OpenAI | |
| from openenv import GenericEnvClient | |
# ---------------------------------------------------------------------------
# Configuration — read from environment variables
# ---------------------------------------------------------------------------
# Each setting reads its primary variable first; an empty or missing value
# falls through to the alias (image / key) or to the hard-coded default.
IMAGE_NAME = os.environ.get("LOCAL_IMAGE_NAME") or os.environ.get("IMAGE_NAME")
API_KEY = os.environ.get("HF_TOKEN") or os.environ.get("API_KEY")
API_BASE_URL = os.environ.get("API_BASE_URL") or "https://router.huggingface.co/v1"
MODEL_NAME = os.environ.get("MODEL_NAME") or "gpt-4o-mini"

BENCHMARK = "data-cleaning-env"
TASKS = ["easy", "medium", "hard", "expert"]

# Step cap for each task episode.
MAX_STEPS_MAP = {
    "easy": 20,
    "medium": 40,
    "hard": 60,
    "expert": 80,
}

# Handle of the uvicorn subprocess launched by connect_env(), if any;
# cleanup() terminates it and resets this back to None.
_server_proc: Optional[subprocess.Popen] = None
# ---------------------------------------------------------------------------
# OpenAI tool definitions for function-calling
# ---------------------------------------------------------------------------
def _function_tool(name: str, description: str,
                   properties: Dict[str, Any], required: List[str]) -> Dict[str, Any]:
    """Wrap one function schema in the OpenAI ``{"type": "function", ...}`` envelope."""
    return {
        "type": "function",
        "function": {
            "name": name,
            "description": description,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": required,
            },
        },
    }


# Tool schemas offered to the model; the tool name becomes the env action_type.
TOOLS = [
    _function_tool(
        "fill_missing",
        "Fill missing (NaN) values in a column.",
        {
            "column": {"type": "string"},
            "strategy": {"type": "string", "enum": ["mean", "median", "mode", "constant"]},
        },
        ["column", "strategy"],
    ),
    _function_tool("drop_duplicates", "Drop exact duplicate rows.", {}, []),
    _function_tool(
        "fix_type",
        "Coerce a column to a target dtype.",
        {
            "column": {"type": "string"},
            "dtype": {"type": "string", "enum": ["int", "float", "str"]},
        },
        ["column", "dtype"],
    ),
    _function_tool(
        "fix_schema_violation",
        "Clamp values that violate constraints.",
        {
            "column": {"type": "string"},
            "constraint": {"type": "string", "enum": ["non_negative", "clamp_range"]},
        },
        ["column", "constraint"],
    ),
    _function_tool(
        "standardize_categories",
        "Lowercase, strip whitespace, collapse spaces.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _function_tool(
        "fix_format_regex",
        "Regex substitution for formatting.",
        {
            "column": {"type": "string"},
            "pattern": {"type": "string"},
            "replacement": {"type": "string"},
        },
        ["column", "pattern", "replacement"],
    ),
    _function_tool(
        "deduplicate_fuzzy",
        "Replace near-duplicate strings with canonical form.",
        {"column": {"type": "string"}, "threshold": {"type": "number"}},
        ["column"],
    ),
    _function_tool(
        "profile_column",
        "Get extended stats for a column. Free.",
        {"column": {"type": "string"}},
        ["column"],
    ),
    _function_tool("done", "Signal cleaning is complete.", {}, []),
]
# ---------------------------------------------------------------------------
# System prompt for the LLM
# ---------------------------------------------------------------------------
# One entry per prompt line; the first line is split only to keep it short.
SYSTEM_PROMPT = "\n".join([
    "You are an expert data-cleaning agent. Clean dirty tabular datasets by calling "
    "tool actions to maximize the composite quality score.",
    "GRADING: accuracy(30%) + completeness(25%) + consistency(25%) + format(20%).",
    "STRATEGY (in order):",
    "1. fill_missing — 'median' for numeric, 'mode' for categorical",
    "2. standardize_categories — for columns with semantic duplicates",
    "3. fix_type — coerce columns with type errors to 'float'",
    "4. fix_schema_violation — fix negatives with 'non_negative'",
    "5. Call done() when no more improvements possible",
    "AVOID: normalize, drop_outliers. Focus on columns with most issues first.",
])
# ---------------------------------------------------------------------------
# Logging helpers (required stdout format)
# ---------------------------------------------------------------------------
def log_start(task: str, model: str) -> None:
    """Print the mandatory ``[START]`` line that opens a task episode."""
    line = "[START] task={} env={} model={}".format(task, BENCHMARK, model)
    print(line, flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Print the mandatory ``[STEP]`` line for one environment step.

    Both ``action`` and ``error`` are flattened onto a single line (every CR and
    LF becomes a space) so one record can never be split across stdout lines;
    ``action`` is additionally truncated to 120 characters. A ``None`` error is
    printed as ``null``.
    """
    action_clean = action.replace("\n", " ").replace("\r", " ")[:120]
    # Sanitise the error exactly like the action: a bare "\r" would otherwise
    # corrupt the line-oriented output format.
    error_str = "null" if error is None else error.replace("\n", " ").replace("\r", " ")
    print(
        f"[STEP] step={step} action={action_clean} reward={reward:.2f} "
        f"done={str(done).lower()} error={error_str}",
        flush=True,
    )
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Print the mandatory ``[END]`` line that closes a task episode.

    ``score`` is rendered with three decimals and each per-step reward with
    two, joined by commas (an empty list yields an empty ``rewards=`` field).
    """
    reward_list = ",".join(map("{:.2f}".format, rewards))
    outcome = str(success).lower()
    print(
        f"[END] success={outcome} steps={steps} score={score:.3f} rewards={reward_list}",
        flush=True,
    )
# ---------------------------------------------------------------------------
# Build observation summary for LLM
# ---------------------------------------------------------------------------
def build_user_message(obs: Dict[str, Any], task: str) -> str:
    """Summarise an observation as plain text for the LLM.

    The message consists of a header line (task, step counter, last reward),
    a ``Columns:`` section with one line per column listing its non-zero
    issue counts (or ``clean``) and whether it looks numeric (has a ``mean``
    stat), an optional budget line, and a closing instruction.

    Fields that are absent *or* explicitly ``None`` in ``obs`` are treated as
    empty/zero, so e.g. a ``None`` reward renders as ``0.00`` instead of
    raising ``TypeError`` in the ``:.2f`` format.
    """
    # (key in column_issues, label shown to the model), in display order.
    issue_labels = (
        ("missing_count", "missing"),
        ("type_errors", "type_errors"),
        ("semantic_duplicate_count", "sem_dups"),
        ("format_violation_count", "format_violations"),
    )
    columns = obs.get("columns") or []
    issues = obs.get("column_issues") or {}
    stats = obs.get("column_stats") or {}
    step = obs.get("step") or 0
    max_steps = obs.get("max_steps") or 0
    # `.get("reward", 0.0)` would still return None when the key is present
    # with a None value; `or` covers both cases.
    reward = obs.get("reward") or 0.0

    lines = [f"Task: {task} | Step: {step}/{max_steps} | Last reward: {reward:.2f}", "", "Columns:"]
    for col in columns:
        col_issues = issues.get(col) or {}
        col_stats = stats.get(col) or {}
        parts = [
            f"{label}={col_issues[key]}"
            for key, label in issue_labels
            if (col_issues.get(key) or 0) > 0
        ]
        issue_str = ", ".join(parts) if parts else "clean"
        kind = "numeric" if col_stats.get("mean") is not None else "categorical"
        lines.append(f" {col} ({kind}): [{issue_str}]")

    budget = obs.get("budget_remaining")
    if budget is not None:
        lines.append(f"\nBudget: {budget:.2f}")
    lines.append("\nChoose the best next action. Call done() if all issues are resolved.")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# LLM action selection
# ---------------------------------------------------------------------------
def llm_choose_action(client: OpenAI, messages: List[Dict[str, Any]]) -> tuple:
    """Ask the model for the next action through forced function calling.

    Only the first tool call of the reply is used. Its JSON arguments are
    filtered down to a fixed set of action fields and merged under
    ``action_type`` (the tool name).

    Returns:
        ``(action_payload, action_str, tool_call)``: the dict passed to
        ``env.step``, a ``name(arguments)`` string for logging, and the raw
        tool-call object so the caller can echo it back into the conversation.

    Raises:
        ValueError: if the reply contains no tool calls, or (as
            ``json.JSONDecodeError``) if the arguments are not valid JSON.
    """
    forwarded_fields = (
        "column", "strategy", "dtype", "method", "constraint",
        "new_name", "datetime_format", "threshold", "delimiter",
        "column2", "merge_strategy", "pattern", "replacement",
    )
    completion = client.chat.completions.create(
        model=MODEL_NAME,
        messages=messages,
        tools=TOOLS,
        tool_choice="required",
        temperature=0.0,
    )
    reply = completion.choices[0].message
    if not reply.tool_calls:
        raise ValueError("No tool calls in response")

    call = reply.tool_calls[0]
    tool_name = call.function.name
    raw_arguments = call.function.arguments
    arguments = json.loads(raw_arguments or "{}")

    action_payload: Dict[str, Any] = {"action_type": tool_name}
    action_payload.update(
        (field, arguments[field]) for field in forwarded_fields if field in arguments
    )
    return action_payload, f"{tool_name}({raw_arguments})", call
# ---------------------------------------------------------------------------
# Heuristic fallback (no LLM key, or the LLM call failed)
# ---------------------------------------------------------------------------
def heuristic_action(obs: Dict[str, Any]) -> Optional[Dict[str, Any]]:
    """Pick a rule-based cleaning action, or return ``None`` if no rule fires.

    Rules are tried in priority order; within a rule, columns are scanned in
    ``obs["columns"]`` order and the first matching column wins:

    1. missing values      -> ``fill_missing`` (``median`` when the column has
       a ``mean`` stat, otherwise ``mode``)
    2. semantic duplicates -> ``standardize_categories``
    3. type errors         -> ``fix_type`` to ``float``
    4. format violations on a column with a ``mean`` stat
                           -> ``fix_schema_violation`` with ``non_negative``
    """
    column_issues = obs.get("column_issues", {})
    columns = obs.get("columns", [])
    column_stats = obs.get("column_stats", {})

    def has(col, key):
        return column_issues.get(col, {}).get(key, 0) > 0

    def looks_numeric(col):
        return column_stats.get(col, {}).get("mean") is not None

    # (predicate, action builder) pairs, highest priority first.
    rules = (
        (lambda col: has(col, "missing_count"),
         lambda col: {"action_type": "fill_missing", "column": col,
                      "strategy": "median" if looks_numeric(col) else "mode"}),
        (lambda col: has(col, "semantic_duplicate_count"),
         lambda col: {"action_type": "standardize_categories", "column": col}),
        (lambda col: has(col, "type_errors"),
         lambda col: {"action_type": "fix_type", "column": col, "dtype": "float"}),
        (lambda col: has(col, "format_violation_count") and looks_numeric(col),
         lambda col: {"action_type": "fix_schema_violation", "column": col,
                      "constraint": "non_negative"}),
    )
    for matches, build in rules:
        for col in columns:
            if matches(col):
                return build(col)
    return None
# ---------------------------------------------------------------------------
# Run one task episode
# ---------------------------------------------------------------------------
def _fallback_action(obs: Dict[str, Any]) -> tuple:
    """Return ``(payload, log_string)`` for the heuristic action, or ``done`` if none applies."""
    payload = heuristic_action(obs) or {"action_type": "done"}
    return payload, json.dumps(payload, separators=(",", ":"))


async def run_task(env: GenericEnvClient, client: Optional[OpenAI], task: str, use_llm: bool) -> tuple:
    """Run a single task episode.

    Prints the ``[START]`` line and one ``[STEP]`` line per step; the caller
    prints ``[END]``. With ``use_llm`` the model chooses each action. If that
    call fails, the heuristic fallback is applied and the failure is reported
    in the step's ``error`` field. Any other exception inside the step loop
    ends the episode with one extra zero-reward ``error`` step.

    Returns:
        ``(score, steps_taken, rewards)``.
    """
    max_steps = MAX_STEPS_MAP.get(task, 20)
    result = await env.reset(task=task)
    obs = result.observation
    rewards: List[float] = []
    steps_taken = 0
    messages: List[Dict[str, Any]] = []
    if use_llm:
        messages = [
            {"role": "system", "content": SYSTEM_PROMPT},
            {"role": "user", "content": build_user_message(obs, task)},
        ]
    log_start(task=task, model=MODEL_NAME)
    loop = asyncio.get_running_loop()
    try:
        for step in range(1, max_steps + 1):
            if result.done:
                break
            error: Optional[str] = None
            tc = None
            if use_llm:
                try:
                    # The OpenAI client is synchronous. Run it in a worker
                    # thread so this coroutine does not block the event loop
                    # that the async env client also runs on.
                    action_payload, action_str, tc = await loop.run_in_executor(
                        None, llm_choose_action, client, messages
                    )
                except Exception as exc:
                    error = f"LLM error: {exc}"
                    action_payload, action_str = _fallback_action(obs)
            else:
                action_payload, action_str = _fallback_action(obs)

            result = await env.step(action_payload)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            log_step(step=step, action=action_str, reward=reward, done=done, error=error)

            # Keep the LLM conversation in sync with the environment.
            if use_llm:
                if tc is not None:
                    # Echo the model's own tool call, then answer it with the new state.
                    messages.append({
                        "role": "assistant",
                        "content": None,
                        "tool_calls": [{
                            "id": tc.id,
                            "type": "function",
                            "function": {"name": tc.function.name, "arguments": tc.function.arguments},
                        }],
                    })
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tc.id,
                        "content": build_user_message(obs, task),
                    })
                else:
                    # The model did not choose this step's action. Tell it what was
                    # applied so its next decision sees the current state, not a
                    # stale one.
                    messages.append({
                        "role": "user",
                        "content": f"Fallback action applied: {action_str}\n\n"
                                   + build_user_message(obs, task),
                    })
            if done:
                break
    except Exception as exc:
        log_step(step=steps_taken + 1, action="error", reward=0.0, done=True, error=str(exc))
        rewards.append(0.0)
        steps_taken += 1

    # Score = sum of rewards / (max_steps * 0.01), clamped to [0, 1].
    # NOTE(review): this is not a per-step average; confirm the intended
    # normalisation against the environment's reward scale.
    total_reward = sum(rewards)
    score = min(max(total_reward / max(max_steps * 0.01, 0.01), 0.0), 1.0)
    return score, steps_taken, rewards
# ---------------------------------------------------------------------------
# Environment connection — try multiple strategies
# ---------------------------------------------------------------------------
async def connect_env() -> GenericEnvClient:
    """Connect to the environment, trying several strategies in order.

    1. ``GenericEnvClient.from_docker_image(IMAGE_NAME)`` when ``IMAGE_NAME`` is set.
    2. A server already answering ``/health`` on localhost port 7860, 8000 or 8080.
    3. The hosted Hugging Face Space.
    4. A uvicorn server started here as a subprocess on port 8765. Its handle
       is stored in ``_server_proc`` so ``cleanup()`` can stop it later.

    Progress messages go to stderr so stdout carries only the mandated
    ``[START]``/``[STEP]``/``[END]`` lines.

    Raises:
        RuntimeError: if the local server process exits, or does not report
            healthy within about 60 seconds.
        ImportError: if strategy 4 is reached and ``requests`` is unavailable.
    """
    global _server_proc

    def diag(message: str) -> None:
        print(message, file=sys.stderr, flush=True)

    # Strategy 1: from_docker_image if IMAGE_NAME is set
    if IMAGE_NAME:
        diag(f"[ENV] Connecting via from_docker_image({IMAGE_NAME})...")
        try:
            env = await GenericEnvClient.from_docker_image(IMAGE_NAME)
            diag("[ENV] Docker connection successful!")
            return env
        except Exception as exc:
            diag(f"[ENV] Docker connection failed: {exc}")
            diag("[ENV] Falling back to other strategies...")

    # Strategy 2: a server may already be running (e.g. started by the validator)
    for port in (7860, 8000, 8080):
        base_url = f"http://localhost:{port}"
        try:
            import requests
            if requests.get(f"{base_url}/health", timeout=3).status_code == 200:
                diag(f"[ENV] Found running server at localhost:{port}")
                env = GenericEnvClient(base_url=base_url)
                await env.connect()
                diag(f"[ENV] WebSocket connected to localhost:{port}!")
                return env
        except Exception:
            # Best-effort probe (requests missing, nothing listening, failed
            # handshake, ...): move on to the next port.
            pass

    # Strategy 3: hosted HF Space
    hf_url = "https://yashmarathe-data-cleaning-openenv.hf.space"
    try:
        import requests
        if requests.get(f"{hf_url}/health", timeout=10).status_code == 200:
            diag("[ENV] Connecting to HF Space...")
            env = GenericEnvClient(base_url=hf_url)
            await env.connect()
            diag("[ENV] HF Space WebSocket connected!")
            return env
    except Exception as exc:
        diag(f"[ENV] HF Space connection failed: {exc}")

    # Strategy 4: start a local server and poll its health endpoint
    diag("[ENV] Starting local server...")
    _server_proc = subprocess.Popen(
        [sys.executable, "-m", "uvicorn",
         "data_cleaning_env.server.app:app",
         "--host", "0.0.0.0", "--port", "8765"],
        stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL,
    )
    import requests
    for waited in range(1, 61):
        try:
            if requests.get("http://localhost:8765/health", timeout=2).status_code == 200:
                diag(f"[ENV] Local server ready after {waited}s")
                break
        except requests.RequestException:
            pass  # not accepting connections yet
        exit_code = _server_proc.poll()
        if exit_code is not None:
            # The server process has already exited; waiting longer cannot succeed.
            raise RuntimeError(
                f"All connection strategies failed (local server exited with code {exit_code})"
            )
        # Non-blocking wait: this is a coroutine running on the event loop.
        await asyncio.sleep(1)
    else:
        raise RuntimeError("All connection strategies failed")
    env = GenericEnvClient(base_url="http://localhost:8765")
    await env.connect()
    diag("[ENV] Local server WebSocket connected!")
    return env
def cleanup():
    """Stop the local server process if ``connect_env`` started one.

    Best-effort and never raises. It calls ``terminate()`` and waits up to 5 s.
    On any failure it falls back to ``kill()`` and waits again, so the child is
    reaped rather than left as a zombie. ``_server_proc`` is reset to ``None``
    afterwards, so calling this twice is harmless.
    """
    global _server_proc
    if _server_proc is None:
        return
    try:
        _server_proc.terminate()
        _server_proc.wait(timeout=5)
    except Exception:
        # Graceful stop failed or timed out: force it, then reap the process.
        try:
            _server_proc.kill()
            _server_proc.wait(timeout=5)
        except Exception:
            pass  # deliberate: shutdown cleanup must not mask earlier output/errors
    _server_proc = None
# ---------------------------------------------------------------------------
# Main — always emits exactly one [START]/[END] pair for every task
# ---------------------------------------------------------------------------
async def main() -> None:
    """Run every task in ``TASKS`` against the environment.

    stdout receives only the mandated ``[START]``/``[STEP]``/``[END]`` lines.
    Configuration, per-task errors, the final score summary and shutdown
    problems are written to stderr. If the environment cannot be reached, a
    failing ``[START]``/``[END]`` pair is still emitted for every task that has
    not already been reported.
    """
    use_llm = bool(API_KEY)
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY) if use_llm else None
    print(
        f"[CONFIG] API_BASE_URL={API_BASE_URL} MODEL={MODEL_NAME} USE_LLM={use_llm} IMAGE={IMAGE_NAME}",
        file=sys.stderr, flush=True,
    )
    env = None
    reported = set()  # tasks whose [END] line has already been printed
    try:
        env = await connect_env()
        scores: Dict[str, float] = {}
        for task in TASKS:
            try:
                score, steps, rewards = await run_task(env, client, task, use_llm)
            except Exception as exc:
                # run_task's step loop handles its own errors, so an exception
                # escaping here was raised before its [START] line (reset / first
                # prompt). Emit the whole pair ourselves.
                log_start(task=task, model=MODEL_NAME)
                log_end(success=False, steps=0, score=0.0, rewards=[])
                print(f"ERROR in task {task}: {exc}", file=sys.stderr, flush=True)
                scores[task] = 0.0
            else:
                log_end(success=score > 0.0, steps=steps, score=score, rewards=rewards)
                scores[task] = round(score, 4)
            reported.add(task)
        print(f"\nFinal scores:\n{json.dumps(scores, indent=2)}", file=sys.stderr, flush=True)
    except Exception as exc:
        # Connection completely failed — emit START/END for every unreported task
        print(f"FATAL: Could not connect to environment: {exc}", file=sys.stderr, flush=True)
        traceback.print_exc()
        for task in TASKS:
            if task in reported:
                continue
            log_start(task=task, model=MODEL_NAME)
            log_end(success=False, steps=0, score=0.0, rewards=[])
    finally:
        if env is not None:
            try:
                await env.close()
            except Exception as e:
                print(f"[DEBUG] env.close() error: {e}", file=sys.stderr, flush=True)
        cleanup()


if __name__ == "__main__":
    asyncio.run(main())