Spaces:
No application file
No application file
| # training.py – Multi‑step DPO training with trajectory serialisation | |
| import json | |
| import torch | |
| from datasets import Dataset | |
| from dataclasses import dataclass | |
| from typing import Optional, List | |
| from unsloth import FastLanguageModel | |
| from trl import DPOTrainer | |
| from transformers import TrainingArguments | |
| # Import your environment and actions | |
| from environment import CodeReviewEnv | |
| from models import ( | |
| RunTests, RunLinter, Inspect, | |
| ProposeFix, WriteComment, AskQuestion, | |
| Done, Skip | |
| ) | |
| # ---------------------------------------------------------------------- | |
| # 1. Structured action parsing (with retry) | |
| # ---------------------------------------------------------------------- | |
@dataclass
class AgentAction:
    """Parsed agent decision: an action name plus an optional text payload.

    The ``@dataclass`` decorator generates the ``__init__`` that the rest of
    this module relies on: instances are built both by keyword
    (``AgentAction(action_type=..., content=...)``) and positionally
    (``AgentAction("invalid", raw_text)``).
    """

    # Name of the requested action, e.g. "run_tests" or "fix".
    action_type: str
    # Free-form payload accompanying the action; None when there is none.
    content: Optional[str] = None
def parse_action(output: str) -> AgentAction:
    """Parse raw model text into an ``AgentAction``.

    Args:
        output: Text expected to hold a JSON object with an ``action_type``
            key and an optional ``content`` key.

    Returns:
        An ``AgentAction`` whose ``action_type`` is lower-cased. When
        ``output`` is not valid JSON, is not a JSON object, or carries a
        non-string ``action_type``, the result is
        ``AgentAction("invalid", output)`` so the caller can still inspect
        the raw text.
    """
    try:
        data = json.loads(output)
    except (ValueError, TypeError, RecursionError):
        # ValueError covers json.JSONDecodeError, TypeError covers non-text
        # input, RecursionError covers pathologically nested payloads.
        # Anything else (e.g. KeyboardInterrupt) is deliberately not swallowed.
        return AgentAction("invalid", output)

    # Valid JSON that is not an object (a list, number, string, ...) has no
    # fields to read.
    if not isinstance(data, dict):
        return AgentAction("invalid", output)

    action_type = data.get("action_type", "")
    if not isinstance(action_type, str):
        return AgentAction("invalid", output)

    return AgentAction(
        action_type=action_type.lower(),
        content=data.get("content"),
    )
def safe_generate(prompt: str, model, tokenizer, max_retries=2) -> str:
    """Sample an action from the model until it is a usable JSON object.

    Args:
        prompt: Prompt text forwarded to ``generate_action``.
        model: Model forwarded to ``generate_action``.
        tokenizer: Tokenizer forwarded to ``generate_action``.
        max_retries: Total number of generation attempts (not the number of
            attempts after the first one).

    Returns:
        The first raw generation that parses as a JSON object, or the
        fallback ``'{"action_type":"skip"}'`` when every attempt fails.
    """
    for _ in range(max_retries):
        raw = generate_action(prompt, model, tokenizer)
        try:
            parsed = json.loads(raw)
        except (ValueError, TypeError, RecursionError):
            # Malformed output: sample again. Only decoding problems are
            # retried; interrupts and generation errors propagate.
            continue
        # A bare JSON scalar or list is syntactically valid but carries no
        # "action_type" field, so it is treated like a failed attempt.
        if isinstance(parsed, dict):
            return raw
    return '{"action_type":"skip"}'
def map_to_env(action: AgentAction):
    """Translate a parsed ``AgentAction`` into an environment action object.

    The ``fix``, ``comment`` and ``question`` actions forward
    ``action.content`` (or an empty string when it is falsy) as their
    payload. Any action type that is not recognised maps to ``Skip()``.
    """
    requested = action.action_type
    # (action name, zero-argument factory) pairs, checked in order. The
    # factories are lazy so ``action.content`` is only read for the actions
    # that actually carry a payload.
    factories = (
        ("run_tests", lambda: RunTests()),
        ("run_linter", lambda: RunLinter()),
        ("inspect", lambda: Inspect()),
        ("fix", lambda: ProposeFix(fix_code=action.content or "")),
        ("comment", lambda: WriteComment(comment_text=action.content or "")),
        ("question", lambda: AskQuestion(question=action.content or "")),
        ("done", lambda: Done()),
    )
    for name, build in factories:
        if requested == name:
            return build()
    return Skip()
| # ---------------------------------------------------------------------- | |
| # 2. Model loading | |
| # ---------------------------------------------------------------------- | |
# NOTE: this runs at import time and binds the module-level ``model`` and
# ``tokenizer`` names used by the rollout and training code below.

# Settings for the base checkpoint.
_base_model_kwargs = {
    "model_name": "unsloth/gemma-2-2b-it-bnb-4bit",
    "max_seq_length": 2048,
    "load_in_4bit": True,
}
model, tokenizer = FastLanguageModel.from_pretrained(**_base_model_kwargs)

# Adapter settings passed to get_peft_model, including the module names the
# adapters are attached to.
_adapter_kwargs = {
    "r": 64,
    "target_modules": [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ],
    "lora_alpha": 64,
    "lora_dropout": 0,
}
model = FastLanguageModel.get_peft_model(model, **_adapter_kwargs)
| # ---------------------------------------------------------------------- | |
| # 3. Generation helper | |
| # ---------------------------------------------------------------------- | |
def generate_action(prompt: str, model, tokenizer) -> str:
    """Sample one completion for ``prompt`` and return it as stripped text.

    Args:
        prompt: User-turn text; it is wrapped in the
            ``<start_of_turn>user`` / ``<start_of_turn>model`` markers
            before tokenisation.
        model: Object exposing ``generate(**inputs, ...)``. Its ``device``
            attribute, when present, decides where the inputs are placed.
        tokenizer: Callable tokenizer that also provides ``decode``.

    Returns:
        The decoded newly generated tokens (the prompt tokens are sliced
        off), with special tokens skipped and surrounding whitespace removed.
    """
    formatted = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Follow the model's own device instead of assuming "cuda"; fall back to
    # the previous hard-coded value for objects without a ``device`` attribute.
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(formatted, return_tensors="pt").to(device)
    prompt_len = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
    )

    # Keep only the tokens generated after the prompt.
    completion_ids = outputs[0][prompt_len:]
    return tokenizer.decode(completion_ids, skip_special_tokens=True).strip()
| # ---------------------------------------------------------------------- | |
| # 4. Prompt builder (initial state + sliding history) | |
| # ---------------------------------------------------------------------- | |
def build_prompt(obs, history_lines: List[str], max_history: int = 6) -> str:
    """Build the agent prompt from the current observation and recent history.

    Args:
        obs: Observation exposing ``code_snippet`` and ``last_tool_output``.
        history_lines: Transcript lines from earlier steps; may be empty.
        max_history: Number of trailing transcript lines to include. The
            default of 6 keeps the previous behaviour (three agent/env
            exchanges at two lines each). Zero or a negative value omits
            the history section.

    Returns:
        The prompt text, followed by a ``Previous steps:`` section when
        there is history to show.
    """
    prompt = f"""
You are a code review agent.
Code:
{obs.code_snippet}
Last Output:
{obs.last_tool_output}
Available actions:
run_tests, run_linter, inspect, fix, comment, question, done
Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}
"""
    # Guard the window size explicitly: a ``[-0:]`` slice would return the
    # whole list rather than nothing.
    if history_lines and max_history > 0:
        history = "\n".join(history_lines[-max_history:])
        prompt += f"\n\nPrevious steps:\n{history}"
    return prompt
| # ---------------------------------------------------------------------- | |
| # 5. Multi‑step rollout | |
| # ---------------------------------------------------------------------- | |
def rollout_episode(env, max_steps=8):
    """Run one episode against ``env`` using the module-level model.

    Each step builds a prompt from the current observation and the running
    transcript, samples an action, applies it to the environment and
    records the outcome. The loop ends after ``max_steps`` steps or as soon
    as the environment reports that the episode is done.

    Returns:
        A ``(trajectory, total_reward)`` tuple. ``trajectory`` is a list of
        dicts with the keys ``state`` (the prompt), ``action`` (the raw
        model output) and ``reward``; ``total_reward`` is the sum of the
        per-step rewards.
    """
    trajectory = []
    transcript = []
    observation = env.reset()

    for _ in range(max_steps):
        state_text = build_prompt(observation, transcript)
        raw_reply = safe_generate(state_text, model, tokenizer)
        env_action = map_to_env(parse_action(raw_reply))

        observation, reward, finished, _info = env.step(env_action)
        trajectory.append(
            dict(state=state_text, action=raw_reply, reward=reward.value)
        )

        # Extend the transcript that feeds the next prompt.
        transcript += [
            f"Agent: {raw_reply}",
            f"Env: {observation.last_tool_output}",
        ]
        if finished:
            break

    return trajectory, sum(entry["reward"] for entry in trajectory)
| # ---------------------------------------------------------------------- | |
| # 6. Collect trajectories | |
| # ---------------------------------------------------------------------- | |
def collect_trajectories(env, n=30):
    """Roll out ``n`` episodes on ``env`` and gather their results.

    Prints the total reward of every episode as it finishes.

    Returns:
        A list of ``(trajectory, total_reward)`` tuples, one per episode,
        in the order the episodes were run.
    """
    episodes = []
    for episode_no in range(1, n + 1):
        steps, total = rollout_episode(env)
        episodes.append((steps, total))
        print(f"Episode {episode_no}: total reward = {total:.3f}")
    return episodes
| # ---------------------------------------------------------------------- | |
| # 7. Build DPO dataset (serialise full trajectory) | |
| # ---------------------------------------------------------------------- | |
def serialize_trajectory(traj):
    """Join the ``action`` entry of every step in ``traj`` with newlines.

    An empty trajectory serialises to the empty string.
    """
    actions = (entry["action"] for entry in traj)
    return "\n".join(actions)
def build_dpo_dataset(trajectories, min_margin=0.2):
    """Turn collected episodes into preference pairs.

    Every unordered pair of episodes is compared by total reward. Pairs
    whose rewards differ by less than ``min_margin`` are dropped; for the
    rest, the higher-reward episode becomes ``chosen`` and the other
    ``rejected``.

    Args:
        trajectories: Sequence of ``(trajectory, total_reward)`` tuples,
            where each trajectory is a list of step dicts with ``state``
            and ``action`` keys.
        min_margin: Minimum absolute reward gap for a pair to be kept.
            Defaults to 0.2, the previously hard-coded threshold.

    Returns:
        A list of dicts with the keys ``prompt``, ``chosen`` and
        ``rejected``.
    """
    from itertools import combinations

    dataset = []
    # combinations() yields the same (i, j) pairs with i < j, in the same
    # order, as a nested index loop.
    for (t1, r1), (t2, r2) in combinations(trajectories, 2):
        # An episode with no steps has neither an initial state nor any
        # actions to serialise, so it cannot take part in a pair.
        if not t1 or not t2:
            continue
        # Exact ties carry no preference signal even when min_margin is 0.
        if r1 == r2 or abs(r1 - r2) < min_margin:
            continue

        chosen_traj, rejected_traj = (t1, t2) if r1 > r2 else (t2, t1)

        # NOTE(review): the prompt is taken from the chosen episode only.
        # This presumes both episodes started from the same initial state;
        # confirm that env.reset() is deterministic, otherwise "rejected"
        # is paired with a prompt it was not generated from.
        dataset.append({
            "prompt": chosen_traj[0]["state"],  # initial state only
            "chosen": serialize_trajectory(chosen_traj),
            "rejected": serialize_trajectory(rejected_traj),
        })
    return dataset
| # ---------------------------------------------------------------------- | |
| # 8. Main training pipeline | |
| # ---------------------------------------------------------------------- | |
def main():
    """Collect rollouts, build preference pairs, train with DPO, and save.

    Raises:
        RuntimeError: If no preference pair could be built from the
            collected trajectories.
    """
    env = CodeReviewEnv()

    print("Collecting trajectories...")
    trajectories = collect_trajectories(env, n=30)

    print("Building DPO dataset...")
    dpo_data = build_dpo_dataset(trajectories)
    if not dpo_data:
        raise RuntimeError("No training data generated.")
    dataset = Dataset.from_list(dpo_data)

    # Use bf16 when the GPU supports it and fp16 otherwise, instead of
    # forcing fp16 unconditionally. Without bf16 support this resolves to
    # the previous fp16=True setting.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    # NOTE(review): some TRL releases expect a DPOConfig instead of
    # TrainingArguments and ``processing_class`` instead of ``tokenizer``;
    # confirm against the TRL version this project pins.
    trainer = DPOTrainer(
        model=model,
        ref_model=None,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            max_steps=100,
            learning_rate=5e-5,
            logging_steps=5,
            fp16=not use_bf16,
            bf16=use_bf16,
            output_dir="dpo_output",
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("Starting DPO training...")
    trainer.train()

    print("Saving model...")
    model.save_pretrained("dpo_final_model")
    tokenizer.save_pretrained("dpo_final_model")
    print("Done.")


if __name__ == "__main__":
    main()