# dataclean-openenv / inference.py
# Commit 5419ba4 (GlitchGhost): handle from_docker_image failure gracefully,
# falling back to the HF Space URL.
"""
Inference Script — DataClean Environment
=========================================
MANDATORY environment variables:
API_BASE_URL The API endpoint for the LLM.
MODEL_NAME The model identifier to use for inference.
HF_TOKEN Your Hugging Face / API key.
IMAGE_NAME The name of the local Docker image (if using from_docker_image()).
Defaults are set only for API_BASE_URL and MODEL_NAME.
Uses OpenAI Client for all LLM calls.
STDOUT FORMAT:
[START] task=<task_name> env=<benchmark> model=<model_name>
[STEP] step=<n> action=<action_str> reward=<0.00> done=<true|false> error=<msg|null>
[END] success=<true|false> steps=<n> score=<score> rewards=<r1,r2,...,rn>
"""
import json
import os
import re
import sys
import textwrap
import time
from typing import List, Optional
from openai import OpenAI
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# LLM endpoint configuration. API key is accepted from either HF_TOKEN or
# API_KEY (HF_TOKEN takes precedence).
API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/novita/v3/openai")
API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY")
MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/llama-3.3-70b-instruct")
# Optional local Docker image for the environment; when unset (or when the
# Docker launch fails) we connect to the hosted Space at ENV_BASE_URL instead.
IMAGE_NAME = os.getenv("IMAGE_NAME")
ENV_BASE_URL = os.getenv("ENV_BASE_URL", "https://glitchghost-dataclean-openenv.hf.space")
# Benchmark label used in the [START] log line.
BENCHMARK = "dataclean_env"
# Per-difficulty step budgets; harder tasks get more cleaning steps.
MAX_STEPS_PER_TASK = {"easy": 12, "medium": 20, "hard": 30}
TEMPERATURE = 0.1  # low temperature: we want near-deterministic JSON actions
MAX_TOKENS = 400  # one JSON action is short; cap completion length
SUCCESS_SCORE_THRESHOLD = 0.1  # minimum final score counted as success
# System prompt sent on every step. The environment expects EXACTLY ONE JSON
# action per model response; the formatting rules live inside the prompt text.
SYSTEM_PROMPT = textwrap.dedent("""\
You are an expert data-quality analyst. You are interacting with a data-cleaning
environment. Your goal is to identify and fix all data-quality issues.
After reviewing the data and quality report, respond with EXACTLY ONE action in
valid JSON format. Available actions:
1. Fix a cell value:
{"action_type": "fix_value", "row_index": <int>, "column_name": "<col>", "new_value": "<corrected>"}
2. Delete a duplicate/invalid row:
{"action_type": "delete_row", "row_index": <int>}
3. Fill a missing value:
{"action_type": "fill_missing", "row_index": <int>, "column_name": "<col>", "new_value": "<value>"}
4. Flag a suspicious cell (partial credit):
{"action_type": "flag_anomaly", "row_index": <int>, "column_name": "<col>"}
5. Submit your work (ends the episode):
{"action_type": "submit"}
6. Do nothing this step:
{"action_type": "noop"}
RULES:
- row_index is 0-based and refers to the ORIGINAL row number shown in the table.
- Respond ONLY with the JSON action. No explanations, no markdown, no extra text.
- Fix the most obvious/critical issues first.
- When all issues appear resolved, use submit.
- Dates should be in YYYY-MM-DD format.
- Prices should be plain numbers without $ or commas.
- Product codes should NOT have dashes (e.g., P102 not P-102).
- Emails should be lowercase.
""").strip()
# ---------------------------------------------------------------------------
# Structured logging (must match validator format exactly)
# ---------------------------------------------------------------------------
def log_start(task: str, env: str, model: str) -> None:
    """Emit the [START] banner line in the exact format the validator parses."""
    print("[START] task=%s env=%s model=%s" % (task, env, model), flush=True)
def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
    """Emit one [STEP] line; reward is 2-dp, done is lowercase, error defaults to 'null'."""
    line = (
        f"[STEP] step={step} action={action} reward={reward:.2f} "
        f"done={str(done).lower()} error={error or 'null'}"
    )
    print(line, flush=True)
def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
    """Emit the [END] summary line; rewards joined as comma-separated 2-dp floats."""
    joined = ",".join("%.2f" % r for r in rewards)
    flag = "true" if success else "false"
    print(f"[END] success={flag} steps={steps} score={score:.3f} rewards={joined}", flush=True)
# ---------------------------------------------------------------------------
# Environment connection
# ---------------------------------------------------------------------------
def _connect_env():
    """Create a DataCleanEnv client.

    When IMAGE_NAME is set, first try launching a local Docker container;
    on any launch failure (or when IMAGE_NAME is unset) fall back to the
    hosted environment at ENV_BASE_URL.
    """
    from dataclean_env.client import DataCleanEnv

    if IMAGE_NAME:
        try:
            print(f" Starting environment from Docker image: {IMAGE_NAME}", flush=True)
            return DataCleanEnv.from_docker_image(image=IMAGE_NAME)
        except Exception as exc:
            # Best-effort: Docker may be unavailable in this runtime.
            print(f" Docker launch failed ({exc}), falling back to URL", flush=True)
    print(f" Connecting to environment at: {ENV_BASE_URL}", flush=True)
    return DataCleanEnv(base_url=ENV_BASE_URL)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# Matches a flat (non-nested) JSON object; sufficient for single-level actions.
ACTION_JSON_RE = re.compile(r"\{[^{}]*\}", re.DOTALL)


def parse_action(text: str) -> dict:
    """Extract the first JSON action object from a model response.

    Strips markdown code fences, then tries the whole payload as JSON;
    failing that, scans for any flat ``{...}`` fragment that parses to a
    dict containing ``action_type``. Falls back to a noop action.
    """
    if not text:
        return {"action_type": "noop"}
    stripped = re.sub(r"```", "", re.sub(r"```(?:json)?\s*", "", text)).strip()
    candidates = [stripped] + [m.group(0) for m in ACTION_JSON_RE.finditer(stripped)]
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except (json.JSONDecodeError, ValueError):
            continue
        if isinstance(parsed, dict) and "action_type" in parsed:
            return parsed
    return {"action_type": "noop"}
def action_to_str(action_dict: dict) -> str:
    """Format an action dict as a compact string for the [STEP] log line.

    Fix: ``new_value`` is compared against ``None`` instead of by truthiness,
    so legitimate falsy corrections (``0``, ``""``) still show up in the log
    with their row/column detail instead of collapsing to the bare action name.
    """
    at = action_dict.get("action_type", "noop")
    ri = action_dict.get("row_index")
    col = action_dict.get("column_name")
    val = action_dict.get("new_value")
    if at in ("fix_value", "fill_missing") and ri is not None and col and val is not None:
        return f"{at}(row={ri},col={col},val={val})"
    if at == "delete_row" and ri is not None:
        return f"{at}(row={ri})"
    if at == "flag_anomaly" and ri is not None and col:
        return f"{at}(row={ri},col={col})"
    return at
def build_user_prompt(obs, step_num: int) -> str:
    """Render an observation into the user prompt.

    Accepts either an attribute-style observation object or a plain dict;
    appends up to the last 5 history entries, then the action instruction.
    """
    if hasattr(obs, "task_description"):
        task = obs.task_description
        difficulty = obs.difficulty
        max_steps = obs.max_steps
        current_score = obs.current_score
        data_preview = obs.data_preview or "(no data)"
        quality_report = obs.quality_report or ""
        history = obs.action_history or []
    else:
        task = obs.get("task_description", "")
        difficulty = obs.get("difficulty", "")
        max_steps = obs.get("max_steps", "?")
        current_score = obs.get("current_score", 0.0)
        data_preview = obs.get("data_preview", "(no data)")
        quality_report = obs.get("quality_report", "")
        history = obs.get("action_history", [])
    lines = [
        f"TASK: {task}",
        f"DIFFICULTY: {difficulty}",
        f"STEP: {step_num}/{max_steps}",
        f"CURRENT SCORE: {current_score}",
        "",
        "CURRENT DATA:",
        data_preview,
        "",
        quality_report,
    ]
    if history:
        lines.append("")
        lines.append("RECENT ACTIONS:")
        lines.extend(f" {h}" for h in history[-5:])
    lines.append("")
    lines.append("Respond with exactly one JSON action.")
    return "\n".join(lines)
# ---------------------------------------------------------------------------
# Run one task
# ---------------------------------------------------------------------------
def run_task(
    llm_client: OpenAI,
    env_client,
    task_name: str,
    max_steps: int,
) -> float:
    """Run a single task episode and return its final score.

    Resets the environment, then loops up to ``max_steps``: builds a prompt
    from the current observation, asks the LLM for one JSON action (with
    retry on rate limits), steps the environment, and logs a [STEP] line.
    If the agent never submitted, a submit action is forced at the end.
    Emits [START] at entry and always [END] on exit, even on error.
    """
    from dataclean_env.models import DataCleanAction
    rewards: List[float] = []
    steps_taken = 0
    score = 0.0
    success = False
    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
    try:
        result = env_client.reset(task_name)
        obs = result.observation
        for step in range(1, max_steps + 1):
            # Episode may already be finished (e.g. done right after reset).
            if result.done:
                break
            user_prompt = build_user_prompt(obs, step)
            messages = [
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ]
            # Fallback action if every LLM attempt fails: do nothing this step.
            response_text = '{"action_type": "noop"}'
            for _attempt in range(3):
                try:
                    completion = llm_client.chat.completions.create(
                        model=MODEL_NAME,
                        messages=messages,
                        temperature=TEMPERATURE,
                        max_tokens=MAX_TOKENS,
                        stream=False,
                    )
                    response_text = completion.choices[0].message.content or ""
                    break
                except Exception as exc:
                    # Retry only on rate limiting (HTTP 429), with 5s/10s backoff.
                    if "429" in str(exc) and _attempt < 2:
                        wait = 5 * (2 ** _attempt)
                        time.sleep(wait)
                        continue
                    response_text = '{"action_type": "noop"}'
                    break
            action_dict = parse_action(response_text)
            action = DataCleanAction(**action_dict)
            result = env_client.step(action)
            obs = result.observation
            reward = result.reward or 0.0
            done = result.done
            rewards.append(reward)
            steps_taken = step
            log_step(
                step=step,
                action=action_to_str(action_dict),
                reward=reward,
                done=done,
                error=None,
            )
            if done:
                break
        # If agent never submitted, force submit
        if not result.done:
            steps_taken += 1
            result = env_client.step(DataCleanAction(action_type="submit"))
            reward = result.reward or 0.0
            rewards.append(reward)
            log_step(
                step=steps_taken,
                action="submit",
                reward=reward,
                done=True,
                error=None,
            )
        # Final score is the last reward (the submit reward), clamped to [0, 1].
        score = rewards[-1] if rewards else 0.0
        score = min(max(score, 0.0), 1.0)
        success = score >= SUCCESS_SCORE_THRESHOLD
    finally:
        # [END] is always emitted, even if reset/step raised mid-episode.
        log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
    return score
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def main() -> None:
    """Entry point: validate credentials, run all three tasks, print a summary."""
    if not API_KEY:
        print("ERROR: HF_TOKEN or API_KEY environment variable not set", flush=True)
        sys.exit(1)
    print("DataClean Environment - Inference", flush=True)
    print(f" API: {API_BASE_URL}", flush=True)
    print(f" Model: {MODEL_NAME}", flush=True)
    llm_client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
    env_client = _connect_env()
    scores = {}
    try:
        for task_name in ("easy", "medium", "hard"):
            scores[task_name] = run_task(
                llm_client, env_client, task_name, MAX_STEPS_PER_TASK[task_name]
            )
    finally:
        # Always release the environment (Docker container or HTTP session).
        env_client.close()
    banner = "=" * 60
    print(f"\n{banner}", flush=True)
    print(" FINAL RESULTS", flush=True)
    print(banner, flush=True)
    for name, score in scores.items():
        print(f" {name:8s}: {score:.3f}", flush=True)
    avg = sum(scores.values()) / len(scores) if scores else 0.0
    print(f" {'AVERAGE':8s}: {avg:.3f}", flush=True)
    print(banner, flush=True)
print(f"{'='*60}", flush=True)
# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()