#!/usr/bin/env python3 """Inference script for the PyTorch Training Run Debugger. Required environment variables (injected by evaluator): API_BASE_URL — LLM API endpoint (must have default) MODEL_NAME — Model identifier (must have default) HF_TOKEN — API token (mandatory, no default) """ from __future__ import annotations import asyncio import json import os import sys from typing import List, Optional from openai import OpenAI from openenv.core import GenericAction, GenericEnvClient # --------------------------------------------------------------------------- # Configuration — EXACTLY per hackathon spec # --------------------------------------------------------------------------- API_BASE_URL = os.getenv("API_BASE_URL", "https://api.openai.com/v1") MODEL_NAME = os.getenv("MODEL_NAME", "gpt-4o") HF_TOKEN = os.getenv("HF_TOKEN") IMAGE_NAME = os.getenv("IMAGE_NAME") or os.getenv("LOCAL_IMAGE_NAME") ENV_URL = os.getenv("ENV_URL", "https://ujjwalpardeshi-pytorch-training-debugger.hf.space") BENCHMARK = "pytorch-training-debugger" MAX_STEPS = 25 SUCCESS_SCORE_THRESHOLD = 0.5 TEMPERATURE = 0.0 MAX_TOKENS = 300 # All tasks to run ALL_TASK_IDS = ["task_001", "task_002", "task_003", "task_004", "task_005", "task_006", "task_007"] # --------------------------------------------------------------------------- # Structured logging — EXACTLY per hackathon spec # --------------------------------------------------------------------------- def log_start(task: str, env: str, model: str) -> None: print(f"[START] task={task} env={env} model={model}", flush=True) def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None: error_val = error if error else "null" done_val = str(done).lower() clean_action = action.replace("\n", " ").replace("\r", " ") print( f"[STEP] step={step} action={clean_action} reward={reward:.2f} done={done_val} error={error_val}", flush=True, ) def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None: rewards_str = ",".join(f"{r:.2f}" for r in rewards) print( f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True, ) # --------------------------------------------------------------------------- # System prompt # --------------------------------------------------------------------------- SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run. You are interacting with an environment that simulates a broken training job. Available actions (respond with JSON only, no explanation): - {"action_type": "inspect_gradients"} - View gradient statistics per layer - {"action_type": "inspect_data_batch"} - View data batch statistics - {"action_type": "inspect_model_modes"} - View model layer modes (train/eval) - {"action_type": "inspect_model_weights"} - View model weight statistics - {"action_type": "inspect_code"} - View PyTorch training code - {"action_type": "modify_config", "target": "", "value": } - {"action_type": "add_callback"} - Add gradient clipping/scheduler - {"action_type": "patch_data_loader"} - Fix data pipeline issues - {"action_type": "fix_model_mode"} - Call model.train() - {"action_type": "fix_code", "line": , "replacement": ""} - {"action_type": "restart_run"} - Restart training (requires a fix first) - {"action_type": "mark_diagnosed", "diagnosis": ""} - Submit diagnosis Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, \ overfitting, batchnorm_eval_mode, code_bug, scheduler_misconfigured IMPORTANT: Respond with ONLY a valid JSON action object.""" def _build_obs_summary(obs: dict) -> dict: """Build a compact observation summary for the LLM context.""" summary: dict = {"available_actions": obs.get("available_actions", [])} if obs.get("error_log"): summary["error_log"] = obs["error_log"] if obs.get("training_loss_history"): summary["loss_trend"] = obs["training_loss_history"][:5] if obs.get("val_accuracy_history"): summary["val_acc_trend"] = obs["val_accuracy_history"][:5] if obs.get("gradient_stats"): summary["gradient_stats"] = [ { "layer": g.get("layer_name", ""), "mean_norm": round(g.get("mean_norm", 0), 4), "exploding": g.get("is_exploding", False), "vanishing": g.get("is_vanishing", False), } for g in obs["gradient_stats"] ] if obs.get("data_batch_stats"): dbs = obs["data_batch_stats"] summary["data_overlap"] = dbs.get("class_overlap_score", 0) summary["duplicate_ratio"] = dbs.get("duplicate_ratio", 0) if obs.get("model_mode_info"): summary["model_modes"] = obs["model_mode_info"] if obs.get("model_weight_stats"): summary["weight_stats"] = [ { "layer": w.get("layer_name", ""), "norm": round(w.get("weight_norm", 0), 4), } for w in obs["model_weight_stats"] ] if obs.get("code_snippet"): cs = obs["code_snippet"] summary["code"] = cs.get("code", "")[:600] summary["hint"] = cs.get("hint", "") if obs.get("notes"): summary["notes"] = obs["notes"] return summary def get_model_message( client: OpenAI, step: int, last_obs_summary: dict, last_reward: float, history: List[str], ) -> str: """Get next action from the LLM with retry logic.""" history_ctx = "\n".join(history[-5:]) if history else "No previous steps." user_content = ( f"Step {step}. Last reward: {last_reward:+.2f}\n" f"Recent history:\n{history_ctx}\n\n" f"Current observation:\n" f"{json.dumps(last_obs_summary, indent=2, default=str)}\n\n" "What action should you take next? Respond with JSON only." ) max_retries = 3 for attempt in range(max_retries): try: completion = client.chat.completions.create( model=MODEL_NAME, messages=[ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": user_content}, ], temperature=TEMPERATURE, max_tokens=MAX_TOKENS, ) text = (completion.choices[0].message.content or "").strip() if text: return text except Exception as exc: print(f"[DEBUG] Model request failed (attempt {attempt+1}): {exc}", flush=True) if attempt < max_retries - 1: import time time.sleep((attempt + 1) * 2) else: raise return '{"action_type": "inspect_gradients"}' def parse_action(raw: str) -> str: """Clean up LLM output to extract JSON action string.""" text = raw.strip().strip("`").strip() if text.startswith("json"): text = text[4:].strip() try: json.loads(text) return text except json.JSONDecodeError: return '{"action_type": "inspect_gradients"}' async def run_task(env: GenericEnvClient, client: OpenAI, task_id: str) -> None: """Run a single task episode with [START]/[END] logging.""" history: List[str] = [] rewards: List[float] = [] steps_taken = 0 score = 0.01 success = False log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) try: result = await env.reset(task_id=task_id, seed=42) obs = result.observation last_reward = 0.0 for step in range(1, MAX_STEPS + 1): if result.done: break obs_summary = _build_obs_summary(obs) raw = get_model_message(client, step, obs_summary, last_reward, history) action_str = parse_action(raw) action = GenericAction(**json.loads(action_str)) result = await env.step(action) obs = result.observation reward = result.reward or 0.0 done = result.done error = ( obs.get("notes") if "invalid" in str(obs.get("notes", "")).lower() else None ) rewards.append(reward) steps_taken = step last_reward = reward log_step(step=step, action=action_str, reward=reward, done=done, error=error) history.append(f"Step {step}: {action_str!r} -> reward {reward:+.2f}") if done: break # Score: clamp strictly between 0 and 1 (evaluator rejects 0.0 and 1.0) total_reward = sum(rewards) score = round(min(max(total_reward, 0.01), 0.99), 2) success = score >= SUCCESS_SCORE_THRESHOLD except Exception as exc: print(f"[DEBUG] Task {task_id} error: {exc}", flush=True) score = 0.01 finally: log_end(success=success, steps=steps_taken, score=score, rewards=rewards) async def main() -> None: # Optional: run specific task or all tasks target_task = os.getenv("TASK_NAME") tasks_to_run = [target_task] if target_task else ALL_TASK_IDS # Initialize client EXACTLY as spec: api_key=HF_TOKEN client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN) print(f"[DEBUG] API_BASE_URL={API_BASE_URL}", flush=True) print(f"[DEBUG] HF_TOKEN={'set' if HF_TOKEN else 'NOT SET'}", flush=True) print(f"[DEBUG] MODEL_NAME={MODEL_NAME}", flush=True) print(f"[DEBUG] Tasks to run: {tasks_to_run}", flush=True) # Mandatory LLM proxy call — ensures at least one call goes through try: test_resp = client.chat.completions.create( model=MODEL_NAME, messages=[{"role": "user", "content": "Say OK"}], max_tokens=5, ) print(f"[DEBUG] LLM proxy test OK: {test_resp.choices[0].message.content}", flush=True) except Exception as exc: print(f"[DEBUG] LLM proxy test failed: {exc}", flush=True) completed_tasks: set = set() env = None try: if IMAGE_NAME: env = await GenericEnvClient.from_docker_image(IMAGE_NAME) else: env = GenericEnvClient( base_url=ENV_URL, message_timeout_s=120.0, ) await env.connect() for task_id in tasks_to_run: await run_task(env, client, task_id) completed_tasks.add(task_id) except Exception as exc: print(f"[DEBUG] Fatal error: {exc}", flush=True) finally: # Emit [START]/[END] for any tasks that didn't run for task_id in tasks_to_run: if task_id not in completed_tasks: log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME) log_end(success=False, steps=0, score=0.01, rewards=[]) if env is not None: try: await env.close() except Exception: pass if __name__ == "__main__": asyncio.run(main())