# 100XZX001's picture
# Upload 13 files
# 29c6586 verified
# training.py – Multi‑step DPO training with trajectory serialisation
import json
import torch
from datasets import Dataset
from dataclasses import dataclass
from typing import Optional, List
from unsloth import FastLanguageModel
from trl import DPOTrainer
from transformers import TrainingArguments
# Import your environment and actions
from environment import CodeReviewEnv
from models import (
RunTests, RunLinter, Inspect,
ProposeFix, WriteComment, AskQuestion,
Done, Skip
)
# ----------------------------------------------------------------------
# 1. Structured action parsing (with retry)
# ----------------------------------------------------------------------
@dataclass
class AgentAction:
    """Structured form of one decision emitted by the model.

    Instances are produced by ``parse_action`` and consumed by
    ``map_to_env``.
    """

    # Name of the requested action (``parse_action`` lower-cases it).
    action_type: str
    # Optional payload accompanying the action; ``None`` when absent.
    content: Optional[str] = None
def parse_action(output: str) -> AgentAction:
    """Parse the model's raw reply into an ``AgentAction``.

    The reply is expected to be a JSON object of the form
    ``{"action_type": "...", "content": "..."}``.

    Args:
        output: Raw text produced by the model.

    Returns:
        An ``AgentAction`` whose ``action_type`` is lower-cased. If the reply
        is not valid JSON, is not a JSON object, or carries a non-string
        ``action_type``, the result is ``AgentAction("invalid", output)`` so
        the caller can still see the offending text.
    """
    try:
        data = json.loads(output)
    except (TypeError, ValueError, RecursionError):
        # json.JSONDecodeError subclasses ValueError; TypeError covers
        # non-string input; RecursionError covers absurdly nested payloads.
        # Anything else (e.g. KeyboardInterrupt) must propagate.
        return AgentAction("invalid", output)

    # Valid JSON is not necessarily an object: "42" or "[1, 2]" parse fine.
    if not isinstance(data, dict):
        return AgentAction("invalid", output)

    action_type = data.get("action_type", "")
    # null / numeric / nested action types cannot be lower-cased.
    if not isinstance(action_type, str):
        return AgentAction("invalid", output)

    return AgentAction(
        action_type=action_type.lower(),
        content=data.get("content"),
    )
def safe_generate(prompt: str, model, tokenizer, max_retries=2) -> str:
    """Sample an action from the model until it is a usable JSON object.

    Args:
        prompt: Prompt text forwarded to ``generate_action``.
        model: Model forwarded to ``generate_action``.
        tokenizer: Tokenizer forwarded to ``generate_action``.
        max_retries: Total number of generation attempts (not the number of
            attempts after the first one).

    Returns:
        The first raw reply that parses as a JSON object, or the literal
        ``'{"action_type":"skip"}'`` when every attempt fails (including
        ``max_retries <= 0``).
    """
    for _ in range(max_retries):
        raw = generate_action(prompt, model, tokenizer)
        try:
            parsed = json.loads(raw)
        except (TypeError, ValueError):
            # Malformed JSON (JSONDecodeError is a ValueError): sample again.
            continue
        # A bare scalar or list is valid JSON but not an action object;
        # treat it like a malformed reply and retry instead of returning it.
        if isinstance(parsed, dict):
            return raw
    return '{"action_type":"skip"}'
def map_to_env(action: AgentAction):
    """Translate a parsed ``AgentAction`` into an environment action object.

    The action name selects the constructor; ``fix``, ``comment`` and
    ``question`` forward ``action.content`` (or ``""`` when it is empty or
    ``None``). Any unrecognised name yields ``Skip()``.
    """
    # Name -> zero-argument factory. Only the selected factory is invoked.
    factories = {
        "run_tests": RunTests,
        "run_linter": RunLinter,
        "inspect": Inspect,
        "fix": lambda: ProposeFix(fix_code=action.content or ""),
        "comment": lambda: WriteComment(comment_text=action.content or ""),
        "question": lambda: AskQuestion(question=action.content or ""),
        "done": Done,
    }
    build = factories.get(action.action_type, Skip)
    return build()
# ----------------------------------------------------------------------
# 2. Model loading
# ----------------------------------------------------------------------
# NOTE: these statements run at import time. The resulting module-level
# `model` and `tokenizer` are read directly by rollout_episode() and by the
# training block at the bottom of this file.

# Load the base checkpoint and its tokenizer with 4-bit loading enabled.
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/gemma-2-2b-it-bnb-4bit",
    max_seq_length=2048,
    load_in_4bit=True,
)
# Wrap the loaded model with PEFT adapters on the listed projection modules;
# `model` is rebound to the wrapped model.
model = FastLanguageModel.get_peft_model(
    model,
    r=64,
    target_modules=[
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj"
    ],
    lora_alpha=64,
    lora_dropout=0,
)
# ----------------------------------------------------------------------
# 3. Generation helper
# ----------------------------------------------------------------------
def generate_action(prompt: str, model, tokenizer) -> str:
    """Sample one reply from the model for ``prompt``.

    The prompt is wrapped in ``<start_of_turn>`` user/model turn markers,
    tokenised, and passed to ``model.generate`` with sampling enabled
    (temperature 0.8, at most 128 new tokens).

    Args:
        prompt: User-turn text.
        model: Model exposing ``generate``.
        tokenizer: Tokenizer used to encode the prompt and decode the reply.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    formatted = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"

    # Put the inputs on whatever device the model lives on instead of
    # assuming the default CUDA device; fall back to "cuda" for model
    # objects that do not expose `.device`.
    device = getattr(model, "device", "cuda")
    inputs = tokenizer(formatted, return_tensors="pt").to(device)
    prompt_length = inputs["input_ids"].shape[1]

    outputs = model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.8,
    )

    # generate() returns prompt + continuation; keep only the continuation.
    continuation = outputs[0][prompt_length:]
    return tokenizer.decode(continuation, skip_special_tokens=True).strip()
# ----------------------------------------------------------------------
# 4. Prompt builder (initial state + sliding history)
# ----------------------------------------------------------------------
def build_prompt(obs, history_lines: List[str]) -> str:
    """Assemble the text prompt for the next agent step.

    Args:
        obs: Observation exposing ``code_snippet`` and ``last_tool_output``.
        history_lines: Transcript lines from earlier steps; only the six
            most recent are included.

    Returns:
        The instruction block, followed by a "Previous steps:" section when
        ``history_lines`` is non-empty.
    """
    sections = [
        "",  # the prompt opens with a newline ...
        "You are a code review agent.",
        "Code:",
        f"{obs.code_snippet}",
        "Last Output:",
        f"{obs.last_tool_output}",
        "Available actions:",
        "run_tests, run_linter, inspect, fix, comment, question, done",
        "Respond ONLY in JSON:",
        '{"action_type": "...", "content": "..."}',
        "",  # ... and closes with one
    ]
    prompt = "\n".join(sections)

    if not history_lines:
        return prompt

    # Sliding window: keep only the tail of the transcript.
    recent = "\n".join(history_lines[-6:])
    return f"{prompt}\n\nPrevious steps:\n{recent}"
# ----------------------------------------------------------------------
# 5. Multi‑step rollout
# ----------------------------------------------------------------------
def rollout_episode(env, max_steps=8):
    """Run one episode in ``env`` using the module-level model and tokenizer.

    Each step builds a prompt from the current observation and the running
    transcript, samples an action, applies it to the environment and records
    the result.

    Args:
        env: Environment exposing ``reset()`` and ``step(action)``; ``step``
            must return a 4-tuple ``(observation, reward, done, info)``
            where ``reward`` has a ``value`` attribute.
        max_steps: Upper bound on the number of steps taken.

    Returns:
        ``(trajectory, total_reward)``: ``trajectory`` is a list of dicts
        with keys ``"state"``, ``"action"`` and ``"reward"``;
        ``total_reward`` is the sum of the per-step rewards.
    """
    observation = env.reset()
    history_lines = []
    trajectory = []

    for _ in range(max_steps):
        state_prompt = build_prompt(observation, history_lines)
        raw_reply = safe_generate(state_prompt, model, tokenizer)
        env_action = map_to_env(parse_action(raw_reply))

        observation, reward, finished, _info = env.step(env_action)

        record = {
            "state": state_prompt,
            "action": raw_reply,
            "reward": reward.value,
        }
        trajectory.append(record)

        # Extend the transcript that feeds the next step's prompt.
        history_lines.extend(
            [
                f"Agent: {raw_reply}",
                f"Env: {observation.last_tool_output}",
            ]
        )

        if finished:
            break

    total_reward = sum(entry["reward"] for entry in trajectory)
    return trajectory, total_reward
# ----------------------------------------------------------------------
# 6. Collect trajectories
# ----------------------------------------------------------------------
def collect_trajectories(env, n=30):
    """Roll out ``n`` episodes in ``env`` and gather their results.

    Prints one progress line per episode.

    Returns:
        A list of ``(trajectory, total_reward)`` tuples, one per episode, in
        the order the episodes were run.
    """
    episodes = []
    for episode_number in range(1, n + 1):
        steps, total = rollout_episode(env)
        episodes.append((steps, total))
        print(f"Episode {episode_number}: total reward = {total:.3f}")
    return episodes
# ----------------------------------------------------------------------
# 7. Build DPO dataset (serialise full trajectory)
# ----------------------------------------------------------------------
def serialize_trajectory(traj):
    """Flatten a trajectory into text: one recorded action per line.

    Only each step's ``"action"`` entry is used; an empty trajectory yields
    an empty string.
    """
    actions = (entry["action"] for entry in traj)
    return "\n".join(actions)
def build_dpo_dataset(trajectories):
    """Turn collected rollouts into DPO preference pairs.

    Every unordered pair of trajectories is considered. A pair becomes a
    training example when:

    * both trajectories contain at least one step,
    * their total rewards differ by at least 0.2, and
    * both started from the same initial prompt.

    The last condition matters because each example stores a single
    ``"prompt"``: the chosen and the rejected completion must both be
    responses to it. Pairing rollouts that started from different initial
    states would label a completion as "rejected" for a prompt it was never
    generated from.

    Args:
        trajectories: Sequence of ``(trajectory, total_reward)`` tuples as
            produced by ``collect_trajectories``.

    Returns:
        A list of dicts with keys ``"prompt"``, ``"chosen"`` and
        ``"rejected"``; possibly empty.
    """
    dataset = []
    for index, (first, first_reward) in enumerate(trajectories):
        for second, second_reward in trajectories[index + 1:]:
            # An empty trajectory has no initial state and no completion.
            if not first or not second:
                continue
            # Skip near-ties: the preference signal is too weak.
            if abs(first_reward - second_reward) < 0.2:
                continue
            prompt = first[0]["state"]
            # Only compare rollouts that answer the same initial prompt.
            if second[0]["state"] != prompt:
                continue

            if first_reward > second_reward:
                chosen, rejected = first, second
            else:
                chosen, rejected = second, first

            dataset.append({
                "prompt": prompt,  # initial state only
                "chosen": serialize_trajectory(chosen),
                "rejected": serialize_trajectory(rejected),
            })
    return dataset
# ----------------------------------------------------------------------
# 8. Main training pipeline
# ----------------------------------------------------------------------
# Script entry point: collect rollouts, build preference pairs, run DPO
# training on the module-level model, then save the model and tokenizer.
if __name__ == "__main__":
    env = CodeReviewEnv()

    print("Collecting trajectories...")
    trajectories = collect_trajectories(env, n=30)

    print("Building DPO dataset...")
    dpo_data = build_dpo_dataset(trajectories)
    if not dpo_data:
        raise RuntimeError("No training data generated.")
    dataset = Dataset.from_list(dpo_data)

    # The model is loaded above without an explicit dtype, so mixed precision
    # is chosen from what the GPU supports rather than hard-coding fp16:
    # forcing fp16 on a bf16-capable GPU conflicts with bf16-loaded weights.
    use_bf16 = torch.cuda.is_bf16_supported()

    trainer = DPOTrainer(
        model=model,
        # No separate reference model is passed.
        ref_model=None,
        args=TrainingArguments(
            per_device_train_batch_size=2,
            gradient_accumulation_steps=4,
            max_steps=100,
            learning_rate=5e-5,
            logging_steps=5,
            fp16=not use_bf16,
            bf16=use_bf16,
            output_dir="dpo_output",
        ),
        train_dataset=dataset,
        tokenizer=tokenizer,
    )

    print("Starting DPO training...")
    trainer.train()

    print("Saving model...")
    model.save_pretrained("dpo_final_model")
    tokenizer.save_pretrained("dpo_final_model")
    print("Done.")