Spaces:
Sleeping
Sleeping
| """ | |
| GRPO Training — Code Migration Environment | |
| =========================================== | |
| Trains a model with Group Relative Policy Optimization (GRPO) using TRL. | |
| The model learns to fix failing tests caused by Python dependency upgrades. | |
| Reward: | |
| +5.0 + 0.2 * (MAX_STEPS_PER_TASK - steps used): tests pass | |
| -3.0: episode ends without passing tests | |
| -2.0: environment reset failed | |
| Usage: | |
| python code_migration/train_grpo.py | |
| Environment variables: | |
| MODEL_NAME (default: Qwen/Qwen3.5-4B) | |
| DATASET_PATH (default: code_migration/data/train.jsonl) | |
| EVAL_DATASET_PATH (default: code_migration/data/eval.jsonl) | |
| OUTPUT_DIR (default: ./grpo_output) | |
| LOG_DIR (default: ./train_logs) | |
| MAX_STEPS_PER_TASK (default: 15) | |
| MAX_TEST_EXEC (default: 3) | |
| NUM_TRAIN_EPOCHS (default: 3) | |
| LORA_R (default: 16) | |
| LORA_ALPHA (default: 32) | |
| DIFFICULTY_FILTER (default: all) | |
| NUM_ROLLOUTS (default: 2) | |
| NUM_TASKS (default: 20) | |
| PER_DEVICE_BATCH (default: 1) | |
| GRAD_ACCUM (default: 8) | |
| MAX_COMPLETION_LENGTH (default: 400) | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import logging | |
| import os | |
| import re | |
| import sys | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| from typing import Any, Dict, List | |
| import torch | |
| from datasets import Dataset | |
| from peft import LoraConfig, get_peft_model | |
| from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig | |
| from trl import GRPOConfig, GRPOTrainer | |
| sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) | |
| from code_migration.models import CodeMigrationAction, _TOOL_REQUIRED_ARGS | |
| from code_migration.server.code_migration_environment import CodeMigrationEnvironment | |
| from code_migration.research_agent import ResearchAgent | |
| # --------------------------------------------------------------------------- | |
| # Config | |
| # --------------------------------------------------------------------------- | |
# All settings are read once at import time from environment variables.
# Model checkpoint to fine-tune.
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3.5-4B")
# Dataset paths; when unset (None), main() falls back to data/train.jsonl and
# data/eval.jsonl next to this file.
DATASET_PATH = os.getenv("DATASET_PATH", None)
EVAL_DATASET_PATH = os.getenv("EVAL_DATASET_PATH", None)
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./grpo_output")
LOG_DIR = os.getenv("LOG_DIR", "./train_logs")
# Per-episode limits handed to the environment.
MAX_STEPS_PER_TASK = int(os.getenv("MAX_STEPS_PER_TASK", "15"))
MAX_TEST_EXEC = int(os.getenv("MAX_TEST_EXEC", "3"))
NUM_TRAIN_EPOCHS = int(os.getenv("NUM_TRAIN_EPOCHS", "3"))
# LoRA adapter hyper-parameters.
LORA_R = int(os.getenv("LORA_R", "16"))
LORA_ALPHA = int(os.getenv("LORA_ALPHA", "32"))
# "all" disables difficulty filtering (main() passes None to the environment).
DIFFICULTY_FILTER = os.getenv("DIFFICULTY_FILTER", "all")
NUM_ROLLOUTS = int(os.getenv("NUM_ROLLOUTS", "2"))
NUM_TASKS = int(os.getenv("NUM_TASKS", "20"))
PER_DEVICE_BATCH = int(os.getenv("PER_DEVICE_BATCH", "1"))
GRAD_ACCUM = int(os.getenv("GRAD_ACCUM", "8"))
MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "400"))
# Sampling temperature for rollouts; higher for exploration during training.
# Overridable via the TEMPERATURE environment variable like every other knob.
TEMPERATURE = float(os.getenv("TEMPERATURE", "0.7"))
| # --------------------------------------------------------------------------- | |
| # Logging | |
| # --------------------------------------------------------------------------- | |
# One timestamped sub-directory of LOG_DIR per run; every log artifact of this
# run (train.log, rollout/eval logs, summary) is written below it.
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path(LOG_DIR) / run_id
log_dir.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    datefmt="%H:%M:%S",
    handlers=[
        logging.StreamHandler(sys.stdout),
        # Log messages in this file contain non-ASCII characters (→, ─, ×, —),
        # so pin the file encoding instead of relying on the locale default.
        logging.FileHandler(log_dir / "train.log", encoding="utf-8"),
    ],
)
log = logging.getLogger("train")
| # --------------------------------------------------------------------------- | |
| # System prompt (same as inference) | |
| # --------------------------------------------------------------------------- | |
# System prompt given to the policy: the tool catalogue, the required JSON
# call format and a few editing guidelines. Built line by line; the trailing
# empty entry yields the final newline.
SYSTEM_PROMPT = "\n".join([
    "You are an expert Python developer fixing failing tests after dependency upgrades.",
    "",
    "Available tools:",
    "- list_dir(dir_path?), search_dir(regex_pattern, dir_path?)",
    "- search_file(regex_pattern, file_path), view_file(file_path, line_no)",
    "- edit_file(file_path, start_line, end_line, replacement_text)",
    "- replace_all_in_file(file_path, regex_pattern, replacement_string)",
    "- revert_last(), execute_tests()",
    "- search_last_log(regex_pattern), view_last_log(line_no)",
    "",
    'Output EXACTLY ONE JSON tool call: {"name": "...", "arguments": {...}}',
    "Be decisive: view error → find code → edit → test. 4-8 steps.",
    "CRITICAL: Line numbers in test logs are TEST LOG line numbers, NOT source file line numbers.",
    "Always use search_file or view_file to find the ACTUAL line number before editing.",
    "Use replace_all_in_file when possible — it doesn't need line numbers and is safer.",
    "",
])
| # --------------------------------------------------------------------------- | |
| # Model family detection | |
| # --------------------------------------------------------------------------- | |
| def _detect_model_family(model_name: str) -> str: | |
| """Detect model family from model name string.""" | |
| name_lower = model_name.lower() | |
| if "gemma" in name_lower: | |
| return "gemma" | |
| if "qwen3" in name_lower: | |
| return "qwen3" | |
| if "qwen" in name_lower: | |
| return "qwen2" | |
| return "unknown" | |
# Family tag of the configured MODEL_NAME, resolved once at import time;
# run_episode reads it to pick chat-template options and artifact stripping.
MODEL_FAMILY: str = _detect_model_family(MODEL_NAME)
| def _strip_model_artifacts(raw_text: str, family: str) -> str: | |
| """Strip model-specific artifacts from generated text.""" | |
| clean = raw_text | |
| if family == "gemma": | |
| clean = re.sub(r"<\|channel>thought\n.*?<channel\|>", "", clean, flags=re.DOTALL) | |
| for tok in ["<turn|>", "<|turn>", "<eos>", "</s>"]: | |
| clean = clean.replace(tok, "") | |
| elif family == "qwen3": | |
| clean = re.sub(r"<think>.*?</think>", "", clean, flags=re.DOTALL) | |
| im_end = clean.find("<|im_end|>") | |
| if im_end != -1: | |
| clean = clean[:im_end] | |
| for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]: | |
| clean = clean.replace(tok, "") | |
| elif family == "qwen2": | |
| im_end = clean.find("<|im_end|>") | |
| if im_end != -1: | |
| clean = clean[:im_end] | |
| for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]: | |
| clean = clean.replace(tok, "") | |
| else: | |
| for tok in ["<eos>", "</s>", "<|im_end|>", "<|endoftext|>", "<turn|>", "<|turn>"]: | |
| clean = clean.replace(tok, "") | |
| return clean.strip() | |
| # --------------------------------------------------------------------------- | |
| # Model loading with 4-bit + LoRA | |
| # --------------------------------------------------------------------------- | |
def load_model_and_tokenizer():
    """Load MODEL_NAME with 4-bit NF4 quantization and wrap it with LoRA adapters.

    The model family is detected from MODEL_NAME. For the "gemma" family every
    ``Gemma4ClippableLinear`` wrapper that exposes a ``linear`` attribute is
    replaced by that inner module before LoRA is applied. All families use the
    same attention + MLP projection layers as LoRA targets.

    Returns:
        Tuple ``(model, tokenizer)`` where ``model`` is the PEFT-wrapped model
        placed on device 0.
    """
    family = _detect_model_family(MODEL_NAME)
    log.info("Loading %s (family=%s) with 4-bit NF4 + LoRA...", MODEL_NAME, family)
    t0 = time.time()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # Fall back to EOS so generate() has a pad token to use.
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map={"": 0},
        trust_remote_code=True,
    )

    # Gemma 4 specific: unwrap ClippableLinear
    if family == "gemma":
        # Collect first, replace afterwards, so the module tree is not mutated
        # while named_modules() is iterating it. The class is matched by name
        # because it is not imported in this file.
        replacements = [
            (name, module.linear)
            for name, module in model.named_modules()
            if type(module).__name__ == "Gemma4ClippableLinear" and hasattr(module, "linear")
        ]
        for name, inner in replacements:
            parts = name.split(".")
            parent = model.get_submodule(".".join(parts[:-1])) if len(parts) > 1 else model
            setattr(parent, parts[-1], inner)
        if replacements:
            log.info("Unwrapped %d ClippableLinear modules", len(replacements))

    # Apply LoRA. The former per-family branch compared against "qwen", a value
    # _detect_model_family never returns, and both of its arms listed the same
    # modules, so a single list is used for every family.
    target_modules = ["q_proj", "k_proj", "v_proj", "o_proj",
                      "gate_proj", "up_proj", "down_proj"]
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=target_modules,
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    elapsed = time.time() - t0
    mem_gb = torch.cuda.memory_allocated(0) / 1e9 if torch.cuda.is_available() else 0
    log.info("Loaded in %.1fs | GPU: %.2f GB", elapsed, mem_gb)
    return model, tokenizer
| # --------------------------------------------------------------------------- | |
| # Tool call parsing (same as inference) | |
| # --------------------------------------------------------------------------- | |
| def _parse_tool_call(text: str) -> Dict[str, Any]: | |
| text = text.strip() | |
| if text.startswith("```"): | |
| lines = text.split("\n") | |
| lines = [l for l in lines if not l.strip().startswith("```")] | |
| text = "\n".join(lines).strip() | |
| # Find first { and match its closing } | |
| start = text.find("{") | |
| if start != -1: | |
| depth = 0 | |
| for i in range(start, len(text)): | |
| if text[i] == "{": | |
| depth += 1 | |
| elif text[i] == "}": | |
| depth -= 1 | |
| if depth == 0: | |
| try: | |
| data = json.loads(text[start:i + 1]) | |
| if "tool_name" in data: | |
| return {"tool_name": data["tool_name"], "tool_args": data.get("tool_args", {})} | |
| if "name" in data: | |
| return {"tool_name": data["name"], "tool_args": data.get("arguments", data.get("parameters", {}))} | |
| if "action" in data: | |
| a = data.pop("action") | |
| return {"tool_name": a, "tool_args": data} | |
| except json.JSONDecodeError: | |
| pass | |
| break | |
| return {"tool_name": "list_dir", "tool_args": {}} | |
| # --------------------------------------------------------------------------- | |
| # Run one episode — returns prompt, completion, reward + detailed log | |
| # --------------------------------------------------------------------------- | |
def run_episode(
    model, tokenizer, env: CodeMigrationEnvironment, task_index: int,
) -> Dict[str, Any]:
    """Run one full episode on a single task and return its training record.

    Resets ``env`` to ``task_index``, runs a ``ResearchAgent`` pass, then loops
    up to ``MAX_STEPS_PER_TASK`` times: sample one tool call from ``model``,
    execute it through ``env.step`` and feed the tool output back to the model
    as the next user message.

    Returns:
        Dict with keys ``prompt``, ``completion``, ``reward``, ``success``,
        ``repo_name``, ``steps`` and ``episode_log``. When the environment
        reports ``done`` right after reset, ``prompt`` and ``completion`` are
        empty strings and the reward is -2.0.
    """
    episode_log = {"task_index": task_index, "steps": []}
    obs = env.reset(task_index=task_index)
    repo_name = obs.metadata.get("repo_name", "unknown")
    difficulty = obs.metadata.get("difficulty", "unknown")
    episode_log["repo_name"] = repo_name
    episode_log["difficulty"] = difficulty
    # A finished observation straight after reset is treated as a failed reset.
    if obs.done:
        log.info(" [%s] reset failed", repo_name)
        episode_log["success"] = False
        episode_log["reward"] = -2.0
        return {"prompt": "", "completion": "", "reward": -2.0,
                "success": False, "repo_name": repo_name, "steps": 0,
                "episode_log": episode_log}
    # --- RESEARCH PHASE ---
    research = ResearchAgent(model, tokenizer, max_steps=12, model_name=MODEL_NAME)
    # Task metadata is read from the environment's private attribute when it
    # is present; otherwise the hard-coded defaults below are used.
    task_meta = env._current_task if hasattr(env, "_current_task") and env._current_task else None
    old_py = task_meta.reproduction_target_version if task_meta else "3.6"
    new_py = task_meta.migration_target_version if task_meta else "3.12"
    related_mods = task_meta.related_modules if task_meta else "builtin"
    research_context = research.research(
        repo_name=repo_name,
        old_python=old_py,
        new_python=new_py,
        related_modules=related_mods,
        test_output=obs.tool_output,
    )
    episode_log["research_context"] = research_context
    episode_log["research_steps"] = getattr(research, "last_research_steps", [])
    log.info(" [%s] research done (%d chars, %d steps)",
             repo_name, len(research_context), len(episode_log["research_steps"]))
    # --- BUILD PROMPT with research + error logs ---
    system_with_research = (
        SYSTEM_PROMPT
        + "\n\n=== MIGRATION RESEARCH (gathered by research agent) ===\n"
        + research_context
        + "\n=== END RESEARCH ===\n\n"
        "A research agent has already analyzed the error and found the relevant "
        "breaking changes above. Use this information to make the fix directly. "
        "Don't waste steps searching — act on the research.\n"
    )
    # Build initial prompt: system + research + error logs
    # (flat-text form returned as the dataset prompt; the chat-format
    # `messages` list below is what generation actually uses)
    initial_prompt = system_with_research + "\n\n" + obs.tool_output
    messages = [
        {"role": "system", "content": system_with_research},
        {"role": "user", "content": obs.tool_output},
    ]
    all_completions = []
    total_steps = 0
    success = False
    for step_num in range(1, MAX_STEPS_PER_TASK + 1):
        if obs.done:
            break
        torch.cuda.empty_cache()
        # Generate
        t0 = time.time()
        try:
            text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True,
                **({"enable_thinking": False} if MODEL_FAMILY == "qwen3" else {}),
            )
            # Gemma 4: strip thinking trigger
            if MODEL_FAMILY == "gemma":
                text = text.replace("<|think|>", "")
            inputs = tokenizer(text, return_tensors="pt").to(model.device)
            input_len = inputs["input_ids"].shape[-1]
            with torch.no_grad():
                outputs = model.generate(
                    **inputs,
                    max_new_tokens=MAX_COMPLETION_LENGTH,
                    temperature=TEMPERATURE,
                    top_p=0.95,
                    do_sample=True,
                    pad_token_id=tokenizer.pad_token_id,
                )
            # Decode only the newly generated tokens; special tokens are kept
            # so _strip_model_artifacts can find and remove them.
            raw_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=False)
            del inputs, outputs
            torch.cuda.empty_cache()
            clean = _strip_model_artifacts(raw_text, MODEL_FAMILY)
            parsed = _parse_tool_call(clean)
            tool_name = parsed["tool_name"]
            tool_args = parsed["tool_args"]
            gen_time = time.time() - t0
        except Exception as e:
            # Best-effort: a failed generation falls back to a list_dir step.
            # NOTE(review): the exception text becomes raw_text and therefore
            # ends up in the recorded completion and the conversation.
            gen_time = time.time() - t0
            log.warning(" Step %d gen failed: %s", step_num, e)
            tool_name, tool_args = "list_dir", {}
            raw_text, clean = str(e), ""
        all_completions.append(clean or raw_text)
        # Validate and execute
        if tool_name not in _TOOL_REQUIRED_ARGS:
            tool_name, tool_args = "list_dir", {}
        try:
            action = CodeMigrationAction(tool_name=tool_name, tool_args=tool_args)
        except Exception:
            # Action construction failed: fall back to a bare list_dir.
            action = CodeMigrationAction(tool_name="list_dir", tool_args={})
        obs = env.step(action)
        total_steps = step_num
        # Success means an execute_tests step whose recorded exit code is 0.
        if action.tool_name == "execute_tests" and obs.metadata.get("last_test_exit_code") == 0:
            success = True
        # Log step
        episode_log["steps"].append({
            "step": step_num, "gen_time": round(gen_time, 2),
            "tool": action.tool_name, "args": action.tool_args,
            "raw_output": raw_text[:1000], "clean_output": clean[:500],
            "result": obs.tool_output[:1000], "reward": obs.reward, "done": obs.done,
        })
        log.info(" [%s] step %d/%.1fs %s → %s",
                 repo_name, step_num, gen_time, action.tool_name,
                 "PASS!" if success else obs.tool_output[:80].replace("\n", " "))
        # Update conversation
        messages.append({"role": "assistant", "content": clean or raw_text})
        messages.append({"role": "user", "content": f"Tool result:\n{obs.tool_output}"})
        # Bound the context: keep the first two messages (system + initial
        # user message) plus the 20 most recent ones.
        if len(messages) > 22:
            messages = messages[:2] + messages[-20:]
        if obs.done:
            break
    # Episode reward: high positive for success, high negative for failure
    if success:
        # 5.0 base plus 0.2 for every step left unused.
        reward = 5.0 + max(0, (MAX_STEPS_PER_TASK - total_steps) * 0.2)
    else:
        reward = -3.0
    episode_log["success"] = success
    episode_log["total_steps"] = total_steps
    episode_log["reward"] = reward
    log.info(" [%s] %s steps=%d reward=%.2f",
             repo_name, "PASS" if success else "FAIL", total_steps, reward)
    return {
        # Prompt text is capped at 4000 characters.
        "prompt": initial_prompt[:4000],
        "completion": "\n".join(all_completions),
        "reward": reward,
        "success": success,
        "repo_name": repo_name,
        "steps": total_steps,
        "episode_log": episode_log,
    }
| # --------------------------------------------------------------------------- | |
| # Collect rollouts | |
| # --------------------------------------------------------------------------- | |
def collect_rollouts(model, tokenizer, env, num_tasks, num_rollouts):
    """Play ``num_rollouts`` episodes for each of the first ``num_tasks`` tasks.

    Episodes that produced a non-empty prompt contribute one row
    (prompt, completion, reward) to the returned ``Dataset``. The episode log
    of every episode is written to ``rollout_logs.json`` in the run's log
    directory.
    """
    episode_count = num_tasks * num_rollouts
    log.info("Collecting rollouts: %d tasks × %d rollouts = %d episodes",
             num_tasks, num_rollouts, episode_count)

    columns = {"prompt": [], "completion": [], "reward": []}
    episode_logs = []
    wins = 0
    for task_no in range(1, num_tasks + 1):
        for rollout_no in range(1, num_rollouts + 1):
            log.info("─ Task %d/%d, Rollout %d/%d",
                     task_no, num_tasks, rollout_no, num_rollouts)
            outcome = run_episode(model, tokenizer, env, task_no - 1)
            # An empty prompt marks an episode that yields no training row.
            if outcome["prompt"]:
                for column_name, values in columns.items():
                    values.append(outcome[column_name])
            episode_logs.append(outcome.get("episode_log", {}))
            if outcome["success"]:
                wins += 1

    log.info("Rollouts done: %d/%d succeeded (%.1f%%)",
             wins, episode_count, 100 * wins / max(episode_count, 1))

    # Save rollout logs
    destination = log_dir / "rollout_logs.json"
    with open(destination, "w") as handle:
        json.dump(episode_logs, handle, indent=2, default=str)
    log.info("Rollout logs saved: %s", destination)

    return Dataset.from_dict(columns)
| # --------------------------------------------------------------------------- | |
| # Reward function for GRPOTrainer | |
| # --------------------------------------------------------------------------- | |
def reward_from_env(completions, **kwargs):
    """Return the pre-computed environment rewards passed in via ``kwargs``.

    Looks for an ``env_reward`` entry first and then for ``reward``, which is
    the column name the rollout dataset built in this file actually uses.
    Reading only ``env_reward`` made this function return zeros for that
    dataset.

    Args:
        completions: Completions being scored; only their count is used.
        **kwargs: Extra per-sample columns supplied by the trainer.

    Returns:
        One float per supplied reward, or ``0.0`` for every completion when
        neither key carries any values.
    """
    for column in ("env_reward", "reward"):
        values = kwargs.get(column)
        if values:
            return [float(v) for v in values]
    return [0.0] * len(completions)
| # --------------------------------------------------------------------------- | |
| # Main | |
| # --------------------------------------------------------------------------- | |
def main() -> None:
    """Entry point: collect rollouts, run GRPO training, save, then evaluate.

    Phase 1 plays episodes with the freshly loaded model to build the rollout
    dataset, phase 2 trains with ``GRPOTrainer`` and saves model and tokenizer
    to ``OUTPUT_DIR``, phase 3 (only when the eval dataset file exists) runs up
    to five eval episodes. A JSON summary of the run is written to the run's
    log directory at the end.
    """
    log.info("=" * 60)
    log.info("GRPO Training — Code Migration")
    log.info(" model: %s", MODEL_NAME)
    log.info(" difficulty: %s", DIFFICULTY_FILTER)
    log.info(" tasks: %d", NUM_TASKS)
    log.info(" rollouts: %d per task", NUM_ROLLOUTS)
    log.info(" max_steps: %d per episode", MAX_STEPS_PER_TASK)
    log.info(" lora: r=%d alpha=%d", LORA_R, LORA_ALPHA)
    log.info(" output: %s", OUTPUT_DIR)
    log.info(" log_dir: %s", log_dir)
    log.info("=" * 60)
    # Load model
    model, tokenizer = load_model_and_tokenizer()
    # Create environment
    # Default dataset location: data/train.jsonl next to this file.
    dataset_path = DATASET_PATH or os.path.join(
        os.path.dirname(__file__), "data", "train.jsonl"
    )
    log.info("Environment dataset: %s", dataset_path)
    env = CodeMigrationEnvironment(
        dataset_path=dataset_path,
        max_steps=MAX_STEPS_PER_TASK,
        max_test_executions=MAX_TEST_EXEC,
        # "all" means no filtering, expressed to the environment as None.
        difficulty_filter=DIFFICULTY_FILTER if DIFFICULTY_FILTER != "all" else None,
    )
    # Never ask for more tasks than the environment's loader holds.
    num_tasks = min(NUM_TASKS, len(env._loader))
    log.info("Training on %d tasks", num_tasks)
    # Phase 1: Collect rollouts
    log.info("=" * 40)
    log.info("Phase 1: Collecting rollouts")
    log.info("=" * 40)
    rollout_dataset = collect_rollouts(
        model, tokenizer, env, num_tasks, NUM_ROLLOUTS,
    )
    log.info("Dataset: %d episodes", len(rollout_dataset))
    # Reward statistics are only logged when at least one episode was kept.
    if rollout_dataset["reward"]:
        log.info("Rewards: mean=%.2f min=%.2f max=%.2f",
                 sum(rollout_dataset["reward"]) / len(rollout_dataset),
                 min(rollout_dataset["reward"]),
                 max(rollout_dataset["reward"]))
    # Phase 2: GRPO Training
    log.info("=" * 40)
    log.info("Phase 2: GRPO Training")
    log.info("=" * 40)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=1e-5,
        max_completion_length=MAX_COMPLETION_LENGTH,
        num_generations=NUM_ROLLOUTS,
        logging_steps=1,
        save_steps=50,
        save_total_limit=3,
        bf16=True,
        report_to="none",
        # Keep the extra dataset columns (e.g. the reward column).
        # NOTE(review): presumably so they reach the reward function — confirm.
        remove_unused_columns=False,
    )
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=rollout_dataset,
        processing_class=tokenizer,
        reward_funcs=reward_from_env,
    )
    log.info("Starting GRPO training...")
    trainer.train()
    # Save
    log.info("Saving model to %s", OUTPUT_DIR)
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    log.info("Model saved.")
    # Phase 3: Eval
    # Default eval location: data/eval.jsonl next to this file; the phase is
    # skipped when that file does not exist.
    eval_path = EVAL_DATASET_PATH or os.path.join(
        os.path.dirname(__file__), "data", "eval.jsonl"
    )
    if os.path.exists(eval_path):
        log.info("=" * 40)
        log.info("Phase 3: Evaluation")
        log.info("=" * 40)
        eval_env = CodeMigrationEnvironment(
            dataset_path=eval_path,
            max_steps=MAX_STEPS_PER_TASK,
            max_test_executions=MAX_TEST_EXEC,
        )
        # Evaluation is capped at the first five eval tasks.
        eval_tasks = min(len(eval_env._loader), 5)
        eval_results = []
        successes = 0
        for i in range(eval_tasks):
            result = run_episode(model, tokenizer, eval_env, i)
            eval_results.append(result)
            if result["success"]:
                successes += 1
        log.info("Eval: %d/%d passed (%.1f%%)",
                 successes, eval_tasks, 100 * successes / max(eval_tasks, 1))
        # Save eval logs
        eval_logs = [r.get("episode_log", {}) for r in eval_results]
        eval_log_path = log_dir / "eval_logs.json"
        with open(eval_log_path, "w") as f:
            json.dump(eval_logs, f, indent=2, default=str)
        log.info("Eval logs saved: %s", eval_log_path)
    # Save training summary
    summary = {
        "run_id": run_id,
        "model": MODEL_NAME,
        "difficulty": DIFFICULTY_FILTER,
        "num_tasks": num_tasks,
        "num_rollouts": NUM_ROLLOUTS,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "output_dir": OUTPUT_DIR,
        "dataset_size": len(rollout_dataset),
        # max(..., 1) avoids dividing by zero for an empty dataset.
        "reward_mean": sum(rollout_dataset["reward"]) / max(len(rollout_dataset), 1),
    }
    with open(log_dir / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    log.info("Training complete! Logs at %s", log_dir)


# Run the pipeline only when executed as a script, not on import.
if __name__ == "__main__":
    main()