#!/usr/bin/env python3
"""LLM baseline agent for the ML training debugger environment.

Talks to any OpenAI-compatible provider (Groq, Cerebras, Gemini, OpenAI;
default: groq). Requires the selected provider's API key environment
variable (or pass it via --api-key). Uses temperature=0.0 for
near-deterministic behavior.

Usage:
    GROQ_API_KEY=... python baseline_inference.py
    GEMINI_API_KEY=... python baseline_inference.py --provider gemini
    python baseline_inference.py --provider openai --api-key YOUR_KEY
"""

from __future__ import annotations

import argparse
import json
import os
import sys
from pathlib import Path

# Load .env file if present (only sets keys not already in the environment)
_env_path = Path(__file__).parent / ".env"
if _env_path.exists():
    for line in _env_path.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())

try:
    from openai import OpenAI
except ImportError:
    print("Error: openai package not installed. Run: pip install openai")
    sys.exit(1)

from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment

ALL_TASKS = [
    "task_001",
    "task_002",
    "task_003",
    "task_004",
    "task_005",
    "task_006",
    "task_007",
]

SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
You are interacting with an environment that simulates a broken training job.

Available actions (respond with JSON only, no explanation):
- {"action_type": "inspect_gradients"} - View gradient statistics per layer
- {"action_type": "inspect_data_batch"} - View data batch statistics and confusion matrix
- {"action_type": "inspect_model_modes"} - View model layer modes (train/eval)
- {"action_type": "inspect_model_weights"} - View model weight statistics
- {"action_type": "inspect_code"} - View PyTorch training code
- {"action_type": "modify_config", "target": "<param>", "value": <value>} - Change a hyperparameter
- {"action_type": "add_callback"} - Add gradient clipping/scheduler
- {"action_type": "patch_data_loader"} - Fix data pipeline issues
- {"action_type": "fix_model_mode"} - Call model.train()
- {"action_type": "fix_code", "line": <line_number>, "replacement": "<new_code>"} - Fix a code line
- {"action_type": "restart_run"} - Restart training (requires a fix first)
- {"action_type": "mark_diagnosed", "diagnosis": "<diagnosis>"} - Submit diagnosis

Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, overfitting,
batchnorm_eval_mode, code_bug, scheduler_misconfigured

Strategy:
1. First investigate by inspecting gradients, data, model modes, and code
2. Form a hypothesis based on the evidence gathered
3. Apply the correct fix for the identified root cause
4. Restart training to verify the fix works
5. Submit your diagnosis

IMPORTANT: Respond with ONLY a valid JSON action object.
No explanation, no markdown, no code blocks."""

def run_llm_episode(task_id: str, client: OpenAI, model_name: str) -> float:
    """Run one LLM agent episode and return its final score."""
    env = MLTrainingEnvironment()
    obs = env.reset(seed=42, episode_id=f"llm_{task_id}", task_id=task_id)

    initial_obs = {
        "training_loss_history": obs.training_loss_history[:5],
        "val_accuracy_history": obs.val_accuracy_history[:5],
        "current_config": obs.current_config.model_dump(),
        "error_log": obs.error_log,
        "available_actions": obs.available_actions,
        "notes": obs.notes,
        "gpu_memory_used_gb": obs.gpu_memory_used_gb,
    }

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"New episode started for a broken PyTorch training run.\n\nInitial observation:\n{json.dumps(initial_obs, indent=2, default=str)}",
        },
    ]

    # Cap each episode at 25 agent steps
    for step in range(25):
        if obs.done:
            break

        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=300,
            )
            action_text = response.choices[0].message.content.strip()
        except Exception as e:
            print(f"  Step {step}: API error — {e}", file=sys.stderr)
            break

        # Clean up common LLM formatting issues (markdown fences, "json" tag)
        action_text = action_text.strip("`").strip()
        if action_text.startswith("json"):
            action_text = action_text[4:].strip()

        messages.append({"role": "assistant", "content": action_text})

        try:
            action_data = json.loads(action_text)
            action = MLTrainingAction(**action_data)
        except Exception as e:  # JSON decode error or pydantic validation error
            messages.append(
                {
                    "role": "user",
                    "content": f"Invalid action format: {e}. Respond with ONLY valid JSON.",
                }
            )
            continue

        obs = env.step(action)

        # Summarize only the populated observation fields to keep the prompt short
        obs_summary: dict = {
            "reward": obs.reward,
            "done": obs.done,
            "step": obs.episode_state.step_count,
            "available_actions": obs.available_actions,
        }
        if obs.error_log:
            obs_summary["error_log"] = obs.error_log
        if obs.gradient_stats:
            obs_summary["gradient_stats"] = [
                {
                    "layer": g.layer_name,
                    "mean_norm": round(g.mean_norm, 4),
                    "exploding": g.is_exploding,
                    "vanishing": g.is_vanishing,
                }
                for g in obs.gradient_stats
            ]
        if obs.data_batch_stats:
            obs_summary["data_overlap"] = obs.data_batch_stats.class_overlap_score
            obs_summary["duplicate_ratio"] = obs.data_batch_stats.duplicate_ratio
        if obs.model_mode_info:
            obs_summary["model_modes"] = obs.model_mode_info
        if obs.code_snippet:
            obs_summary["code"] = obs.code_snippet.code[:600]
            obs_summary["hint"] = obs.code_snippet.hint

        messages.append(
            {
                "role": "user",
                "content": f"Observation after your action:\n{json.dumps(obs_summary, indent=2, default=str)}",
            }
        )

    session = env._get_session()
    return session.last_score if session and session.last_score is not None else 0.0


PROVIDERS = {
    "groq": {
        "env_key": "GROQ_API_KEY",
        "base_url": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.3-70b-versatile",
    },
    "cerebras": {
        "env_key": "CEREBRAS_API_KEY",
        "base_url": "https://api.cerebras.ai/v1",
        "default_model": "llama3.1-8b",
    },
    "gemini": {
        "env_key": "GEMINI_API_KEY",
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
        "default_model": "gemini-2.0-flash",
    },
    "openai": {
        "env_key": "OPENAI_API_KEY",
        "base_url": None,  # use the SDK's default OpenAI endpoint
        "default_model": "gpt-4o",
    },
}

def main() -> None:
    parser = argparse.ArgumentParser(description="LLM baseline agent")
    # --url is currently unused: the environment runs in-process
    parser.add_argument("--url", default="http://localhost:7860")
    parser.add_argument("--api-key", default=None, help="API key")
    parser.add_argument(
        "--provider",
        default="groq",
        choices=list(PROVIDERS.keys()),
        help="LLM provider (default: groq)",
    )
    parser.add_argument("--model", default=None, help="Model name (defaults to the provider's model)")
    args = parser.parse_args()

    prov = PROVIDERS[args.provider]
    api_key = args.api_key or os.environ.get(prov["env_key"])
    if not api_key:
        print(f"Error: Set {prov['env_key']} env var or pass --api-key")
        sys.exit(1)

    model_name = args.model or prov["default_model"]
    client_kwargs: dict = {"api_key": api_key}
    if prov["base_url"]:
        client_kwargs["base_url"] = prov["base_url"]
    client = OpenAI(**client_kwargs)

    scores: dict[str, float] = {}
    print(f"Running LLM baseline with {args.provider}/{model_name}...", file=sys.stderr)
    for task_id in ALL_TASKS:
        try:
            score = run_llm_episode(task_id, client, model_name)
            scores[task_id] = round(score, 4)
            print(f"  {task_id}: {score:.4f}", file=sys.stderr)
        except Exception as e:
            print(f"  {task_id}: ERROR — {e}", file=sys.stderr)
            scores[task_id] = 0.0

    print(json.dumps(scores, indent=2))


if __name__ == "__main__":
    main()