# migratron / code_migration / train_grpo.py
# amrithanandini — "integrated backend and frontend" (1b35d41)
"""
GRPO Training — Code Migration Environment
===========================================
Trains a model with Group Relative Policy Optimization (GRPO) using TRL.
The model learns to fix failing tests caused by Python dependency upgrades.
Reward:
+2.0 + efficiency bonus tests pass
-2.0 hit max steps without passing
Usage:
python code_migration/train_grpo.py
Environment variables:
MODEL_NAME (default: google/gemma-4-E4B-it)
DATASET_PATH (default: code_migration/data/train.jsonl)
EVAL_DATASET_PATH (default: code_migration/data/eval.jsonl)
OUTPUT_DIR (default: ./grpo_output)
LOG_DIR (default: ./train_logs)
MAX_STEPS_PER_TASK (default: 15)
MAX_TEST_EXEC (default: 5)
NUM_TRAIN_EPOCHS (default: 3)
LORA_R (default: 16)
LORA_ALPHA (default: 32)
DIFFICULTY_FILTER (default: Easy)
NUM_ROLLOUTS (default: 4)
NUM_TASKS (default: 10)
"""
from __future__ import annotations
import json
import logging
import os
import re
import sys
import time
from datetime import datetime
from pathlib import Path
from typing import Any, Dict, List
import torch
from datasets import Dataset
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
from trl import GRPOConfig, GRPOTrainer
sys.path.insert(0, os.path.join(os.path.dirname(__file__), ".."))
from code_migration.models import CodeMigrationAction, _TOOL_REQUIRED_ARGS
from code_migration.server.code_migration_environment import CodeMigrationEnvironment
from code_migration.research_agent import ResearchAgent
# ---------------------------------------------------------------------------
# Config
# ---------------------------------------------------------------------------
MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen3.5-4B")
DATASET_PATH = os.getenv("DATASET_PATH", None)
EVAL_DATASET_PATH = os.getenv("EVAL_DATASET_PATH", None)
OUTPUT_DIR = os.getenv("OUTPUT_DIR", "./grpo_output")
LOG_DIR = os.getenv("LOG_DIR", "./train_logs")
MAX_STEPS_PER_TASK = int(os.getenv("MAX_STEPS_PER_TASK", "15"))
MAX_TEST_EXEC = int(os.getenv("MAX_TEST_EXEC", "3"))
NUM_TRAIN_EPOCHS = int(os.getenv("NUM_TRAIN_EPOCHS", "3"))
LORA_R = int(os.getenv("LORA_R", "16"))
LORA_ALPHA = int(os.getenv("LORA_ALPHA", "32"))
DIFFICULTY_FILTER = os.getenv("DIFFICULTY_FILTER", "all")
NUM_ROLLOUTS = int(os.getenv("NUM_ROLLOUTS", "2"))
NUM_TASKS = int(os.getenv("NUM_TASKS", "20"))
PER_DEVICE_BATCH = int(os.getenv("PER_DEVICE_BATCH", "1"))
GRAD_ACCUM = int(os.getenv("GRAD_ACCUM", "8"))
MAX_COMPLETION_LENGTH = int(os.getenv("MAX_COMPLETION_LENGTH", "400"))
TEMPERATURE = 0.7 # higher for exploration during training
# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
log_dir = Path(LOG_DIR) / run_id
log_dir.mkdir(parents=True, exist_ok=True)
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
datefmt="%H:%M:%S",
handlers=[
logging.StreamHandler(sys.stdout),
logging.FileHandler(log_dir / "train.log"),
],
)
log = logging.getLogger("train")
# ---------------------------------------------------------------------------
# System prompt (same as inference)
# ---------------------------------------------------------------------------
# Tool catalogue and behavioural rules given to the fixing model. The rendered
# text is meant to match the prompt used at inference time, so keep it in sync.
# Joining with "\n" (plus a trailing "") reproduces the newline-terminated lines.
SYSTEM_PROMPT = "\n".join([
    "You are an expert Python developer fixing failing tests after dependency upgrades.",
    "",
    "Available tools:",
    "- list_dir(dir_path?), search_dir(regex_pattern, dir_path?)",
    "- search_file(regex_pattern, file_path), view_file(file_path, line_no)",
    "- edit_file(file_path, start_line, end_line, replacement_text)",
    "- replace_all_in_file(file_path, regex_pattern, replacement_string)",
    "- revert_last(), execute_tests()",
    "- search_last_log(regex_pattern), view_last_log(line_no)",
    "",
    'Output EXACTLY ONE JSON tool call: {"name": "...", "arguments": {...}}',
    "Be decisive: view error → find code → edit → test. 4-8 steps.",
    "CRITICAL: Line numbers in test logs are TEST LOG line numbers, NOT source file line numbers.",
    "Always use search_file or view_file to find the ACTUAL line number before editing.",
    "Use replace_all_in_file when possible — it doesn't need line numbers and is safer.",
    "",
])
# ---------------------------------------------------------------------------
# Model family detection
# ---------------------------------------------------------------------------
def _detect_model_family(model_name: str) -> str:
    """Classify a model id into the chat-format family this script knows about.

    The check is a case-insensitive substring match applied in priority order,
    so "qwen3" is tested before the generic "qwen" (which maps to "qwen2").
    Anything unmatched is reported as "unknown".
    """
    lowered = model_name.lower()
    rules = (("gemma", "gemma"), ("qwen3", "qwen3"), ("qwen", "qwen2"))
    return next((family for needle, family in rules if needle in lowered), "unknown")
# Chat-format family of the configured MODEL_NAME; the generation loop uses it
# to pick chat-template options and to clean up raw model output.
MODEL_FAMILY = _detect_model_family(MODEL_NAME)
def _strip_model_artifacts(raw_text: str, family: str) -> str:
    """Remove family-specific control tokens and reasoning blocks from output.

    * qwen3: drop ``<think>...</think>`` blocks, then cut at the first
      ``<|im_end|>`` and delete leftover Qwen control tokens.
    * qwen2: same as qwen3 but without thinking-block removal.
    * gemma: drop ``<|channel>thought ... <channel|>`` blocks and delete
      Gemma turn/eos markers.
    * anything else: delete a union of common special tokens.

    The result is whitespace-stripped.
    """
    text = raw_text
    if family in ("qwen3", "qwen2"):
        if family == "qwen3":
            text = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
        # Keep only what precedes the first end-of-turn marker.
        text, _, _ = text.partition("<|im_end|>")
        leftovers = ("<|im_end|>", "<|endoftext|>", "<|im_start|>")
    elif family == "gemma":
        text = re.sub(r"<\|channel>thought\n.*?<channel\|>", "", text, flags=re.DOTALL)
        leftovers = ("<turn|>", "<|turn>", "<eos>", "</s>")
    else:
        leftovers = ("<eos>", "</s>", "<|im_end|>", "<|endoftext|>", "<turn|>", "<|turn>")
    # Removal order matters only in pathological overlaps; it matches the
    # sequence each family has always used.
    for token in leftovers:
        text = text.replace(token, "")
    return text.strip()
# ---------------------------------------------------------------------------
# Model loading with 4-bit + LoRA
# ---------------------------------------------------------------------------
def load_model_and_tokenizer():
    """Load ``MODEL_NAME`` quantized to 4-bit NF4 and wrap it with LoRA adapters.

    Handles the families recognised by ``_detect_model_family`` (Gemma, Qwen 2
    and Qwen 3); unrecognised causal LMs go through the same settings.

    Returns:
        Tuple ``(model, tokenizer)``: the PEFT-wrapped model (base weights
        loaded onto CUDA device 0) and its tokenizer, whose ``pad_token``
        defaults to ``eos_token`` when the checkpoint defines none.
    """
    family = _detect_model_family(MODEL_NAME)
    log.info("Loading %s (family=%s) with 4-bit NF4 + LoRA...", MODEL_NAME, family)
    t0 = time.time()

    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
    if tokenizer.pad_token is None:
        # generate() is called with pad_token_id; reuse EOS when none is defined.
        tokenizer.pad_token = tokenizer.eos_token

    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
    )
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        quantization_config=bnb_config,
        device_map={"": 0},  # whole model on GPU 0
        trust_remote_code=True,
    )

    if family == "gemma":
        # Gemma 4 wraps some linear layers in Gemma4ClippableLinear. Replace each
        # wrapper with its inner `.linear` module (presumably so LoRA can target
        # a plain Linear — confirm against the Gemma 4 modeling code). Targets
        # are collected first so the module tree is not mutated mid-iteration.
        replacements = [
            (name, module.linear)
            for name, module in model.named_modules()
            if type(module).__name__ == "Gemma4ClippableLinear" and hasattr(module, "linear")
        ]
        for name, inner in replacements:
            parent_name, _, attr = name.rpartition(".")
            parent = model.get_submodule(parent_name) if parent_name else model
            setattr(parent, attr, inner)
        if replacements:
            log.info("Unwrapped %d ClippableLinear modules", len(replacements))

    # Gemma, Qwen and unknown families all target the same attention/MLP
    # projection names, so a single list serves every family.
    lora_config = LoraConfig(
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=[
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM",
    )
    model = get_peft_model(model, lora_config)
    model.print_trainable_parameters()

    elapsed = time.time() - t0
    mem_gb = torch.cuda.memory_allocated(0) / 1e9 if torch.cuda.is_available() else 0
    log.info("Loaded in %.1fs | GPU: %.2f GB", elapsed, mem_gb)
    return model, tokenizer
# ---------------------------------------------------------------------------
# Tool call parsing (same as inference)
# ---------------------------------------------------------------------------
def _coerce_tool_args(args: Any) -> Dict[str, Any]:
    """Return tool arguments as a dict, decoding JSON-string payloads; else ``{}``."""
    if isinstance(args, str):
        # OpenAI-style calls encode "arguments" as a JSON string.
        try:
            args = json.loads(args)
        except json.JSONDecodeError:
            return {}
    return args if isinstance(args, dict) else {}


def _tool_call_from_json(data: Any) -> "Dict[str, Any] | None":
    """Map one decoded JSON object to a tool call, or return None if it is not one."""
    if not isinstance(data, dict):
        return None
    if "tool_name" in data:
        name, args = data["tool_name"], data.get("tool_args", {})
    elif "name" in data:
        name, args = data["name"], data.get("arguments", data.get("parameters", {}))
    elif "action" in data:
        name = data["action"]
        args = {key: value for key, value in data.items() if key != "action"}
    else:
        return None
    # A non-string name (e.g. a list) would be unhashable for the caller's
    # membership test against the tool table, so reject it here.
    if not isinstance(name, str):
        return None
    return {"tool_name": name, "tool_args": _coerce_tool_args(args)}


def _parse_tool_call(text: str) -> Dict[str, Any]:
    """Extract a ``{"tool_name": str, "tool_args": dict}`` call from a model reply.

    Accepted JSON shapes (the first matching object in the text wins):
      * ``{"tool_name": ..., "tool_args": {...}}``
      * ``{"name": ..., "arguments": {...}}`` (or ``"parameters"``); the
        arguments may also be a JSON-encoded string
      * ``{"action": ..., <remaining keys are the arguments>}``

    Markdown code-fence lines are removed first. Candidate objects are decoded
    with ``json.JSONDecoder.raw_decode`` starting at each ``{``, so braces
    inside string values (common in ``replacement_text``) do not break the scan.
    Falls back to ``list_dir`` with no arguments when nothing usable is found.
    """
    text = text.strip()
    if text.startswith("```"):
        text = "\n".join(
            line for line in text.split("\n") if not line.strip().startswith("```")
        ).strip()

    decoder = json.JSONDecoder()
    pos = text.find("{")
    while pos != -1:
        try:
            data, end = decoder.raw_decode(text, pos)
        except json.JSONDecodeError:
            # Not a JSON object here (e.g. prose braces); try the next "{".
            pos = text.find("{", pos + 1)
            continue
        call = _tool_call_from_json(data)
        if call is not None:
            return call
        # Well-formed object that is not a tool call: skip past it entirely so
        # its nested objects are not mistaken for one.
        pos = text.find("{", end)
    return {"tool_name": "list_dir", "tool_args": {}}
# ---------------------------------------------------------------------------
# Run one episode — returns prompt, completion, reward + detailed log
# ---------------------------------------------------------------------------
def _generate_tool_call(model, tokenizer, messages):
    """Sample one assistant turn for ``messages`` and parse it into a tool call.

    Returns:
        Tuple ``(raw_text, clean_text, tool_name, tool_args)``. Tensors are
        locals, so they are released when this function returns or raises.
    """
    template_kwargs = {"enable_thinking": False} if MODEL_FAMILY == "qwen3" else {}
    prompt_text = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True, **template_kwargs,
    )
    if MODEL_FAMILY == "gemma":
        # Gemma 4: strip the thinking trigger so the model answers directly.
        prompt_text = prompt_text.replace("<|think|>", "")
    # The rendered chat template already carries its special tokens (e.g. BOS);
    # tokenizing with add_special_tokens=True would prepend a second BOS for
    # tokenizers that add one by default.
    inputs = tokenizer(
        prompt_text, return_tensors="pt", add_special_tokens=False,
    ).to(model.device)
    input_len = inputs["input_ids"].shape[-1]
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=MAX_COMPLETION_LENGTH,
            temperature=TEMPERATURE,
            top_p=0.95,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
        )
    # Decode only the newly generated tokens, keeping special tokens so the
    # family-specific cleaner can cut at end-of-turn markers.
    raw_text = tokenizer.decode(outputs[0][input_len:], skip_special_tokens=False)
    clean = _strip_model_artifacts(raw_text, MODEL_FAMILY)
    parsed = _parse_tool_call(clean)
    return raw_text, clean, parsed["tool_name"], parsed["tool_args"]


def _episode_reward(success, total_steps):
    """Episode reward: 5.0 plus 0.2 per unused step on success, otherwise -3.0."""
    if success:
        return 5.0 + max(0, (MAX_STEPS_PER_TASK - total_steps) * 0.2)
    return -3.0


def _play_episode(model, tokenizer, env, task_index):
    """Episode logic behind ``run_episode``; the caller manages train/eval mode."""
    episode_log = {"task_index": task_index, "steps": []}
    obs = env.reset(task_index=task_index)
    repo_name = obs.metadata.get("repo_name", "unknown")
    difficulty = obs.metadata.get("difficulty", "unknown")
    episode_log["repo_name"] = repo_name
    episode_log["difficulty"] = difficulty

    if obs.done:
        # Reset failed: report an empty prompt so collect_rollouts skips it.
        log.info(" [%s] reset failed", repo_name)
        episode_log["success"] = False
        episode_log["reward"] = -2.0
        return {"prompt": "", "completion": "", "reward": -2.0,
                "success": False, "repo_name": repo_name, "steps": 0,
                "episode_log": episode_log}

    # --- RESEARCH PHASE ---
    research = ResearchAgent(model, tokenizer, max_steps=12, model_name=MODEL_NAME)
    task_meta = getattr(env, "_current_task", None) or None
    old_py = task_meta.reproduction_target_version if task_meta else "3.6"
    new_py = task_meta.migration_target_version if task_meta else "3.12"
    related_mods = task_meta.related_modules if task_meta else "builtin"
    research_context = research.research(
        repo_name=repo_name,
        old_python=old_py,
        new_python=new_py,
        related_modules=related_mods,
        test_output=obs.tool_output,
    )
    episode_log["research_context"] = research_context
    episode_log["research_steps"] = getattr(research, "last_research_steps", [])
    log.info(" [%s] research done (%d chars, %d steps)",
             repo_name, len(research_context), len(episode_log["research_steps"]))

    # --- BUILD PROMPT with research + error logs ---
    system_with_research = (
        SYSTEM_PROMPT
        + "\n\n=== MIGRATION RESEARCH (gathered by research agent) ===\n"
        + research_context
        + "\n=== END RESEARCH ===\n\n"
        "A research agent has already analyzed the error and found the relevant "
        "breaking changes above. Use this information to make the fix directly. "
        "Don't waste steps searching — act on the research.\n"
    )
    initial_prompt = system_with_research + "\n\n" + obs.tool_output
    messages = [
        {"role": "system", "content": system_with_research},
        {"role": "user", "content": obs.tool_output},
    ]

    all_completions = []
    total_steps = 0
    success = False
    for step_num in range(1, MAX_STEPS_PER_TASK + 1):
        if obs.done:
            break
        torch.cuda.empty_cache()
        t0 = time.time()
        try:
            raw_text, clean, tool_name, tool_args = _generate_tool_call(
                model, tokenizer, messages,
            )
        except Exception as e:
            # Best effort: a failed generation becomes a harmless list_dir step.
            log.warning(" Step %d gen failed: %s", step_num, e)
            tool_name, tool_args = "list_dir", {}
            raw_text, clean = str(e), ""
        gen_time = time.time() - t0
        torch.cuda.empty_cache()
        all_completions.append(clean or raw_text)

        # Validate and execute. The isinstance check guards the dict lookup
        # against unhashable names coming out of the parser.
        if not isinstance(tool_name, str) or tool_name not in _TOOL_REQUIRED_ARGS:
            tool_name, tool_args = "list_dir", {}
        try:
            action = CodeMigrationAction(tool_name=tool_name, tool_args=tool_args)
        except Exception:
            action = CodeMigrationAction(tool_name="list_dir", tool_args={})
        obs = env.step(action)
        total_steps = step_num
        if action.tool_name == "execute_tests" and obs.metadata.get("last_test_exit_code") == 0:
            success = True

        episode_log["steps"].append({
            "step": step_num, "gen_time": round(gen_time, 2),
            "tool": action.tool_name, "args": action.tool_args,
            "raw_output": raw_text[:1000], "clean_output": clean[:500],
            "result": obs.tool_output[:1000], "reward": obs.reward, "done": obs.done,
        })
        log.info(" [%s] step %d/%.1fs %s → %s",
                 repo_name, step_num, gen_time, action.tool_name,
                 "PASS!" if success else obs.tool_output[:80].replace("\n", " "))

        messages.append({"role": "assistant", "content": clean or raw_text})
        messages.append({"role": "user", "content": f"Tool result:\n{obs.tool_output}"})
        if len(messages) > 22:
            # Keep system + initial user turn plus the 20 most recent turns; 20
            # is even, so the assistant/user alternation is preserved.
            messages = messages[:2] + messages[-20:]
        if obs.done:
            break

    reward = _episode_reward(success, total_steps)
    episode_log["success"] = success
    episode_log["total_steps"] = total_steps
    episode_log["reward"] = reward
    log.info(" [%s] %s steps=%d reward=%.2f",
             repo_name, "PASS" if success else "FAIL", total_steps, reward)
    return {
        # NOTE(review): truncation cuts from the end, i.e. the initial test
        # output is dropped first when the research context is long.
        "prompt": initial_prompt[:4000],
        "completion": "\n".join(all_completions),
        "reward": reward,
        "success": success,
        "repo_name": repo_name,
        "steps": total_steps,
        "episode_log": episode_log,
    }


def run_episode(
    model, tokenizer, env: CodeMigrationEnvironment, task_index: int,
) -> Dict[str, Any]:
    """Play one episode on ``env`` for ``task_index`` and score it.

    A ResearchAgent first gathers migration notes from the initial test
    output. The model then emits one JSON tool call per step, up to
    MAX_STEPS_PER_TASK steps, until the environment reports done. The model
    is put in eval mode for the whole episode, so dropout is off while
    sampling, and its previous train/eval mode is restored afterwards.

    Returns:
        dict with:
            ``prompt``: system + research + initial test output, truncated to
                4000 chars; empty if the reset failed.
            ``completion``: all assistant turns joined by newlines.
            ``reward``, ``success``, ``repo_name``, ``steps``.
            ``episode_log``: a detailed per-step log.
    """
    was_training = getattr(model, "training", False)
    model.eval()
    try:
        return _play_episode(model, tokenizer, env, task_index)
    finally:
        if was_training:
            model.train()
# ---------------------------------------------------------------------------
# Collect rollouts
# ---------------------------------------------------------------------------
def collect_rollouts(model, tokenizer, env, num_tasks, num_rollouts):
    """Play ``num_rollouts`` episodes for each of the first ``num_tasks`` tasks.

    Episodes that return an empty prompt (failed environment reset) are left
    out of the dataset. They still count toward the success-rate denominator
    and are written to ``<log_dir>/rollout_logs.json``.

    Returns:
        ``datasets.Dataset`` with columns ``prompt``, ``completion`` and ``reward``.
    """
    total = num_tasks * num_rollouts
    log.info("Collecting rollouts: %d tasks × %d rollouts = %d episodes",
             num_tasks, num_rollouts, total)
    columns = {"prompt": [], "completion": [], "reward": []}
    episode_logs = []
    n_success = 0

    for task_idx in range(num_tasks):
        for rollout_idx in range(num_rollouts):
            log.info("─ Task %d/%d, Rollout %d/%d",
                     task_idx + 1, num_tasks, rollout_idx + 1, num_rollouts)
            outcome = run_episode(model, tokenizer, env, task_idx)
            if outcome["prompt"]:
                for column, values in columns.items():
                    values.append(outcome[column])
            episode_logs.append(outcome.get("episode_log", {}))
            if outcome["success"]:
                n_success += 1

    log.info("Rollouts done: %d/%d succeeded (%.1f%%)",
             n_success, total, 100 * n_success / max(total, 1))

    rollout_log_path = log_dir / "rollout_logs.json"
    rollout_log_path.write_text(json.dumps(episode_logs, indent=2, default=str))
    log.info("Rollout logs saved: %s", rollout_log_path)
    return Dataset.from_dict(columns)
# ---------------------------------------------------------------------------
# Reward function for GRPOTrainer
# ---------------------------------------------------------------------------
def reward_from_env(completions, **kwargs):
    """GRPOTrainer reward function that replays pre-computed episode rewards.

    TRL passes extra dataset columns to reward functions as keyword arguments,
    with one value per completion. ``collect_rollouts`` stores the episode
    reward in a column named ``reward``, so that is what arrives here. An
    ``env_reward`` column is also accepted and takes precedence.

    NOTE(review): these values score the *rollout* trajectories, not the
    completions GRPOTrainer samples at training time. Every generation for a
    given prompt also receives the same value, so GRPO's group-relative
    advantages are zero. A real reward signal needs to score ``completions``
    themselves.

    Returns:
        One float per completion; all zeros if no reward column is present.
    """
    rewards = kwargs.get("env_reward")
    if not rewards:
        rewards = kwargs.get("reward")
    if not rewards:
        return [0.0] * len(completions)
    return [float(r) for r in rewards]
# ---------------------------------------------------------------------------
# Main
# ---------------------------------------------------------------------------
def _run_evaluation(model, tokenizer, eval_path):
    """Play up to 5 eval episodes from ``eval_path``, log them, return the success rate."""
    log.info("=" * 40)
    log.info("Phase 3: Evaluation")
    log.info("=" * 40)
    # Training runs the model in train mode (dropout active); sample in eval mode.
    model.eval()
    eval_env = CodeMigrationEnvironment(
        dataset_path=eval_path,
        max_steps=MAX_STEPS_PER_TASK,
        max_test_executions=MAX_TEST_EXEC,
    )
    eval_tasks = min(len(eval_env._loader), 5)
    eval_results = [run_episode(model, tokenizer, eval_env, i) for i in range(eval_tasks)]
    successes = sum(1 for result in eval_results if result["success"])
    log.info("Eval: %d/%d passed (%.1f%%)",
             successes, eval_tasks, 100 * successes / max(eval_tasks, 1))

    eval_log_path = log_dir / "eval_logs.json"
    with open(eval_log_path, "w") as f:
        json.dump([r.get("episode_log", {}) for r in eval_results], f, indent=2, default=str)
    log.info("Eval logs saved: %s", eval_log_path)
    return successes / eval_tasks if eval_tasks else None


def main():
    """Run the pipeline: collect rollouts, train with GRPO, save, then evaluate.

    Exits with status 1 when no usable rollouts were collected, because
    GRPOTrainer cannot train on an empty dataset.
    """
    log.info("=" * 60)
    log.info("GRPO Training — Code Migration")
    log.info(" model: %s", MODEL_NAME)
    log.info(" difficulty: %s", DIFFICULTY_FILTER)
    log.info(" tasks: %d", NUM_TASKS)
    log.info(" rollouts: %d per task", NUM_ROLLOUTS)
    log.info(" max_steps: %d per episode", MAX_STEPS_PER_TASK)
    log.info(" lora: r=%d alpha=%d", LORA_R, LORA_ALPHA)
    log.info(" output: %s", OUTPUT_DIR)
    log.info(" log_dir: %s", log_dir)
    log.info("=" * 60)

    model, tokenizer = load_model_and_tokenizer()

    # Default dataset lives next to this file.
    dataset_path = DATASET_PATH or os.path.join(
        os.path.dirname(__file__), "data", "train.jsonl"
    )
    log.info("Environment dataset: %s", dataset_path)
    env = CodeMigrationEnvironment(
        dataset_path=dataset_path,
        max_steps=MAX_STEPS_PER_TASK,
        max_test_executions=MAX_TEST_EXEC,
        difficulty_filter=DIFFICULTY_FILTER if DIFFICULTY_FILTER != "all" else None,
    )
    num_tasks = min(NUM_TASKS, len(env._loader))
    log.info("Training on %d tasks", num_tasks)

    # Phase 1: Collect rollouts
    log.info("=" * 40)
    log.info("Phase 1: Collecting rollouts")
    log.info("=" * 40)
    rollout_dataset = collect_rollouts(
        model, tokenizer, env, num_tasks, NUM_ROLLOUTS,
    )
    log.info("Dataset: %d episodes", len(rollout_dataset))
    if len(rollout_dataset) == 0:
        log.error("No usable rollouts collected (no tasks, or every environment "
                  "reset failed); aborting before training.")
        sys.exit(1)
    rewards = list(rollout_dataset["reward"])
    log.info("Rewards: mean=%.2f min=%.2f max=%.2f",
             sum(rewards) / len(rewards), min(rewards), max(rewards))

    # Phase 2: GRPO Training
    log.info("=" * 40)
    log.info("Phase 2: GRPO Training")
    log.info("=" * 40)
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    training_args = GRPOConfig(
        output_dir=OUTPUT_DIR,
        num_train_epochs=NUM_TRAIN_EPOCHS,
        per_device_train_batch_size=PER_DEVICE_BATCH,
        gradient_accumulation_steps=GRAD_ACCUM,
        learning_rate=1e-5,
        max_completion_length=MAX_COMPLETION_LENGTH,
        num_generations=NUM_ROLLOUTS,
        logging_steps=1,
        save_steps=50,
        save_total_limit=3,
        bf16=True,
        report_to="none",
        # Keep the "reward" column so it reaches reward_from_env as a kwarg.
        remove_unused_columns=False,
    )
    trainer = GRPOTrainer(
        model=model,
        args=training_args,
        train_dataset=rollout_dataset,
        processing_class=tokenizer,
        reward_funcs=reward_from_env,
    )
    log.info("Starting GRPO training...")
    trainer.train()

    log.info("Saving model to %s", OUTPUT_DIR)
    trainer.save_model(OUTPUT_DIR)
    tokenizer.save_pretrained(OUTPUT_DIR)
    log.info("Model saved.")

    # Phase 3: Eval (skipped when no eval dataset exists)
    eval_path = EVAL_DATASET_PATH or os.path.join(
        os.path.dirname(__file__), "data", "eval.jsonl"
    )
    eval_success_rate = None
    if os.path.exists(eval_path):
        eval_success_rate = _run_evaluation(model, tokenizer, eval_path)

    summary = {
        "run_id": run_id,
        "model": MODEL_NAME,
        "difficulty": DIFFICULTY_FILTER,
        "num_tasks": num_tasks,
        "num_rollouts": NUM_ROLLOUTS,
        "lora_r": LORA_R,
        "lora_alpha": LORA_ALPHA,
        "output_dir": OUTPUT_DIR,
        "dataset_size": len(rollout_dataset),
        "reward_mean": sum(rewards) / len(rewards),
        "eval_success_rate": eval_success_rate,
    }
    with open(log_dir / "training_summary.json", "w") as f:
        json.dump(summary, f, indent=2)
    log.info("Training complete! Logs at %s", log_dir)
if __name__ == "__main__":
main()