100XZX001 commited on
Commit
a3960cd
Β·
verified Β·
1 Parent(s): a1d7a9c

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +708 -328
training.py CHANGED
@@ -1,35 +1,96 @@
1
- # training.py – Clean PPO + QLoRA + Supervised Warm‑up (evidence‑driven RL)
2
- import os
 
 
 
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
5
- import json
 
 
 
 
6
  import torch
7
  import torch.nn.functional as F
8
  from torch.optim import AdamW
9
- from dataclasses import dataclass
10
- from typing import List, Optional
 
11
  import numpy as np
12
- import random
13
- import matplotlib
14
- matplotlib.use("Agg")
15
- import matplotlib.pyplot as plt
16
- from collections import Counter
17
 
18
- from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
19
- from peft import LoraConfig, get_peft_model, TaskType
20
 
21
  from environment import CodeReviewEnv
22
  from redteam import BUG_DB
23
- from models import map_to_env as model_map_to_env
24
 
25
- # =========================================================
26
- # DEVICE
27
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
29
 
30
- # =========================================================
31
  # DATA STRUCTURES
32
- # =========================================================
33
  @dataclass
34
  class AgentAction:
35
  action_type: str
@@ -37,305 +98,486 @@ class AgentAction:
37
 
38
  @dataclass
39
  class Trajectory:
40
- states: List[str]
41
- actions: List[str]
42
- rewards: List[float]
43
  logprobs: List[float]
44
- dones: List[bool]
 
45
 
46
- # =========================================================
 
 
 
 
 
 
 
 
 
 
47
  # ACTION PARSER
48
- # =========================================================
49
- def parse_action(output: str) -> AgentAction:
 
 
50
  try:
51
- data = json.loads(output)
52
- return AgentAction(
53
- action_type=data.get("action_type", "").lower(),
54
- content=data.get("content")
55
- )
56
- except:
57
- return AgentAction("skip", None)
 
 
 
 
 
 
 
58
 
59
  def map_to_env(action: AgentAction):
60
  return model_map_to_env(action.action_type, action.content)
61
 
62
- # =========================================================
63
- # MODEL
64
- # =========================================================
65
  def load_model():
66
- model_name = "microsoft/Phi-3-mini-4k-instruct"
67
-
68
- bnb = BitsAndBytesConfig(
69
- load_in_4bit=True,
70
- bnb_4bit_compute_dtype=torch.bfloat16,
71
- bnb_4bit_quant_type="nf4"
72
  )
73
-
74
- model = AutoModelForCausalLM.from_pretrained(
75
- model_name,
76
- quantization_config=bnb,
77
- device_map="auto",
78
- torch_dtype=torch.bfloat16
 
79
  )
80
-
81
- tokenizer = AutoTokenizer.from_pretrained(model_name)
82
  tokenizer.pad_token = tokenizer.eos_token
83
-
84
- lora = LoraConfig(
85
- r=16,
86
- lora_alpha=32,
87
- target_modules=["q_proj","k_proj","v_proj","o_proj",
88
- "gate_proj","up_proj","down_proj"],
89
- lora_dropout=0.0,
90
- bias="none",
91
- task_type=TaskType.CAUSAL_LM
92
- )
93
-
94
- model = get_peft_model(model, lora)
95
- model.gradient_checkpointing_enable()
96
  return model, tokenizer
97
 
98
- # =========================================================
99
- # PROMPT BUILDER (full environment context)
100
- # =========================================================
101
  def build_prompt(obs, history_lines: List[str]) -> str:
102
- author_msg = getattr(obs, "author_response", "") or ""
103
- tool_output = getattr(obs, "last_tool_output", "") or ""
104
- author_personality = getattr(obs, "author_personality", "defensive")
105
-
106
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
107
-
108
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
109
- - Tests pass (high pass ratio)
110
- - Lint is clean (zero errors)
111
- - Documentation or references are provided
112
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
113
-
114
- Workflow:
115
- 1. Use `inspect` to understand the code.
116
- 2. Use `run_tests` and `run_linter` to gather evidence.
117
- 3. Use `query_docs` when you need references or language‑specific guidance.
118
- 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
119
- 5. If the developer pushes back, read their response carefully and address their specific concern.
120
- 6. Once convinced, use `done` to finish.
121
-
122
- Code:
123
- {obs.code_snippet}
124
-
125
- Author says:
126
- {author_msg if author_msg else "(no response yet – start with inspection)"}
127
-
128
- Last tool output:
129
- {tool_output if tool_output else "(none)"}
130
-
131
- Available actions:
132
- run_tests, run_linter, inspect, query_docs, fix, comment, question, done
133
-
134
- Respond ONLY in JSON:
135
- {{"action_type": "...", "content": "..."}}"""
136
-
137
  if history_lines:
138
- history = "\n".join(history_lines[-6:])
139
- prompt += f"\n\nPrevious steps:\n{history}"
140
- return prompt
141
-
142
- # =========================================================
143
- # GENERATION
144
- # =========================================================
145
- def generate_action(prompt, model, tokenizer, temperature):
146
- messages = [{"role": "user", "content": prompt}]
147
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
148
- inputs = tokenizer(formatted, return_tensors="pt", truncation=True).to(DEVICE)
149
-
150
- outputs = model.generate(
151
- **inputs,
152
- max_new_tokens=128,
153
- do_sample=temperature > 0,
154
- temperature=temperature if temperature > 0 else None,
155
- return_dict_in_generate=True,
156
- output_scores=True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
  )
158
 
159
- gen_ids = outputs.sequences[0][inputs["input_ids"].shape[1]:]
160
- text = tokenizer.decode(gen_ids, skip_special_tokens=True)
161
-
162
- logprobs = []
163
- for i, token_id in enumerate(gen_ids):
164
- if i < len(outputs.scores):
165
- logits = outputs.scores[i][0]
166
- lp = F.log_softmax(logits, dim=-1)[token_id]
167
- logprobs.append(lp)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
- if not logprobs:
170
- return '{"action_type":"skip"}', -100.0
171
 
172
- return text, torch.stack(logprobs).sum().item()
173
 
174
- # =========================================================
175
  # TRAJECTORY COLLECTION
176
- # =========================================================
177
- def collect_trajectory(env, model, tokenizer, max_steps, temperature):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
178
  obs = env.reset()
179
- history_lines = []
180
-
181
- states, actions, rewards, logprobs, dones = [], [], [], [], []
182
- metrics = {"test_score": [], "actions": []}
183
-
184
- for step in range(max_steps):
185
- prompt = build_prompt(obs, history_lines)
186
- states.append(prompt)
187
-
188
- action_text, lp = generate_action(prompt, model, tokenizer, temperature)
189
- actions.append(action_text)
190
- logprobs.append(lp)
191
-
192
- action = parse_action(action_text)
193
- env_action = map_to_env(action)
194
- next_obs, reward, done, _ = env.step(env_action)
195
-
196
- rewards.append(float(np.clip(reward.value, -1, 1)))
197
- dones.append(done)
198
-
199
- history_lines.append(f"Agent: {action_text}")
200
- history_lines.append(f"Env: {next_obs.last_tool_output}")
201
-
202
- metrics["test_score"].append(getattr(next_obs, "current_test_score", 0.0))
203
- metrics["actions"].append(action.action_type)
204
-
205
- obs = next_obs
206
- if done:
 
 
 
 
 
 
 
 
 
207
  break
208
 
209
- return Trajectory(states, actions, rewards, logprobs, dones), metrics
210
 
211
- # =========================================================
212
- # SUPERVISED WARM‑UP
213
- # =========================================================
214
- def supervised_warmup(model, tokenizer, data_path="training_data.json", epochs=3):
215
- print("\n=== SUPERVISED WARMUP ===")
 
 
216
 
217
- with open(data_path) as f:
218
  data = json.load(f)
219
 
220
- optimizer = AdamW(model.parameters(), lr=2e-5)
221
  model.train()
 
222
 
223
- for epoch in range(epochs):
224
  random.shuffle(data)
225
- total_loss = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
226
 
227
- for ex in data:
228
- prompt = ex["prompt"]
229
- action = ex["action"]
 
 
 
 
 
230
 
231
- messages = [
232
- {"role": "user", "content": prompt},
233
- {"role": "assistant", "content": action},
234
- ]
235
- text = tokenizer.apply_chat_template(messages, tokenize=False)
236
- inputs = tokenizer(text, return_tensors="pt", truncation=True).to(DEVICE)
237
 
238
- outputs = model(**inputs, labels=inputs["input_ids"])
239
- loss = outputs.loss
240
 
241
- optimizer.zero_grad()
 
242
  loss.backward()
243
- optimizer.step()
244
-
245
- total_loss += loss.item()
246
 
247
- print(f"Epoch {epoch+1} Loss: {total_loss/len(data):.4f}")
248
-
249
- print("βœ“ Warmup done\n")
250
-
251
- # =========================================================
252
- # PPO UPDATE (FIXED advantage = return – baseline)
253
- # =========================================================
254
- def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2, gamma=0.99):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
255
  model.train()
 
256
 
257
- losses = []
258
- kls = []
 
 
 
 
 
259
 
260
- # =========================
261
- # Compute returns + baseline
262
- # =========================
263
  all_returns = []
264
-
265
  traj_returns = []
266
  for traj in trajectories:
267
- returns = []
268
- running = 0.0
269
-
270
- for r in reversed(traj.rewards):
271
- running = r + gamma * running
272
- returns.insert(0, running)
273
-
274
- returns = torch.tensor(returns, dtype=torch.float32, device=DEVICE)
275
- traj_returns.append(returns)
276
- all_returns.extend(returns.tolist())
277
-
278
- baseline = torch.tensor(np.mean(all_returns), device=DEVICE) if all_returns else torch.tensor(0.0, device=DEVICE)
279
-
280
- # =========================
281
- # PPO update
282
- # =========================
283
- for traj, returns in zip(trajectories, traj_returns):
284
-
 
 
 
 
 
 
 
 
 
 
 
 
285
  for i in range(len(traj.states)):
286
- state = traj.states[i]
287
  action = traj.actions[i]
 
 
288
 
289
- old_lp = torch.tensor(traj.logprobs[i], device=DEVICE)
290
-
291
- # Advantage (detached)
292
- adv = (returns[i] - baseline).detach()
293
-
294
- messages = [{"role": "user", "content": state}]
295
- formatted = tokenizer.apply_chat_template(
296
- messages, tokenize=False, add_generation_prompt=True
297
  )
 
298
 
299
- full = formatted + action
300
- inputs = tokenizer(full, return_tensors="pt", truncation=True).to(DEVICE)
301
-
302
- logits = model(**inputs).logits
303
-
304
- action_ids = tokenizer.encode(action, add_special_tokens=False)
305
- prefix_len = len(tokenizer.encode(formatted, add_special_tokens=False))
306
 
307
- logps = []
308
- entropy = 0.0
 
 
 
 
309
 
310
- for idx in range(len(action_ids)):
311
- pos = prefix_len + idx
312
- if pos == 0 or pos >= logits.shape[1]:
313
- continue
314
 
315
- token_logits = logits[0, pos - 1]
316
- log_probs = F.log_softmax(token_logits, dim=-1)
317
-
318
- lp = log_probs[action_ids[idx]]
319
- logps.append(lp)
320
-
321
- probs = torch.exp(log_probs)
322
- entropy += (-(probs * log_probs).sum()).detach()
323
-
324
- if not logps:
325
  continue
326
 
327
- new_lp = torch.stack(logps).sum()
328
-
329
- # PPO ratio
330
- ratio = torch.exp(new_lp - old_lp)
 
 
331
 
332
- s1 = ratio * adv
333
- s2 = torch.clamp(ratio, 1 - clip, 1 + clip) * adv
 
 
 
334
 
335
  policy_loss = -torch.min(s1, s2)
336
- loss = policy_loss - 0.01 * (entropy / len(logps))
337
 
338
- if torch.isnan(loss):
339
  continue
340
 
341
  optimizer.zero_grad()
@@ -343,89 +585,227 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2, gamma=0.99):
343
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
344
  optimizer.step()
345
 
346
- kl = (old_lp - new_lp).detach().cpu().item()
347
- kls.append(kl)
348
  losses.append(loss.item())
349
-
350
- return (
351
- float(np.mean(losses)) if losses else 0.0,
352
- float(np.mean(kls)) if kls else 0.0,
 
 
 
 
353
  )
354
- # =========================================================
355
- # MAIN TRAINING LOOP
356
- # =========================================================
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  def train():
358
  model, tokenizer = load_model()
359
  env = CodeReviewEnv()
360
 
361
- # ---------- Supervised warm‑up ----------
362
- supervised_warmup(model, tokenizer, data_path="training_data.json", epochs=3)
 
 
 
363
 
364
- optimizer = AdamW(model.parameters(), lr=3e-5)
365
- reward_hist, success_hist, kl_hist = [], [], []
366
- task_levels = list(BUG_DB.keys())
367
 
368
- # Baseline evaluation (after warm‑up, before PPO)
369
- baseline_rewards = []
370
- for _ in range(5):
371
- env.set_task(random.choice(task_levels))
372
- traj, _ = collect_trajectory(env, model, tokenizer, 6, 0.0)
373
- baseline_rewards.append(sum(traj.rewards))
374
- baseline_reward = np.mean(baseline_rewards)
375
- print(f"Baseline reward: {baseline_reward:+.4f}")
376
 
377
- # PPO iterations
378
- for it in range(15):
379
- print(f"\nIteration {it+1}")
380
- temperature = max(0.7 * (1 - it/15), 0.1)
381
 
382
- trajectories = []
383
- successes = 0
384
- action_counter = Counter()
385
-
386
- for _ in range(6):
387
- env.set_task(random.choice(task_levels))
388
- traj, metrics = collect_trajectory(env, model, tokenizer, 6, temperature)
389
- trajectories.append(traj)
390
 
391
- for a in metrics["actions"]:
392
- action_counter[a] += 1
393
- if sum(traj.rewards) > 0:
394
- successes += 1
 
 
395
 
396
- avg_reward = np.mean([sum(t.rewards) for t in trajectories])
397
- success_rate = successes / len(trajectories)
398
-
399
- loss, kl = ppo_update(trajectories, model, tokenizer, optimizer)
400
-
401
- reward_hist.append(avg_reward)
402
- success_hist.append(success_rate)
403
- kl_hist.append(kl)
404
-
405
- print(f"Reward: {avg_reward:+.4f} Success: {success_rate:.2%} KL: {kl:.4f}")
406
- print(f"Actions: {dict(action_counter)}")
407
-
408
- # ===================== Plots =====================
409
- iters = list(range(1, len(reward_hist)+1))
410
-
411
- plt.figure()
412
- plt.plot(iters, reward_hist)
413
- plt.axhline(y=baseline_reward, linestyle="--", color="gray")
414
- plt.title("PPO Reward Curve")
415
- plt.savefig("reward_curve.png")
416
-
417
- plt.figure()
418
- plt.plot(iters, success_hist)
419
- plt.title("Success Rate")
420
- plt.savefig("success_rate.png")
421
 
422
- plt.figure()
423
- plt.plot(iters, kl_hist)
424
- plt.title("KL Divergence")
425
- plt.savefig("kl_divergence.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
426
 
427
- print(f"\nTraining complete. Plots saved.")
428
- print(f"Final reward: {np.mean(reward_hist[-3:]):+.4f}")
429
 
430
  if __name__ == "__main__":
431
  train()
 
1
+ # training.py – PPO + QLoRA + Supervised Warm-up
2
+ # Model : Qwen/Qwen2.5-1.5B-Instruct (via Unsloth – 2Γ— faster, fits Colab T4)
3
+ # Fixed : label-masking, BPE-boundary alignment, log-ratio clamping, OOM guards
4
+ # Evidence: reward curves, before/after traces, per-difficulty breakdown, KL, entropy
5
+ # ============================================================
6
+ import os, json, random, re
7
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
 
9
+ import matplotlib
10
+ matplotlib.use("Agg")
11
+ import matplotlib.pyplot as plt
12
+ import matplotlib.gridspec as gridspec
13
+
14
  import torch
15
  import torch.nn.functional as F
16
  from torch.optim import AdamW
17
+ from dataclasses import dataclass, field
18
+ from typing import List, Optional, Dict
19
+ from collections import Counter, defaultdict
20
  import numpy as np
 
 
 
 
 
21
 
22
+ # ── Unsloth gives 2Γ— throughput with identical outputs ────────────────────────
23
+ from unsloth import FastLanguageModel
24
 
25
  from environment import CodeReviewEnv
26
  from redteam import BUG_DB
 
27
 
28
+ # Graceful import: use project map_to_env if available, else inline fallback.
29
+ try:
30
+ from models import map_to_env as model_map_to_env
31
+ _HAVE_MODEL_MAP = True
32
+ except (ImportError, AttributeError):
33
+ _HAVE_MODEL_MAP = False
34
+
35
+ if not _HAVE_MODEL_MAP:
36
+ try:
37
+ from models import (RunTests, RunLinter, Inspect, ProposeFix,
38
+ WriteComment, AskQuestion, Done, Skip, QueryDocs)
39
+ def model_map_to_env(action_type: str, content=None):
40
+ return {
41
+ "run_tests": RunTests(),
42
+ "run_linter": RunLinter(),
43
+ "inspect": Inspect(),
44
+ "query_docs": QueryDocs(content or "python bug fix"),
45
+ "fix": ProposeFix(content or ""),
46
+ "comment": WriteComment(content or ""),
47
+ "question": AskQuestion(content or ""),
48
+ "done": Done(),
49
+ }.get(action_type, Skip())
50
+ except ImportError:
51
+ # Last resort: duck-typed object the env can introspect.
52
+ class _EnvAction:
53
+ def __init__(self, **kw): self.__dict__.update(kw)
54
+ def model_map_to_env(action_type: str, content=None):
55
+ return _EnvAction(action_type=action_type, content=content)
56
+
57
+ # ══════════════════════════════════════════════════════════════════════════════
58
+ # CONFIG
59
+ # ══════════════════════════════════════════════════════════════════════════════
60
+ CFG = dict(
61
+ model_name = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
62
+ max_seq_len = 512, # hard cap; prevents OOM on T4
63
+ lora_r = 16,
64
+ lora_alpha = 32,
65
+
66
+ # Warm-up
67
+ warmup_data = "training_data.json",
68
+ warmup_epochs = 2,
69
+ warmup_lr = 2e-5,
70
+ warmup_grad_acc = 4, # effective batch = 4 examples
71
+
72
+ # PPO
73
+ ppo_iters = 15,
74
+ trajs_per_iter = 6,
75
+ max_steps = 7,
76
+ ppo_lr = 3e-5,
77
+ clip_eps = 0.2,
78
+ entropy_coef = 0.01,
79
+ gamma = 0.99,
80
+ log_ratio_clamp = 5.0, # ← prevents exp-explosion / NaN loss
81
+ temp_start = 0.8,
82
+ temp_end = 0.1,
83
+
84
+ # Eval
85
+ eval_episodes = 10, # episodes per evaluation snapshot
86
+ )
87
+
88
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
89
+ TASK_LEVELS = list(BUG_DB.keys()) # [easy, medium, hard, harder, hardest]
90
 
91
+ # ══════════════════════════════════════════════════════════════════════════════
92
  # DATA STRUCTURES
93
+ # ══════════════════════════════════════════════════════════════════════════════
94
  @dataclass
95
  class AgentAction:
96
  action_type: str
 
98
 
99
  @dataclass
100
  class Trajectory:
101
+ states: List[str]
102
+ actions: List[str]
103
+ rewards: List[float]
104
  logprobs: List[float]
105
+ dones: List[bool]
106
+ task: str = ""
107
 
108
+ @dataclass
109
+ class EvalSnapshot:
110
+ """Captures full agent behaviour for before/after comparison."""
111
+ avg_reward: float
112
+ per_task: Dict[str, float] = field(default_factory=dict)
113
+ action_dist: Dict[str, float] = field(default_factory=dict)
114
+ success_rate: float = 0.0
115
+ avg_steps: float = 0.0
116
+ traces: List[dict] = field(default_factory=list)
117
+
118
+ # ══════════════════════════════════════════════════════════════════════════════
119
  # ACTION PARSER
120
+ # ══════════════════════════════════════════════════════════════════════════════
121
+ def parse_action(text: str) -> AgentAction:
122
+ """Robust parser: tries strict JSON, then regex, then keyword heuristic."""
123
+ text = text.strip()
124
  try:
125
+ d = json.loads(text)
126
+ return AgentAction(d.get("action_type","skip").lower(), d.get("content"))
127
+ except json.JSONDecodeError:
128
+ pass
129
+ m = re.search(r'"action_type"\s*:\s*"(\w+)"', text)
130
+ if m:
131
+ cm = re.search(r'"content"\s*:\s*"(.*?)"', text, re.DOTALL)
132
+ return AgentAction(m.group(1).lower(), cm.group(1) if cm else None)
133
+ tl = text.lower()
134
+ for kw in ("run_tests","run_linter","inspect","query_docs","fix",
135
+ "comment","question","done"):
136
+ if kw in tl:
137
+ return AgentAction(kw)
138
+ return AgentAction("skip")
139
 
140
  def map_to_env(action: AgentAction):
141
  return model_map_to_env(action.action_type, action.content)
142
 
143
+ # ══════════════════════════════════════════════════════════════════════════════
144
+ # MODEL (Qwen2.5-1.5B via Unsloth)
145
+ # ══════════════════════════════════════════════════════════════════════════════
146
  def load_model():
147
+ print(f"Loading {CFG['model_name']} …")
148
+ model, tokenizer = FastLanguageModel.from_pretrained(
149
+ model_name = CFG["model_name"],
150
+ max_seq_length = CFG["max_seq_len"],
151
+ load_in_4bit = True,
 
152
  )
153
+ model = FastLanguageModel.get_peft_model(
154
+ model,
155
+ r = CFG["lora_r"],
156
+ lora_alpha = CFG["lora_alpha"],
157
+ target_modules = ["q_proj","k_proj","v_proj","o_proj",
158
+ "gate_proj","up_proj","down_proj"],
159
+ lora_dropout = 0.0,
160
  )
 
 
161
  tokenizer.pad_token = tokenizer.eos_token
162
+ print(f" trainable params: "
163
+ f"{sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.1f}M")
 
 
 
 
 
 
 
 
 
 
 
164
  return model, tokenizer
165
 
166
+ # ══════════════════════════════════════════════════════════════════════════════
167
+ # PROMPT BUILDER
168
+ # ═════════════���════════════════════════════════════════════════════════════════
169
  def build_prompt(obs, history_lines: List[str]) -> str:
170
+ author_msg = getattr(obs, "author_response", "") or ""
171
+ tool_output = getattr(obs, "last_tool_output", "") or ""
172
+ personality = getattr(obs, "author_personality","defensive")
173
+
174
+ # Trim tool output to avoid context explosion
175
+ if len(tool_output) > 600:
176
+ tool_output = tool_output[:600] + " …[truncated]"
177
+
178
+ p = (
179
+ f"You are an AI code review agent. Convince the developer (personality: "
180
+ f"**{personality}**) to accept your fix. Name your fix function `fix`.\n\n"
181
+ "Evidence required: tests pass, lint clean, docs cited, reasoning uses "
182
+ "'because'/'therefore' (>30 words).\n\n"
183
+ "Workflow: inspect β†’ run_tests β†’ run_linter β†’ query_docs β†’ fix β†’ "
184
+ "comment/question β†’ done.\n\n"
185
+ f"Code:\n{obs.code_snippet}\n\n"
186
+ f"Author: {author_msg or '(no response yet – start with inspect)'}\n\n"
187
+ f"Last tool: {tool_output or '(none)'}\n\n"
188
+ "Actions: run_tests, run_linter, inspect, query_docs, fix, comment, question, done\n\n"
189
+ 'Respond ONLY in JSON: {"action_type": "...", "content": "..."}'
190
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
191
  if history_lines:
192
+ p += "\n\nRecent steps:\n" + "\n".join(history_lines[-4:])
193
+ return p
194
+
195
+ # ══════════════════════════════════════════════════════════════════════════════
196
+ # BUG FIX 1 – label masking in supervised warmup
197
+ # (original: labels=inputs["input_ids"] trains on ALL tokens, including prompt)
198
+ # ══════════════════════════════════════════════════════════════════════════════
199
+ def _masked_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
200
+ """Return labels with prompt positions set to -100 (ignored by CE loss)."""
201
+ labels = input_ids.clone()
202
+ labels[0, :prompt_len] = -100
203
+ return labels
204
+
205
+ # ══════════════════════════════════════════════════════════════════════════════
206
+ # BUG FIX 2 – BPE-boundary-safe logprob computation
207
+ # (original: tokenize(prompt) + tokenize(action) β‰  tokenize(prompt+action))
208
+ # ══════════════════════════════════════���═══════════════════════════════════════
209
+ def _compute_action_logprob(
210
+ logits: torch.Tensor, # [1, seq_len, vocab]
211
+ input_ids: torch.Tensor, # [1, seq_len]
212
+ prompt_len: int, # #tokens in the prompt part of the joint sequence
213
+ ) -> tuple:
214
+ """
215
+ Compute sum of log-probs for *action* tokens only, using the jointly
216
+ tokenised sequence so BPE boundaries are respected.
217
+
218
+ Returns (total_logprob, avg_entropy, n_tokens).
219
+ """
220
+ action_len = input_ids.shape[1] - prompt_len
221
+ if action_len <= 0:
222
+ return torch.tensor(0.0, device=DEVICE), torch.tensor(0.0, device=DEVICE), 0
223
+
224
+ total_lp = torch.tensor(0.0, device=DEVICE)
225
+ total_ent = torch.tensor(0.0, device=DEVICE)
226
+
227
+ for k in range(action_len):
228
+ pos = prompt_len + k # position of the k-th action token
229
+ pred_pos = pos - 1 # logit at pred_pos predicts token at pos
230
+ if pred_pos < 0 or pred_pos >= logits.shape[1]:
231
+ continue
232
+ token_id = input_ids[0, pos]
233
+ lp_dist = F.log_softmax(logits[0, pred_pos], dim=-1)
234
+ total_lp = total_lp + lp_dist[token_id]
235
+ probs = torch.exp(lp_dist)
236
+ total_ent = total_ent + (-(probs * lp_dist).sum()).detach()
237
+
238
+ n = action_len
239
+ return total_lp, total_ent / max(n, 1), n
240
+
241
+ # ══════════════════════════════════════════════════════════════════════════════
242
+ # GENERATION (returns text + joint-sequence logprob)
243
+ # ══════════════════════════════════════════════════════════════════════════════
244
+ @torch.no_grad()
245
+ def generate_action(prompt: str, model, tokenizer,
246
+ temperature: float) -> tuple:
247
+ messages = [{"role": "user", "content": prompt}]
248
+ formatted = tokenizer.apply_chat_template(
249
+ messages, tokenize=False, add_generation_prompt=True
250
  )
251
 
252
+ inputs = tokenizer(
253
+ formatted, return_tensors="pt",
254
+ max_length=CFG["max_seq_len"] - 128, # leave room for response
255
+ truncation=True
256
+ ).to(DEVICE)
257
+ prompt_len = inputs["input_ids"].shape[1]
258
+
259
+ gen_kwargs = dict(
260
+ max_new_tokens = 128,
261
+ do_sample = temperature > 0,
262
+ return_dict_in_generate = True,
263
+ output_scores = True,
264
+ pad_token_id = tokenizer.eos_token_id,
265
+ eos_token_id = tokenizer.eos_token_id,
266
+ )
267
+ if temperature > 0:
268
+ gen_kwargs["temperature"] = temperature
269
+
270
+ out = model.generate(**inputs, **gen_kwargs)
271
+ gen_ids = out.sequences[0][prompt_len:]
272
+ text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
273
+
274
+ if not text:
275
+ fallback = random.choice([
276
+ '{"action_type":"inspect"}',
277
+ '{"action_type":"run_tests"}',
278
+ '{"action_type":"run_linter"}',
279
+ ])
280
+ print(f" [WARN] empty generation β†’ fallback {fallback}")
281
+ # BUG FIX 3: don't use -100 sentinel; use a mildly negative logprob
282
+ # so that PPO ratio = exp(new - old) stays finite when re-evaluated
283
+ return fallback, -10.0
284
+
285
+ # Recompute logprob from the full joint sequence (BPE-safe)
286
+ joint_ids = torch.cat(
287
+ [inputs["input_ids"], gen_ids.unsqueeze(0).to(DEVICE)], dim=1
288
+ )
289
+ joint_ids = joint_ids[:, :CFG["max_seq_len"]]
290
 
291
+ logits = model(input_ids=joint_ids).logits
292
+ lp, _, _ = _compute_action_logprob(logits, joint_ids, prompt_len)
293
 
294
+ return text, lp.item()
295
 
296
+ # ══════════════════════════════════════════════════════════════════════════════
297
  # TRAJECTORY COLLECTION
298
+ # ══════════════════════════════════════════════════════════════════════════════
299
+ # Per-action shaped rewards. These create reward variance so that
300
+ # trajectories with meaningful tool use beat inspect-only episodes.
301
+ _STEP_REWARD = {
302
+ "run_tests": +0.08,
303
+ "run_linter": +0.05,
304
+ "fix": +0.15,
305
+ "comment": +0.08,
306
+ "query_docs": +0.05,
307
+ "question": +0.04,
308
+ "inspect": 0.00, # neutral – observe before acting
309
+ "done": 0.00, # env handles the terminal reward
310
+ "skip": -0.10, # penalise doing nothing
311
+ }
312
+
313
+ def collect_trajectory(env, model, tokenizer,
314
+ max_steps: int, temperature: float,
315
+ task: str) -> tuple:
316
+ """
317
+ FIX 4 – Override env done/reward for non-terminal actions.
318
+
319
+ Root cause of the degenerate policy:
320
+ β€’ env.step(Inspect()) returns done=True, reward=+0.002
321
+ β€’ agent discovers inspect β†’ tiny reward β†’ done is the easiest path
322
+ β€’ every trajectory is identical β†’ zero advantage β†’ PPO does nothing
323
+
324
+ Fix: only accept env's done+reward when the agent explicitly emits
325
+ {"action_type": "done"}. For every other action, use a shaped step
326
+ reward and force the episode to continue.
327
+ """
328
+ env.set_task(task)
329
  obs = env.reset()
330
+ history: List[str] = []
331
+ traj = Trajectory([], [], [], [], [], task=task)
332
+ action_seq = []
333
+
334
+ for step_num in range(max_steps):
335
+ prompt = build_prompt(obs, history)
336
+ traj.states.append(prompt)
337
+
338
+ text, lp = generate_action(prompt, model, tokenizer, temperature)
339
+ traj.actions.append(text)
340
+ traj.logprobs.append(lp)
341
+
342
+ action = parse_action(text)
343
+ action_seq.append(action.action_type)
344
+
345
+ obs, reward, env_done, _ = env.step(map_to_env(action))
346
+ raw_r = float(reward.value)
347
+
348
+ if action.action_type == "done":
349
+ # Agent explicitly chose to terminate β†’ honour env reward
350
+ shaped_r = raw_r
351
+ effective_done = True
352
+ else:
353
+ # Intermediate step: use shaped reward, ignore env's done signal.
354
+ # Also keep a fraction of any large env reward (e.g. test pass).
355
+ shaped_r = _STEP_REWARD.get(action.action_type, 0.0)
356
+ if raw_r > 0.1: # env signalling meaningful progress
357
+ shaped_r += raw_r * 0.3
358
+ effective_done = False # ← key: don't let env short-circuit
359
+
360
+ traj.rewards.append(float(np.clip(shaped_r, -1.0, 1.0)))
361
+ traj.dones.append(effective_done)
362
+
363
+ history.append(f"Agent: {text[:120]}")
364
+ history.append(f"Env: {(obs.last_tool_output or '')[:120]}")
365
+
366
+ if effective_done:
367
  break
368
 
369
+ return traj, action_seq
370
 
371
+ # ══════════════════════════════════════════════════════════════════════════════
372
+ # SUPERVISED WARM-UP (BUG FIX 1: action-only label masking)
373
+ # ══════════════════════════════════════════════════════════════════════════════
374
+ def supervised_warmup(model, tokenizer):
375
+ print("\n" + "="*60)
376
+ print("SUPERVISED WARM-UP")
377
+ print("="*60)
378
 
379
+ with open(CFG["warmup_data"], encoding="utf-8") as f:
380
  data = json.load(f)
381
 
382
+ opt = AdamW(model.parameters(), lr=CFG["warmup_lr"])
383
  model.train()
384
+ loss_history = []
385
 
386
+ for epoch in range(CFG["warmup_epochs"]):
387
  random.shuffle(data)
388
+ epoch_loss, n_valid = 0.0, 0
389
+ opt.zero_grad()
390
+
391
+ for step, ex in enumerate(data):
392
+ # ── Tokenise prompt and full sequence jointly ────────────────
393
+ prompt_chat = tokenizer.apply_chat_template(
394
+ [{"role": "user", "content": ex["prompt"]}],
395
+ tokenize=False, add_generation_prompt=True
396
+ )
397
+ full_chat = tokenizer.apply_chat_template(
398
+ [{"role": "user", "content": ex["prompt"]},
399
+ {"role": "assistant", "content": ex["action"]}],
400
+ tokenize=False
401
+ )
402
 
403
+ prompt_ids = tokenizer(
404
+ prompt_chat, return_tensors="pt",
405
+ max_length=CFG["max_seq_len"], truncation=True
406
+ )["input_ids"]
407
+ full_inputs = tokenizer(
408
+ full_chat, return_tensors="pt",
409
+ max_length=CFG["max_seq_len"], truncation=True
410
+ ).to(DEVICE)
411
 
412
+ prompt_len = prompt_ids.shape[1]
413
+ if prompt_len >= full_inputs["input_ids"].shape[1]:
414
+ continue # action got truncated away
 
 
 
415
 
416
+ # BUG FIX 1 ── mask prompt tokens so loss is action-only
417
+ labels = _masked_labels(full_inputs["input_ids"], prompt_len)
418
 
419
+ out = model(**full_inputs, labels=labels)
420
+ loss = out.loss / CFG["warmup_grad_acc"]
421
  loss.backward()
 
 
 
422
 
423
+ if (step + 1) % CFG["warmup_grad_acc"] == 0:
424
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
425
+ opt.step()
426
+ opt.zero_grad()
427
+
428
+ epoch_loss += loss.item() * CFG["warmup_grad_acc"]
429
+ n_valid += 1
430
+
431
+ if (step + 1) % 50 == 0:
432
+ print(f" epoch {epoch+1} step {step+1}/{len(data)}"
433
+ f" loss={epoch_loss/n_valid:.4f}")
434
+
435
+ avg = epoch_loss / max(n_valid, 1)
436
+ loss_history.append(avg)
437
+ print(f" Epoch {epoch+1} complete: avg_loss={avg:.4f}")
438
+
439
+ torch.cuda.empty_cache()
440
+ print(f"βœ“ Warm-up done. Loss: {' β†’ '.join(f'{l:.4f}' for l in loss_history)}\n")
441
+ return loss_history
442
+
443
+ # ══════════════════════════════════════════════════════════════════════════════
444
+ # EVALUATION (produces rich EvalSnapshot for comparison plots)
445
+ # ══════════════════════════════════════════════════════════════════════════════
446
+ @torch.no_grad()
447
+ def evaluate(env, model, tokenizer, label: str = "") -> EvalSnapshot:
448
+ model.eval()
449
+ per_task: Dict[str, List[float]] = defaultdict(list)
450
+ action_counter: Counter = Counter()
451
+ all_steps, all_success = [], []
452
+ traces = []
453
+
454
+ for ep in range(CFG["eval_episodes"]):
455
+ task = TASK_LEVELS[ep % len(TASK_LEVELS)]
456
+ traj, actions = collect_trajectory(
457
+ env, model, tokenizer, CFG["max_steps"], 0.0, task
458
+ )
459
+ ep_r = sum(traj.rewards)
460
+ per_task[task].append(ep_r)
461
+ action_counter.update(actions)
462
+ all_steps.append(len(traj.actions))
463
+ # FIX 6 – meaningful success = agent explicitly called "done".
464
+ # ep_r > 0 is misleading: even a single inspect returns +0.002.
465
+ all_success.append(1 if "done" in actions else 0)
466
+ traces.append({"task": task, "reward": round(ep_r, 4),
467
+ "steps": len(traj.actions), "actions": actions})
468
+
469
+ total_actions = max(sum(action_counter.values()), 1)
470
+ snap = EvalSnapshot(
471
+ avg_reward = float(np.mean([r for rs in per_task.values() for r in rs])),
472
+ per_task = {t: float(np.mean(rs)) for t, rs in per_task.items()},
473
+ action_dist = {a: c/total_actions for a, c in action_counter.most_common()},
474
+ success_rate = float(np.mean(all_success)),
475
+ avg_steps = float(np.mean(all_steps)),
476
+ traces = traces,
477
+ )
478
+ if label:
479
+ print(f"\n── {label} ──")
480
+ print(f" avg_reward={snap.avg_reward:+.4f} "
481
+ f"success={snap.success_rate:.0%} steps={snap.avg_steps:.1f}")
482
+ print(f" per-task: " +
483
+ " ".join(f"{t}={v:+.3f}" for t,v in snap.per_task.items()))
484
+ print(f" top actions: " +
485
+ " ".join(f"{a}={p:.0%}" for a,p in list(snap.action_dist.items())[:5]))
486
  model.train()
487
+ return snap
488
 
489
+ # ══════════════════════════════════════════════════════════════════════════════
490
+ # PPO UPDATE (BUG FIX 2 + 3: BPE-safe logprob + log-ratio clamping)
491
+ # ══════════════════════════════════════════════════════════════════════════════
492
+ def ppo_update(trajectories: List[Trajectory],
493
+ model, tokenizer, optimizer) -> dict:
494
+ model.train()
495
+ losses, kls, entropies = [], [], []
496
 
497
+ # ── Compute discounted returns and a global mean baseline ────────────────
 
 
498
  all_returns = []
 
499
  traj_returns = []
500
  for traj in trajectories:
501
+ ret, running = [], 0.0
502
+ for r, done in zip(reversed(traj.rewards), reversed(traj.dones)):
503
+ running = r + CFG["gamma"] * (0.0 if done else running)
504
+ ret.insert(0, running)
505
+ traj_returns.append(ret)
506
+ all_returns.extend(ret)
507
+
508
+ # FIX 5 – Normalise advantages to zero mean / unit std.
509
+ # When all returns are identical (e.g. every episode returns 0.002),
510
+ # baseline = mean = every return, so adv = 0 for all steps, the
511
+ # policy loss is 0, and PPO never updates. Normalising creates real
512
+ # signal: better-than-average trajectories get positive advantage,
513
+ # worse-than-average get negative, even if the absolute spread is tiny.
514
+ ret_arr = np.array(all_returns) if all_returns else np.array([0.0])
515
+ ret_mean = float(ret_arr.mean())
516
+ ret_std = float(ret_arr.std())
517
+
518
+ if ret_std < 1e-6:
519
+ # Truly zero variance – nothing to learn this iteration.
520
+ print(" [PPO] Zero return variance – skipping gradient update.")
521
+ return dict(loss=0.0, kl=0.0, entropy=0.0)
522
+
523
+ # Build a lookup so we can retrieve the normalised advantage by
524
+ # (trajectory index, step index) during the update loop below.
525
+ norm_returns: List[List[float]] = [
526
+ [(r - ret_mean) / (ret_std + 1e-8) for r in ret_list]
527
+ for ret_list in traj_returns
528
+ ]
529
+
530
+ for traj_idx, (traj, returns) in enumerate(zip(trajectories, traj_returns)):
531
  for i in range(len(traj.states)):
532
+ state = traj.states[i]
533
  action = traj.actions[i]
534
+ old_lp = traj.logprobs[i]
535
+ adv = norm_returns[traj_idx][i] # ← normalised advantage
536
 
537
+ # ── Tokenise jointly (BPE FIX 2) ────────────────────────────────
538
+ prompt_chat = tokenizer.apply_chat_template(
539
+ [{"role": "user", "content": state}],
540
+ tokenize=False, add_generation_prompt=True
 
 
 
 
541
  )
542
+ full_text = prompt_chat + action
543
 
544
+ full_ids = tokenizer(
545
+ full_text, return_tensors="pt",
546
+ max_length=CFG["max_seq_len"], truncation=True
547
+ ).to(DEVICE)
 
 
 
548
 
549
+ # Count prompt tokens IN THE JOINT SEQUENCE (not separately)
550
+ prompt_ids = tokenizer(
551
+ prompt_chat, return_tensors="pt",
552
+ max_length=CFG["max_seq_len"] - 10, truncation=True
553
+ )["input_ids"]
554
+ prompt_len = min(prompt_ids.shape[1], full_ids["input_ids"].shape[1] - 1)
555
 
556
+ logits = model(**full_ids).logits
 
 
 
557
 
558
+ new_lp, avg_ent, n_tokens = _compute_action_logprob(
559
+ logits, full_ids["input_ids"], prompt_len
560
+ )
561
+ if n_tokens == 0:
 
 
 
 
 
 
562
  continue
563
 
564
+ # BUG FIX 3 ── clamp log-ratio before exp to prevent NaN
565
+ old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=DEVICE)
566
+ log_ratio = torch.clamp(new_lp - old_lp_t,
567
+ -CFG["log_ratio_clamp"],
568
+ CFG["log_ratio_clamp"])
569
+ ratio = torch.exp(log_ratio)
570
 
571
+ adv_t = torch.tensor(adv, dtype=torch.float32, device=DEVICE)
572
+ s1 = ratio * adv_t
573
+ s2 = torch.clamp(ratio,
574
+ 1.0 - CFG["clip_eps"],
575
+ 1.0 + CFG["clip_eps"]) * adv_t
576
 
577
  policy_loss = -torch.min(s1, s2)
578
+ loss = policy_loss - CFG["entropy_coef"] * avg_ent
579
 
580
+ if torch.isnan(loss) or torch.isinf(loss):
581
  continue
582
 
583
  optimizer.zero_grad()
 
585
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
586
  optimizer.step()
587
 
 
 
588
  losses.append(loss.item())
589
+ kls.append((old_lp_t - new_lp).detach().cpu().item())
590
+ entropies.append(avg_ent.item())
591
+
592
+ torch.cuda.empty_cache()
593
+ return dict(
594
+ loss = float(np.mean(losses)) if losses else 0.0,
595
+ kl = float(np.mean(kls)) if kls else 0.0,
596
+ entropy = float(np.mean(entropies)) if entropies else 0.0,
597
  )
598
+
599
+ # ══════════════════════════════════════════════════════════════════════════════
600
+ # PLOTTING (rich evidence panel)
601
+ # ══════════════════════════════════════════════════════════════════════════════
602
+ def plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
603
+ baseline_snap: EvalSnapshot,
604
+ postwarmup_snap: EvalSnapshot,
605
+ final_snap: EvalSnapshot):
606
+
607
+ iters = list(range(1, len(reward_hist) + 1))
608
+
609
+ # ── Figure 1: training curves (2Γ—3 grid) ─────────────────────────────────
610
+ fig = plt.figure(figsize=(18, 10))
611
+ gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)
612
+
613
+ # (0,0) Warm-up loss
614
+ ax = fig.add_subplot(gs[0, 0])
615
+ ax.plot(range(1, len(warmup_losses)+1), warmup_losses,
616
+ marker="o", color="mediumpurple", linewidth=2)
617
+ ax.set_title("A. Warm-up CE Loss ↓", fontweight="bold")
618
+ ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.grid(alpha=0.3)
619
+
620
+ # (0,1) PPO reward
621
+ ax = fig.add_subplot(gs[0, 1])
622
+ smooth = np.convolve(reward_hist, np.ones(3)/3, mode="same")
623
+ ax.plot(iters, reward_hist, alpha=0.35, color="steelblue", linewidth=1)
624
+ ax.plot(iters, smooth, color="steelblue", linewidth=2.5, label="reward (smoothed)")
625
+ ax.axhline(baseline_snap.avg_reward, color="gray", linestyle=":",
626
+ label=f"pre-warmup ({baseline_snap.avg_reward:+.3f})")
627
+ ax.axhline(postwarmup_snap.avg_reward, color="mediumpurple", linestyle="--",
628
+ label=f"post-warmup ({postwarmup_snap.avg_reward:+.3f})")
629
+ ax.axhline(final_snap.avg_reward, color="forestgreen", linestyle="-.",
630
+ label=f"final ({final_snap.avg_reward:+.3f})")
631
+ ax.set_title("B. PPO Reward ↑", fontweight="bold")
632
+ ax.set_xlabel("Iteration"); ax.set_ylabel("Avg Reward")
633
+ ax.legend(fontsize=7); ax.grid(alpha=0.3)
634
+
635
+ # (0,2) Success rate
636
+ ax = fig.add_subplot(gs[0, 2])
637
+ ax.plot(iters, success_hist, marker="s", color="seagreen", linewidth=2)
638
+ ax.set_ylim(0, 1)
639
+ ax.set_title("C. Episode Success Rate ↑", fontweight="bold")
640
+ ax.set_xlabel("Iteration"); ax.set_ylabel("Fraction")
641
+ ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y,_: f"{y:.0%}"))
642
+ ax.grid(alpha=0.3)
643
+
644
+ # (1,0) KL divergence
645
+ ax = fig.add_subplot(gs[1, 0])
646
+ ax.plot(iters, kl_hist, marker="^", color="tomato", linewidth=2)
647
+ ax.axhline(0, color="gray", linewidth=0.8)
648
+ ax.set_title("D. KL Divergence", fontweight="bold")
649
+ ax.set_xlabel("Iteration"); ax.set_ylabel("KL"); ax.grid(alpha=0.3)
650
+
651
+ # (1,1) Entropy
652
+ ax = fig.add_subplot(gs[1, 1])
653
+ ax.plot(iters, entropy_hist, marker="D", color="darkorange", linewidth=2)
654
+ ax.set_title("E. Policy Entropy", fontweight="bold")
655
+ ax.set_xlabel("Iteration"); ax.set_ylabel("Entropy"); ax.grid(alpha=0.3)
656
+
657
+ # (1,2) Per-difficulty final reward
658
+ ax = fig.add_subplot(gs[1, 2])
659
+ tasks = TASK_LEVELS
660
+ vals_base = [baseline_snap.per_task.get(t, 0) for t in tasks]
661
+ vals_final = [final_snap.per_task.get(t, 0) for t in tasks]
662
+ x = np.arange(len(tasks))
663
+ ax.bar(x - 0.2, vals_base, 0.35, label="baseline",color="lightcoral", alpha=0.8)
664
+ ax.bar(x + 0.2, vals_final, 0.35, label="final", color="steelblue", alpha=0.8)
665
+ ax.set_xticks(x); ax.set_xticklabels(tasks, fontsize=8)
666
+ ax.set_title("F. Per-Difficulty Reward", fontweight="bold")
667
+ ax.set_ylabel("Avg Reward"); ax.legend(fontsize=8); ax.grid(alpha=0.3, axis="y")
668
+ ax.axhline(0, color="gray", linewidth=0.8)
669
+
670
+ fig.suptitle(f"Code-Review Agent – Full Training Evidence "
671
+ f"(Qwen2.5-1.5B, PPO + QLoRA)",
672
+ fontsize=13, fontweight="bold")
673
+ fig.savefig("training_summary.png", dpi=150, bbox_inches="tight")
674
+ plt.close(fig)
675
+ print(" Saved: training_summary.png")
676
+
677
+ # ── Figure 2: before / after action distribution ─────────────────────────
678
+ fig, axes = plt.subplots(1, 3, figsize=(16, 4), sharey=False)
679
+ for ax, snap, title in zip(
680
+ axes,
681
+ [baseline_snap, postwarmup_snap, final_snap],
682
+ ["Before (baseline)", "After warm-up", "After PPO (final)"]
683
+ ):
684
+ if snap.action_dist:
685
+ labels = list(snap.action_dist.keys())
686
+ vals = [snap.action_dist[l]*100 for l in labels]
687
+ bars = ax.barh(labels, vals,
688
+ color=plt.cm.tab10(np.linspace(0, 0.8, len(labels))))
689
+ ax.bar_label(bars, fmt="%.0f%%", padding=3, fontsize=8)
690
+ ax.set_xlim(0, 105)
691
+ ax.set_title(title, fontweight="bold")
692
+ ax.set_xlabel("% of actions")
693
+ ax.grid(alpha=0.3, axis="x")
694
+
695
+ fig.suptitle("Action Distribution: Before vs After Training",
696
+ fontsize=12, fontweight="bold")
697
+ plt.tight_layout()
698
+ fig.savefig("action_distribution.png", dpi=150, bbox_inches="tight")
699
+ plt.close(fig)
700
+ print(" Saved: action_distribution.png")
701
+
702
+ # ══════════════════════════════════════════════════════════════════════════════
703
+ # MAIN
704
+ # ══════════════════════════════════════════════════════════════════════════════
705
  def train():
706
  model, tokenizer = load_model()
707
  env = CodeReviewEnv()
708
 
709
+ # ── PHASE 0: pre-warmup baseline ────────────────────────────────────────
710
+ print("\n" + "="*60)
711
+ print("PHASE 0 – BASELINE (untrained)")
712
+ print("="*60)
713
+ baseline_snap = evaluate(env, model, tokenizer, "Baseline")
714
 
715
+ # ── PHASE 1: supervised warm-up ─────────────────────────────────────────
716
+ warmup_losses = supervised_warmup(model, tokenizer)
 
717
 
718
+ postwarmup_snap = evaluate(env, model, tokenizer, "Post-Warmup")
 
 
 
 
 
 
 
719
 
720
+ # ── PHASE 2: PPO ────────────────────────────────────────────────────────
721
+ optimizer = AdamW(model.parameters(), lr=CFG["ppo_lr"])
722
+ reward_hist, success_hist, kl_hist, entropy_hist = [], [], [], []
 
723
 
724
+ print("\n" + "="*60)
725
+ print(f"PHASE 2 – PPO ({CFG['ppo_iters']} iterations Γ— "
726
+ f"{CFG['trajs_per_iter']} trajectories)")
727
+ print("="*60)
 
 
 
 
728
 
729
+ for it in range(CFG["ppo_iters"]):
730
+ # Linearly anneal exploration temperature
731
+ # FIX 7 – exponential decay with a floor (never below 0.35).
732
+ # Linear annealing to 0.1 collapses exploration before we learn
733
+ # anything; keeping >= 0.35 ensures trajectory diversity.
734
+ t = max(CFG["temp_start"] * (0.93 ** it), 0.35)
735
 
736
+ print(f"\n── Iteration {it+1}/{CFG['ppo_iters']} temp={t:.2f} ──")
737
+ trajectories, action_counts = [], Counter()
738
+ successes = 0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
739
 
740
+ for j in range(CFG["trajs_per_iter"]):
741
+ task = TASK_LEVELS[j % len(TASK_LEVELS)]
742
+ traj, actions = collect_trajectory(
743
+ env, model, tokenizer, CFG["max_steps"], t, task
744
+ )
745
+ trajectories.append(traj)
746
+ action_counts.update(actions)
747
+ ep_r = sum(traj.rewards)
748
+ # FIX 6b – consistent with evaluate(): only explicit done counts
749
+ successes += int("done" in actions)
750
+ print(f" traj {j+1}/{CFG['trajs_per_iter']} task={task}"
751
+ f" steps={len(traj.actions)} reward={ep_r:+.3f}")
752
+
753
+ avg_r = float(np.mean([sum(t.rewards) for t in trajectories]))
754
+ success_r = successes / CFG["trajs_per_iter"]
755
+
756
+ m = ppo_update(trajectories, model, tokenizer, optimizer)
757
+
758
+ reward_hist.append(avg_r)
759
+ success_hist.append(success_r)
760
+ kl_hist.append(m["kl"])
761
+ entropy_hist.append(m["entropy"])
762
+
763
+ delta = avg_r - baseline_snap.avg_reward
764
+ print(f" β†’ avg_reward={avg_r:+.4f} Ξ”baseline={delta:+.4f}"
765
+ f" success={success_r:.0%}"
766
+ f" loss={m['loss']:.4f} kl={m['kl']:.4f} ent={m['entropy']:.4f}")
767
+ print(f" actions: {dict(action_counts.most_common(5))}")
768
+
769
+ # ── PHASE 3: final evaluation ───────────────────────────────────────────
770
+ print("\n" + "="*60)
771
+ print("PHASE 3 – FINAL EVALUATION")
772
+ print("="*60)
773
+ final_snap = evaluate(env, model, tokenizer, "Final")
774
+
775
+ # ── Summary table ───────────────────────────────────────────────────────
776
+ print("\n" + "="*60)
777
+ print("TRAINING SUMMARY")
778
+ print("="*60)
779
+ print(f" {'Stage':<20} {'Reward':>10} {'Success':>10} {'Ξ” baseline':>12}")
780
+ print(f" {'-'*54}")
781
+ for label, snap in [("Baseline", baseline_snap),
782
+ ("Post-warmup", postwarmup_snap),
783
+ ("Final (PPO)", final_snap)]:
784
+ delta = snap.avg_reward - baseline_snap.avg_reward
785
+ print(f" {label:<20} {snap.avg_reward:>+10.4f}"
786
+ f" {snap.success_rate:>10.0%} {delta:>+11.4f}")
787
+
788
+ improve = final_snap.avg_reward - baseline_snap.avg_reward
789
+ verdict = "βœ“ LEARNED" if improve > 0 else "βœ— NO IMPROVEMENT"
790
+ print(f"\n {verdict} (total Ξ” = {improve:+.4f})")
791
+
792
+ print("\nBefore β†’ After traces (one per difficulty):")
793
+ btask = {t["task"]: t for t in baseline_snap.traces}
794
+ ftask = {t["task"]: t for t in final_snap.traces}
795
+ for task in TASK_LEVELS:
796
+ b = btask.get(task, {})
797
+ f = ftask.get(task, {})
798
+ print(f" {task:8s} baseline actions={b.get('actions',[])} "
799
+ f"reward={b.get('reward',0):+.3f}"
800
+ f" β”‚ final actions={f.get('actions',[])} "
801
+ f"reward={f.get('reward',0):+.3f}")
802
+
803
+ # ── Plots ───────────────────────────────────────────────────────────────
804
+ plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
805
+ baseline_snap, postwarmup_snap, final_snap)
806
+
807
+ print("\nAll done. Saved: training_summary.png action_distribution.png")
808
 
 
 
809
 
810
  if __name__ == "__main__":
811
  train()