# data-cleaning-openenv / inference.py
# Author: Manas281 — "Update inference.py" (commit b871adb, verified)
"""
inference.py β€” OpenEnv submission file
"""
import os, json, sys
from openai import OpenAI
from data_cleaning_env import DataCleaningEnvironment, CleaningAction
# ── Config ────────────────────────────────────────────────────────────────────
# Endpoint and model are overridable via environment variables; the defaults
# target Groq's OpenAI-compatible API.
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Groq model name
# NOTE(review): HF_TOKEN is passed straight through as the API key for the
# endpoint above — confirm the deployment injects the right credential here.
HF_TOKEN = os.getenv("HF_TOKEN")
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")
client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
# System prompt pins the model's reply to a bare JSON action object so that
# parse_llm_response can decode it without markdown stripping heroics.
SYSTEM_PROMPT = (
    "You are a data cleaning expert. "
    "Respond ONLY with a valid JSON object, no markdown, no explanation.\n"
    'Format: {"action_type": "<remove_nulls|fix_dates|remove_outliers>", "column": "<col_or_null>"}'
)
# Task id → human-readable task name, used only for log lines.
TASK_NAMES = {1: "remove_nulls", 2: "fix_dates", 3: "remove_outliers"}
ENV_NAME = "data_cleaning"
def parse_llm_response(text: str, task_id: int) -> CleaningAction:
    """Parse the LLM's reply into a CleaningAction.

    Strips markdown code fences, then parses only the outermost ``{...}``
    span so that replies wrapped in extra prose still decode (the original
    fed the whole reply to ``json.loads`` and fell back to keyword guessing
    whenever any prose surrounded the JSON). Unknown action types are
    normalized to ``"remove_nulls"``; on a total parse failure, keyword
    heuristics pick a plausible action.

    Args:
        text: Raw assistant message content.
        task_id: Task id to stamp onto the returned action.

    Returns:
        A CleaningAction; never raises on malformed input.
    """
    text = text.strip().replace("```json", "").replace("```", "").strip()
    # Tolerate prose around the JSON: decode only the outermost {...} span.
    start, end = text.find("{"), text.rfind("}")
    candidate = text[start:end + 1] if start != -1 and end > start else text
    try:
        data = json.loads(candidate)
        if not isinstance(data, dict):
            # e.g. a bare string or list — original code would have hit an
            # AttributeError on .get; treat it as a parse failure instead.
            raise ValueError("expected a JSON object")
        action_type = data.get("action_type", "remove_nulls")
        if action_type not in ("remove_nulls", "fix_dates", "remove_outliers"):
            action_type = "remove_nulls"
        return CleaningAction(
            task_id=task_id,
            action_type=action_type,
            column=data.get("column"),
        )
    except (ValueError, TypeError):
        # json.JSONDecodeError subclasses ValueError. Keyword fallback:
        lowered = text.lower()
        if "date" in lowered:
            return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
        if "outlier" in lowered:
            return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
        return CleaningAction(task_id=task_id, action_type="remove_nulls")
def heuristic_action(task_id: int, obs) -> CleaningAction:
    """Rule-based fallback: choose an action from the observed issue counts.

    Priority order: remove nulls first, then fix malformed dates, and only
    then sweep outliers across all columns.
    """
    if obs.null_count > 0:
        return CleaningAction(task_id=task_id, action_type="remove_nulls")
    if obs.date_format_errors > 0:
        return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
    return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
def run_episode(task_id: int, seed: int):
    """Run a single environment episode for one task.

    Tries the LLM first; if the API call or parsing raises, falls back to
    the local heuristic so the episode always produces an action.

    Args:
        task_id: Which cleaning task (1..3) to run.
        seed: Environment seed for this episode.

    Returns:
        Tuple of (reward, action_str, done, error_str); ``error_str`` is
        ``"null"`` unless the LLM call failed.
    """
    env = DataCleaningEnvironment(task_id=task_id, seed=seed)
    try:
        obs = env.reset()
        error_str = "null"
        action = None
        user_msg = (
            f"Task {task_id}: {obs.task_description}\n"
            f"Nulls: {obs.null_count}, Date errors: {obs.date_format_errors}, "
            f"Outliers: {obs.outlier_count}\n"
            f"Preview:\n{obs.dataset_preview}\n"
            f"Respond with JSON only."
        )
        # ── Primary: LLM via OpenAI client ───────────────────────────────────
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=100,
                temperature=0.1,
            )
            action = parse_llm_response(resp.choices[0].message.content, task_id)
        except Exception as e:
            # Best-effort: record the failure and fall through to the heuristic.
            error_str = str(e).replace("\n", " ")
        # ── Fallback: heuristic if LLM failed ────────────────────────────────
        if action is None:
            action = heuristic_action(task_id, obs)
        col = action.column if action.column else "null"
        action_str = f"{action.action_type}('{col}')"
        _, reward, done, _ = env.step(action)
        return float(reward), action_str, bool(done), error_str
    finally:
        # Original only closed the env on the success path; close it even when
        # reset()/step() raises so episodes don't leak environment resources.
        if hasattr(env, "close"):
            env.close()
def main():
    """Run N_EPISODES episodes for each of the three tasks and write scores.json.

    Each task's score is its mean episode reward, clamped to [0, 1]. A task
    that raises mid-run no longer aborts the whole submission: the exception
    is logged, a partial score is salvaged from completed episodes, and the
    remaining tasks still run so scores.json is always written (the original
    had no ``except`` clause, so one failing episode killed the entire run).
    """
    all_results = {}
    n_episodes = int(os.getenv("N_EPISODES", "10"))
    for task_id in [1, 2, 3]:
        task_name = TASK_NAMES[task_id]
        print(f"[START] task={task_name} env={ENV_NAME} model={MODEL_NAME}", flush=True)
        episode_rewards = []
        success = False
        score = 0.0
        try:
            for seed in range(n_episodes):
                reward, action_str, done, error_str = run_episode(task_id, seed)
                episode_rewards.append(reward)
                print(
                    f"[STEP] step={seed + 1} action={action_str} "
                    f"reward={reward:.2f} done={str(done).lower()} error={error_str}",
                    flush=True,
                )
            score = sum(episode_rewards) / len(episode_rewards)
            score = round(min(max(score, 0.0), 1.0), 2)
            all_results[task_id] = score
            success = score > 0.0
        except Exception as e:
            # Salvage a partial score from completed episodes and keep going.
            print(f"[ERROR] task={task_name} {str(e)}".replace("\n", " "), flush=True)
            if episode_rewards:
                score = round(min(max(sum(episode_rewards) / len(episode_rewards), 0.0), 1.0), 2)
            all_results[task_id] = score
            success = score > 0.0
        finally:
            rewards_str = ",".join(f"{r:.2f}" for r in episode_rewards)
            # ── [END] with score= field as required ──────────────────────────
            print(
                f"[END] success={str(success).lower()} "
                f"steps={len(episode_rewards)} "
                f"score={score:.2f} "
                f"rewards={rewards_str}",
                flush=True,
            )
    overall = round(sum(all_results.values()) / max(len(all_results), 1), 4)
    with open("scores.json", "w") as f:
        json.dump({"tasks": all_results, "overall": overall}, f, indent=2)
    print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True)
if __name__ == "__main__":
    main()