Spaces:
Sleeping
Sleeping
"""
inference.py — OpenEnv submission file
"""
import os
import json
import sys

from openai import OpenAI

from data_cleaning_env import DataCleaningEnvironment, CleaningAction

# ── Config ────────────────────────────────────────────────────────────────────
API_BASE_URL = os.getenv("API_BASE_URL", "https://api.groq.com/openai/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")  # Groq model name
HF_TOKEN = os.getenv("HF_TOKEN")

# Fail fast: the API client is useless without a key.
if HF_TOKEN is None:
    raise ValueError("HF_TOKEN environment variable is required")

client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)

# Prompt instructing the model to reply with bare JSON only.
SYSTEM_PROMPT = (
    "You are a data cleaning expert. "
    "Respond ONLY with a valid JSON object, no markdown, no explanation.\n"
    'Format: {"action_type": "<remove_nulls|fix_dates|remove_outliers>", "column": "<col_or_null>"}'
)

TASK_NAMES = {1: "remove_nulls", 2: "fix_dates", 3: "remove_outliers"}
ENV_NAME = "data_cleaning"
def parse_llm_response(text: str, task_id: int) -> CleaningAction:
    """Turn the LLM's reply into a CleaningAction.

    Markdown code fences are stripped, then the text is JSON-decoded. Any
    decode/shape failure falls through to keyword matching on the stripped
    text, so a usable action is always returned.
    """
    cleaned = text.strip().replace("```json", "").replace("```", "").strip()
    try:
        payload = json.loads(cleaned)
        kind = payload.get("action_type", "remove_nulls")
        # Unknown action types are coerced to the safest default.
        if kind not in ("remove_nulls", "fix_dates", "remove_outliers"):
            kind = "remove_nulls"
        return CleaningAction(
            task_id=task_id,
            action_type=kind,
            column=payload.get("column"),
        )
    except Exception:
        # Non-JSON reply: guess the intent from keywords in the raw text.
        lowered = cleaned.lower()
        if "date" in lowered:
            return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
        if "outlier" in lowered:
            return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
        return CleaningAction(task_id=task_id, action_type="remove_nulls")
def heuristic_action(task_id: int, obs) -> CleaningAction:
    """Rule-based fallback used when the LLM call fails.

    Priority order: null values first, then malformed dates, otherwise
    treat the task as outlier removal.
    """
    if obs.null_count > 0:
        return CleaningAction(task_id=task_id, action_type="remove_nulls")
    if obs.date_format_errors > 0:
        return CleaningAction(task_id=task_id, action_type="fix_dates", column="hire_date")
    return CleaningAction(task_id=task_id, action_type="remove_outliers", column="all")
def run_episode(task_id: int, seed: int):
    """Run a single cleaning episode.

    Builds a prompt from the environment observation, asks the LLM for an
    action (falling back to `heuristic_action` on any API/parse error),
    applies it, and reports the result.

    Args:
        task_id: Which cleaning task (1-3) to run.
        seed: Environment seed for this episode.

    Returns:
        (reward, action_str, done, error_str) — error_str is "null" when the
        LLM path succeeded, otherwise the flattened exception message.
    """
    env = DataCleaningEnvironment(task_id=task_id, seed=seed)
    try:
        obs = env.reset()
        error_str = "null"
        action = None
        user_msg = (
            f"Task {task_id}: {obs.task_description}\n"
            f"Nulls: {obs.null_count}, Date errors: {obs.date_format_errors}, "
            f"Outliers: {obs.outlier_count}\n"
            f"Preview:\n{obs.dataset_preview}\n"
            f"Respond with JSON only."
        )
        # ── Primary: LLM via OpenAI client ───────────────────────────────────
        try:
            resp = client.chat.completions.create(
                model=MODEL_NAME,
                messages=[
                    {"role": "system", "content": SYSTEM_PROMPT},
                    {"role": "user", "content": user_msg},
                ],
                max_tokens=100,
                temperature=0.1,
            )
            action = parse_llm_response(resp.choices[0].message.content, task_id)
        except Exception as e:
            # Record the failure (newlines flattened for single-line logs)
            # and fall through to the heuristic.
            error_str = str(e).replace("\n", " ")
        # ── Fallback: heuristic if LLM failed ────────────────────────────────
        if action is None:
            action = heuristic_action(task_id, obs)
        col = action.column if action.column else "null"
        action_str = f"{action.action_type}('{col}')"
        _, reward, done, _ = env.step(action)
        return float(reward), action_str, bool(done), error_str
    finally:
        # Bug fix: close the environment even when reset()/step() raises —
        # previously it was only closed on the happy path.
        if hasattr(env, "close"):
            env.close()
def main():
    """Run every task, log per-episode results, and write scores.json.

    For each task: runs N_EPISODES episodes (env var, default 10), prints
    [START]/[STEP]/[END] markers (END is emitted via `finally` even when an
    episode raises), then writes per-task and overall scores to scores.json.
    """
    all_results = {}
    n_episodes = int(os.getenv("N_EPISODES", "10"))
    for task_id in [1, 2, 3]:
        task_name = TASK_NAMES[task_id]
        print(f"[START] task={task_name} env={ENV_NAME} model={MODEL_NAME}", flush=True)
        episode_rewards = []
        success = False
        score = 0.0
        try:
            for seed in range(n_episodes):
                reward, action_str, done, error_str = run_episode(task_id, seed)
                episode_rewards.append(reward)
                print(
                    f"[STEP] step={seed + 1} action={action_str} "
                    f"reward={reward:.2f} done={str(done).lower()} error={error_str}",
                    flush=True,
                )
            # Bug fix: guard the mean — N_EPISODES=0 previously raised
            # ZeroDivisionError here.
            if episode_rewards:
                score = sum(episode_rewards) / len(episode_rewards)
                score = round(min(max(score, 0.0), 1.0), 2)
            all_results[task_id] = score
            success = score > 0.0
        finally:
            rewards_str = ",".join(f"{r:.2f}" for r in episode_rewards)
            # ── [END] with score= field as required ──────────────────────────
            print(
                f"[END] success={str(success).lower()} "
                f"steps={len(episode_rewards)} "
                f"score={score:.2f} "
                f"rewards={rewards_str}",
                flush=True,
            )
    overall = round(sum(all_results.values()) / max(len(all_results), 1), 4)
    with open("scores.json", "w") as f:
        json.dump({"tasks": all_results, "overall": overall}, f, indent=2)
    print(f"[SUMMARY] overall_score={overall} task_scores={all_results}", flush=True)


if __name__ == "__main__":
    main()