""" GRPO Training — Code Migration Environment =========================================== Trains a model with Group Relative Policy Optimization (GRPO) using TRL. The model learns to fix failing tests caused by Python dependency upgrades. Reward: +2.0 + efficiency bonus tests pass -2.0 hit max steps without passing Usage: python code_migration/train_grpo.py Environment variables: MODEL_NAME (default: google/gemma-4-E4B-it) DATASET_PATH (default: code_migration/data/train.jsonl) EVAL_DATASET_PATH (default: code_migration/data/eval.jsonl) OUTPUT_DIR (default: ./grpo_output) LOG_DIR (default: ./train_logs) MAX_STEPS_PER_TASK (default: 15) MAX_TEST_EXEC (default: 5) NUM_TRAIN_EPOCHS (default: 3) LORA_R (default: 16) LORA_ALPHA (default: 32) DIFFICULTY_FILTER (default: Easy) NUM_ROLLOUTS (default: 4) NUM_TASKS (default: 10) """ from __future__ import annotations import json import logging import os import re import sys import time from datetime import datetime from pathlib import Path from typing import Any, Dict, List import torch from datasets import Dataset from peft import LoraConfig, get_peft_model from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig from trl import GRPOConfig, GRPOTrainer sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..")) from code_migration.models import CodeMigrationAction, _TOOL_REQUIRED_ARGS from code_migration.server.code_migration_environment import CodeMigrationEnvironment from code_migration.research_agent import ResearchAgent # --------------------------------------------------------------------------- # Config # --------------------------------------------------------------------------- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3.5-4B") DATASET_PATH = os.getenv("DATASET_PATH", None) EVAL_DATASET_PATH = os.getenv("EVAL_DATASET_PATH", None) OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./grpo_output") LOG_DIR = os.getenv("LOG_DIR", "./train_logs") MAX_STEPS_PER_TASK = int(os.getenv("MAX_STEPS_PER_TASK", "15")) MAX_TEST_EXEC = int(os.getenv("MAX_TEST_EXEC", "3")) NUM_TRAIN_EPOCHS = int(os.getenv("NUM_TRAIN_EPOCHS", "3")) LORA_R = int(os.getenv("LORA_R", "16")) LORA_ALPHA = int(os.getenv("LORA_ALPHA", "32")) DIFFICULTY_FILTER = os.getenv("DIFFICULTY_FILTER", "all") NUM_ROLLOUTS = int(os.getenv("NUM_ROLLOUTS", "2")) NUM_TASKS = int(os.getenv("NUM_TASKS", "20")) PER_DEVICE_BATCH = int(os.getenv("PER_DEVICE_BATCH", "1")) GRAD_ACCUM = int(os.getenv("GRAD_ACCUM", "8")) MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "400")) TEMPERATURE = 0.7 # higher for exploration during training # --------------------------------------------------------------------------- # Logging # --------------------------------------------------------------------------- run_id = datetime.now().strftime("%Y%m%d_%H%M%S") log_dir = Path(LOG_DIR) / run_id log_dir.mkdir(parents=True, exist_ok=True) logging.basicConfig( level=logging.INFO, format="%(asctime)s [%(levelname)s] %(message)s", datefmt="%H:%M:%S", handlers=[ logging.StreamHandler(sys.stdout), logging.FileHandler(log_dir / "train.log"), ], ) log = logging.getLogger("train") # --------------------------------------------------------------------------- # System prompt (same as inference) # --------------------------------------------------------------------------- SYSTEM_PROMPT = ( "You are an expert Python developer fixing failing tests after dependency upgrades.\n\n" "Available tools:\n" "- list_dir(dir_path?), search_dir(regex_pattern, dir_path?)\n" "- search_file(regex_pattern, file_path), view_file(file_path, line_no)\n" "- edit_file(file_path, start_line, end_line, replacement_text)\n" "- replace_all_in_file(file_path, regex_pattern, replacement_string)\n" "- revert_last(), execute_tests()\n" "- search_last_log(regex_pattern), view_last_log(line_no)\n\n" "Output EXACTLY ONE JSON tool call: {\"name\": \"...\", \"arguments\": {...}}\n" "Be decisive: view error → find code → edit → test. 4-8 steps.\n" "CRITICAL: Line numbers in test logs are TEST LOG line numbers, NOT source file line numbers.\n" "Always use search_file or view_file to find the ACTUAL line number before editing.\n" "Use replace_all_in_file when possible — it doesn't need line numbers and is safer.\n" ) # --------------------------------------------------------------------------- # Model family detection # --------------------------------------------------------------------------- def _detect_model_family(model_name: str) -> str: """Detect model family from model name string.""" name_lower = model_name.lower() if "gemma" in name_lower: return "gemma" if "qwen3" in name_lower: return "qwen3" if "qwen" in name_lower: return "qwen2" return "unknown" MODEL_FAMILY = _detect_model_family(MODEL_NAME) def _strip_model_artifacts(raw_text: str, family: str) -> str: """Strip model-specific artifacts from generated text.""" clean = raw_text if family == "gemma": clean = re.sub(r"<\|channel>thought\n.*?", "", clean, flags=re.DOTALL) for tok in ["", "<|turn>", "", ""]: clean = clean.replace(tok, "") elif family == "qwen3": clean = re.sub(r".*?", "", clean, flags=re.DOTALL) im_end = clean.find("<|im_end|>") if im_end != -1: clean = clean[:im_end] for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]: clean = clean.replace(tok, "") elif family == "qwen2": im_end = clean.find("<|im_end|>") if im_end != -1: clean = clean[:im_end] for tok in ["<|im_end|>", "<|endoftext|>", "<|im_start|>"]: clean = clean.replace(tok, "") else: for tok in ["", "", "<|im_end|>", "<|endoftext|>", "", "<|turn>"]: clean = clean.replace(tok, "") return clean.strip() # --------------------------------------------------------------------------- # Model loading with 4-bit + LoRA # --------------------------------------------------------------------------- def load_model_and_tokenizer(): """Load model with 4-bit NF4 quantization and LoRA for training. Supports Gemma 4 and Qwen 2.5 model families. """ family = _detect_model_family(MODEL_NAME) log.info("Loading %s (family=%s) with 4-bit NF4 + LoRA...", MODEL_NAME, family) t0 = time.time() tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True) if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token bnb_config = BitsAndBytesConfig( load_in_4bit=True, bnb_4bit_quant_type="nf4", bnb_4bit_compute_dtype=torch.bfloat16, bnb_4bit_use_double_quant=True, ) model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, quantization_config=bnb_config, device_map={"": 0}, trust_remote_code=True, ) # Gemma 4 specific: unwrap ClippableLinear if family == "gemma": replacements = [] for name, module in model.named_modules(): if type(module).__name__ == "Gemma4ClippableLinear": if hasattr(module, "linear"): replacements.append((name, module.linear)) for name, inner in replacements: parts = name.split(".") parent = model.get_submodule(".".join(parts[:-1])) if len(parts) > 1 else model setattr(parent, parts[-1], inner) if replacements: log.info("Unwrapped %d ClippableLinear modules", len(replacements)) # Apply LoRA — target modules differ by model family if family == "qwen": target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] else: # Gemma and default target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"] lora_config = LoraConfig( r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=target_modules, lora_dropout=0.05, bias="none", task_type="CAUSAL_LM", ) model = get_peft_model(model, lora_config) model.print_trainable_parameters() elapsed = time.time() - t0 mem_gb = torch.cuda.memory_allocated(0) / 1e9 if torch.cuda.is_available() else 0 log.info("Loaded in %.1fs | GPU: %.2f GB", elapsed, mem_gb) return model, tokenizer # --------------------------------------------------------------------------- # Tool call parsing (same as inference) # --------------------------------------------------------------------------- def _parse_tool_call(text: str) -> Dict[str, Any]: text = text.strip() if text.startswith("```"): lines = text.split("\n") lines = [l for l in lines if not l.strip().startswith("```")] text = "\n".join(lines).strip() # Find first { and match its closing } start = text.find("{") if start != -1: depth = 0 for i in range(start, len(text)): if text[i] == "{": depth += 1 elif text[i] == "}": depth -= 1 if depth == 0: try: data = json.loads(text[start:i + 1]) if "tool_name" in data: return {"tool_name": data["tool_name"], "tool_args": data.get("tool_args", {})} if "name" in data: return {"tool_name": data["name"], "tool_args": data.get("arguments", data.get("parameters", {}))} if "action" in data: a = data.pop("action") return {"tool_name": a, "tool_args": data} except json.JSONDecodeError: pass break return {"tool_name": "list_dir", "tool_args": {}} # --------------------------------------------------------------------------- # Run one episode — returns prompt, completion, reward + detailed log # --------------------------------------------------------------------------- def run_episode( model, tokenizer, env: CodeMigrationEnvironment, task_index: int, ) -> Dict[str, Any]: """Run one full episode. Returns dict with prompt, completion, reward, log.""" episode_log = {"task_index": task_index, "steps": []} obs = env.reset(task_index=task_index) repo_name = obs.metadata.get("repo_name", "unknown") difficulty = obs.metadata.get("difficulty", "unknown") episode_log["repo_name"] = repo_name episode_log["difficulty"] = difficulty if obs.done: log.info(" [%s] reset failed", repo_name) episode_log["success"] = False episode_log["reward"] = -2.0 return {"prompt": "", "completion": "", "reward": -2.0, "success": False, "repo_name": repo_name, "steps": 0, "episode_log": episode_log} # --- RESEARCH PHASE --- research = ResearchAgent(model, tokenizer, max_steps=12, model_name=MODEL_NAME) task_meta = env._current_task if hasattr(env, "_current_task") and env._current_task else None old_py = task_meta.reproduction_target_version if task_meta else "3.6" new_py = task_meta.migration_target_version if task_meta else "3.12" related_mods = task_meta.related_modules if task_meta else "builtin" research_context = research.research( repo_name=repo_name, old_python=old_py, new_python=new_py, related_modules=related_mods, test_output=obs.tool_output, ) episode_log["research_context"] = research_context episode_log["research_steps"] = getattr(research, "last_research_steps", []) log.info(" [%s] research done (%d chars, %d steps)", repo_name, len(research_context), len(episode_log["research_steps"])) # --- BUILD PROMPT with research + error logs --- system_with_research = ( SYSTEM_PROMPT + "\n\n=== MIGRATION RESEARCH (gathered by research agent) ===\n" + research_context + "\n=== END RESEARCH ===\n\n" "A research agent has already analyzed the error and found the relevant " "breaking changes above. Use this information to make the fix directly. " "Don't waste steps searching — act on the research.\n" ) # Build initial prompt: system + research + error logs initial_prompt = system_with_research + "\n\n" + obs.tool_output messages = [ {"role": "system", "content": system_with_research}, {"role": "user", "content": obs.tool_output}, ] all_completions = [] total_steps = 0 success = False for step_num in range(1, MAX_STEPS_PER_TASK + 1): if obs.done: break torch.cuda.empty_cache() # Generate t0 = time.time() try: text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True, **({"enable_thinking": False} if MODEL_FAMILY == "qwen3" else {}), ) # Gemma 4: strip thinking trigger if MODEL_FAMILY == "gemma": text = text.replace("<|think|>", "") inputs = tokenizer(text, return_tensors="pt").to(model.device) input_len = inputs["input_ids"].shape[-1] with torch.no_grad(): outputs = model.generate( **inputs, max_new_tokens=MAX_COMPLETION_LENGTH, temperature=TEMPERATURE, top_p=0.95, do_sample=True, pad_token_id=tokenizer.pad_token_id, ) raw_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=False) del inputs, outputs torch.cuda.empty_cache() clean = _strip_model_artifacts(raw_text, MODEL_FAMILY) parsed = _parse_tool_call(clean) tool_name = parsed["tool_name"] tool_args = parsed["tool_args"] gen_time = time.time() - t0 except Exception as e: gen_time = time.time() - t0 log.warning(" Step %d gen failed: %s", step_num, e) tool_name, tool_args = "list_dir", {} raw_text, clean = str(e), "" all_completions.append(clean or raw_text) # Validate and execute if tool_name not in _TOOL_REQUIRED_ARGS: tool_name, tool_args = "list_dir", {} try: action = CodeMigrationAction(tool_name=tool_name, tool_args=tool_args) except Exception: action = CodeMigrationAction(tool_name="list_dir", tool_args={}) obs = env.step(action) total_steps = step_num if action.tool_name == "execute_tests" and obs.metadata.get("last_test_exit_code") == 0: success = True # Log step episode_log["steps"].append({ "step": step_num, "gen_time": round(gen_time, 2), "tool": action.tool_name, "args": action.tool_args, "raw_output": raw_text[:1000], "clean_output": clean[:500], "result": obs.tool_output[:1000], "reward": obs.reward, "done": obs.done, }) log.info(" [%s] step %d/%.1fs %s → %s", repo_name, step_num, gen_time, action.tool_name, "PASS!" if success else obs.tool_output[:80].replace("\n", " ")) # Update conversation messages.append({"role": "assistant", "content": clean or raw_text}) messages.append({"role": "user", "content": f"Tool result:\n{obs.tool_output}"}) if len(messages) > 22: messages = messages[:2] + messages[-20:] if obs.done: break # Episode reward: high positive for success, high negative for failure if success: reward = 5.0 + max(0, (MAX_STEPS_PER_TASK - total_steps) * 0.2) else: reward = -3.0 episode_log["success"] = success episode_log["total_steps"] = total_steps episode_log["reward"] = reward log.info(" [%s] %s steps=%d reward=%.2f", repo_name, "PASS" if success else "FAIL", total_steps, reward) return { "prompt": initial_prompt[:4000], "completion": "\n".join(all_completions), "reward": reward, "success": success, "repo_name": repo_name, "steps": total_steps, "episode_log": episode_log, } # --------------------------------------------------------------------------- # Collect rollouts # --------------------------------------------------------------------------- def collect_rollouts(model, tokenizer, env, num_tasks, num_rollouts): """Run episodes and build training dataset.""" log.info("Collecting rollouts: %d tasks × %d rollouts = %d episodes", num_tasks, num_rollouts, num_tasks * num_rollouts) prompts, completions, rewards = [], [], [] all_logs = [] total_success = 0 for task_idx in range(num_tasks): for rollout_idx in range(num_rollouts): log.info("─ Task %d/%d, Rollout %d/%d", task_idx + 1, num_tasks, rollout_idx + 1, num_rollouts) result = run_episode(model, tokenizer, env, task_idx) if result["prompt"]: prompts.append(result["prompt"]) completions.append(result["completion"]) rewards.append(result["reward"]) all_logs.append(result.get("episode_log", {})) if result["success"]: total_success += 1 total = num_tasks * num_rollouts log.info("Rollouts done: %d/%d succeeded (%.1f%%)", total_success, total, 100 * total_success / max(total, 1)) # Save rollout logs rollout_log_path = log_dir / "rollout_logs.json" with open(rollout_log_path, "w") as f: json.dump(all_logs, f, indent=2, default=str) log.info("Rollout logs saved: %s", rollout_log_path) return Dataset.from_dict({ "prompt": prompts, "completion": completions, "reward": rewards, }) # --------------------------------------------------------------------------- # Reward function for GRPOTrainer # --------------------------------------------------------------------------- def reward_from_env(completions, **kwargs): """Extract pre-computed rewards from kwargs.""" env_rewards = kwargs.get("env_reward", []) if env_rewards: return [float(r) for r in env_rewards] return [0.0] * len(completions) # --------------------------------------------------------------------------- # Main # --------------------------------------------------------------------------- def main(): log.info("=" * 60) log.info("GRPO Training — Code Migration") log.info(" model: %s", MODEL_NAME) log.info(" difficulty: %s", DIFFICULTY_FILTER) log.info(" tasks: %d", NUM_TASKS) log.info(" rollouts: %d per task", NUM_ROLLOUTS) log.info(" max_steps: %d per episode", MAX_STEPS_PER_TASK) log.info(" lora: r=%d alpha=%d", LORA_R, LORA_ALPHA) log.info(" output: %s", OUTPUT_DIR) log.info(" log_dir: %s", log_dir) log.info("=" * 60) # Load model model, tokenizer = load_model_and_tokenizer() # Create environment dataset_path = DATASET_PATH or os.path.join( os.path.dirname(__file__), "data", "train.jsonl" ) log.info("Environment dataset: %s", dataset_path) env = CodeMigrationEnvironment( dataset_path=dataset_path, max_steps=MAX_STEPS_PER_TASK, max_test_executions=MAX_TEST_EXEC, difficulty_filter=DIFFICULTY_FILTER if DIFFICULTY_FILTER != "all" else None, ) num_tasks = min(NUM_TASKS, len(env._loader)) log.info("Training on %d tasks", num_tasks) # Phase 1: Collect rollouts log.info("=" * 40) log.info("Phase 1: Collecting rollouts") log.info("=" * 40) rollout_dataset = collect_rollouts( model, tokenizer, env, num_tasks, NUM_ROLLOUTS, ) log.info("Dataset: %d episodes", len(rollout_dataset)) if rollout_dataset["reward"]: log.info("Rewards: mean=%.2f min=%.2f max=%.2f", sum(rollout_dataset["reward"]) / len(rollout_dataset), min(rollout_dataset["reward"]), max(rollout_dataset["reward"])) # Phase 2: GRPO Training log.info("=" * 40) log.info("Phase 2: GRPO Training") log.info("=" * 40) os.makedirs(OUTPUT_DIR, exist_ok=True) training_args = GRPOConfig( output_dir=OUTPUT_DIR, num_train_epochs=NUM_TRAIN_EPOCHS, per_device_train_batch_size=PER_DEVICE_BATCH, gradient_accumulation_steps=GRAD_ACCUM, learning_rate=1e-5, max_completion_length=MAX_COMPLETION_LENGTH, num_generations=NUM_ROLLOUTS, logging_steps=1, save_steps=50, save_total_limit=3, bf16=True, report_to="none", remove_unused_columns=False, ) trainer = GRPOTrainer( model=model, args=training_args, train_dataset=rollout_dataset, processing_class=tokenizer, reward_funcs=reward_from_env, ) log.info("Starting GRPO training...") trainer.train() # Save log.info("Saving model to %s", OUTPUT_DIR) trainer.save_model(OUTPUT_DIR) tokenizer.save_pretrained(OUTPUT_DIR) log.info("Model saved.") # Phase 3: Eval eval_path = EVAL_DATASET_PATH or os.path.join( os.path.dirname(__file__), "data", "eval.jsonl" ) if os.path.exists(eval_path): log.info("=" * 40) log.info("Phase 3: Evaluation") log.info("=" * 40) eval_env = CodeMigrationEnvironment( dataset_path=eval_path, max_steps=MAX_STEPS_PER_TASK, max_test_executions=MAX_TEST_EXEC, ) eval_tasks = min(len(eval_env._loader), 5) eval_results = [] successes = 0 for i in range(eval_tasks): result = run_episode(model, tokenizer, eval_env, i) eval_results.append(result) if result["success"]: successes += 1 log.info("Eval: %d/%d passed (%.1f%%)", successes, eval_tasks, 100 * successes / max(eval_tasks, 1)) # Save eval logs eval_logs = [r.get("episode_log", {}) for r in eval_results] eval_log_path = log_dir / "eval_logs.json" with open(eval_log_path, "w") as f: json.dump(eval_logs, f, indent=2, default=str) log.info("Eval logs saved: %s", eval_log_path) # Save training summary summary = { "run_id": run_id, "model": MODEL_NAME, "difficulty": DIFFICULTY_FILTER, "num_tasks": num_tasks, "num_rollouts": NUM_ROLLOUTS, "lora_r": LORA_R, "lora_alpha": LORA_ALPHA, "output_dir": OUTPUT_DIR, "dataset_size": len(rollout_dataset), "reward_mean": sum(rollout_dataset["reward"]) / max(len(rollout_dataset), 1), } with open(log_dir / "training_summary.json", "w") as f: json.dump(summary, f, indent=2) log.info("Training complete! Logs at %s", log_dir) if __name__ == "__main__": main()