""" inference.py — OpenEnv submission file """ import os, json, sys from openai import OpenAI from data_cleaning_env import DataCleaningEnvironment, CleaningAction # ── Config ──────────────────────────────────────────────────────────────────── API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1") MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b") # Groq model name HF_TOKEN = os.getenv("HF_TOKEN") if HF_TOKEN is None: raise ValueError("HF_TOKEN environment variable is required") client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) SYSTEM_PROMPT = ( "You are a data cleaning expert. " "Respond ONLY with a valid JSON object, no markdown, no explanation.\n" 'Format: {"action_type": "", "column": ""}' ) TASK_NAMES = {1: "remove_nulls", 2: "fix_dates", 3: "remove_outliers"} ENV_NAME = "data_cleaning" def parse_llm_response(text: str, task_id: int) -> CleaningAction: text = text.strip().replace("```json", "").replace("```", "").strip() try: data = json.loads(text) action_type = data.get("action_type", "remove_nulls") if action_type not in ["remove_nulls", "fix_dates", "remove_outliers"]: action_type = "remove_nulls" return CleaningAction( task_id=task_id, action_type=action_type, column=data.get("column") ) except Exception: if "date" in text.lower(): return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date") elif "outlier" in text.lower(): return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all") return CleaningAction(task_id=task_id, action_type="remove_nulls") def heuristic_action(task_id: int, obs) -> CleaningAction: if obs.null_count > 0: return CleaningAction(task_id=task_id, action_type="remove_nulls") elif obs.date_format_errors > 0: return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date") else: return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all") def run_episode(task_id: int, seed: int): env = DataCleaningEnvironment(task_id=task_id, seed=seed) obs = env.reset() error_str = "null" action = None user_msg = ( f"Task {task_id}: {obs.task_description}\n" f"Nulls: {obs.null_count}, Date errors: {obs.date_format_errors}, " f"Outliers: {obs.outlier_count}\n" f"Preview:\n{obs.dataset_preview}\n" f"Respond with JSON only." ) # ── Primary: LLM via OpenAI client ─────────────────────────────────────── try: resp = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_msg}, ], max_tokens=100, temperature=0.1, ) action = parse_llm_response(resp.choices[0].message.content, task_id) except Exception as e: error_str = str(e).replace("\n", " ") # ── Fallback: heuristic if LLM failed ──────────────────────────────────── if action is None: action = heuristic_action(task_id, obs) col = action.column if action.column else "null" action_str = f"{action.action_type}('{col}')" _, reward, done, _ = env.step(action) if hasattr(env, "close"): env.close() return float(reward), action_str, bool(done), error_str def main(): all_results = {} n_episodes = int(os.getenv("N_EPISODES", "10")) for task_id in [1, 2, 3]: task_name = TASK_NAMES[task_id] print(f"[START] task={task_name} env={ENV_NAME} model={MODEL_NAME}", flush=True) episode_rewards = [] success = False score = 0.0 try: for seed in range(n_episodes): reward, action_str, done, error_str = run_episode(task_id, seed) episode_rewards.append(reward) print( f"[STEP] step={seed + 1} action={action_str} " f"reward={reward:.2f} done={str(done).lower()} error={error_str}", flush=True, ) score = sum(episode_rewards) / len(episode_rewards) score = round(min(max(score, 0.0), 1.0), 2) all_results[task_id] = score success = score > 0.0 finally: rewards_str = ",".join(f"{r:.2f}" for r in episode_rewards) # ── [END] with score= field as required ────────────────────────── print( f"[END] success={str(success).lower()} " f"steps={len(episode_rewards)} " f"score={score:.2f} " f"rewards={rewards_str}", flush=True, ) overall = round(sum(all_results.values()) / max(len(all_results), 1), 4) with open("scores.json", "w") as f: json.dump({"tasks": all_results, "overall": overall}, f, indent=2) print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True) if __name__ == "__main__": main()