"""LLM baseline agent driven through the OpenAI-compatible SDK.

Supports several providers (Groq, Cerebras, Gemini, OpenAI); see PROVIDERS.
Requires the chosen provider's API key environment variable (or pass --api-key).
Uses temperature=0.0 for near-deterministic behavior.

Usage:
    GROQ_API_KEY=... python baseline_inference.py
    python baseline_inference.py --provider gemini --api-key YOUR_KEY
"""
|
|
from __future__ import annotations
|
|
import argparse
import json
import os
import sys
from pathlib import Path
|
|
# Load KEY=VALUE pairs from a local .env file without overriding variables
# that are already set in the environment.
_env_path = Path(__file__).parent / ".env"
if _env_path.exists():
    for line in _env_path.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#") and "=" in line:
            key, _, value = line.partition("=")
            os.environ.setdefault(key.strip(), value.strip())
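
# A minimal .env for this script might look like the following (placeholder
# values; any of the provider keys defined in PROVIDERS below can be set here):
#   GROQ_API_KEY=your-groq-key
#   GEMINI_API_KEY=your-gemini-key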
|
|
try:
    from openai import OpenAI
except ImportError:
    print("Error: openai package not installed. Run: pip install openai", file=sys.stderr)
    sys.exit(1)
|
|
from ml_training_debugger.models import MLTrainingAction
from server.environment import MLTrainingEnvironment
|
|
ALL_TASKS = [
    "task_001",
    "task_002",
    "task_003",
    "task_004",
    "task_005",
    "task_006",
    "task_007",
]
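
# Seven task IDs, presumably one per fault type in the "Valid diagnoses" list
# of SYSTEM_PROMPT below; the exact task-to-fault mapping is an assumption.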
|
|
SYSTEM_PROMPT = """You are an expert ML engineer debugging a PyTorch training run.
You are interacting with an environment that simulates a broken training job.

Available actions (respond with JSON only, no explanation):
- {"action_type": "inspect_gradients"} - View gradient statistics per layer
- {"action_type": "inspect_data_batch"} - View data batch statistics and confusion matrix
- {"action_type": "inspect_model_modes"} - View model layer modes (train/eval)
- {"action_type": "inspect_model_weights"} - View model weight statistics
- {"action_type": "inspect_code"} - View PyTorch training code
- {"action_type": "modify_config", "target": "<field>", "value": <val>} - Change a hyperparameter
- {"action_type": "add_callback"} - Add gradient clipping/scheduler
- {"action_type": "patch_data_loader"} - Fix data pipeline issues
- {"action_type": "fix_model_mode"} - Call model.train()
- {"action_type": "fix_code", "line": <int>, "replacement": "<code>"} - Fix a code line
- {"action_type": "restart_run"} - Restart training (requires a fix first)
- {"action_type": "mark_diagnosed", "diagnosis": "<cause>"} - Submit diagnosis

Valid diagnoses: lr_too_high, vanishing_gradients, data_leakage, overfitting, batchnorm_eval_mode, code_bug, scheduler_misconfigured

Strategy:
1. First investigate by inspecting gradients, data, model modes, and code
2. Form a hypothesis based on the evidence gathered
3. Apply the correct fix for the identified root cause
4. Restart training to verify the fix works
5. Submit your diagnosis

IMPORTANT: Respond with ONLY a valid JSON action object. No explanation, no markdown, no code blocks."""
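
# Example replies that parse cleanly under the prompt above (illustrative;
# "learning_rate" is a hypothetical config field, not confirmed by this env):
#   {"action_type": "inspect_gradients"}
#   {"action_type": "modify_config", "target": "learning_rate", "value": 0.001}
#   {"action_type": "mark_diagnosed", "diagnosis": "lr_too_high"}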
|
|
|
|
def run_llm_episode(task_id: str, client: OpenAI, model_name: str) -> float:
    """Run one LLM agent episode and return its final score (0.0 on failure)."""
    env = MLTrainingEnvironment()
    obs = env.reset(seed=42, episode_id=f"llm_{task_id}", task_id=task_id)

    # Truncate the histories to the first few entries to keep the prompt compact.
    initial_obs = {
        "training_loss_history": obs.training_loss_history[:5],
        "val_accuracy_history": obs.val_accuracy_history[:5],
        "current_config": obs.current_config.model_dump(),
        "error_log": obs.error_log,
        "available_actions": obs.available_actions,
        "notes": obs.notes,
        "gpu_memory_used_gb": obs.gpu_memory_used_gb,
    }

    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {
            "role": "user",
            "content": f"New episode started for a broken PyTorch training run.\n\nInitial observation:\n{json.dumps(initial_obs, indent=2, default=str)}",
        },
    ]

    for step in range(25):  # hard cap on episode length
        if obs.done:
            break

        try:
            response = client.chat.completions.create(
                model=model_name,
                messages=messages,
                temperature=0.0,
                max_tokens=300,
            )
            action_text = response.choices[0].message.content.strip()
        except Exception as e:
            print(f"  Step {step}: API error: {e}", file=sys.stderr)
            break

        # Strip markdown code fences the model may emit despite instructions.
        action_text = action_text.strip("`").strip()
        if action_text.startswith("json"):
            action_text = action_text[4:].strip()

        messages.append({"role": "assistant", "content": action_text})

        try:
            action_data = json.loads(action_text)
            action = MLTrainingAction(**action_data)
        except Exception as e:  # bad JSON or a schema validation failure
            messages.append(
                {
                    "role": "user",
                    "content": f"Invalid action format: {e}. Respond with ONLY valid JSON.",
                }
            )
            continue

        obs = env.step(action)

        # Relay only the observation fields that are actually populated,
        # keeping the conversation short.
        obs_summary: dict = {
            "reward": obs.reward,
            "done": obs.done,
            "step": obs.episode_state.step_count,
            "available_actions": obs.available_actions,
        }
        if obs.error_log:
            obs_summary["error_log"] = obs.error_log
        if obs.gradient_stats:
            obs_summary["gradient_stats"] = [
                {
                    "layer": g.layer_name,
                    "mean_norm": round(g.mean_norm, 4),
                    "exploding": g.is_exploding,
                    "vanishing": g.is_vanishing,
                }
                for g in obs.gradient_stats
            ]
        if obs.data_batch_stats:
            obs_summary["data_overlap"] = obs.data_batch_stats.class_overlap_score
            obs_summary["duplicate_ratio"] = obs.data_batch_stats.duplicate_ratio
        if obs.model_mode_info:
            obs_summary["model_modes"] = obs.model_mode_info
        if obs.code_snippet:
            obs_summary["code"] = obs.code_snippet.code[:600]
            obs_summary["hint"] = obs.code_snippet.hint

        messages.append(
            {
                "role": "user",
                "content": f"Observation after your action:\n{json.dumps(obs_summary, indent=2, default=str)}",
            }
        )

    # The environment records the final score on its session object
    # (private accessor, mirroring how the env stores results).
    session = env._get_session()
    return session.last_score if session and session.last_score is not None else 0.0
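
# A minimal by-hand usage sketch (assumes GROQ_API_KEY is set; the base URL and
# model name are the Groq defaults from PROVIDERS below):
#   client = OpenAI(
#       api_key=os.environ["GROQ_API_KEY"],
#       base_url="https://api.groq.com/openai/v1",
#   )
#   print(run_llm_episode("task_001", client, "llama-3.3-70b-versatile"))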
|
|
|
|
PROVIDERS = {
    "groq": {
        "env_key": "GROQ_API_KEY",
        "base_url": "https://api.groq.com/openai/v1",
        "default_model": "llama-3.3-70b-versatile",
    },
    "cerebras": {
        "env_key": "CEREBRAS_API_KEY",
        "base_url": "https://api.cerebras.ai/v1",
        "default_model": "llama3.1-8b",
    },
    "gemini": {
        "env_key": "GEMINI_API_KEY",
        "base_url": "https://generativelanguage.googleapis.com/v1beta/openai/",
        "default_model": "gemini-2.0-flash",
    },
    "openai": {
        "env_key": "OPENAI_API_KEY",
        "base_url": None,  # use the SDK's default OpenAI endpoint
        "default_model": "gpt-4o",
    },
}
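
# Any other OpenAI-compatible endpoint can be added as one more entry, e.g. a
# hypothetical Together AI config (env key, base_url, and model are assumptions):
#   "together": {
#       "env_key": "TOGETHER_API_KEY",
#       "base_url": "https://api.together.xyz/v1",
#       "default_model": "meta-llama/Llama-3.3-70B-Instruct-Turbo",
#   },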
|
|
|
|
def main() -> None:
    parser = argparse.ArgumentParser(description="LLM baseline agent")
    parser.add_argument(
        "--url",
        default="http://localhost:7860",
        help="Environment URL (currently unused; the environment runs in-process)",
    )
    parser.add_argument("--api-key", default=None, help="API key")
    parser.add_argument(
        "--provider",
        default="groq",
        choices=list(PROVIDERS.keys()),
        help="LLM provider (default: groq)",
    )
    parser.add_argument(
        "--model", default=None, help="Model name (defaults to the provider's default_model)"
    )
    args = parser.parse_args()

    prov = PROVIDERS[args.provider]
    api_key = args.api_key or os.environ.get(prov["env_key"])
    if not api_key:
        print(f"Error: Set {prov['env_key']} env var or pass --api-key", file=sys.stderr)
        sys.exit(1)

    model_name = args.model or prov["default_model"]
    client_kwargs: dict = {"api_key": api_key}
    if prov["base_url"]:
        client_kwargs["base_url"] = prov["base_url"]
    client = OpenAI(**client_kwargs)

    scores: dict[str, float] = {}
    print(f"Running LLM baseline with {args.provider}/{model_name}...", file=sys.stderr)

    for task_id in ALL_TASKS:
        try:
            score = run_llm_episode(task_id, client, model_name)
            scores[task_id] = round(score, 4)
            print(f"  {task_id}: {score:.4f}", file=sys.stderr)
        except Exception as e:
            print(f"  {task_id}: ERROR: {e}", file=sys.stderr)
            scores[task_id] = 0.0

    # Machine-readable results go to stdout; progress went to stderr above.
    print(json.dumps(scores, indent=2))
|
|
|
|
if __name__ == "__main__":
    main()
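
# Example invocations (assuming the relevant API key is set):
#   GROQ_API_KEY=... python baseline_inference.py
#   python baseline_inference.py --provider gemini --api-key YOUR_KEY
#
# Progress lines go to stderr; stdout carries only the final JSON score map,
# e.g. (values illustrative): {"task_001": 0.85, "task_002": 0.0, ...}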
|
|