100XZX001 commited on
Commit
6b4dcf0
·
verified ·
1 Parent(s): 778f292

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +210 -100
training.py CHANGED
@@ -1,4 +1,4 @@
1
- # training.py – True PPO-based RL training with multi-step trajectories
2
 
3
  import json
4
  import torch
@@ -7,11 +7,15 @@ from torch.optim import AdamW
7
  from dataclasses import dataclass
8
  from typing import List, Dict, Tuple, Optional
9
  import numpy as np
 
 
10
 
11
  from unsloth import FastLanguageModel
12
  from transformers import TrainingArguments
 
 
13
 
14
- # Import your environment and actions
15
  from environment import CodeReviewEnv
16
  from models import (
17
  RunTests, RunLinter, Inspect,
@@ -20,7 +24,7 @@ from models import (
20
  )
21
 
22
  # ======================================================================
23
- # 1. ACTION PARSING (unchanged from original)
24
  # ======================================================================
25
  @dataclass
26
  class AgentAction:
@@ -28,6 +32,8 @@ class AgentAction:
28
  content: Optional[str] = None
29
 
30
  def parse_action(output: str) -> AgentAction:
 
 
31
  try:
32
  data = json.loads(output)
33
  return AgentAction(
@@ -35,7 +41,36 @@ def parse_action(output: str) -> AgentAction:
35
  content=data.get("content")
36
  )
37
  except:
38
- return AgentAction("invalid", output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
39
 
40
  def map_to_env(action: AgentAction):
41
  if action.action_type == "run_tests":
@@ -56,7 +91,7 @@ def map_to_env(action: AgentAction):
56
  return Skip()
57
 
58
  # ======================================================================
59
- # 2. MODEL SETUP
60
  # ======================================================================
61
  def load_model():
62
  model, tokenizer = FastLanguageModel.from_pretrained(
@@ -64,32 +99,139 @@ def load_model():
64
  max_seq_length=2048,
65
  load_in_4bit=True,
66
  )
 
67
  model = FastLanguageModel.get_peft_model(
68
  model,
69
- r=64,
70
  target_modules=[
71
  "q_proj", "k_proj", "v_proj", "o_proj",
72
  "gate_proj", "up_proj", "down_proj"
73
  ],
74
- lora_alpha=64,
75
- lora_dropout=0,
76
  )
 
 
 
77
  return model, tokenizer
78
 
79
  # ======================================================================
80
- # 3. ACTION GENERATION WITH LOGPROB TRACKING
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  # ======================================================================
82
  def generate_action_with_logprob(
83
  prompt: str,
84
  model,
85
  tokenizer,
86
- temperature: float = 0.8,
87
  max_retries: int = 2
88
  ) -> Tuple[str, float]:
89
- """
90
- Generate action and return (action_text, logprob)
91
- """
92
- formatted = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
93
  inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
94
 
95
  for attempt in range(max_retries):
@@ -97,25 +239,37 @@ def generate_action_with_logprob(
97
  outputs = model.generate(
98
  **inputs,
99
  max_new_tokens=128,
100
- do_sample=True,
101
- temperature=temperature,
 
102
  return_dict_in_generate=True,
103
  output_scores=True,
104
  )
105
 
106
- # Extract generated tokens
107
  generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
108
  action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
109
 
110
- # Compute logprob for the generated sequence
111
  logprobs = []
112
  for idx, token_id in enumerate(generated_ids):
113
  if idx < len(outputs.scores):
114
- token_logits = outputs.scores[idx][0] # [vocab_size]
115
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
116
  logprobs.append(token_logprob)
 
117
 
118
- total_logprob = sum(logprobs)
 
 
 
 
 
 
 
 
 
 
 
119
 
120
  # Validate JSON
121
  try:
@@ -123,14 +277,13 @@ def generate_action_with_logprob(
123
  return action_text, total_logprob
124
  except:
125
  if attempt == max_retries - 1:
126
- # Return skip action with penalty logprob
127
  return '{"action_type":"skip"}', -100.0
128
  continue
129
 
130
  return '{"action_type":"skip"}', -100.0
131
 
132
  # ======================================================================
133
- # 4. PROMPT BUILDER (unchanged from original)
134
  # ======================================================================
135
  def build_prompt(obs, history_lines: List[str]) -> str:
136
  prompt = f"""You are a code review agent.
@@ -153,7 +306,7 @@ Respond ONLY in JSON:
153
  return prompt
154
 
155
  # ======================================================================
156
- # 5. TRAJECTORY STORAGE
157
  # ======================================================================
158
  @dataclass
159
  class Trajectory:
@@ -176,18 +329,15 @@ class Trajectory:
176
  }
177
 
178
  # ======================================================================
179
- # 6. ROLLOUT COLLECTION
180
  # ======================================================================
181
  def collect_trajectory(
182
  env: CodeReviewEnv,
183
  model,
184
  tokenizer,
185
  max_steps: int = 10,
186
- temperature: float = 0.8
187
  ) -> Trajectory:
188
- """
189
- Collect a single trajectory with full RL data.
190
- """
191
  obs = env.reset()
192
  history_lines = []
193
 
@@ -198,18 +348,15 @@ def collect_trajectory(
198
  dones = []
199
 
200
  for step in range(max_steps):
201
- # Build prompt
202
  prompt = build_prompt(obs, history_lines)
203
  states.append(prompt)
204
 
205
- # Generate action with logprob
206
  action_text, logprob = generate_action_with_logprob(
207
  prompt, model, tokenizer, temperature
208
  )
209
  actions.append(action_text)
210
  logprobs.append(logprob)
211
 
212
- # Parse and execute
213
  action = parse_action(action_text)
214
  env_action = map_to_env(action)
215
  next_obs, reward, done, _ = env.step(env_action)
@@ -217,7 +364,6 @@ def collect_trajectory(
217
  rewards.append(reward.value)
218
  dones.append(done)
219
 
220
- # Update history
221
  history_lines.append(f"Agent: {action_text}")
222
  history_lines.append(f"Env: {next_obs.last_tool_output}")
223
 
@@ -234,7 +380,6 @@ def collect_trajectories(
234
  n_trajectories: int,
235
  max_steps: int = 10
236
  ) -> List[Trajectory]:
237
- """Collect multiple trajectories."""
238
  trajectories = []
239
  for i in range(n_trajectories):
240
  traj = collect_trajectory(env, model, tokenizer, max_steps)
@@ -245,7 +390,7 @@ def collect_trajectories(
245
  return trajectories
246
 
247
  # ======================================================================
248
- # 7. ADVANTAGE ESTIMATION (GAE)
249
  # ======================================================================
250
  def compute_gae(
251
  rewards: List[float],
@@ -254,13 +399,7 @@ def compute_gae(
254
  gamma: float = 0.99,
255
  lambda_: float = 0.95
256
  ) -> Tuple[List[float], List[float]]:
257
- """
258
- Compute Generalized Advantage Estimation.
259
- If no value function provided, use reward-to-go as returns.
260
- """
261
  n = len(rewards)
262
-
263
- # Compute returns (reward-to-go)
264
  returns = [0.0] * n
265
  running_return = 0.0
266
  for t in reversed(range(n)):
@@ -269,16 +408,13 @@ def compute_gae(
269
  running_return = rewards[t] + gamma * running_return
270
  returns[t] = running_return
271
 
272
- # If no value function, use returns as advantages (centered)
273
  if values is None:
274
  advantages = returns
275
- # Normalize advantages
276
  adv_mean = np.mean(advantages)
277
  adv_std = np.std(advantages) + 1e-8
278
  advantages = [(a - adv_mean) / adv_std for a in advantages]
279
  return advantages, returns
280
 
281
- # GAE with value function
282
  advantages = [0.0] * n
283
  gae = 0.0
284
  for t in reversed(range(n)):
@@ -289,44 +425,39 @@ def compute_gae(
289
  gae = delta + gamma * lambda_ * gae
290
  advantages[t] = gae
291
 
292
- # Normalize
293
  adv_mean = np.mean(advantages)
294
  adv_std = np.std(advantages) + 1e-8
295
  advantages = [(a - adv_mean) / adv_std for a in advantages]
296
-
297
  return advantages, returns
298
 
299
  # ======================================================================
300
- # 8. COMPUTE NEW LOGPROBS (for PPO ratio)
301
  # ======================================================================
302
  def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
303
- """
304
- Compute log probability of action given prompt.
305
- """
306
- formatted = f"<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n{action}"
307
- inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
308
 
309
  with torch.no_grad():
310
  outputs = model(**inputs)
311
- logits = outputs.logits # [1, seq_len, vocab_size]
312
 
313
- # Get action tokens
314
  action_ids = tokenizer.encode(action, add_special_tokens=False)
315
- action_start = inputs['input_ids'].shape[1] - len(action_ids)
 
316
 
317
- # Compute logprob for action tokens
318
  logprobs = []
319
  for idx, token_id in enumerate(action_ids):
320
- position = action_start + idx - 1 # -1 because logits are shifted
321
- if position >= 0 and position < logits.shape[1]:
322
  token_logits = logits[0, position]
323
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
324
  logprobs.append(token_logprob)
325
-
326
  return sum(logprobs) if logprobs else -100.0
327
 
328
  # ======================================================================
329
- # 9. PPO UPDATE
330
  # ======================================================================
331
  def ppo_update(
332
  trajectories: List[Trajectory],
@@ -339,12 +470,8 @@ def ppo_update(
339
  gamma: float = 0.99,
340
  lambda_: float = 0.95,
341
  ) -> Dict[str, float]:
342
- """
343
- Perform PPO policy update.
344
- """
345
  model.train()
346
 
347
- # Flatten all trajectories into single dataset
348
  all_states = []
349
  all_actions = []
350
  all_old_logprobs = []
@@ -352,11 +479,9 @@ def ppo_update(
352
  all_returns = []
353
 
354
  for traj in trajectories:
355
- # Compute advantages for this trajectory
356
  advantages, returns = compute_gae(
357
  traj.rewards, traj.dones, values=None, gamma=gamma, lambda_=lambda_
358
  )
359
-
360
  all_states.extend(traj.states)
361
  all_actions.extend(traj.actions)
362
  all_old_logprobs.extend(traj.logprobs)
@@ -364,46 +489,42 @@ def ppo_update(
364
  all_returns.extend(returns)
365
 
366
  n_samples = len(all_states)
367
-
368
  total_loss = 0.0
369
  total_policy_loss = 0.0
370
  total_entropy = 0.0
371
  n_updates = 0
372
 
373
- # Multiple epochs over the data
374
  for epoch in range(n_epochs):
375
- # Shuffle data
376
  indices = np.random.permutation(n_samples)
377
-
378
  for i in indices:
379
  state = all_states[i]
380
  action = all_actions[i]
381
  old_logprob = all_old_logprobs[i]
382
  advantage = all_advantages[i]
383
 
384
- # Compute new logprob with gradient
385
- formatted = f"<start_of_turn>user\n{state}<end_of_turn>\n<start_of_turn>model\n{action}"
386
- inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
 
 
387
 
388
  outputs = model(**inputs)
389
  logits = outputs.logits
390
 
391
- # Get action tokens
392
  action_ids = tokenizer.encode(action, add_special_tokens=False)
393
- action_start = inputs['input_ids'].shape[1] - len(action_ids)
 
394
 
395
- # Compute logprob for action
396
  logprobs = []
397
  entropy = 0.0
398
  for idx, token_id in enumerate(action_ids):
399
  position = action_start + idx - 1
400
- if position >= 0 and position < logits.shape[1]:
401
  token_logits = logits[0, position]
402
  log_probs = F.log_softmax(token_logits, dim=-1)
403
  token_logprob = log_probs[token_id]
404
  logprobs.append(token_logprob)
405
 
406
- # Entropy
407
  probs = F.softmax(token_logits, dim=-1)
408
  entropy += -(probs * log_probs).sum()
409
 
@@ -413,16 +534,12 @@ def ppo_update(
413
  new_logprob = sum(logprobs)
414
  avg_entropy = entropy / len(logprobs) if logprobs else 0.0
415
 
416
- # PPO objective
417
  ratio = torch.exp(new_logprob - old_logprob)
418
  surr1 = ratio * advantage
419
  surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
420
  policy_loss = -torch.min(surr1, surr2)
421
-
422
- # Total loss
423
  loss = policy_loss - entropy_coef * avg_entropy
424
 
425
- # Backprop
426
  optimizer.zero_grad()
427
  loss.backward()
428
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
@@ -440,7 +557,7 @@ def ppo_update(
440
  }
441
 
442
  # ======================================================================
443
- # 10. EVALUATION
444
  # ======================================================================
445
  def evaluate_policy(
446
  env: CodeReviewEnv,
@@ -449,22 +566,16 @@ def evaluate_policy(
449
  n_episodes: int = 10,
450
  max_steps: int = 10
451
  ) -> Dict[str, float]:
452
- """
453
- Evaluate policy over multiple episodes.
454
- """
455
  model.eval()
456
-
457
  total_rewards = []
458
  episode_lengths = []
459
  success_count = 0
460
 
461
  for _ in range(n_episodes):
462
- traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.5)
463
  total_reward = sum(traj.rewards)
464
  total_rewards.append(total_reward)
465
  episode_lengths.append(len(traj))
466
-
467
- # Define success (e.g., reward > threshold)
468
  if total_reward > 0.5:
469
  success_count += 1
470
 
@@ -476,7 +587,7 @@ def evaluate_policy(
476
  }
477
 
478
  # ======================================================================
479
- # 11. MAIN TRAINING LOOP
480
  # ======================================================================
481
  def train_ppo(
482
  n_iterations: int = 50,
@@ -490,14 +601,18 @@ def train_ppo(
490
  lambda_: float = 0.95,
491
  eval_every: int = 5,
492
  ):
493
- """
494
- Main PPO training loop.
495
- """
496
  print("Loading model...")
497
  model, tokenizer = load_model()
498
 
499
- optimizer = AdamW(model.parameters(), lr=learning_rate)
 
 
 
 
 
 
500
 
 
501
  env = CodeReviewEnv()
502
 
503
  print(f"\n{'='*60}")
@@ -510,20 +625,17 @@ def train_ppo(
510
  for iteration in range(n_iterations):
511
  print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")
512
 
513
- # Collect trajectories
514
  print("Collecting trajectories...")
515
  trajectories = collect_trajectories(
516
  env, model, tokenizer, trajectories_per_iter, max_steps
517
  )
518
 
519
- # Compute statistics
520
  avg_reward = np.mean([sum(t.rewards) for t in trajectories])
521
  avg_length = np.mean([len(t) for t in trajectories])
522
 
523
  print(f"Avg reward: {avg_reward:.3f}")
524
  print(f"Avg length: {avg_length:.1f}")
525
 
526
- # PPO update
527
  print("Updating policy...")
528
  metrics = ppo_update(
529
  trajectories,
@@ -541,7 +653,6 @@ def train_ppo(
541
  print(f"Policy loss: {metrics['policy_loss']:.4f}")
542
  print(f"Entropy: {metrics['entropy']:.4f}")
543
 
544
- # Evaluation
545
  if (iteration + 1) % eval_every == 0:
546
  print("\nEvaluating policy...")
547
  eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
@@ -549,7 +660,6 @@ def train_ppo(
549
  print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
550
  print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")
551
 
552
- # Final save
553
  print("\n" + "="*60)
554
  print("Training complete. Saving model...")
555
  model.save_pretrained("ppo_final_model")
@@ -558,7 +668,7 @@ def train_ppo(
558
  print("="*60)
559
 
560
  # ======================================================================
561
- # 12. ENTRY POINT
562
  # ======================================================================
563
  if __name__ == "__main__":
564
  train_ppo(
 
1
+ # training.py – FIXED PPO training (no variable names changed)
2
 
3
  import json
4
  import torch
 
7
  from dataclasses import dataclass
8
  from typing import List, Dict, Tuple, Optional
9
  import numpy as np
10
+ import re
11
+ import random
12
 
13
  from unsloth import FastLanguageModel
14
  from transformers import TrainingArguments
15
+ from trl import SFTTrainer
16
+ from datasets import Dataset
17
 
18
+ # Import your environment and actions (unchanged)
19
  from environment import CodeReviewEnv
20
  from models import (
21
  RunTests, RunLinter, Inspect,
 
24
  )
25
 
26
  # ======================================================================
27
+ # 1. ACTION PARSING (improved with fallback)
28
  # ======================================================================
29
  @dataclass
30
  class AgentAction:
 
32
  content: Optional[str] = None
33
 
34
  def parse_action(output: str) -> AgentAction:
35
+ """Robust JSON parsing with regex fallback and keyword detection."""
36
+ # Try strict JSON first
37
  try:
38
  data = json.loads(output)
39
  return AgentAction(
 
41
  content=data.get("content")
42
  )
43
  except:
44
+ pass
45
+
46
+ # Try to extract JSON from markdown blocks
47
+ json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
48
+ if json_match:
49
+ try:
50
+ data = json.loads(json_match.group(1))
51
+ return AgentAction(
52
+ action_type=data.get("action_type", "").lower(),
53
+ content=data.get("content")
54
+ )
55
+ except:
56
+ pass
57
+
58
+ # Try to find "action_type" field with regex
59
+ action_pattern = r'"action_type"\s*:\s*"(\w+)"'
60
+ match = re.search(action_pattern, output)
61
+ if match:
62
+ return AgentAction(action_type=match.group(1).lower())
63
+
64
+ # Keyword detection as last resort
65
+ output_lower = output.lower()
66
+ if "test" in output_lower:
67
+ return AgentAction("run_tests")
68
+ if "lint" in output_lower:
69
+ return AgentAction("run_linter")
70
+ if "inspect" in output_lower:
71
+ return AgentAction("inspect")
72
+
73
+ return AgentAction("invalid", output)
74
 
75
  def map_to_env(action: AgentAction):
76
  if action.action_type == "run_tests":
 
91
  return Skip()
92
 
93
  # ======================================================================
94
+ # 2. MODEL SETUP (stabilised LoRA)
95
  # ======================================================================
96
  def load_model():
97
  model, tokenizer = FastLanguageModel.from_pretrained(
 
99
  max_seq_length=2048,
100
  load_in_4bit=True,
101
  )
102
+ # FIXED: Lower rank (16), dropout=0 for stability
103
  model = FastLanguageModel.get_peft_model(
104
  model,
105
+ r=16, # was 64 → causes collapse
106
  target_modules=[
107
  "q_proj", "k_proj", "v_proj", "o_proj",
108
  "gate_proj", "up_proj", "down_proj"
109
  ],
110
+ lora_alpha=32, # adjusted for r=16
111
+ lora_dropout=0.0, # dropout can cause empty outputs
112
  )
113
+ # Ensure tokenizer has correct chat template for Gemma-2
114
+ if tokenizer.chat_template is None:
115
+ tokenizer.chat_template = "{{ bos_token }}{% for message in messages %}{% if message['role'] == 'user' %}<start_of_turn>user\n{{ message['content'] }}<end_of_turn>\n<start_of_turn>model\n{% elif message['role'] == 'assistant' %}{{ message['content'] }}<end_of_turn>\n{% endif %}{% endfor %}"
116
  return model, tokenizer
117
 
118
  # ======================================================================
119
+ # 3. MODEL SANITY CHECK (new – ensures model can generate text)
120
+ # ======================================================================
121
+ def test_model_sanity(model, tokenizer) -> bool:
122
+ print("\n" + "="*60)
123
+ print("SANITY CHECK: Testing base model generation")
124
+ print("="*60)
125
+ test_prompt = "Hello, how are you?"
126
+ messages = [{"role": "user", "content": test_prompt}]
127
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
128
+ inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
129
+ with torch.no_grad():
130
+ outputs = model.generate(
131
+ **inputs,
132
+ max_new_tokens=30,
133
+ do_sample=True,
134
+ temperature=0.7,
135
+ min_new_tokens=1,
136
+ eos_token_id=tokenizer.eos_token_id,
137
+ pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
138
+ )
139
+ generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
140
+ response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
141
+ print(f"Prompt: {test_prompt}")
142
+ print(f"Response: {repr(response)}")
143
+ if len(response) == 0:
144
+ print("❌ Model produces empty output – cannot train.")
145
+ return False
146
+ print("✓ Model sanity check PASSED\n")
147
+ return True
148
+
149
+ # ======================================================================
150
+ # 4. SUPERVISED WARM-UP (teaches JSON output)
151
+ # ======================================================================
152
+ def supervised_warmup(model, tokenizer, n_examples=500, epochs=2):
153
+ print("\n" + "="*60)
154
+ print("SUPERVISED WARM-UP: Teaching JSON format")
155
+ print("="*60)
156
+
157
+ examples = []
158
+ action_templates = [
159
+ '{"action_type": "run_tests"}',
160
+ '{"action_type": "run_linter"}',
161
+ '{"action_type": "inspect"}',
162
+ '{"action_type": "fix", "content": "def corrected():\n pass"}',
163
+ '{"action_type": "comment", "content": "This looks good."}',
164
+ '{"action_type": "question", "content": "Why is this variable used?"}',
165
+ '{"action_type": "done"}',
166
+ ]
167
+
168
+ for i in range(n_examples):
169
+ code = f"def example_{i}():\n return {i % 10}"
170
+ last_outputs = [
171
+ "Tests passed: 2/3",
172
+ "Linter found 1 error",
173
+ "Inspection complete",
174
+ "No previous action",
175
+ ]
176
+ last_output = random.choice(last_outputs)
177
+ # Use same prompt structure as build_prompt
178
+ prompt = f"""You are a code review agent.
179
+
180
+ Code:
181
+ {code}
182
+
183
+ Last Output:
184
+ {last_output}
185
+
186
+ Available actions:
187
+ run_tests, run_linter, inspect, fix, comment, question, done
188
+
189
+ Respond ONLY in JSON:
190
+ {{"action_type": "...", "content": "..."}}"""
191
+
192
+ action_json = random.choice(action_templates)
193
+ messages = [
194
+ {"role": "user", "content": prompt},
195
+ {"role": "assistant", "content": action_json}
196
+ ]
197
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
198
+ examples.append({"text": full_text})
199
+
200
+ dataset = Dataset.from_list(examples)
201
+ trainer = SFTTrainer(
202
+ model=model,
203
+ tokenizer=tokenizer,
204
+ train_dataset=dataset,
205
+ dataset_text_field="text",
206
+ max_seq_length=512,
207
+ args=TrainingArguments(
208
+ output_dir="warmup_output",
209
+ num_train_epochs=epochs,
210
+ per_device_train_batch_size=4,
211
+ gradient_accumulation_steps=2,
212
+ learning_rate=2e-5,
213
+ logging_steps=50,
214
+ save_strategy="no",
215
+ fp16=True,
216
+ ),
217
+ )
218
+ print(f"Training on {n_examples} examples for {epochs} epochs...")
219
+ trainer.train()
220
+ print("✓ Warm-up complete\n")
221
+
222
+ # ======================================================================
223
+ # 5. ACTION GENERATION WITH LOGPROB TRACKING (fixed)
224
  # ======================================================================
225
  def generate_action_with_logprob(
226
  prompt: str,
227
  model,
228
  tokenizer,
229
+ temperature: float = 0.0, # changed: greedy by default for stability
230
  max_retries: int = 2
231
  ) -> Tuple[str, float]:
232
+ """Generate action using correct chat template, with fallback."""
233
+ messages = [{"role": "user", "content": prompt}]
234
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
235
  inputs = tokenizer(formatted, return_tensors="pt").to("cuda")
236
 
237
  for attempt in range(max_retries):
 
239
  outputs = model.generate(
240
  **inputs,
241
  max_new_tokens=128,
242
+ do_sample=(temperature > 0),
243
+ temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
244
+ min_new_tokens=1,
245
  return_dict_in_generate=True,
246
  output_scores=True,
247
  )
248
 
 
249
  generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
250
  action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
251
 
252
+ # Compute logprob
253
  logprobs = []
254
  for idx, token_id in enumerate(generated_ids):
255
  if idx < len(outputs.scores):
256
+ token_logits = outputs.scores[idx][0]
257
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
258
  logprobs.append(token_logprob)
259
+ total_logprob = sum(logprobs) if logprobs else -100.0
260
 
261
+ # If empty, use fallback
262
+ if not action_text:
263
+ fallback_actions = [
264
+ '{"action_type": "run_tests"}',
265
+ '{"action_type": "run_linter"}',
266
+ '{"action_type": "inspect"}',
267
+ '{"action_type": "skip"}',
268
+ ]
269
+ action_text = random.choice(fallback_actions)
270
+ total_logprob = -50.0
271
+ print(f"[WARN] Empty generation → using fallback: {action_text}")
272
+ return action_text, total_logprob
273
 
274
  # Validate JSON
275
  try:
 
277
  return action_text, total_logprob
278
  except:
279
  if attempt == max_retries - 1:
 
280
  return '{"action_type":"skip"}', -100.0
281
  continue
282
 
283
  return '{"action_type":"skip"}', -100.0
284
 
285
  # ======================================================================
286
+ # 6. PROMPT BUILDER (unchanged exactly as you wrote)
287
  # ======================================================================
288
  def build_prompt(obs, history_lines: List[str]) -> str:
289
  prompt = f"""You are a code review agent.
 
306
  return prompt
307
 
308
  # ======================================================================
309
+ # 7. TRAJECTORY STORAGE (unchanged)
310
  # ======================================================================
311
  @dataclass
312
  class Trajectory:
 
329
  }
330
 
331
  # ======================================================================
332
+ # 8. ROLLOUT COLLECTION (uses fixed generate)
333
  # ======================================================================
334
  def collect_trajectory(
335
  env: CodeReviewEnv,
336
  model,
337
  tokenizer,
338
  max_steps: int = 10,
339
+ temperature: float = 0.0 # changed to greedy
340
  ) -> Trajectory:
 
 
 
341
  obs = env.reset()
342
  history_lines = []
343
 
 
348
  dones = []
349
 
350
  for step in range(max_steps):
 
351
  prompt = build_prompt(obs, history_lines)
352
  states.append(prompt)
353
 
 
354
  action_text, logprob = generate_action_with_logprob(
355
  prompt, model, tokenizer, temperature
356
  )
357
  actions.append(action_text)
358
  logprobs.append(logprob)
359
 
 
360
  action = parse_action(action_text)
361
  env_action = map_to_env(action)
362
  next_obs, reward, done, _ = env.step(env_action)
 
364
  rewards.append(reward.value)
365
  dones.append(done)
366
 
 
367
  history_lines.append(f"Agent: {action_text}")
368
  history_lines.append(f"Env: {next_obs.last_tool_output}")
369
 
 
380
  n_trajectories: int,
381
  max_steps: int = 10
382
  ) -> List[Trajectory]:
 
383
  trajectories = []
384
  for i in range(n_trajectories):
385
  traj = collect_trajectory(env, model, tokenizer, max_steps)
 
390
  return trajectories
391
 
392
  # ======================================================================
393
+ # 9. ADVANTAGE ESTIMATION (unchanged)
394
  # ======================================================================
395
  def compute_gae(
396
  rewards: List[float],
 
399
  gamma: float = 0.99,
400
  lambda_: float = 0.95
401
  ) -> Tuple[List[float], List[float]]:
 
 
 
 
402
  n = len(rewards)
 
 
403
  returns = [0.0] * n
404
  running_return = 0.0
405
  for t in reversed(range(n)):
 
408
  running_return = rewards[t] + gamma * running_return
409
  returns[t] = running_return
410
 
 
411
  if values is None:
412
  advantages = returns
 
413
  adv_mean = np.mean(advantages)
414
  adv_std = np.std(advantages) + 1e-8
415
  advantages = [(a - adv_mean) / adv_std for a in advantages]
416
  return advantages, returns
417
 
 
418
  advantages = [0.0] * n
419
  gae = 0.0
420
  for t in reversed(range(n)):
 
425
  gae = delta + gamma * lambda_ * gae
426
  advantages[t] = gae
427
 
 
428
  adv_mean = np.mean(advantages)
429
  adv_std = np.std(advantages) + 1e-8
430
  advantages = [(a - adv_mean) / adv_std for a in advantages]
 
431
  return advantages, returns
432
 
433
  # ======================================================================
434
+ # 10. COMPUTE NEW LOGPROBS (unchanged)
435
  # ======================================================================
436
  def compute_logprob(prompt: str, action: str, model, tokenizer) -> float:
437
+ messages = [{"role": "user", "content": prompt}]
438
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
439
+ full_text = formatted + action
440
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
 
441
 
442
  with torch.no_grad():
443
  outputs = model(**inputs)
444
+ logits = outputs.logits
445
 
 
446
  action_ids = tokenizer.encode(action, add_special_tokens=False)
447
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
448
+ action_start = len(prefix_ids)
449
 
 
450
  logprobs = []
451
  for idx, token_id in enumerate(action_ids):
452
+ position = action_start + idx - 1
453
+ if 0 <= position < logits.shape[1]:
454
  token_logits = logits[0, position]
455
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
456
  logprobs.append(token_logprob)
 
457
  return sum(logprobs) if logprobs else -100.0
458
 
459
  # ======================================================================
460
+ # 11. PPO UPDATE (unchanged except uses compute_logprob correctly)
461
  # ======================================================================
462
  def ppo_update(
463
  trajectories: List[Trajectory],
 
470
  gamma: float = 0.99,
471
  lambda_: float = 0.95,
472
  ) -> Dict[str, float]:
 
 
 
473
  model.train()
474
 
 
475
  all_states = []
476
  all_actions = []
477
  all_old_logprobs = []
 
479
  all_returns = []
480
 
481
  for traj in trajectories:
 
482
  advantages, returns = compute_gae(
483
  traj.rewards, traj.dones, values=None, gamma=gamma, lambda_=lambda_
484
  )
 
485
  all_states.extend(traj.states)
486
  all_actions.extend(traj.actions)
487
  all_old_logprobs.extend(traj.logprobs)
 
489
  all_returns.extend(returns)
490
 
491
  n_samples = len(all_states)
 
492
  total_loss = 0.0
493
  total_policy_loss = 0.0
494
  total_entropy = 0.0
495
  n_updates = 0
496
 
 
497
  for epoch in range(n_epochs):
 
498
  indices = np.random.permutation(n_samples)
 
499
  for i in indices:
500
  state = all_states[i]
501
  action = all_actions[i]
502
  old_logprob = all_old_logprobs[i]
503
  advantage = all_advantages[i]
504
 
505
+ # Use the same chat template for PPO update
506
+ messages = [{"role": "user", "content": state}]
507
+ formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
508
+ full_text = formatted + action
509
+ inputs = tokenizer(full_text, return_tensors="pt").to("cuda")
510
 
511
  outputs = model(**inputs)
512
  logits = outputs.logits
513
 
 
514
  action_ids = tokenizer.encode(action, add_special_tokens=False)
515
+ prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
516
+ action_start = len(prefix_ids)
517
 
 
518
  logprobs = []
519
  entropy = 0.0
520
  for idx, token_id in enumerate(action_ids):
521
  position = action_start + idx - 1
522
+ if 0 <= position < logits.shape[1]:
523
  token_logits = logits[0, position]
524
  log_probs = F.log_softmax(token_logits, dim=-1)
525
  token_logprob = log_probs[token_id]
526
  logprobs.append(token_logprob)
527
 
 
528
  probs = F.softmax(token_logits, dim=-1)
529
  entropy += -(probs * log_probs).sum()
530
 
 
534
  new_logprob = sum(logprobs)
535
  avg_entropy = entropy / len(logprobs) if logprobs else 0.0
536
 
 
537
  ratio = torch.exp(new_logprob - old_logprob)
538
  surr1 = ratio * advantage
539
  surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
540
  policy_loss = -torch.min(surr1, surr2)
 
 
541
  loss = policy_loss - entropy_coef * avg_entropy
542
 
 
543
  optimizer.zero_grad()
544
  loss.backward()
545
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
 
557
  }
558
 
559
  # ======================================================================
560
+ # 12. EVALUATION (unchanged)
561
  # ======================================================================
562
  def evaluate_policy(
563
  env: CodeReviewEnv,
 
566
  n_episodes: int = 10,
567
  max_steps: int = 10
568
  ) -> Dict[str, float]:
 
 
 
569
  model.eval()
 
570
  total_rewards = []
571
  episode_lengths = []
572
  success_count = 0
573
 
574
  for _ in range(n_episodes):
575
+ traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
576
  total_reward = sum(traj.rewards)
577
  total_rewards.append(total_reward)
578
  episode_lengths.append(len(traj))
 
 
579
  if total_reward > 0.5:
580
  success_count += 1
581
 
 
587
  }
588
 
589
  # ======================================================================
590
+ # 13. MAIN TRAINING LOOP (added sanity check and warm-up)
591
  # ======================================================================
592
  def train_ppo(
593
  n_iterations: int = 50,
 
601
  lambda_: float = 0.95,
602
  eval_every: int = 5,
603
  ):
 
 
 
604
  print("Loading model...")
605
  model, tokenizer = load_model()
606
 
607
+ # NEW: Sanity check before any training
608
+ if not test_model_sanity(model, tokenizer):
609
+ print("\n❌ Model sanity check failed – cannot proceed.")
610
+ return
611
+
612
+ # NEW: Supervised warm-up to teach JSON format
613
+ supervised_warmup(model, tokenizer, n_examples=500, epochs=2)
614
 
615
+ optimizer = AdamW(model.parameters(), lr=learning_rate)
616
  env = CodeReviewEnv()
617
 
618
  print(f"\n{'='*60}")
 
625
  for iteration in range(n_iterations):
626
  print(f"\n--- Iteration {iteration + 1}/{n_iterations} ---")
627
 
 
628
  print("Collecting trajectories...")
629
  trajectories = collect_trajectories(
630
  env, model, tokenizer, trajectories_per_iter, max_steps
631
  )
632
 
 
633
  avg_reward = np.mean([sum(t.rewards) for t in trajectories])
634
  avg_length = np.mean([len(t) for t in trajectories])
635
 
636
  print(f"Avg reward: {avg_reward:.3f}")
637
  print(f"Avg length: {avg_length:.1f}")
638
 
 
639
  print("Updating policy...")
640
  metrics = ppo_update(
641
  trajectories,
 
653
  print(f"Policy loss: {metrics['policy_loss']:.4f}")
654
  print(f"Entropy: {metrics['entropy']:.4f}")
655
 
 
656
  if (iteration + 1) % eval_every == 0:
657
  print("\nEvaluating policy...")
658
  eval_metrics = evaluate_policy(env, model, tokenizer, n_episodes=10)
 
660
  print(f"Eval success rate: {eval_metrics['success_rate']:.2%}")
661
  print(f"Eval avg length: {eval_metrics['avg_length']:.1f}")
662
 
 
663
  print("\n" + "="*60)
664
  print("Training complete. Saving model...")
665
  model.save_pretrained("ppo_final_model")
 
668
  print("="*60)
669
 
670
  # ======================================================================
671
+ # 14. ENTRY POINT (unchanged)
672
  # ======================================================================
673
  if __name__ == "__main__":
674
  train_ppo(