# training.py - Multi-step DPO training with trajectory serialisation
import json
from dataclasses import dataclass
from typing import Optional, List

import torch
from datasets import Dataset
from unsloth import FastLanguageModel
from trl import DPOTrainer
from transformers import TrainingArguments

# Import your environment and actions
from environment import CodeReviewEnv
from models import (
    RunTests,
    RunLinter,
    Inspect,
    ProposeFix,
    WriteComment,
    AskQuestion,
    Done,
    Skip,
)


# ----------------------------------------------------------------------
# 1. Structured action parsing (with retry)
# ----------------------------------------------------------------------
@dataclass
class AgentAction:
    """One action decoded from the model's JSON reply."""

    # Lower-cased action name, or "invalid" when the reply could not be parsed.
    action_type: str
    # Optional payload (fix code, comment text, question); for invalid
    # replies this holds the raw model output instead.
    content: Optional[str] = None


def parse_action(output: str) -> AgentAction:
    """Decode a model reply into an :class:`AgentAction`.

    The reply must be a JSON object whose ``action_type`` is a string.
    Anything else (malformed JSON, a non-object value, a non-string
    ``action_type``) yields ``AgentAction("invalid", output)`` so the caller
    can fall back without handling exceptions.
    """
    try:
        data = json.loads(output)
    except (TypeError, ValueError):
        # json.JSONDecodeError subclasses ValueError; TypeError covers
        # non-string input.
        return AgentAction("invalid", output)

    if not isinstance(data, dict):
        return AgentAction("invalid", output)

    action_type = data.get("action_type", "")
    if not isinstance(action_type, str):
        return AgentAction("invalid", output)

    return AgentAction(
        action_type=action_type.lower(),
        content=data.get("content"),
    )


def safe_generate(prompt: str, model, tokenizer, max_retries=2) -> str:
    """Sample an action, retrying until the reply is a JSON object.

    ``generate_action`` is called up to ``max_retries`` times in total. The
    first reply that parses as a JSON object is returned unchanged; if no
    attempt qualifies, a literal ``skip`` action is returned instead.
    """
    for _ in range(max_retries):
        raw = generate_action(prompt, model, tokenizer)
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            continue
        # Valid JSON that is not an object (e.g. a bare number or string)
        # cannot carry an action, so treat it like a parse failure.
        if isinstance(parsed, dict):
            return raw
    return '{"action_type":"skip"}'


def map_to_env(action: AgentAction):
    """Translate an :class:`AgentAction` into the matching environment action.

    Unrecognised action types (including ``"invalid"``) map to ``Skip()``.
    A missing ``content`` is passed on as an empty string.
    """
    text = action.content or ""
    builders = {
        "run_tests": RunTests,
        "run_linter": RunLinter,
        "inspect": Inspect,
        "fix": lambda: ProposeFix(fix_code=text),
        "comment": lambda: WriteComment(comment_text=text),
        "question": lambda: AskQuestion(question=text),
        "done": Done,
    }
    return builders.get(action.action_type, Skip)()


# ----------------------------------------------------------------------
# 2.
# Model loading
# ----------------------------------------------------------------------
# NOTE: loading happens at import time, so importing this module downloads
# and initialises the 4-bit base model and attaches the LoRA adapters.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)

model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=[
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ],
    lora_alpha=64,
    lora_dropout=0,
)


# ----------------------------------------------------------------------
# 3. Generation helper
# ----------------------------------------------------------------------
def generate_action(
    prompt: str,
    model,
    tokenizer,
    max_new_tokens: int = 128,
    temperature: float = 0.8,
) -> str:
    """Sample one reply from the model for ``prompt``.

    Args:
        prompt: Task text placed in the user turn.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the reply.
        max_new_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The decoded continuation only (prompt tokens removed), stripped of
        surrounding whitespace.
    """
    # NOTE(review): Gemma's chat format normally wraps turns in
    # start-of-turn / end-of-turn markers; only the plain role labels are
    # used here. Confirm this is intended before changing the prompt format.
    formatted = f"user\n{prompt}\nmodel\n"

    # Follow the model's own device rather than assuming "cuda"
    # (assumes the model exposes `.device`, as Hugging Face models do).
    inputs = tokenizer(formatted, return_tensors="pt").to(model.device)
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )

    # Slice off the prompt so only newly generated tokens are decoded.
    completion_ids = outputs[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()


# ----------------------------------------------------------------------
# 4. Prompt builder (initial state + sliding history)
# ----------------------------------------------------------------------
def build_prompt(obs, history_lines: List[str], history_window: int = 6) -> str:
    """Build the agent prompt from an observation and recent history.

    Args:
        obs: Observation providing ``code_snippet`` and ``last_tool_output``.
        history_lines: Alternating agent/environment lines from earlier steps.
        history_window: How many trailing history lines to include; the
            default of 6 corresponds to the last 3 exchanges.

    Returns:
        The prompt text, with a "Previous steps" section appended when there
        is history to show.
    """
    prompt = f"""
You are a code review agent.

Code:
{obs.code_snippet}

Last Output:
{obs.last_tool_output}

Available actions:
run_tests, run_linter, inspect, fix, comment, question, done

Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}
"""

    # A window of 0 must mean "no history"; `lines[-0:]` would return all.
    recent = history_lines[-history_window:] if history_window > 0 else []
    if recent:
        history = "\n".join(recent)
        prompt += f"\n\nPrevious steps:\n{history}"

    return prompt


# ----------------------------------------------------------------------
# 5.
# Multi-step rollout
# ----------------------------------------------------------------------
def rollout_episode(env, max_steps=8):
    """Run one episode and record every step.

    Uses the module-level ``model`` and ``tokenizer`` for generation.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return ``(observation, reward, done, info)`` where
            ``reward`` has a ``value`` attribute.
        max_steps: Maximum number of agent actions in the episode.

    Returns:
        ``(trajectory, total_reward)`` where ``trajectory`` is a list of
        ``{"state", "action", "reward"}`` dicts and ``total_reward`` is the
        sum of the per-step rewards.
    """
    obs = env.reset()
    history_lines = []
    trajectory = []

    for _ in range(max_steps):
        prompt = build_prompt(obs, history_lines)
        raw = safe_generate(prompt, model, tokenizer)
        env_action = map_to_env(parse_action(raw))

        next_obs, reward, done, _info = env.step(env_action)

        trajectory.append({
            "state": prompt,
            "action": raw,
            "reward": reward.value,
        })

        # Update history (for next turn)
        history_lines.append(f"Agent: {raw}")
        history_lines.append(f"Env: {next_obs.last_tool_output}")

        obs = next_obs
        if done:
            break

    total_reward = sum(entry["reward"] for entry in trajectory)
    return trajectory, total_reward


# ----------------------------------------------------------------------
# 6. Collect trajectories
# ----------------------------------------------------------------------
def collect_trajectories(env, n=30):
    """Roll out ``n`` episodes and return a list of ``(trajectory, reward)``.

    Prints the total reward of each episode as it completes.
    """
    data = []
    for i in range(n):
        traj, reward = rollout_episode(env)
        data.append((traj, reward))
        print(f"Episode {i+1}: total reward = {reward:.3f}")
    return data


# ----------------------------------------------------------------------
# 7. Build DPO dataset (serialise full trajectory)
# ----------------------------------------------------------------------
def serialize_trajectory(traj):
    """Join the raw action strings of a trajectory, one per line."""
    return "\n".join(step["action"] for step in traj)


def build_dpo_dataset(trajectories, min_margin=0.2, require_shared_prompt=True):
    """Turn scored trajectories into DPO preference pairs.

    Every unordered pair of episodes is compared; the higher-reward one
    becomes ``chosen`` and the other ``rejected``.

    Args:
        trajectories: Iterable of ``(trajectory, total_reward)`` tuples as
            produced by ``collect_trajectories``.
        min_margin: Pairs whose rewards differ by less than this are dropped
            as too close to express a preference.
        require_shared_prompt: When true, only pair episodes that started
            from the same initial prompt. The emitted ``prompt`` is the
            chosen episode's initial state, so pairing episodes with
            different starting states would score the rejected response
            against a prompt it was never generated for.

    Returns:
        A list of ``{"prompt", "chosen", "rejected"}`` dicts.
    """
    # Episodes with no steps have neither a prompt nor a response.
    episodes = [(traj, total) for traj, total in trajectories if traj]

    dataset = []
    for i, (traj_a, reward_a) in enumerate(episodes):
        for traj_b, reward_b in episodes[i + 1:]:
            if abs(reward_a - reward_b) < min_margin:
                continue
            if require_shared_prompt and traj_a[0]["state"] != traj_b[0]["state"]:
                continue

            if reward_a > reward_b:
                chosen, rejected = traj_a, traj_b
            else:
                chosen, rejected = traj_b, traj_a

            dataset.append({
                "prompt": chosen[0]["state"],  # initial state only
                "chosen": serialize_trajectory(chosen),
                "rejected": serialize_trajectory(rejected),
            })

    return dataset


# ----------------------------------------------------------------------
# 8.
# Main training pipeline
# ----------------------------------------------------------------------
def main() -> None:
    """Collect rollouts, build the preference dataset, and run DPO training.

    Raises:
        RuntimeError: If no preference pairs could be built from the
            collected trajectories.
    """
    env = CodeReviewEnv()

    print("Collecting trajectories...")
    trajectories = collect_trajectories(env, n=30)

    print("Building DPO dataset...")
    dpo_data = build_dpo_dataset(trajectories)
    if not dpo_data:
        raise RuntimeError("No training data generated.")

    dataset = Dataset.from_list(dpo_data)

    # Pick mixed precision from the hardware instead of forcing fp16:
    # on GPUs with bfloat16 support the model is typically loaded in bf16,
    # and an fp16 trainer config then conflicts with the model dtype.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # NOTE(review): newer TRL releases expect `DPOConfig` for `args` and
    # `processing_class` instead of `tokenizer` - confirm against the
    # pinned TRL version.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            max_steps=100,
            learning_rate=5e-5,
            logging_steps=5,
            fp16=not use_bf16,
            bf16=use_bf16,
            output_dir="dpo_output",
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("Starting DPO training...")
    trainer.train()

    print("Saving model...")
    model.save_pretrained("dpo_final_model")
    tokenizer.save_pretrained("dpo_final_model")
    print("Done.")


if __name__ == "__main__":
    main()