100XZX001 committed on
Commit
868dd5a
·
verified ·
1 Parent(s): a3960cd

Upload 17 files

Browse files
Files changed (2) hide show
  1. environment.py +1 -122
  2. training.py +850 -726
environment.py CHANGED
@@ -30,8 +30,6 @@ from rubrics import (
30
  # ======================================================================
31
  # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
  # ======================================================================
33
-
34
-
35
  @dataclass
36
  class EnhancedObservation:
37
  code_snippet: str
@@ -77,7 +75,6 @@ def execute_code(code: str, timeout_sec: int = 5) -> Tuple[bool, str, str]:
77
  f.write(code)
78
  tmp_path = f.name
79
 
80
-
81
  try:
82
  result = subprocess.run(
83
  [sys.executable, tmp_path],
@@ -205,124 +202,6 @@ class CodeReviewEnv:
205
  ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
206
  AntiHackingRubric(),
207
  core_rubrics[-1],
208
-
209
-
210
-
211
-
212
-
213
-
214
-
215
-
216
-
217
-
218
-
219
-
220
-
221
-
222
-
223
-
224
-
225
-
226
-
227
-
228
-
229
-
230
-
231
-
232
-
233
-
234
-
235
-
236
-
237
-
238
-
239
-
240
-
241
-
242
-
243
-
244
-
245
-
246
-
247
-
248
-
249
-
250
-
251
-
252
-
253
-
254
-
255
-
256
-
257
-
258
-
259
-
260
-
261
-
262
-
263
-
264
-
265
-
266
-
267
-
268
-
269
-
270
-
271
-
272
-
273
-
274
-
275
-
276
-
277
-
278
-
279
-
280
-
281
-
282
-
283
-
284
-
285
-
286
-
287
-
288
-
289
-
290
-
291
-
292
-
293
-
294
-
295
-
296
-
297
-
298
-
299
-
300
-
301
-
302
-
303
-
304
-
305
-
306
-
307
-
308
-
309
-
310
-
311
-
312
-
313
-
314
-
315
-
316
-
317
-
318
-
319
-
320
-
321
-
322
-
323
-
324
-
325
-
326
  ]
327
  raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
328
 
@@ -746,4 +625,4 @@ class CodeReviewEnv:
746
  test_results=self._test_results,
747
  step=self._step_count,
748
  done=self._done
749
- )
 
30
  # ======================================================================
31
  # FULLY MARKOV OBSERVATION (NOTHING HIDDEN)
32
  # ======================================================================
 
 
33
  @dataclass
34
  class EnhancedObservation:
35
  code_snippet: str
 
75
  f.write(code)
76
  tmp_path = f.name
77
 
 
78
  try:
79
  result = subprocess.run(
80
  [sys.executable, tmp_path],
 
202
  ExplorationRubric(penalty=-0.05, bonus=self.diversity_bonus * 0.7),
203
  AntiHackingRubric(),
204
  core_rubrics[-1],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
205
  ]
206
  raise ValueError(f"Unknown reward_profile: {self.reward_profile}")
207
 
 
625
  test_results=self._test_results,
626
  step=self._step_count,
627
  done=self._done
628
+ )
training.py CHANGED
@@ -1,811 +1,935 @@
1
- # training.py – PPO + QLoRA + Supervised Warm-up
2
- # Model : Qwen/Qwen2.5-1.5B-Instruct (via Unsloth – 2× faster, fits Colab T4)
3
- # Fixed : label-masking, BPE-boundary alignment, log-ratio clamping, OOM guards
4
- # Evidence: reward curves, before/after traces, per-difficulty breakdown, KL, entropy
5
- # ============================================================
6
- import os, json, random, re
7
- os.environ["TOKENIZERS_PARALLELISM"] = "false"
8
-
9
- import matplotlib
10
- matplotlib.use("Agg")
11
- import matplotlib.pyplot as plt
12
- import matplotlib.gridspec as gridspec
13
-
14
  import torch
15
  import torch.nn.functional as F
16
  from torch.optim import AdamW
17
- from dataclasses import dataclass, field
18
- from typing import List, Optional, Dict
19
- from collections import Counter, defaultdict
20
  import numpy as np
 
 
 
21
 
22
- # ── Unsloth gives 2× throughput with identical outputs ────────────────────────
23
  from unsloth import FastLanguageModel
 
 
 
24
 
25
  from environment import CodeReviewEnv
26
  from redteam import BUG_DB
27
-
28
- # Graceful import: use project map_to_env if available, else inline fallback.
29
- try:
30
- from models import map_to_env as model_map_to_env
31
- _HAVE_MODEL_MAP = True
32
- except (ImportError, AttributeError):
33
- _HAVE_MODEL_MAP = False
34
-
35
- if not _HAVE_MODEL_MAP:
36
- try:
37
- from models import (RunTests, RunLinter, Inspect, ProposeFix,
38
- WriteComment, AskQuestion, Done, Skip, QueryDocs)
39
- def model_map_to_env(action_type: str, content=None):
40
- return {
41
- "run_tests": RunTests(),
42
- "run_linter": RunLinter(),
43
- "inspect": Inspect(),
44
- "query_docs": QueryDocs(content or "python bug fix"),
45
- "fix": ProposeFix(content or ""),
46
- "comment": WriteComment(content or ""),
47
- "question": AskQuestion(content or ""),
48
- "done": Done(),
49
- }.get(action_type, Skip())
50
- except ImportError:
51
- # Last resort: duck-typed object the env can introspect.
52
- class _EnvAction:
53
- def __init__(self, **kw): self.__dict__.update(kw)
54
- def model_map_to_env(action_type: str, content=None):
55
- return _EnvAction(action_type=action_type, content=content)
56
-
57
- # ══════════════════════════════════════════════════════════════════════════════
58
- # CONFIG
59
- # ══════════════════════════════════════════════════════════════════════════════
60
- CFG = dict(
61
- model_name = "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit",
62
- max_seq_len = 512, # hard cap; prevents OOM on T4
63
- lora_r = 16,
64
- lora_alpha = 32,
65
-
66
- # Warm-up
67
- warmup_data = "training_data.json",
68
- warmup_epochs = 2,
69
- warmup_lr = 2e-5,
70
- warmup_grad_acc = 4, # effective batch = 4 examples
71
-
72
- # PPO
73
- ppo_iters = 15,
74
- trajs_per_iter = 6,
75
- max_steps = 7,
76
- ppo_lr = 3e-5,
77
- clip_eps = 0.2,
78
- entropy_coef = 0.01,
79
- gamma = 0.99,
80
- log_ratio_clamp = 5.0, # ← prevents exp-explosion / NaN loss
81
- temp_start = 0.8,
82
- temp_end = 0.1,
83
-
84
- # Eval
85
- eval_episodes = 10, # episodes per evaluation snapshot
86
  )
87
 
88
- DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
89
- TASK_LEVELS = list(BUG_DB.keys()) # [easy, medium, hard, harder, hardest]
90
-
91
- # ══════════════════════════════════════════════════════════════════════════════
92
- # DATA STRUCTURES
93
- # ══════════════════════════════════════════════════════════════════════════════
94
  @dataclass
95
  class AgentAction:
96
  action_type: str
97
  content: Optional[str] = None
98
 
99
- @dataclass
100
- class Trajectory:
101
- states: List[str]
102
- actions: List[str]
103
- rewards: List[float]
104
- logprobs: List[float]
105
- dones: List[bool]
106
- task: str = ""
107
-
108
- @dataclass
109
- class EvalSnapshot:
110
- """Captures full agent behaviour for before/after comparison."""
111
- avg_reward: float
112
- per_task: Dict[str, float] = field(default_factory=dict)
113
- action_dist: Dict[str, float] = field(default_factory=dict)
114
- success_rate: float = 0.0
115
- avg_steps: float = 0.0
116
- traces: List[dict] = field(default_factory=list)
117
-
118
- # ══════════════════════════════════════════════════════════════════════════════
119
- # ACTION PARSER
120
- # ══════════════════════════════════════════════════════════════════════════════
121
- def parse_action(text: str) -> AgentAction:
122
- """Robust parser: tries strict JSON, then regex, then keyword heuristic."""
123
- text = text.strip()
124
  try:
125
- d = json.loads(text)
126
- return AgentAction(d.get("action_type","skip").lower(), d.get("content"))
127
- except json.JSONDecodeError:
 
 
 
128
  pass
129
- m = re.search(r'"action_type"\s*:\s*"(\w+)"', text)
130
- if m:
131
- cm = re.search(r'"content"\s*:\s*"(.*?)"', text, re.DOTALL)
132
- return AgentAction(m.group(1).lower(), cm.group(1) if cm else None)
133
- tl = text.lower()
134
- for kw in ("run_tests","run_linter","inspect","query_docs","fix",
135
- "comment","question","done"):
136
- if kw in tl:
137
- return AgentAction(kw)
138
- return AgentAction("skip")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  def map_to_env(action: AgentAction):
141
  return model_map_to_env(action.action_type, action.content)
142
 
143
- # ══════════════════════════════════════════════════════════════════════════════
144
- # MODEL (Qwen2.5-1.5B via Unsloth)
145
- # ══════════════════════════════════════════════════════════════════════════════
146
  def load_model():
147
- print(f"Loading {CFG['model_name']} …")
148
  model, tokenizer = FastLanguageModel.from_pretrained(
149
- model_name = CFG["model_name"],
150
- max_seq_length = CFG["max_seq_len"],
151
- load_in_4bit = True,
152
  )
153
  model = FastLanguageModel.get_peft_model(
154
  model,
155
- r = CFG["lora_r"],
156
- lora_alpha = CFG["lora_alpha"],
157
- target_modules = ["q_proj","k_proj","v_proj","o_proj",
158
- "gate_proj","up_proj","down_proj"],
159
- lora_dropout = 0.0,
 
 
160
  )
161
- tokenizer.pad_token = tokenizer.eos_token
162
- print(f" trainable params: "
163
- f"{sum(p.numel() for p in model.parameters() if p.requires_grad)/1e6:.1f}M")
164
  return model, tokenizer
165
 
166
- # ══════════════════════════════════════════════════════════════════════════════
167
- # PROMPT BUILDER
168
- # ══════════════════════════════════════════════════════════════════════════════
169
- def build_prompt(obs, history_lines: List[str]) -> str:
170
- author_msg = getattr(obs, "author_response", "") or ""
171
- tool_output = getattr(obs, "last_tool_output", "") or ""
172
- personality = getattr(obs, "author_personality","defensive")
173
-
174
- # Trim tool output to avoid context explosion
175
- if len(tool_output) > 600:
176
- tool_output = tool_output[:600] + " …[truncated]"
177
-
178
- p = (
179
- f"You are an AI code review agent. Convince the developer (personality: "
180
- f"**{personality}**) to accept your fix. Name your fix function `fix`.\n\n"
181
- "Evidence required: tests pass, lint clean, docs cited, reasoning uses "
182
- "'because'/'therefore' (>30 words).\n\n"
183
- "Workflow: inspect β†’ run_tests β†’ run_linter β†’ query_docs β†’ fix β†’ "
184
- "comment/question β†’ done.\n\n"
185
- f"Code:\n{obs.code_snippet}\n\n"
186
- f"Author: {author_msg or '(no response yet – start with inspect)'}\n\n"
187
- f"Last tool: {tool_output or '(none)'}\n\n"
188
- "Actions: run_tests, run_linter, inspect, query_docs, fix, comment, question, done\n\n"
189
- 'Respond ONLY in JSON: {"action_type": "...", "content": "..."}'
190
- )
191
- if history_lines:
192
- p += "\n\nRecent steps:\n" + "\n".join(history_lines[-4:])
193
- return p
194
-
195
- # ══════════════════════════════════════════════════════════════════════════════
196
- # BUG FIX 1 – label masking in supervised warmup
197
- # (original: labels=inputs["input_ids"] trains on ALL tokens, including prompt)
198
- # ══════════════════════════════════════════════════════════════════════════════
199
- def _masked_labels(input_ids: torch.Tensor, prompt_len: int) -> torch.Tensor:
200
- """Return labels with prompt positions set to -100 (ignored by CE loss)."""
201
- labels = input_ids.clone()
202
- labels[0, :prompt_len] = -100
203
- return labels
204
-
205
- # ══════════════════════════════════════════════════════════════════════════════
206
- # BUG FIX 2 – BPE-boundary-safe logprob computation
207
- # (original: tokenize(prompt) + tokenize(action) ≠ tokenize(prompt+action))
208
- # ══════════════════════════════════════════════════════════════════════════════
209
- def _compute_action_logprob(
210
- logits: torch.Tensor, # [1, seq_len, vocab]
211
- input_ids: torch.Tensor, # [1, seq_len]
212
- prompt_len: int, # #tokens in the prompt part of the joint sequence
213
- ) -> tuple:
214
  """
215
- Compute sum of log-probs for *action* tokens only, using the jointly
216
- tokenised sequence so BPE boundaries are respected.
217
-
218
- Returns (total_logprob, avg_entropy, n_tokens).
219
  """
220
- action_len = input_ids.shape[1] - prompt_len
221
- if action_len <= 0:
222
- return torch.tensor(0.0, device=DEVICE), torch.tensor(0.0, device=DEVICE), 0
 
 
 
 
 
 
 
223
 
224
- total_lp = torch.tensor(0.0, device=DEVICE)
225
- total_ent = torch.tensor(0.0, device=DEVICE)
 
 
 
226
 
227
- for k in range(action_len):
228
- pos = prompt_len + k # position of the k-th action token
229
- pred_pos = pos - 1 # logit at pred_pos predicts token at pos
230
- if pred_pos < 0 or pred_pos >= logits.shape[1]:
231
- continue
232
- token_id = input_ids[0, pos]
233
- lp_dist = F.log_softmax(logits[0, pred_pos], dim=-1)
234
- total_lp = total_lp + lp_dist[token_id]
235
- probs = torch.exp(lp_dist)
236
- total_ent = total_ent + (-(probs * lp_dist).sum()).detach()
237
-
238
- n = action_len
239
- return total_lp, total_ent / max(n, 1), n
240
-
241
- # ══════════════════════════════════════════════════════════════════════════════
242
- # GENERATION (returns text + joint-sequence logprob)
243
- # ══════════════════════════════════════════════════════════════════════════════
244
- @torch.no_grad()
245
- def generate_action(prompt: str, model, tokenizer,
246
- temperature: float) -> tuple:
247
- messages = [{"role": "user", "content": prompt}]
248
- formatted = tokenizer.apply_chat_template(
249
- messages, tokenize=False, add_generation_prompt=True
250
- )
251
 
252
- inputs = tokenizer(
253
- formatted, return_tensors="pt",
254
- max_length=CFG["max_seq_len"] - 128, # leave room for response
255
- truncation=True
256
- ).to(DEVICE)
257
- prompt_len = inputs["input_ids"].shape[1]
258
-
259
- gen_kwargs = dict(
260
- max_new_tokens = 128,
261
- do_sample = temperature > 0,
262
- return_dict_in_generate = True,
263
- output_scores = True,
264
- pad_token_id = tokenizer.eos_token_id,
265
- eos_token_id = tokenizer.eos_token_id,
266
- )
267
- if temperature > 0:
268
- gen_kwargs["temperature"] = temperature
269
-
270
- out = model.generate(**inputs, **gen_kwargs)
271
- gen_ids = out.sequences[0][prompt_len:]
272
- text = tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
273
-
274
- if not text:
275
- fallback = random.choice([
276
- '{"action_type":"inspect"}',
277
- '{"action_type":"run_tests"}',
278
- '{"action_type":"run_linter"}',
279
- ])
280
- print(f" [WARN] empty generation β†’ fallback {fallback}")
281
- # BUG FIX 3: don't use -100 sentinel; use a mildly negative logprob
282
- # so that PPO ratio = exp(new - old) stays finite when re-evaluated
283
- return fallback, -10.0
284
-
285
- # Recompute logprob from the full joint sequence (BPE-safe)
286
- joint_ids = torch.cat(
287
- [inputs["input_ids"], gen_ids.unsqueeze(0).to(DEVICE)], dim=1
288
- )
289
- joint_ids = joint_ids[:, :CFG["max_seq_len"]]
290
-
291
- logits = model(input_ids=joint_ids).logits
292
- lp, _, _ = _compute_action_logprob(logits, joint_ids, prompt_len)
293
-
294
- return text, lp.item()
295
-
296
- # ══════════════════════════════════════════════════════════════════════════════
297
- # TRAJECTORY COLLECTION
298
- # ══════════════════════════════════════════════════════════════════════════════
299
- # Per-action shaped rewards. These create reward variance so that
300
- # trajectories with meaningful tool use beat inspect-only episodes.
301
- _STEP_REWARD = {
302
- "run_tests": +0.08,
303
- "run_linter": +0.05,
304
- "fix": +0.15,
305
- "comment": +0.08,
306
- "query_docs": +0.05,
307
- "question": +0.04,
308
- "inspect": 0.00, # neutral – observe before acting
309
- "done": 0.00, # env handles the terminal reward
310
- "skip": -0.10, # penalise doing nothing
311
- }
312
-
313
- def collect_trajectory(env, model, tokenizer,
314
- max_steps: int, temperature: float,
315
- task: str) -> tuple:
316
- """
317
- FIX 4 – Override env done/reward for non-terminal actions.
318
 
319
- Root cause of the degenerate policy:
320
- ‒ env.step(Inspect()) returns done=True, reward=+0.002
321
- ‒ agent discovers inspect → tiny reward → done is the easiest path
322
- ‒ every trajectory is identical → zero advantage → PPO does nothing
 
 
 
 
 
 
 
 
323
 
324
- Fix: only accept env's done+reward when the agent explicitly emits
325
- {"action_type": "done"}. For every other action, use a shaped step
326
- reward and force the episode to continue.
327
- """
328
- env.set_task(task)
329
- obs = env.reset()
330
- history: List[str] = []
331
- traj = Trajectory([], [], [], [], [], task=task)
332
- action_seq = []
333
-
334
- for step_num in range(max_steps):
335
- prompt = build_prompt(obs, history)
336
- traj.states.append(prompt)
337
-
338
- text, lp = generate_action(prompt, model, tokenizer, temperature)
339
- traj.actions.append(text)
340
- traj.logprobs.append(lp)
341
-
342
- action = parse_action(text)
343
- action_seq.append(action.action_type)
344
-
345
- obs, reward, env_done, _ = env.step(map_to_env(action))
346
- raw_r = float(reward.value)
347
-
348
- if action.action_type == "done":
349
- # Agent explicitly chose to terminate → honour env reward
350
- shaped_r = raw_r
351
- effective_done = True
352
- else:
353
- # Intermediate step: use shaped reward, ignore env's done signal.
354
- # Also keep a fraction of any large env reward (e.g. test pass).
355
- shaped_r = _STEP_REWARD.get(action.action_type, 0.0)
356
- if raw_r > 0.1: # env signalling meaningful progress
357
- shaped_r += raw_r * 0.3
358
- effective_done = False # ← key: don't let env short-circuit
359
-
360
- traj.rewards.append(float(np.clip(shaped_r, -1.0, 1.0)))
361
- traj.dones.append(effective_done)
362
-
363
- history.append(f"Agent: {text[:120]}")
364
- history.append(f"Env: {(obs.last_tool_output or '')[:120]}")
365
-
366
- if effective_done:
367
- break
368
-
369
- return traj, action_seq
370
-
371
- # ══════════════════════════════════════════════════════════════════════════════
372
- # SUPERVISED WARM-UP (BUG FIX 1: action-only label masking)
373
- # ══════════════════════════════════════════════════════════════════════════════
374
- def supervised_warmup(model, tokenizer):
375
- print("\n" + "="*60)
376
- print("SUPERVISED WARM-UP")
377
- print("="*60)
378
 
379
- with open(CFG["warmup_data"], encoding="utf-8") as f:
380
- data = json.load(f)
 
 
 
 
 
 
 
 
381
 
382
- opt = AdamW(model.parameters(), lr=CFG["warmup_lr"])
383
- model.train()
384
- loss_history = []
 
 
 
 
385
 
386
- for epoch in range(CFG["warmup_epochs"]):
387
- random.shuffle(data)
388
- epoch_loss, n_valid = 0.0, 0
389
- opt.zero_grad()
390
 
391
- for step, ex in enumerate(data):
392
- # ── Tokenise prompt and full sequence jointly ────────────────
393
- prompt_chat = tokenizer.apply_chat_template(
394
- [{"role": "user", "content": ex["prompt"]}],
395
- tokenize=False, add_generation_prompt=True
396
- )
397
- full_chat = tokenizer.apply_chat_template(
398
- [{"role": "user", "content": ex["prompt"]},
399
- {"role": "assistant", "content": ex["action"]}],
400
- tokenize=False
401
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
402
 
403
- prompt_ids = tokenizer(
404
- prompt_chat, return_tensors="pt",
405
- max_length=CFG["max_seq_len"], truncation=True
406
- )["input_ids"]
407
- full_inputs = tokenizer(
408
- full_chat, return_tensors="pt",
409
- max_length=CFG["max_seq_len"], truncation=True
410
- ).to(DEVICE)
 
 
 
 
411
 
412
- prompt_len = prompt_ids.shape[1]
413
- if prompt_len >= full_inputs["input_ids"].shape[1]:
414
- continue # action got truncated away
 
 
 
415
 
416
- # BUG FIX 1 ── mask prompt tokens so loss is action-only
417
- labels = _masked_labels(full_inputs["input_ids"], prompt_len)
 
 
418
 
419
- out = model(**full_inputs, labels=labels)
420
- loss = out.loss / CFG["warmup_grad_acc"]
421
- loss.backward()
 
 
422
 
423
- if (step + 1) % CFG["warmup_grad_acc"] == 0:
424
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
425
- opt.step()
426
- opt.zero_grad()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
427
 
428
- epoch_loss += loss.item() * CFG["warmup_grad_acc"]
429
- n_valid += 1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
430
 
431
- if (step + 1) % 50 == 0:
432
- print(f" epoch {epoch+1} step {step+1}/{len(data)}"
433
- f" loss={epoch_loss/n_valid:.4f}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
434
 
435
- avg = epoch_loss / max(n_valid, 1)
436
- loss_history.append(avg)
437
- print(f" Epoch {epoch+1} complete: avg_loss={avg:.4f}")
 
 
 
 
 
 
438
 
439
- torch.cuda.empty_cache()
440
- print(f"βœ“ Warm-up done. Loss: {' β†’ '.join(f'{l:.4f}' for l in loss_history)}\n")
441
- return loss_history
442
-
443
- # ══════════════════════════════════════════════════════════════════════════════
444
- # EVALUATION (produces rich EvalSnapshot for comparison plots)
445
- # ══════════════════════════════════════════════════════════════════════════════
446
- @torch.no_grad()
447
- def evaluate(env, model, tokenizer, label: str = "") -> EvalSnapshot:
448
- model.eval()
449
- per_task: Dict[str, List[float]] = defaultdict(list)
450
- action_counter: Counter = Counter()
451
- all_steps, all_success = [], []
452
- traces = []
453
-
454
- for ep in range(CFG["eval_episodes"]):
455
- task = TASK_LEVELS[ep % len(TASK_LEVELS)]
456
- traj, actions = collect_trajectory(
457
- env, model, tokenizer, CFG["max_steps"], 0.0, task
458
- )
459
- ep_r = sum(traj.rewards)
460
- per_task[task].append(ep_r)
461
- action_counter.update(actions)
462
- all_steps.append(len(traj.actions))
463
- # FIX 6 – meaningful success = agent explicitly called "done".
464
- # ep_r > 0 is misleading: even a single inspect returns +0.002.
465
- all_success.append(1 if "done" in actions else 0)
466
- traces.append({"task": task, "reward": round(ep_r, 4),
467
- "steps": len(traj.actions), "actions": actions})
468
-
469
- total_actions = max(sum(action_counter.values()), 1)
470
- snap = EvalSnapshot(
471
- avg_reward = float(np.mean([r for rs in per_task.values() for r in rs])),
472
- per_task = {t: float(np.mean(rs)) for t, rs in per_task.items()},
473
- action_dist = {a: c/total_actions for a, c in action_counter.most_common()},
474
- success_rate = float(np.mean(all_success)),
475
- avg_steps = float(np.mean(all_steps)),
476
- traces = traces,
477
- )
478
- if label:
479
- print(f"\n── {label} ──")
480
- print(f" avg_reward={snap.avg_reward:+.4f} "
481
- f"success={snap.success_rate:.0%} steps={snap.avg_steps:.1f}")
482
- print(f" per-task: " +
483
- " ".join(f"{t}={v:+.3f}" for t,v in snap.per_task.items()))
484
- print(f" top actions: " +
485
- " ".join(f"{a}={p:.0%}" for a,p in list(snap.action_dist.items())[:5]))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
486
  model.train()
487
- return snap
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
488
 
489
- # ══════════════════════════════════════════════════════════════════════════════
490
- # PPO UPDATE (BUG FIX 2 + 3: BPE-safe logprob + log-ratio clamping)
491
- # ══════════════════════════════════════════════════════════════════════════════
492
- def ppo_update(trajectories: List[Trajectory],
493
- model, tokenizer, optimizer) -> dict:
494
- model.train()
495
- losses, kls, entropies = [], [], []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
496
 
497
- # ── Compute discounted returns and a global mean baseline ────────────────
498
- all_returns = []
499
- traj_returns = []
500
- for traj in trajectories:
501
- ret, running = [], 0.0
502
- for r, done in zip(reversed(traj.rewards), reversed(traj.dones)):
503
- running = r + CFG["gamma"] * (0.0 if done else running)
504
- ret.insert(0, running)
505
- traj_returns.append(ret)
506
- all_returns.extend(ret)
507
-
508
- # FIX 5 – Normalise advantages to zero mean / unit std.
509
- # When all returns are identical (e.g. every episode returns 0.002),
510
- # baseline = mean = every return, so adv = 0 for all steps, the
511
- # policy loss is 0, and PPO never updates. Normalising creates real
512
- # signal: better-than-average trajectories get positive advantage,
513
- # worse-than-average get negative, even if the absolute spread is tiny.
514
- ret_arr = np.array(all_returns) if all_returns else np.array([0.0])
515
- ret_mean = float(ret_arr.mean())
516
- ret_std = float(ret_arr.std())
517
-
518
- if ret_std < 1e-6:
519
- # Truly zero variance – nothing to learn this iteration.
520
- print(" [PPO] Zero return variance – skipping gradient update.")
521
- return dict(loss=0.0, kl=0.0, entropy=0.0)
522
-
523
- # Build a lookup so we can retrieve the normalised advantage by
524
- # (trajectory index, step index) during the update loop below.
525
- norm_returns: List[List[float]] = [
526
- [(r - ret_mean) / (ret_std + 1e-8) for r in ret_list]
527
- for ret_list in traj_returns
528
- ]
529
-
530
- for traj_idx, (traj, returns) in enumerate(zip(trajectories, traj_returns)):
531
- for i in range(len(traj.states)):
532
- state = traj.states[i]
533
- action = traj.actions[i]
534
- old_lp = traj.logprobs[i]
535
- adv = norm_returns[traj_idx][i] # ← normalised advantage
536
-
537
- # ── Tokenise jointly (BPE FIX 2) ────────────────────────────────
538
- prompt_chat = tokenizer.apply_chat_template(
539
- [{"role": "user", "content": state}],
540
- tokenize=False, add_generation_prompt=True
541
- )
542
- full_text = prompt_chat + action
543
 
544
- full_ids = tokenizer(
545
- full_text, return_tensors="pt",
546
- max_length=CFG["max_seq_len"], truncation=True
547
- ).to(DEVICE)
 
548
 
549
- # Count prompt tokens IN THE JOINT SEQUENCE (not separately)
550
- prompt_ids = tokenizer(
551
- prompt_chat, return_tensors="pt",
552
- max_length=CFG["max_seq_len"] - 10, truncation=True
553
- )["input_ids"]
554
- prompt_len = min(prompt_ids.shape[1], full_ids["input_ids"].shape[1] - 1)
555
 
556
- logits = model(**full_ids).logits
 
 
557
 
558
- new_lp, avg_ent, n_tokens = _compute_action_logprob(
559
- logits, full_ids["input_ids"], prompt_len
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
560
  )
561
- if n_tokens == 0:
562
- continue
563
 
564
- # BUG FIX 3 ── clamp log-ratio before exp to prevent NaN
565
- old_lp_t = torch.tensor(old_lp, dtype=torch.float32, device=DEVICE)
566
- log_ratio = torch.clamp(new_lp - old_lp_t,
567
- -CFG["log_ratio_clamp"],
568
- CFG["log_ratio_clamp"])
569
- ratio = torch.exp(log_ratio)
570
 
571
- adv_t = torch.tensor(adv, dtype=torch.float32, device=DEVICE)
572
- s1 = ratio * adv_t
573
- s2 = torch.clamp(ratio,
574
- 1.0 - CFG["clip_eps"],
575
- 1.0 + CFG["clip_eps"]) * adv_t
576
 
577
- policy_loss = -torch.min(s1, s2)
578
- loss = policy_loss - CFG["entropy_coef"] * avg_ent
 
579
 
580
- if torch.isnan(loss) or torch.isinf(loss):
 
581
  continue
582
 
583
- optimizer.zero_grad()
 
 
584
  loss.backward()
585
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
586
- optimizer.step()
587
 
588
- losses.append(loss.item())
589
- kls.append((old_lp_t - new_lp).detach().cpu().item())
590
- entropies.append(avg_ent.item())
591
-
592
- torch.cuda.empty_cache()
593
- return dict(
594
- loss = float(np.mean(losses)) if losses else 0.0,
595
- kl = float(np.mean(kls)) if kls else 0.0,
596
- entropy = float(np.mean(entropies)) if entropies else 0.0,
597
- )
598
 
599
- # ══════════════════════════════════════════════════════════════════════════════
600
- # PLOTTING (rich evidence panel)
601
- # ══════════════════════════════════════════════════════════════════════════════
602
- def plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
603
- baseline_snap: EvalSnapshot,
604
- postwarmup_snap: EvalSnapshot,
605
- final_snap: EvalSnapshot):
606
-
607
- iters = list(range(1, len(reward_hist) + 1))
608
-
609
- # ── Figure 1: training curves (2×3 grid) ─────────────────────────────────
610
- fig = plt.figure(figsize=(18, 10))
611
- gs = gridspec.GridSpec(2, 3, figure=fig, hspace=0.45, wspace=0.35)
612
-
613
- # (0,0) Warm-up loss
614
- ax = fig.add_subplot(gs[0, 0])
615
- ax.plot(range(1, len(warmup_losses)+1), warmup_losses,
616
- marker="o", color="mediumpurple", linewidth=2)
617
- ax.set_title("A. Warm-up CE Loss ↓", fontweight="bold")
618
- ax.set_xlabel("Epoch"); ax.set_ylabel("Loss"); ax.grid(alpha=0.3)
619
-
620
- # (0,1) PPO reward
621
- ax = fig.add_subplot(gs[0, 1])
622
- smooth = np.convolve(reward_hist, np.ones(3)/3, mode="same")
623
- ax.plot(iters, reward_hist, alpha=0.35, color="steelblue", linewidth=1)
624
- ax.plot(iters, smooth, color="steelblue", linewidth=2.5, label="reward (smoothed)")
625
- ax.axhline(baseline_snap.avg_reward, color="gray", linestyle=":",
626
- label=f"pre-warmup ({baseline_snap.avg_reward:+.3f})")
627
- ax.axhline(postwarmup_snap.avg_reward, color="mediumpurple", linestyle="--",
628
- label=f"post-warmup ({postwarmup_snap.avg_reward:+.3f})")
629
- ax.axhline(final_snap.avg_reward, color="forestgreen", linestyle="-.",
630
- label=f"final ({final_snap.avg_reward:+.3f})")
631
- ax.set_title("B. PPO Reward ↑", fontweight="bold")
632
- ax.set_xlabel("Iteration"); ax.set_ylabel("Avg Reward")
633
- ax.legend(fontsize=7); ax.grid(alpha=0.3)
634
-
635
- # (0,2) Success rate
636
- ax = fig.add_subplot(gs[0, 2])
637
- ax.plot(iters, success_hist, marker="s", color="seagreen", linewidth=2)
638
- ax.set_ylim(0, 1)
639
- ax.set_title("C. Episode Success Rate ↑", fontweight="bold")
640
- ax.set_xlabel("Iteration"); ax.set_ylabel("Fraction")
641
- ax.yaxis.set_major_formatter(plt.FuncFormatter(lambda y,_: f"{y:.0%}"))
642
- ax.grid(alpha=0.3)
643
 
644
- # (1,0) KL divergence
645
- ax = fig.add_subplot(gs[1, 0])
646
- ax.plot(iters, kl_hist, marker="^", color="tomato", linewidth=2)
647
- ax.axhline(0, color="gray", linewidth=0.8)
648
- ax.set_title("D. KL Divergence", fontweight="bold")
649
- ax.set_xlabel("Iteration"); ax.set_ylabel("KL"); ax.grid(alpha=0.3)
650
-
651
- # (1,1) Entropy
652
- ax = fig.add_subplot(gs[1, 1])
653
- ax.plot(iters, entropy_hist, marker="D", color="darkorange", linewidth=2)
654
- ax.set_title("E. Policy Entropy", fontweight="bold")
655
- ax.set_xlabel("Iteration"); ax.set_ylabel("Entropy"); ax.grid(alpha=0.3)
656
-
657
- # (1,2) Per-difficulty final reward
658
- ax = fig.add_subplot(gs[1, 2])
659
- tasks = TASK_LEVELS
660
- vals_base = [baseline_snap.per_task.get(t, 0) for t in tasks]
661
- vals_final = [final_snap.per_task.get(t, 0) for t in tasks]
662
- x = np.arange(len(tasks))
663
- ax.bar(x - 0.2, vals_base, 0.35, label="baseline",color="lightcoral", alpha=0.8)
664
- ax.bar(x + 0.2, vals_final, 0.35, label="final", color="steelblue", alpha=0.8)
665
- ax.set_xticks(x); ax.set_xticklabels(tasks, fontsize=8)
666
- ax.set_title("F. Per-Difficulty Reward", fontweight="bold")
667
- ax.set_ylabel("Avg Reward"); ax.legend(fontsize=8); ax.grid(alpha=0.3, axis="y")
668
- ax.axhline(0, color="gray", linewidth=0.8)
669
-
670
- fig.suptitle(f"Code-Review Agent – Full Training Evidence "
671
- f"(Qwen2.5-1.5B, PPO + QLoRA)",
672
- fontsize=13, fontweight="bold")
673
- fig.savefig("training_summary.png", dpi=150, bbox_inches="tight")
674
- plt.close(fig)
675
- print(" Saved: training_summary.png")
676
-
677
- # ── Figure 2: before / after action distribution ─────────────────────────
678
- fig, axes = plt.subplots(1, 3, figsize=(16, 4), sharey=False)
679
- for ax, snap, title in zip(
680
- axes,
681
- [baseline_snap, postwarmup_snap, final_snap],
682
- ["Before (baseline)", "After warm-up", "After PPO (final)"]
683
- ):
684
- if snap.action_dist:
685
- labels = list(snap.action_dist.keys())
686
- vals = [snap.action_dist[l]*100 for l in labels]
687
- bars = ax.barh(labels, vals,
688
- color=plt.cm.tab10(np.linspace(0, 0.8, len(labels))))
689
- ax.bar_label(bars, fmt="%.0f%%", padding=3, fontsize=8)
690
- ax.set_xlim(0, 105)
691
- ax.set_title(title, fontweight="bold")
692
- ax.set_xlabel("% of actions")
693
- ax.grid(alpha=0.3, axis="x")
694
-
695
- fig.suptitle("Action Distribution: Before vs After Training",
696
- fontsize=12, fontweight="bold")
697
- plt.tight_layout()
698
- fig.savefig("action_distribution.png", dpi=150, bbox_inches="tight")
699
- plt.close(fig)
700
- print(" Saved: action_distribution.png")
701
 
702
- # ══════════════════════════════════════════════════════════════════════════════
703
- # MAIN
704
- # ══════════════════════════════════════════════════════════════════════════════
705
- def train():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
706
  model, tokenizer = load_model()
 
 
707
  env = CodeReviewEnv()
 
708
 
709
- # ── PHASE 0: pre-warmup baseline ────────────────────────────────────────
 
 
710
  print("\n" + "="*60)
711
- print("PHASE 0 – BASELINE (untrained)")
712
  print("="*60)
713
- baseline_snap = evaluate(env, model, tokenizer, "Baseline")
714
-
715
- # ── PHASE 1: supervised warm-up ─────────────────────────────────────────
716
- warmup_losses = supervised_warmup(model, tokenizer)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
717
 
718
- postwarmup_snap = evaluate(env, model, tokenizer, "Post-Warmup")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
719
 
720
- # ── PHASE 2: PPO ────────────────────────────────────────────────────────
721
- optimizer = AdamW(model.parameters(), lr=CFG["ppo_lr"])
722
- reward_hist, success_hist, kl_hist, entropy_hist = [], [], [], []
723
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
724
  print("\n" + "="*60)
725
- print(f"PHASE 2 – PPO ({CFG['ppo_iters']} iterations Γ— "
726
- f"{CFG['trajs_per_iter']} trajectories)")
727
  print("="*60)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728
 
729
- for it in range(CFG["ppo_iters"]):
730
- # Linearly anneal exploration temperature
731
- # FIX 7 – exponential decay with a floor (never below 0.35).
732
- # Linear annealing to 0.1 collapses exploration before we learn
733
- # anything; keeping >= 0.35 ensures trajectory diversity.
734
- t = max(CFG["temp_start"] * (0.93 ** it), 0.35)
735
-
736
- print(f"\n── Iteration {it+1}/{CFG['ppo_iters']} temp={t:.2f} ──")
737
- trajectories, action_counts = [], Counter()
738
- successes = 0
739
-
740
- for j in range(CFG["trajs_per_iter"]):
741
- task = TASK_LEVELS[j % len(TASK_LEVELS)]
742
- traj, actions = collect_trajectory(
743
- env, model, tokenizer, CFG["max_steps"], t, task
744
- )
745
- trajectories.append(traj)
746
- action_counts.update(actions)
747
- ep_r = sum(traj.rewards)
748
- # FIX 6b – consistent with evaluate(): only explicit done counts
749
- successes += int("done" in actions)
750
- print(f" traj {j+1}/{CFG['trajs_per_iter']} task={task}"
751
- f" steps={len(traj.actions)} reward={ep_r:+.3f}")
752
-
753
- avg_r = float(np.mean([sum(t.rewards) for t in trajectories]))
754
- success_r = successes / CFG["trajs_per_iter"]
755
-
756
- m = ppo_update(trajectories, model, tokenizer, optimizer)
757
-
758
- reward_hist.append(avg_r)
759
- success_hist.append(success_r)
760
- kl_hist.append(m["kl"])
761
- entropy_hist.append(m["entropy"])
762
-
763
- delta = avg_r - baseline_snap.avg_reward
764
- print(f" β†’ avg_reward={avg_r:+.4f} Ξ”baseline={delta:+.4f}"
765
- f" success={success_r:.0%}"
766
- f" loss={m['loss']:.4f} kl={m['kl']:.4f} ent={m['entropy']:.4f}")
767
- print(f" actions: {dict(action_counts.most_common(5))}")
768
-
769
- # ── PHASE 3: final evaluation ───────────────────────────────────────────
770
- print("\n" + "="*60)
771
- print("PHASE 3 – FINAL EVALUATION")
772
- print("="*60)
773
- final_snap = evaluate(env, model, tokenizer, "Final")
774
 
775
- # ── Summary table ───────────────────────────────────────────────────────
776
- print("\n" + "="*60)
777
- print("TRAINING SUMMARY")
778
- print("="*60)
779
- print(f" {'Stage':<20} {'Reward':>10} {'Success':>10} {'Ξ” baseline':>12}")
780
- print(f" {'-'*54}")
781
- for label, snap in [("Baseline", baseline_snap),
782
- ("Post-warmup", postwarmup_snap),
783
- ("Final (PPO)", final_snap)]:
784
- delta = snap.avg_reward - baseline_snap.avg_reward
785
- print(f" {label:<20} {snap.avg_reward:>+10.4f}"
786
- f" {snap.success_rate:>10.0%} {delta:>+11.4f}")
787
-
788
- improve = final_snap.avg_reward - baseline_snap.avg_reward
789
- verdict = "βœ“ LEARNED" if improve > 0 else "βœ— NO IMPROVEMENT"
790
- print(f"\n {verdict} (total Ξ” = {improve:+.4f})")
791
-
792
- print("\nBefore β†’ After traces (one per difficulty):")
793
- btask = {t["task"]: t for t in baseline_snap.traces}
794
- ftask = {t["task"]: t for t in final_snap.traces}
795
- for task in TASK_LEVELS:
796
- b = btask.get(task, {})
797
- f = ftask.get(task, {})
798
- print(f" {task:8s} baseline actions={b.get('actions',[])} "
799
- f"reward={b.get('reward',0):+.3f}"
800
- f" β”‚ final actions={f.get('actions',[])} "
801
- f"reward={f.get('reward',0):+.3f}")
802
-
803
- # ── Plots ───────────────────────────────────────────────────────────────
804
- plot_all(warmup_losses, reward_hist, success_hist, kl_hist, entropy_hist,
805
- baseline_snap, postwarmup_snap, final_snap)
806
-
807
- print("\nAll done. Saved: training_summary.png action_distribution.png")
 
 
 
 
 
 
 
 
 
 
808
 
 
 
 
809
 
810
  if __name__ == "__main__":
811
- train()
 
1
+ # training.py – Memory-safe: Phi-3-mini + expert demos + fast PPO (iteration count is set in train_ppo)
2
+ import os
3
+ os.environ["TRITON_DISABLE"] = "1"
4
+ os.environ["TOKENIZERS_PARALLELISM"] = "false" # Issue #12: prevent OOM from parallel tokenization
5
+
6
+ import torch._dynamo
7
+ torch._dynamo.config.disable = True
8
+ import json
 
 
 
 
 
9
  import torch
10
  import torch.nn.functional as F
11
  from torch.optim import AdamW
12
+ from dataclasses import dataclass
13
+ from typing import List, Dict, Tuple, Optional
 
14
  import numpy as np
15
+ import re
16
+ import random
17
+ import matplotlib.pyplot as plt
18
 
 
19
  from unsloth import FastLanguageModel
20
+ from transformers import TrainingArguments
21
+ from trl import SFTTrainer
22
+ from datasets import Dataset
23
 
24
  from environment import CodeReviewEnv
25
  from redteam import BUG_DB
26
+ from models import (
27
+ RunTests, RunLinter, Inspect,
28
+ ProposeFix, WriteComment, AskQuestion,
29
+ Done, Skip, QueryDocs, map_to_env as model_map_to_env
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  )
31
 
32
+ # ======================================================================
 
 
 
 
 
33
@dataclass
class AgentAction:
    """A parsed agent decision: an action name plus optional free-text payload."""

    action_type: str
    content: Optional[str] = None


def _action_from_json(text) -> Optional[AgentAction]:
    """Return an AgentAction decoded from a JSON object string, or None if it is unusable."""
    try:
        data = json.loads(text)
    except (TypeError, ValueError, RecursionError):
        # Not JSON at all (json.JSONDecodeError is a ValueError subclass).
        return None
    if not isinstance(data, dict):
        return None
    action_type = data.get("action_type", "")
    if not isinstance(action_type, str):
        return None
    return AgentAction(action_type=action_type.lower(), content=data.get("content"))


def parse_action(output: str) -> AgentAction:
    """Convert raw model output into an ``AgentAction``.

    Strategies are tried in order, from strict to lenient:

    1. the whole output is a JSON object;
    2. a JSON object inside a Markdown code fence (optionally tagged ``json``);
    3. a bare ``"action_type": "<name>"`` fragment anywhere in the text;
    4. keyword heuristics ("test", "lint", "inspect", "doc").

    Args:
        output: Text produced by the model.

    Returns:
        The parsed action. When nothing matches, an action of type
        ``"invalid"`` whose content is the original text.
    """
    if not isinstance(output, str):
        output = "" if output is None else str(output)

    parsed = _action_from_json(output)
    if parsed is not None:
        return parsed

    fenced = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
    if fenced:
        parsed = _action_from_json(fenced.group(1))
        if parsed is not None:
            return parsed

    fragment = re.search(r'"action_type"\s*:\s*"(\w+)"', output)
    if fragment:
        return AgentAction(action_type=fragment.group(1).lower())

    # Last resort: guess the intent from keywords, in priority order.
    lowered = output.lower()
    if "test" in lowered:
        return AgentAction("run_tests")
    if "lint" in lowered:
        return AgentAction("run_linter")
    if "inspect" in lowered:
        return AgentAction("inspect")
    if "doc" in lowered:  # also covers "documentation"
        return AgentAction("query_docs", "bug fix guidance")
    return AgentAction("invalid", output)
71
 
72
def map_to_env(action: AgentAction):
    """Translate a parsed ``AgentAction`` into an environment action.

    Delegates to ``model_map_to_env`` (imported from ``models``), passing the
    action's type and its optional content.
    """
    action_type, content = action.action_type, action.content
    return model_map_to_env(action_type, content)
74
 
75
+ # ======================================================================
 
 
76
def load_model():
    """Load the 4-bit Phi-3-mini checkpoint and wrap it with LoRA adapters.

    Returns:
        A ``(model, tokenizer)`` tuple: the PEFT-wrapped model returned by
        ``FastLanguageModel.get_peft_model`` and the tokenizer returned by
        ``FastLanguageModel.from_pretrained``.
    """
    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
        max_seq_length=480,  # smaller window for memory
        load_in_4bit=True,
    )

    # Modules that receive LoRA adapters.
    lora_targets = [
        "q_proj", "k_proj", "v_proj", "o_proj",
        "gate_proj", "up_proj", "down_proj",
    ]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=16,
        target_modules=lora_targets,
        lora_alpha=32,
        lora_dropout=0.0,
    )
    return peft_model, tokenizer
93
 
94
def test_model_sanity(model, tokenizer) -> bool:
    """Generate a short reply to a fixed greeting to confirm the model produces text.

    Args:
        model: Model exposing ``generate``; inputs are moved to ``"cuda"``.
        tokenizer: Tokenizer with a chat template.

    Returns:
        True when the decoded reply is non-empty, False otherwise.
    """
    print("\n" + "="*60)
    print("SANITY CHECK: Testing base model generation")
    print("="*60)
    test_prompt = "Hello, how are you?"
    messages = [{"role": "user", "content": test_prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")

    # Compare against None explicitly: a pad id of 0 is valid and must not
    # be replaced by the EOS id.
    pad_token_id = tokenizer.pad_token_id
    if pad_token_id is None:
        pad_token_id = tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=30,
            do_sample=True,
            temperature=0.7,
            min_new_tokens=1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=pad_token_id,
        )
    # Keep only the newly generated tokens (drop the echoed prompt).
    generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
    response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
    print(f"Prompt: {test_prompt}")
    print(f"Response: {repr(response)}")
    if len(response) == 0:
        print("❌ Model produces empty output – cannot train.")
        return False
    print("✓ Model sanity check PASSED\n")
    return True
121
+
122
+ # ======================================================================
123
def _expert_fix_from_context(obs) -> str:
    """Pick a canned fix template, always defining a function named ``fix``.

    The template is chosen by substring checks on the observation's
    ``bug_description`` (lower-cased) and ``code_snippet``. Checks run in
    order and the first match wins; a generic pass-through template is
    returned when nothing matches.

    Args:
        obs: Observation object; ``bug_description`` and ``code_snippet``
            are read with ``getattr`` and default to empty strings.

    Returns:
        Python source code for the proposed fix. Every template is valid,
        properly indented Python.
    """
    bug = (getattr(obs, "bug_description", "") or "").lower()
    code = (getattr(obs, "code_snippet", "") or "").lower()

    if "division" in bug or "average" in code:
        return (
            "def fix(data):\n"
            "    if not data:\n"
            "        return 0\n"
            "    return sum(data) / len(data)"
        )

    if "operator" in bug or "sign" in bug:
        return (
            "def fix(a, b):\n"
            "    return a + b"
        )

    if "off_by_one" in bug or "loop" in bug:
        return (
            "def fix(items):\n"
            "    return len(items)"
        )

    if "null" in bug or "key" in bug or "dict" in code:
        return (
            "def fix(payload):\n"
            "    users = payload.get('users', {})\n"
            "    user_id = payload.get('id')\n"
            "    return users.get(user_id)"
        )

    # Concurrency-heavy tasks (harder/hardest).
    if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
        return (
            "import threading\n"
            "_lock = threading.Lock()\n"
            "\n"
            "def fix(counter):\n"
            "    with _lock:\n"
            "        if counter is None:\n"
            "            return 0\n"
            "        return counter + 1"
        )

    if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
        return (
            "import threading\n"
            "_lock_a = threading.Lock()\n"
            "_lock_b = threading.Lock()\n"
            "\n"
            "def fix(work):\n"
            "    first, second = (_lock_a, _lock_b)\n"
            "    if id(first) > id(second):\n"
            "        first, second = second, first\n"
            "    with first:\n"
            "        with second:\n"
            "            return work() if callable(work) else work"
        )

    if "fork_join" in bug or "join" in bug:
        return (
            "import threading\n"
            "\n"
            "def fix(worker):\n"
            "    t = threading.Thread(target=worker)\n"
            "    t.start()\n"
            "    t.join()\n"
            "    return True"
        )

    # Generic safe fallback keeps the RL pipeline alive for unknown bugs.
    return (
        "def fix(data):\n"
        "    if data is None:\n"
        "        return None\n"
        "    return data"
    )
205
 
 
 
 
 
206
 
207
def _expert_supervised_policy(obs) -> str:
    """Scripted expert used to produce demonstration actions.

    Decision order: inspect -> run_tests -> run_linter -> query_docs ->
    (query_docs again) -> fix -> comment -> done / question. Every
    observation field is read with ``getattr`` and a default, so missing
    attributes are tolerated.

    Args:
        obs: Current environment observation.

    Returns:
        The chosen action as a JSON string.
    """
    def field(name, default):
        return getattr(obs, name, default)

    author_text = (field("author_response", "") or "").lower()
    last_output = (field("last_tool_output", "") or "").lower()

    # --- evidence gathering ------------------------------------------------
    if not field("tests_run", False):
        if "inspect" in last_output:
            return '{"action_type": "run_tests"}'
        return '{"action_type": "inspect"}'

    if not field("linter_run", False):
        return '{"action_type": "run_linter"}'

    if not field("docs_queried", False):
        return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'

    # Query the docs again on hard tasks while the evidence is still weak.
    # NOTE(review): whenever the score is below 0.6 and step >= 3 this branch
    # returns, so the fix branch below is not reached in that state.
    if field("current_test_score", 0.0) < 0.6 and field("step", 0) >= 3:
        hint = (field("bug_description", "") or "concurrency bug").replace('"', "'")
        query = {
            "action_type": "query_docs",
            "content": f"python {hint} lock ordering race condition mitigation patterns",
        }
        return json.dumps(query)

    # --- repair --------------------------------------------------------------
    if field("current_test_score", 0.0) < 0.95:
        proposal = {"action_type": "fix", "content": _expert_fix_from_context(obs)}
        return json.dumps(proposal)

    # --- negotiation -----------------------------------------------------------
    pushback_cues = ("not convinced", "explain", "brief")
    if author_text and any(cue in author_text for cue in pushback_cues):
        return (
            '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
            'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
            'The change is intentionally small to reduce regression risk."}'
        )

    confidence = float(field("author_confidence", 0.0))
    acceptance_bar = float(field("author_threshold", 0.5))
    test_score = float(field("current_test_score", 0.0))
    if confidence >= acceptance_bar and test_score >= 0.8:
        return '{"action_type": "done"}'

    # Tests look fine but the author has not accepted yet: keep the dialogue moving.
    return (
        '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
    )
260
 
261
+ # ======================================================================
262
def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8, tasks=None):
    """Fine-tune the model on demonstrations produced by the scripted expert.

    Each episode picks a random task, steps the environment with
    ``_expert_supervised_policy`` and records one chat-formatted
    (prompt, expert action) example per step. The collected examples are
    then trained on with ``SFTTrainer``.

    Args:
        model: Model to fine-tune in place.
        tokenizer: Tokenizer providing ``apply_chat_template``.
        env: Environment exposing ``set_task``, ``reset`` and ``step``.
        n_episodes: Number of demonstration episodes to roll out.
        epochs: Number of training epochs over the collected examples.
        max_steps: Maximum number of expert steps per episode.
        tasks: Task names to sample from. Defaults to the five difficulty
            levels "easy" through "hardest".

    Returns:
        None.
    """
    print("\n" + "="*60)
    print("SUPERVISED WARM-UP: Real environment demonstrations")
    print("="*60)

    if tasks is None:
        tasks = ["easy", "medium", "hard", "harder", "hardest"]

    examples = []
    for ep in range(n_episodes):
        task = random.choice(tasks)
        env.set_task(task)
        obs = env.reset()
        history = []
        done = False

        steps = 0
        while not done and steps < max_steps:
            # The prompt is built from the observation *before* the expert acts.
            prompt = build_prompt(obs, history)
            action_text = _expert_supervised_policy(obs)
            action = parse_action(action_text)
            env_action = map_to_env(action)
            next_obs, _, done, _ = env.step(env_action)

            messages = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": action_text},
            ]
            full_text = tokenizer.apply_chat_template(messages, tokenize=False)
            examples.append({"text": full_text})

            history.append(f"Agent: {action_text}")
            history.append(f"Env: {next_obs.last_tool_output}")
            history = history[-8:]  # keep only the most recent exchanges
            obs = next_obs
            steps += 1

        print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")

    if not examples:
        print("No supervised examples generated; skipping warm-up.")
        return

    dataset = Dataset.from_list(examples)
    trainer = SFTTrainer(
        model=model,
        tokenizer=tokenizer,
        train_dataset=dataset,
        dataset_text_field="text",
        max_seq_length=480,
        args=TrainingArguments(
            output_dir="warmup_output",
            num_train_epochs=epochs,
            per_device_train_batch_size=2,
            gradient_accumulation_steps=2,
            learning_rate=2e-5,
            logging_steps=50,
            save_strategy="no",
            bf16=True,
        ),
    )
    print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
    trainer.train()
    print("✓ Supervised warm-up (real env) complete\n")
    torch.cuda.empty_cache()
325
 
326
+ # ======================================================================
327
def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
    """Generate one action string and the log-probability of its tokens.

    The prompt is wrapped in the chat template and up to 64 new tokens are
    generated. Output that is not valid JSON triggers a retry; after
    ``max_retries`` failed attempts a ``skip`` action is returned.

    Args:
        prompt: User-turn text describing the current state.
        model: Model exposing ``generate``; inputs are moved to ``"cuda"``.
        tokenizer: Tokenizer with a chat template.
        temperature: Sampling temperature; ``0`` selects greedy decoding.
        max_retries: Maximum number of generation attempts.

    Returns:
        ``(action_text, total_logprob)``. ``total_logprob`` is the sum of the
        per-token log-probabilities taken from the generation scores. The
        sentinels ``-50.0`` (empty generation, random fallback action) and
        ``-100.0`` (no valid JSON produced) mark substituted actions.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(formatted, return_tensors="pt", max_length=480, truncation=True).to("cuda")
    prompt_len = inputs['input_ids'].shape[1]
    sampling = temperature > 0

    # NOTE(review): with temperature == 0 decoding is greedy, so a retry will
    # normally regenerate the same text; retries mainly help when sampling.
    for _ in range(max_retries):
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=64,
                do_sample=sampling,
                temperature=max(temperature, 0.01) if sampling else 1.0,
                min_new_tokens=1,
                return_dict_in_generate=True,
                output_scores=True,
            )
        generated_ids = outputs.sequences[0][prompt_len:]
        action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

        if not action_text:
            fallback_actions = [
                '{"action_type": "run_tests"}',
                '{"action_type": "run_linter"}',
                '{"action_type": "inspect"}',
                '{"action_type": "skip"}',
            ]
            action_text = random.choice(fallback_actions)
            print(f"[WARN] Empty generation → using fallback: {action_text}")
            return action_text, -50.0

        try:
            json.loads(action_text)
        except ValueError:
            # json.JSONDecodeError subclasses ValueError; try again.
            continue

        # One score tensor per generated step; pick the chosen token's log-prob.
        token_logprobs = [
            F.log_softmax(step_scores[0], dim=-1)[token_id].item()
            for token_id, step_scores in zip(generated_ids, outputs.scores)
        ]
        total_logprob = sum(token_logprobs) if token_logprobs else -100.0
        return action_text, total_logprob

    return '{"action_type":"skip"}', -100.0
374
 
375
+ # ======================================================================
376
def build_prompt(obs, history_lines: List[str]) -> str:
    """Render the instruction prompt for the current observation.

    Args:
        obs: Observation; ``code_snippet`` is required, while
            ``author_response``, ``last_tool_output`` and
            ``author_personality`` are optional (the personality defaults to
            "defensive").
        history_lines: Transcript lines from earlier steps; only the last six
            are appended to the prompt.

    Returns:
        The full prompt text.
    """
    author_msg = getattr(obs, "author_response", "") or ""
    tool_output = getattr(obs, "last_tool_output", "") or ""
    author_personality = getattr(obs, "author_personality", "defensive")

    # Placeholders shown when the environment has not produced any text yet.
    author_section = author_msg or "(no response yet – start with inspection)"
    tool_section = tool_output or "(none)"

    prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.

The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
- Tests pass (high pass ratio)
- Lint is clean (zero errors)
- Documentation or references are provided
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)

Workflow:
1. Use `inspect` to understand the code.
2. Use `run_tests` and `run_linter` to gather evidence.
3. Use `query_docs` when you need references or language-specific guidance.
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
5. If the developer pushes back, read their response carefully and address their specific concern.
6. Once convinced, use `done` to finish.

Code:
{obs.code_snippet}

Author says:
{author_section}

Last tool output:
{tool_section}

Available actions:
run_tests, run_linter, inspect, query_docs, fix, comment, question, done

Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}"""

    if history_lines:
        history = "\n".join(history_lines[-6:])
        prompt += f"\n\nPrevious steps:\n{history}"
    return prompt
416
 
417
+ # ======================================================================
418
@dataclass
class Trajectory:
    """Per-step record of one rollout; the five lists are filled in lockstep."""

    # Prompt text shown to the policy at each step.
    states: List[str]
    # Raw action string chosen at each step.
    actions: List[str]
    # Scalar reward received after each step.
    rewards: List[float]
    # Log-probability recorded for each chosen action.
    logprobs: List[float]
    # Episode-finished flag returned after each step.
    dones: List[bool]

    def __len__(self):
        """Return the number of recorded steps."""
        step_count = len(self.states)
        return step_count
426
 
427
def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
    """Roll out a single episode with the current policy.

    The environment is reset, then for at most ``max_steps`` steps a prompt
    is built, an action is generated and parsed, and the environment is
    stepped. The rollout stops early when the environment reports ``done``.

    Args:
        env: Environment exposing ``reset`` and ``step``.
        model: Policy model used for generation.
        tokenizer: Tokenizer paired with ``model``.
        max_steps: Upper bound on the number of steps.
        temperature: Temperature forwarded to ``generate_action_with_logprob``.

    Returns:
        A ``Trajectory`` holding prompts, raw actions, rewards, log-probs and
        done flags, one entry per executed step.
    """
    observation = env.reset()
    transcript = []
    rollout = Trajectory([], [], [], [], [])

    steps_taken = 0
    while steps_taken < max_steps:
        state_prompt = build_prompt(observation, transcript)
        rollout.states.append(state_prompt)

        raw_action, action_logprob = generate_action_with_logprob(
            state_prompt, model, tokenizer, temperature
        )
        rollout.actions.append(raw_action)
        rollout.logprobs.append(action_logprob)

        env_action = map_to_env(parse_action(raw_action))
        observation, reward, finished, _ = env.step(env_action)
        rollout.rewards.append(reward.value)
        rollout.dones.append(finished)

        transcript.append(f"Agent: {raw_action}")
        transcript.append(f"Env: {observation.last_tool_output}")

        steps_taken += 1
        if finished:
            break
    return rollout
447
+
448
def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
                         task_levels=None, task_weights=None, temperature=0.0):
    """Collect several rollouts, sampling a task level for each one.

    Args:
        env: Environment exposing ``set_task``.
        model: Policy model.
        tokenizer: Tokenizer paired with ``model``.
        n_trajectories: Number of episodes to collect.
        max_steps: Step limit per episode.
        task_levels: Task names to sample from; defaults to the keys of
            ``BUG_DB``.
        task_weights: Optional sampling weights, one per task level.
        temperature: Temperature forwarded to ``collect_trajectory``. The
            default of ``0.0`` keeps the previous greedy behaviour; pass a
            positive value to sample exploratory rollouts.

    Returns:
        List of ``Trajectory`` objects in collection order.

    Raises:
        ValueError: If ``task_weights`` does not match ``task_levels`` in
            length or does not have a positive total.
    """
    if task_levels is None:
        task_levels = list(BUG_DB.keys())
    if task_weights is not None:
        if len(task_weights) != len(task_levels):
            raise ValueError("task_weights must match task_levels length")
        if sum(task_weights) <= 0:
            raise ValueError("task_weights must have a positive total")

    trajectories = []
    for i in range(n_trajectories):
        sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
        env.set_task(sampled_task)
        traj = collect_trajectory(env, model, tokenizer, max_steps, temperature)
        total_reward = sum(traj.rewards)
        print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
        trajectories.append(traj)
    return trajectories
465
+
466
def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
    """Compute discounted returns and baseline-subtracted advantages.

    Returns are accumulated backwards through the episode; the running
    return is reset at every step whose ``dones`` flag is set. With
    ``standardize`` the advantages are the returns centred on their mean and
    divided by their standard deviation (plus ``1e-8``); otherwise they are
    a copy of the returns.

    Args:
        rewards: Per-step rewards.
        dones: Per-step episode-finished flags, same length as ``rewards``.
        gamma: Discount factor.
        standardize: Whether to centre and scale the advantages.

    Returns:
        ``(advantages, returns)`` as two lists of floats. Empty input yields
        two empty lists.
    """
    n = len(rewards)
    if n == 0:
        # Avoid NaN-producing mean/std calls on an empty array.
        return [], []

    returns = [0.0] * n
    running = 0.0
    for t in reversed(range(n)):
        if dones[t]:
            running = 0.0  # nothing after a terminal step contributes
        running = rewards[t] + gamma * running
        returns[t] = running

    if not standardize:
        return returns.copy(), returns

    centered = np.asarray(returns, dtype=float) - np.mean(returns)
    scale = np.std(centered) + 1e-8
    return (centered / scale).tolist(), returns
485
+
486
def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
               entropy_coef=0.01, gamma=0.99):
    """Run clipped-surrogate policy updates, one optimizer step per sample.

    For every recorded (state, action) pair the action's log-probability is
    recomputed under the current model, compared with the stored one, and
    the clipped surrogate loss minus an entropy bonus is minimised.

    Args:
        trajectories: Rollouts providing ``states``, ``actions``,
            ``rewards``, ``logprobs`` and ``dones``.
        model: Policy model (switched to train mode); inputs go to ``"cuda"``.
        tokenizer: Tokenizer with a chat template.
        optimizer: Optimizer over the model parameters.
        n_epochs: Number of passes over the collected samples.
        clip_epsilon: Clipping range for the probability ratio.
        entropy_coef: Weight of the entropy bonus.
        gamma: Discount factor used for the advantages.

    Returns:
        Dict with the mean ``loss``, ``policy_loss`` and ``entropy`` over all
        updates (all ``0.0`` when no update was performed).
    """
    model.train()

    samples = []
    for traj in trajectories:
        advantages, _ = compute_returns_and_advantages(
            traj.rewards, traj.dones, gamma=gamma, standardize=True
        )
        samples.extend(zip(traj.states, traj.actions, traj.logprobs, advantages))

    total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
    for _ in range(n_epochs):
        for i in np.random.permutation(len(samples)):
            state, action, old_logprob, advantage = samples[i]

            messages = [{"role": "user", "content": state}]
            formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
            inputs = tokenizer(formatted + action, return_tensors="pt",
                               max_length=480, truncation=True).to("cuda")
            full_ids = inputs["input_ids"][0]

            # Locate the first action token as the length of the token prefix
            # shared with the prompt-only encoding (same tokenizer settings).
            # This stays aligned whether or not the tokenizer prepends special
            # tokens, and it reads the targets from the sequence the model
            # actually sees (truncation included).
            prefix_ids = tokenizer(formatted, max_length=480, truncation=True)["input_ids"]
            action_start = 0
            for full_id, prefix_id in zip(full_ids.tolist(), prefix_ids):
                if full_id != prefix_id:
                    break
                action_start += 1
            if action_start == 0 or action_start >= full_ids.shape[0]:
                continue  # no action tokens survived truncation

            logits = model(**inputs).logits[0]
            # logits[t] predicts token t + 1.
            # NOTE(review): assumes the tokenizer appends no trailing special
            # token; otherwise it would be scored as part of the action.
            log_probs = F.log_softmax(logits[action_start - 1:-1].float(), dim=-1)
            targets = full_ids[action_start:]
            new_logprob = log_probs.gather(1, targets.unsqueeze(1)).sum()
            avg_entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

            # Clamp the log-ratio before exponentiating: the stored sentinel
            # log-probs (-50 / -100) would otherwise yield astronomically large
            # ratios that can overflow and poison the gradients.
            log_ratio = torch.clamp(new_logprob - old_logprob, -20.0, 20.0)
            ratio = torch.exp(log_ratio)
            surr1 = ratio * advantage
            surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
            policy_loss = -torch.min(surr1, surr2)
            loss = policy_loss - entropy_coef * avg_entropy

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            total_loss += loss.item()
            total_policy_loss += policy_loss.item()
            total_entropy += avg_entropy.item()
            n_updates += 1

    torch.cuda.empty_cache()
    return {"loss": total_loss / n_updates if n_updates else 0.0,
            "policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
            "entropy": total_entropy / n_updates if n_updates else 0.0}
545
 
546
def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
                    task_levels=None, verbose=False):
    """Run greedy evaluation episodes and summarise the episode rewards.

    Tasks are visited round-robin over ``task_levels`` (the keys of
    ``BUG_DB`` when omitted). The model is switched to eval mode and every
    episode is collected with temperature ``0.0``.

    Args:
        env: Environment exposing ``set_task``.
        model: Policy model.
        tokenizer: Tokenizer paired with ``model``.
        n_episodes: Number of evaluation episodes.
        max_steps: Step limit per episode.
        task_levels: Task names to cycle through.
        verbose: When true, also record one trace dict per episode.

    Returns:
        Dict with ``avg_reward``, ``std_reward``, ``min_reward``,
        ``max_reward`` and ``traces`` (empty unless ``verbose``).
    """
    model.eval()
    levels = list(BUG_DB.keys()) if task_levels is None else task_levels

    episode_rewards = []
    traces = []  # human-readable behavior logs
    for episode in range(n_episodes):
        task = levels[episode % len(levels)]
        env.set_task(task)
        rollout = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
        episode_total = sum(rollout.rewards)
        episode_rewards.append(episode_total)

        if not verbose:
            continue

        # Reduce each raw action string to its action type; "?" marks
        # anything that cannot be decoded.
        action_names = []
        for raw_action in rollout.actions:
            try:
                name = json.loads(raw_action).get("action_type", "?")
            except Exception:
                name = "?"
            action_names.append(name)

        traces.append({
            "task": task,
            "reward": round(episode_total, 4),
            "steps": len(rollout),
            "actions": action_names,
        })

    summary = {
        "avg_reward": float(np.mean(episode_rewards)),
        "std_reward": float(np.std(episode_rewards)),
        "min_reward": float(np.min(episode_rewards)),
        "max_reward": float(np.max(episode_rewards)),
        "traces": traces,
    }
    return summary
580
+
581
+ # ======================================================================
582
+ # MANUAL WARM-UP (no SFTTrainer → no multiprocessing OOM)
583
+ # ======================================================================
584
def json_warmup(model, tokenizer, json_path="training_data.json",
                n_episodes=20, epochs=2, lr=2e-5):
    """Supervised warm-up on pre-generated expert demonstrations.

    Loads a JSON list of examples, each with ``"prompt"`` and ``"action"``
    keys, and takes one AdamW step per example using cross-entropy on the
    action tokens only. No ``SFTTrainer`` is involved.

    Args:
        model: Model to fine-tune in place; inputs are moved to ``"cuda"``.
        tokenizer: Tokenizer with a chat template.
        json_path: Path to the demonstrations file.
        n_episodes: Number of episodes to use; the file is assumed to hold
            7 steps per episode, so at most ``n_episodes * 7`` examples are
            kept.
        epochs: Number of passes over the selected examples.
        lr: Learning rate for the warm-up optimizer.

    Returns:
        List with the average loss of each epoch.
    """
    print("\n" + "="*60)
    print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
    print("="*60)

    with open(json_path, encoding="utf-8") as f:
        data = json.load(f)

    # Each episode = 7 steps. Select n_episodes worth.
    steps_per_episode = 7
    max_examples = n_episodes * steps_per_episode
    if max_examples < len(data):
        data = data[:max_examples]

    print(f" {len(data)} examples ({len(data)//steps_per_episode} episodes), "
          f"{epochs} epoch(s), lr={lr}")

    model.train()
    warmup_opt = AdamW(model.parameters(), lr=lr)
    warmup_losses = []  # per-epoch avg loss

    for epoch in range(epochs):
        random.shuffle(data)
        epoch_loss = 0.0
        n_valid = 0

        for i, example in enumerate(data):
            prompt = example["prompt"]
            action = example["action"]

            # ---- tokenize full sequence (prompt + action) ----
            messages = [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": action},
            ]
            full_text = tokenizer.apply_chat_template(messages, tokenize=False)
            inputs = tokenizer(full_text, return_tensors="pt",
                               max_length=480, truncation=True).to("cuda")
            full_ids = inputs["input_ids"][0]

            # ---- find where the action tokens start ----
            # The prompt-only text is encoded with the same tokenizer settings
            # as the full text, and the action start is the length of their
            # shared token prefix. This stays aligned whether or not the
            # tokenizer prepends special tokens.
            prompt_only = tokenizer.apply_chat_template(
                [{"role": "user", "content": prompt}],
                tokenize=False, add_generation_prompt=True
            )
            prompt_ids = tokenizer(prompt_only, max_length=480, truncation=True)["input_ids"]
            prompt_len = 0
            for full_id, prompt_id in zip(full_ids.tolist(), prompt_ids):
                if full_id != prompt_id:
                    break
                prompt_len += 1

            total_len = full_ids.shape[0]
            if prompt_len == 0 or prompt_len >= total_len:
                continue  # the action was truncated away, skip

            # ---- cross-entropy on action tokens only ----
            logits = model(**inputs).logits

            # next-token prediction: logits[t] predicts token[t+1]
            shift_logits = logits[0, prompt_len - 1 : total_len - 1]
            shift_labels = full_ids[prompt_len:total_len]

            loss = F.cross_entropy(shift_logits, shift_labels)

            warmup_opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            warmup_opt.step()

            epoch_loss += loss.item()
            n_valid += 1

            if (i + 1) % 25 == 0:
                avg = epoch_loss / n_valid
                print(f" epoch {epoch+1} step {i+1:3d}/{len(data)} "
                      f"running_loss={avg:.4f}")

        avg_loss = epoch_loss / max(n_valid, 1)
        warmup_losses.append(avg_loss)
        print(f" Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
              f"({n_valid} valid examples)")

    torch.cuda.empty_cache()
    print(f"✓ Warm-up complete. Loss: "
          f"{' → '.join(f'{l:.4f}' for l in warmup_losses)}\n")
    return warmup_losses
677
+
678
+
679
+ # ======================================================================
680
+ # MAIN TRAINING PIPELINE
681
+ # ======================================================================
682
def _print_episode_traces(traces):
    """Print one summary line (task, reward, steps, actions) per episode trace."""
    for trace in traces:
        print(f" task={trace['task']:8s} reward={trace['reward']:+.4f} "
              f"steps={trace['steps']} actions={trace['actions']}")


def _save_training_plots(warmup_losses, reward_history, eval_history,
                         loss_history, policy_loss_history,
                         baseline_reward, warmup_reward):
    """Write the warm-up, reward, loss and combined summary figures as PNGs.

    Files are saved into the current working directory:
    ``warmup_loss.png`` (only when ``warmup_losses`` is non-empty),
    ``reward_curve.png``, ``loss_curve.png`` and ``training_summary.png``.
    """
    # One x-position per PPO iteration; derived from the recorded history so
    # the axis can never disagree with the data being plotted.
    iters = list(range(1, len(reward_history) + 1))
    warmup_epochs = range(1, len(warmup_losses) + 1) if warmup_losses else None

    # --- 1. Warm-up loss curve ---
    if warmup_losses:
        fig, ax = plt.subplots(figsize=(7, 4))
        ax.plot(warmup_epochs, warmup_losses,
                marker="o", linewidth=2, color="tab:purple")
        ax.set_title("Warm-up Loss (supervised, per epoch)",
                     fontsize=13, fontweight="bold")
        ax.set_xlabel("Epoch")
        ax.set_ylabel("Cross-Entropy Loss")
        ax.grid(alpha=0.3)
        fig.tight_layout()
        fig.savefig("warmup_loss.png", dpi=150)
        plt.close(fig)

    # --- 2. PPO reward curve ---
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(iters, reward_history, marker="o", linewidth=2,
            label="Collect reward", color="tab:blue")
    ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
            label="Eval reward", color="tab:green")
    ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
               linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
    ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
               linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
    ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Average Reward")
    ax.legend(loc="best", fontsize=8)
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig("reward_curve.png", dpi=150)
    plt.close(fig)

    # --- 3. PPO loss curve ---
    fig, ax = plt.subplots(figsize=(9, 5))
    ax.plot(iters, loss_history, marker="o", linewidth=2,
            label="Total loss", color="tab:red")
    ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
            label="Policy loss", color="tab:orange")
    ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
    ax.set_xlabel("Iteration")
    ax.set_ylabel("Loss")
    ax.legend(loc="best")
    ax.grid(alpha=0.3)
    fig.tight_layout()
    fig.savefig("loss_curve.png", dpi=150)
    plt.close(fig)

    # --- 4. Combined 3-panel summary ---
    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Panel A: warm-up loss (axes are still labelled when there is no data)
    if warmup_losses:
        axes[0].plot(warmup_epochs, warmup_losses,
                     marker="o", linewidth=2, color="tab:purple")
    axes[0].set_title("A. Warm-up Loss ↓")
    axes[0].set_xlabel("Epoch")
    axes[0].set_ylabel("CE Loss")
    axes[0].grid(alpha=0.3)

    # Panel B: PPO reward
    axes[1].plot(iters, reward_history, marker="o", linewidth=2,
                 color="tab:blue", label="Collect")
    axes[1].plot(iters, eval_history, marker="s", linewidth=2,
                 linestyle="--", color="tab:green", label="Eval")
    axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
                    linewidth=1.5, label="Baseline")
    axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
                    linewidth=1.5, label="Post-warmup")
    axes[1].set_title("B. PPO Reward ↑")
    axes[1].set_xlabel("Iteration")
    axes[1].set_ylabel("Avg Reward")
    axes[1].legend(fontsize=7)
    axes[1].grid(alpha=0.3)

    # Panel C: PPO loss
    axes[2].plot(iters, loss_history, marker="o", linewidth=2,
                 color="tab:red", label="Total")
    axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
                 linestyle="--", color="tab:orange", label="Policy")
    axes[2].set_title("C. PPO Loss ↓")
    axes[2].set_xlabel("Iteration")
    axes[2].set_ylabel("Loss")
    axes[2].legend(fontsize=7)
    axes[2].grid(alpha=0.3)

    fig.suptitle("Code Review Agent – Full Training Evidence",
                 fontsize=14, fontweight="bold")
    fig.tight_layout()
    fig.savefig("training_summary.png", dpi=150)
    plt.close(fig)


def train_ppo():
    """Run the full training pipeline and save diagnostic plots.

    Phases, in order:
      0. Evaluate the freshly loaded model to obtain a baseline reward.
      1. Supervised warm-up via ``json_warmup`` followed by a re-evaluation.
      2. PPO: for each iteration collect trajectories from the environment,
         apply ``ppo_update`` and evaluate the updated policy.
      3. Final evaluation, printed summary and PNG plots.

    Returns:
        None. Returns early (before any training) when
        ``test_model_sanity`` reports a failure.
    """
    # --- Hyperparameters ---
    n_iterations = 8            # enough for a clear upward trend
    trajectories_per_iter = 4   # on-policy data per iteration
    n_epochs = 1
    max_steps = 6
    learning_rate = 3e-5
    clip_epsilon = 0.2
    entropy_coef = 0.01
    gamma = 0.99
    banner = "=" * 60

    # --- Pre-load embedder before LLM (Issue #13) ---
    from rltool import ToolBox
    print("Pre-loading sentence-transformer embedder...")
    ToolBox._get_embedder()
    print("✓ Embedder ready")

    # --- Load model ---
    print("Loading model...")
    model, tokenizer = load_model()
    if not test_model_sanity(model, tokenizer):
        return
    env = CodeReviewEnv()
    task_levels = list(BUG_DB.keys())

    # ==================================================================
    # PHASE 0: BASELINE (untrained policy)
    # ==================================================================
    print("\n" + banner)
    print("PHASE 0 – BASELINE EVALUATION (untrained)")
    print(banner)
    baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
                               max_steps=max_steps, task_levels=task_levels,
                               verbose=True)
    baseline_reward = baseline["avg_reward"]
    print(f"Baseline avg reward: {baseline_reward:.4f} "
          f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
    print("Baseline behavior:")
    _print_episode_traces(baseline["traces"])

    # ==================================================================
    # PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
    # ==================================================================
    warmup_losses = json_warmup(
        model, tokenizer,
        json_path="training_data.json",
        n_episodes=20,  # 140 examples (20 × 7 steps)
        epochs=2,
        lr=2e-5,
    )

    # Post-warmup evaluation
    print(banner)
    print("POST WARM-UP EVALUATION")
    print(banner)
    post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
                                  max_steps=max_steps, task_levels=task_levels,
                                  verbose=True)
    warmup_reward = post_warmup["avg_reward"]
    print(f"Post-warmup avg reward: {warmup_reward:.4f} "
          f"(Δ vs baseline: {warmup_reward - baseline_reward:+.4f})")
    print("Post-warmup behavior:")
    _print_episode_traces(post_warmup["traces"])

    # ==================================================================
    # PHASE 2: TRUE RL – PPO (on-policy, real environment interaction)
    # ==================================================================
    optimizer = AdamW(model.parameters(), lr=learning_rate)
    print(f"\n{banner}")
    print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations × "
          f"{trajectories_per_iter} trajectories (true RL)")
    print(f"{banner}\n")

    reward_history = []
    eval_history = []
    loss_history = []
    policy_loss_history = []
    entropy_history = []

    for iteration in range(n_iterations):
        print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")

        # Collect on-policy trajectories from REAL environment
        trajectories = collect_trajectories(
            env, model, tokenizer, trajectories_per_iter, max_steps,
            task_levels=task_levels, task_weights=None
        )
        # Episode return = sum of per-step rewards; average over the batch.
        avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
        reward_history.append(avg_reward)
        print(f" Collect avg reward: {avg_reward:+.4f}")

        # PPO policy gradient update
        metrics = ppo_update(
            trajectories, model, tokenizer, optimizer,
            n_epochs=n_epochs, clip_epsilon=clip_epsilon,
            entropy_coef=entropy_coef, gamma=gamma
        )
        loss_history.append(float(metrics["loss"]))
        policy_loss_history.append(float(metrics["policy_loss"]))
        entropy_history.append(float(metrics["entropy"]))
        print(f" Update loss={metrics['loss']:.4f} "
              f"policy={metrics['policy_loss']:.4f} "
              f"entropy={metrics['entropy']:.4f}")

        # Evaluate the policy after the update
        eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
                                 max_steps=max_steps, task_levels=task_levels,
                                 verbose=False)
        eval_history.append(eval_m["avg_reward"])
        delta = eval_m["avg_reward"] - baseline_reward
        print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
              f"(Δ baseline: {delta:+.4f})")

    # ==================================================================
    # PHASE 3: FINAL EVALUATION (proof of learning)
    # ==================================================================
    print("\n" + banner)
    print("PHASE 3 – FINAL EVALUATION (after all training)")
    print(banner)
    final = evaluate_policy(env, model, tokenizer, n_episodes=5,
                            max_steps=max_steps, task_levels=task_levels,
                            verbose=True)
    print(f"Final avg reward: {final['avg_reward']:.4f} "
          f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
    print("Final behavior:")
    _print_episode_traces(final["traces"])

    total_improvement = final["avg_reward"] - baseline_reward
    ppo_improvement = final["avg_reward"] - warmup_reward
    reward_trend = " → ".join(f"{r:+.3f}" for r in reward_history)
    loss_trend = " → ".join(f"{loss:.4f}" for loss in loss_history)
    print(f"\n{banner}")
    print("TRAINING SUMMARY")
    print(f" Baseline reward: {baseline_reward:+.4f}")
    print(f" Post-warmup reward: {warmup_reward:+.4f} "
          f"(warmup Δ: {warmup_reward - baseline_reward:+.4f})")
    print(f" Final reward: {final['avg_reward']:+.4f} "
          f"(PPO Δ: {ppo_improvement:+.4f})")
    print(f" Total improvement: {total_improvement:+.4f}")
    print(f" Reward trend (PPO): {reward_trend}")
    print(f" Loss trend (PPO): {loss_trend}")
    if total_improvement > 0:
        print(f" ✓ Agent IMPROVED by {total_improvement:+.4f}")
    else:
        print(" ✗ No overall improvement detected")
    print(banner)

    # ==================================================================
    # PLOTS
    # ==================================================================
    _save_training_plots(
        warmup_losses, reward_history, eval_history,
        loss_history, policy_loss_history,
        baseline_reward, warmup_reward,
    )

    print("Plots saved: warmup_loss.png, reward_curve.png, "
          "loss_curve.png, training_summary.png")
    print(banner)
933
 
934
if __name__ == "__main__":
    # Start the training pipeline only when run as a script, not on import.
    train_ppo()