100XZX001 commited on
Commit
a604258
·
verified ·
1 Parent(s): 6d77d18

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +229 -674
training.py CHANGED
@@ -1,4 +1,6 @@
1
- # training.py – Vanilla bitsandbytes QLoRA + custom PPO (no unsloth, no Triton)
 
 
2
  import os
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
@@ -7,761 +9,314 @@ import torch
7
  import torch.nn.functional as F
8
  from torch.optim import AdamW
9
  from dataclasses import dataclass
10
- from typing import List, Dict, Tuple, Optional
11
  import numpy as np
12
- import re
13
  import random
14
  import matplotlib
15
- matplotlib.use('Agg')
16
  import matplotlib.pyplot as plt
 
17
 
18
- from transformers import (
19
- AutoModelForCausalLM,
20
- AutoTokenizer,
21
- AutoConfig,
22
- BitsAndBytesConfig,
23
- TrainingArguments
24
- )
25
  from peft import LoraConfig, get_peft_model, TaskType
26
 
27
  from environment import CodeReviewEnv
28
  from redteam import BUG_DB
29
- from models import (
30
- RunTests, RunLinter, Inspect,
31
- ProposeFix, WriteComment, AskQuestion,
32
- Done, Skip, QueryDocs, map_to_env as model_map_to_env
33
- )
 
34
 
35
- # ======================================================================
 
 
36
  @dataclass
37
  class AgentAction:
38
  action_type: str
39
  content: Optional[str] = None
40
 
 
 
 
 
 
 
 
 
 
 
 
41
  def parse_action(output: str) -> AgentAction:
42
  try:
43
  data = json.loads(output)
44
- return AgentAction(
45
- action_type=data.get("action_type", "").lower(),
46
- content=data.get("content")
47
- )
48
  except:
49
- pass
50
- json_match = re.search(r'```(?:json)?\s*(\{.*?\})\s*```', output, re.DOTALL)
51
- if json_match:
52
- try:
53
- data = json.loads(json_match.group(1))
54
- return AgentAction(
55
- action_type=data.get("action_type", "").lower(),
56
- content=data.get("content")
57
- )
58
- except:
59
- pass
60
- action_pattern = r'"action_type"\s*:\s*"(\w+)"'
61
- match = re.search(action_pattern, output)
62
- if match:
63
- return AgentAction(action_type=match.group(1).lower())
64
- output_lower = output.lower()
65
- if "test" in output_lower:
66
- return AgentAction("run_tests")
67
- if "lint" in output_lower:
68
- return AgentAction("run_linter")
69
- if "inspect" in output_lower:
70
- return AgentAction("inspect")
71
- if "doc" in output_lower or "documentation" in output_lower:
72
- return AgentAction("query_docs", "bug fix guidance")
73
- return AgentAction("invalid", output)
74
 
75
  def map_to_env(action: AgentAction):
76
  return model_map_to_env(action.action_type, action.content)
77
 
78
- # ======================================================================
79
- # Model loading – no unsloth, no Triton kernels
80
- # ======================================================================
81
  def load_model():
82
  model_name = "microsoft/Phi-3-mini-4k-instruct"
83
 
84
- bnb_config = BitsAndBytesConfig(
85
  load_in_4bit=True,
86
  bnb_4bit_compute_dtype=torch.bfloat16,
87
- bnb_4bit_use_double_quant=True,
88
- bnb_4bit_quant_type="nf4",
89
  )
90
 
91
  model = AutoModelForCausalLM.from_pretrained(
92
  model_name,
93
- quantization_config=bnb_config,
94
  device_map="auto",
95
- attn_implementation="eager", # avoid flash‑attn
96
- torch_dtype=torch.bfloat16,
97
  )
98
 
99
  tokenizer = AutoTokenizer.from_pretrained(model_name)
100
  tokenizer.pad_token = tokenizer.eos_token
101
 
102
- lora_config = LoraConfig(
103
  r=16,
104
  lora_alpha=32,
105
- target_modules=[
106
- "q_proj", "k_proj", "v_proj", "o_proj",
107
- "gate_proj", "up_proj", "down_proj"
108
- ],
109
- lora_dropout=0.0,
110
- bias="none",
111
- task_type=TaskType.CAUSAL_LM,
112
  )
113
 
114
- model = get_peft_model(model, lora_config)
 
 
115
  return model, tokenizer
116
 
117
- # ======================================================================
118
- def test_model_sanity(model, tokenizer) -> bool:
119
- print("\n" + "="*60)
120
- print("SANITY CHECK: Testing base model generation")
121
- print("="*60)
122
- test_prompt = "Hello, how are you?"
123
- messages = [{"role": "user", "content": test_prompt}]
124
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
125
- inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
126
- with torch.no_grad():
127
- outputs = model.generate(
128
- **inputs,
129
- max_new_tokens=30,
130
- do_sample=True,
131
- temperature=0.7,
132
- min_new_tokens=1,
133
- eos_token_id=tokenizer.eos_token_id,
134
- pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
135
- )
136
- generated_ids = outputs[0][inputs['input_ids'].shape[1]:]
137
- response = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
138
- print(f"Prompt: {test_prompt}")
139
- print(f"Response: {repr(response)}")
140
- if len(response) == 0:
141
- print("❌ Model produces empty output – cannot train.")
142
- return False
143
- print("✓ Model sanity check PASSED\n")
144
- return True
145
-
146
- # ======================================================================
147
- def _expert_fix_from_context(obs) -> str:
148
- """Build a conservative fix template based on bug hints."""
149
- bug = (getattr(obs, "bug_description", "") or "").lower()
150
- code = getattr(obs, "code_snippet", "") or ""
151
-
152
- if "division" in bug or "average" in code.lower():
153
- return (
154
- "def fix(data):\n"
155
- " if not data:\n"
156
- " return 0\n"
157
- " return sum(data) / len(data)"
158
- )
159
-
160
- if "operator" in bug or "sign" in bug:
161
- return (
162
- "def fix(a, b):\n"
163
- " return a + b"
164
- )
165
-
166
- if "off_by_one" in bug or "loop" in bug:
167
- return (
168
- "def fix(items):\n"
169
- " return len(items)"
170
- )
171
-
172
- if "null" in bug or "key" in bug or "dict" in code.lower():
173
- return (
174
- "def fix(payload):\n"
175
- " users = payload.get('users', {})\n"
176
- " user_id = payload.get('id')\n"
177
- " return users.get(user_id)"
178
- )
179
-
180
- if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
181
- return (
182
- "import threading\n"
183
- "_lock = threading.Lock()\n"
184
- "\n"
185
- "def fix(counter):\n"
186
- " with _lock:\n"
187
- " if counter is None:\n"
188
- " return 0\n"
189
- " return counter + 1"
190
- )
191
-
192
- if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
193
- return (
194
- "import threading\n"
195
- "_lock_a = threading.Lock()\n"
196
- "_lock_b = threading.Lock()\n"
197
- "\n"
198
- "def fix(work):\n"
199
- " first, second = (_lock_a, _lock_b)\n"
200
- " if id(first) > id(second):\n"
201
- " first, second = second, first\n"
202
- " with first:\n"
203
- " with second:\n"
204
- " return work() if callable(work) else work"
205
- )
206
-
207
- if "fork_join" in bug or "join" in bug:
208
- return (
209
- "import threading\n"
210
- "\n"
211
- "def fix(worker):\n"
212
- " t = threading.Thread(target=worker)\n"
213
- " t.start()\n"
214
- " t.join()\n"
215
- " return True"
216
- )
217
-
218
- return (
219
- "def fix(data):\n"
220
- " if data is None:\n"
221
- " return None\n"
222
- " return data"
223
- )
224
 
 
225
 
226
- def _expert_supervised_policy(obs) -> str:
227
- """Real workflow policy: inspect -> tests/linter -> docs -> fix -> negotiate -> done."""
228
- author_msg = (getattr(obs, "author_response", "") or "").lower()
229
- tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
230
-
231
- if not getattr(obs, "tests_run", False):
232
- if "inspect" not in tool_output:
233
- return '{"action_type": "inspect"}'
234
- return '{"action_type": "run_tests"}'
235
-
236
- if not getattr(obs, "linter_run", False):
237
- return '{"action_type": "run_linter"}'
238
-
239
- if not getattr(obs, "docs_queried", False):
240
- return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
241
-
242
- if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
243
- bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
244
- return json.dumps(
245
- {
246
- "action_type": "query_docs",
247
- "content": f"python {bug_hint} lock ordering race condition mitigation patterns",
248
- }
249
- )
250
-
251
- if getattr(obs, "current_test_score", 0.0) < 0.95:
252
- fix_code = _expert_fix_from_context(obs)
253
- return json.dumps({"action_type": "fix", "content": fix_code})
254
-
255
- if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
256
- return (
257
- '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
258
- 'keeps behavior deterministic, and aligns with the observed test and lint feedback. '
259
- 'The change is intentionally small to reduce regression risk."}'
260
- )
261
-
262
- conf = float(getattr(obs, "author_confidence", 0.0))
263
- threshold = float(getattr(obs, "author_threshold", 0.5))
264
- score = float(getattr(obs, "current_test_score", 0.0))
265
- if conf >= threshold and score >= 0.8:
266
- return '{"action_type": "done"}'
267
-
268
- return (
269
- '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
270
  )
271
 
272
- # ======================================================================
273
- def build_prompt(obs, history_lines: List[str]) -> str:
274
- author_msg = getattr(obs, "author_response", "") or ""
275
- tool_output = getattr(obs, "last_tool_output", "") or ""
276
- author_personality = getattr(obs, "author_personality", "defensive")
277
 
278
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
 
 
 
 
 
279
 
280
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
281
- - Tests pass (high pass ratio)
282
- - Lint is clean (zero errors)
283
- - Documentation or references are provided
284
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
285
 
286
- Workflow:
287
- 1. Use `inspect` to understand the code.
288
- 2. Use `run_tests` and `run_linter` to gather evidence.
289
- 3. Use `query_docs` when you need references or language-specific guidance.
290
- 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
291
- 5. If the developer pushes back, read their response carefully and address their specific concern.
292
- 6. Once convinced, use `done` to finish.
293
 
294
- Code:
295
- {obs.code_snippet}
 
 
 
296
 
297
- Author says:
298
- {author_msg if author_msg else "(no response yet – start with inspection)"}
299
 
300
- Last tool output:
301
- {tool_output if tool_output else "(none)"}
302
 
303
- Available actions:
304
- run_tests, run_linter, inspect, query_docs, fix, comment, question, done
305
 
306
- Respond ONLY in JSON:
307
- {{"action_type": "...", "content": "..."}}"""
308
 
309
- if history_lines:
310
- history = "\n".join(history_lines[-6:])
311
- prompt += f"\n\nPrevious steps:\n{history}"
312
- return prompt
313
 
314
- # ======================================================================
315
- @dataclass
316
- class Trajectory:
317
- states: List[str]
318
- actions: List[str]
319
- rewards: List[float]
320
- logprobs: List[float]
321
- dones: List[bool]
322
- def __len__(self): return len(self.states)
323
-
324
- def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
325
- messages = [{"role": "user", "content": prompt}]
326
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
327
- # 1024 max length, no unsloth
328
- inputs = tokenizer(formatted, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
329
-
330
- for attempt in range(max_retries):
331
- with torch.no_grad():
332
- outputs = model.generate(
333
- **inputs,
334
- max_new_tokens=128,
335
- do_sample=(temperature > 0),
336
- temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
337
- min_new_tokens=1,
338
- return_dict_in_generate=True,
339
- output_scores=True,
340
- )
341
- generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
342
- action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
343
-
344
- logprobs = []
345
- for idx, token_id in enumerate(generated_ids):
346
- if idx < len(outputs.scores):
347
- token_logits = outputs.scores[idx][0]
348
- token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
349
- logprobs.append(token_logprob)
350
- total_logprob = sum(logprobs) if logprobs else -100.0
351
-
352
- if not action_text:
353
- fallback_actions = [
354
- '{"action_type": "run_tests"}',
355
- '{"action_type": "run_linter"}',
356
- '{"action_type": "inspect"}',
357
- '{"action_type": "skip"}',
358
- ]
359
- action_text = random.choice(fallback_actions)
360
- total_logprob = -50.0
361
- print(f"[WARN] Empty generation → using fallback: {action_text}")
362
- return action_text, total_logprob
363
-
364
- try:
365
- json.loads(action_text)
366
- return action_text, total_logprob
367
- except:
368
- if attempt == max_retries - 1:
369
- return '{"action_type":"skip"}', -100.0
370
- continue
371
- return '{"action_type":"skip"}', -100.0
372
-
373
- def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
374
- obs = env.reset()
375
- history_lines = []
376
- states, actions, rewards, logprobs, dones = [], [], [], [], []
377
- for step in range(max_steps):
378
- prompt = build_prompt(obs, history_lines)
379
- states.append(prompt)
380
- action_text, logprob = generate_action_with_logprob(prompt, model, tokenizer, temperature)
381
  actions.append(action_text)
382
- logprobs.append(logprob)
 
383
  action = parse_action(action_text)
384
  env_action = map_to_env(action)
385
- next_obs, reward, done, _ = env.step(env_action)
386
- rewards.append(reward.value)
 
 
387
  dones.append(done)
388
- history_lines.append(f"Agent: {action_text}")
389
- history_lines.append(f"Env: {next_obs.last_tool_output}")
390
- obs = next_obs
391
- if done: break
392
- return Trajectory(states, actions, rewards, logprobs, dones)
393
-
394
- def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
395
- task_levels=None, task_weights=None):
396
- if task_levels is None:
397
- task_levels = list(BUG_DB.keys())
398
- if task_weights is not None and len(task_weights) != len(task_levels):
399
- raise ValueError("task_weights must match task_levels length")
400
- if task_weights is not None and sum(task_weights) <= 0:
401
- raise ValueError("task_weights must have a positive total")
402
- trajectories = []
403
- for i in range(n_trajectories):
404
- sampled_task = random.choices(task_levels, weights=task_weights, k=1)[0]
405
- env.set_task(sampled_task)
406
- traj = collect_trajectory(env, model, tokenizer, max_steps)
407
- total_reward = sum(traj.rewards)
408
- print(f"Trajectory {i+1}/{n_trajectories}: task={sampled_task}, steps={len(traj)}, reward={total_reward:.3f}")
409
- trajectories.append(traj)
410
- return trajectories
411
-
412
- def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
413
- n = len(rewards)
414
- returns = [0.0]*n
415
- running = 0.0
416
- for t in reversed(range(n)):
417
- if dones[t]: running = 0.0
418
- running = rewards[t] + gamma * running
419
- returns[t] = running
420
- if standardize:
421
- advantages = np.array(returns) - np.mean(returns)
422
- adv_std = np.std(advantages) + 1e-8
423
- advantages = (advantages / adv_std).tolist()
424
- else:
425
- advantages = returns.copy()
426
- return advantages, returns
427
-
428
- def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsilon=0.2,
429
- entropy_coef=0.01, gamma=0.99):
430
  model.train()
431
- all_states, all_actions, all_old_logprobs, all_advantages = [], [], [], []
 
 
 
432
  for traj in trajectories:
433
- advantages, _ = compute_returns_and_advantages(traj.rewards, traj.dones, gamma=gamma, standardize=True)
434
- all_states.extend(traj.states)
435
- all_actions.extend(traj.actions)
436
- all_old_logprobs.extend(traj.logprobs)
437
- all_advantages.extend(advantages)
438
- n_samples = len(all_states)
439
- total_loss, total_policy_loss, total_entropy, n_updates = 0.0, 0.0, 0.0, 0
440
- for epoch in range(n_epochs):
441
- indices = np.random.permutation(n_samples)
442
- for i in indices:
443
- state = all_states[i]
444
- action = all_actions[i]
445
- old_logprob = all_old_logprobs[i]
446
- advantage = all_advantages[i]
447
  messages = [{"role": "user", "content": state}]
448
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
449
- full_text = formatted + action
450
- inputs = tokenizer(full_text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
451
- outputs = model(**inputs)
452
- logits = outputs.logits
 
453
  action_ids = tokenizer.encode(action, add_special_tokens=False)
454
- prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
455
- action_start = len(prefix_ids)
456
- logprobs = []
457
  entropy = 0.0
458
- for idx, token_id in enumerate(action_ids):
459
- position = action_start + idx - 1
460
- if 0 <= position < logits.shape[1]:
461
- token_logits = logits[0, position]
462
- log_probs = F.log_softmax(token_logits, dim=-1)
463
- token_logprob = log_probs[token_id]
464
- logprobs.append(token_logprob)
465
- probs = F.softmax(token_logits, dim=-1)
466
- entropy += -(probs * log_probs).sum()
467
- if not logprobs: continue
468
- new_logprob = sum(logprobs)
469
- avg_entropy = entropy / len(logprobs) if logprobs else 0.0
470
- ratio = torch.exp(new_logprob - old_logprob)
471
- surr1 = ratio * advantage
472
- surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
473
- policy_loss = -torch.min(surr1, surr2)
474
- loss = policy_loss - entropy_coef * avg_entropy
475
- optimizer.zero_grad()
476
- loss.backward()
477
- torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
478
- optimizer.step()
479
- total_loss += loss.item()
480
- total_policy_loss += policy_loss.item()
481
- total_entropy += avg_entropy.item()
482
- n_updates += 1
483
- torch.cuda.empty_cache()
484
- return {"loss": total_loss / n_updates if n_updates else 0.0,
485
- "policy_loss": total_policy_loss / n_updates if n_updates else 0.0,
486
- "entropy": total_entropy / n_updates if n_updates else 0.0}
487
-
488
- def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
489
- task_levels=None, verbose=False):
490
- model.eval()
491
- if task_levels is None:
492
- task_levels = list(BUG_DB.keys())
493
- total_rewards = []
494
- traces = []
495
- for ep in range(n_episodes):
496
- task = task_levels[ep % len(task_levels)]
497
- env.set_task(task)
498
- traj = collect_trajectory(env, model, tokenizer, max_steps, temperature=0.0)
499
- ep_reward = sum(traj.rewards)
500
- total_rewards.append(ep_reward)
501
- if verbose:
502
- actions_taken = []
503
- for a in traj.actions:
504
- try: actions_taken.append(json.loads(a).get("action_type", "?"))
505
- except: actions_taken.append("?")
506
- traces.append({
507
- "task": task,
508
- "reward": round(ep_reward, 4),
509
- "steps": len(traj),
510
- "actions": actions_taken,
511
- })
512
- return {
513
- "avg_reward": float(np.mean(total_rewards)),
514
- "std_reward": float(np.std(total_rewards)),
515
- "min_reward": float(np.min(total_rewards)),
516
- "max_reward": float(np.max(total_rewards)),
517
- "traces": traces,
518
- }
519
-
520
- # ======================================================================
521
- # Manual warm-up from JSON (no SFTTrainer, no Unsloth)
522
- def json_warmup(model, tokenizer, json_path="training_data.json",
523
- n_episodes=25, epochs=3, lr=2e-5):
524
- print("\n" + "="*60)
525
- print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
526
- print("="*60)
527
-
528
- with open(json_path, encoding="utf-8") as f:
529
- data = json.load(f)
530
-
531
- steps_per_episode = 7
532
- max_examples = n_episodes * steps_per_episode
533
- if max_examples < len(data):
534
- data = data[:max_examples]
535
-
536
- print(f" {len(data)} examples ({len(data)//steps_per_episode} episodes), "
537
- f"{epochs} epoch(s), lr={lr}")
538
 
539
- model.train()
540
- warmup_opt = AdamW(model.parameters(), lr=lr)
541
- warmup_losses = []
542
-
543
- for epoch in range(epochs):
544
- random.shuffle(data)
545
- epoch_loss = 0.0
546
- n_valid = 0
547
-
548
- for i, example in enumerate(data):
549
- prompt = example["prompt"]
550
- action = example["action"]
551
-
552
- messages = [
553
- {"role": "user", "content": prompt},
554
- {"role": "assistant", "content": action},
555
- ]
556
- full_text = tokenizer.apply_chat_template(messages, tokenize=False)
557
- inputs = tokenizer(full_text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
558
-
559
- prompt_only = tokenizer.apply_chat_template(
560
- [{"role": "user", "content": prompt}],
561
- tokenize=False, add_generation_prompt=True
562
- )
563
- prompt_ids = tokenizer.encode(prompt_only, add_special_tokens=False)
564
- prompt_len = len(prompt_ids)
565
-
566
- total_len = inputs.input_ids.shape[1]
567
- if prompt_len >= total_len:
568
- continue
569
 
570
- outputs = model(**inputs)
571
- logits = outputs.logits
572
 
573
- shift_logits = logits[0, prompt_len - 1 : total_len - 1]
574
- shift_labels = inputs.input_ids[0, prompt_len : total_len]
575
 
576
- min_len = min(shift_logits.shape[0], shift_labels.shape[0])
577
- if min_len == 0:
 
 
578
  continue
579
 
580
- loss = F.cross_entropy(shift_logits[:min_len], shift_labels[:min_len])
 
 
 
 
581
 
582
- warmup_opt.zero_grad()
 
 
 
 
 
583
  loss.backward()
584
  torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
585
- warmup_opt.step()
586
-
587
- epoch_loss += loss.item()
588
- n_valid += 1
589
-
590
- if (i + 1) % 25 == 0:
591
- avg = epoch_loss / n_valid
592
- print(f" epoch {epoch+1} step {i+1:3d}/{len(data)} "
593
- f"running_loss={avg:.4f}")
594
-
595
- avg_loss = epoch_loss / max(n_valid, 1)
596
- warmup_losses.append(avg_loss)
597
- print(f" Epoch {epoch+1} done: avg_loss={avg_loss:.4f} "
598
- f"({n_valid} valid examples)")
599
-
600
- torch.cuda.empty_cache()
601
- print(f"✓ Warm-up complete. Loss: "
602
- f"{' → '.join(f'{l:.4f}' for l in warmup_losses)}\n")
603
- return warmup_losses
604
-
605
-
606
- # ======================================================================
607
- # MAIN TRAINING PIPELINE
608
- def train_ppo():
609
- n_iterations = 15
610
- trajectories_per_iter = 6
611
- n_epochs = 2
612
- max_steps = 8
613
- learning_rate = 3e-5
614
- clip_epsilon = 0.2
615
- entropy_coef = 0.01
616
- gamma = 0.99
617
-
618
- # Pre-load embedder (unchanged)
619
- from rltool import ToolBox
620
- print("Pre-loading sentence-transformer embedder...")
621
- ToolBox._get_embedder()
622
- print("✓ Embedder ready")
623
 
 
 
 
 
 
 
 
 
 
 
624
  model, tokenizer = load_model()
625
- if not test_model_sanity(model, tokenizer):
626
- return
627
  env = CodeReviewEnv()
 
 
 
 
 
628
  task_levels = list(BUG_DB.keys())
629
 
630
- # Phase 0: baseline
631
- print("\n" + "="*60)
632
- print("PHASE 0 – BASELINE EVALUATION (untrained)")
633
- print("="*60)
634
- baseline = evaluate_policy(env, model, tokenizer, n_episodes=5,
635
- max_steps=max_steps, task_levels=task_levels,
636
- verbose=True)
637
- baseline_reward = baseline["avg_reward"]
638
- print(f"Baseline avg reward: {baseline_reward:.4f} "
639
- f"(min={baseline['min_reward']:.4f}, max={baseline['max_reward']:.4f})")
640
- print("Baseline behavior:")
641
- for t in baseline["traces"]:
642
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
643
- f"steps={t['steps']} actions={t['actions']}")
644
-
645
- # Phase 1: supervised warm-up
646
- warmup_losses = json_warmup(model, tokenizer, json_path="training_data.json",
647
- n_episodes=25, epochs=3, lr=2e-5)
648
-
649
- print("="*60)
650
- print("POST WARM-UP EVALUATION")
651
- print("="*60)
652
- post_warmup = evaluate_policy(env, model, tokenizer, n_episodes=5,
653
- max_steps=max_steps, task_levels=task_levels,
654
- verbose=True)
655
- warmup_reward = post_warmup["avg_reward"]
656
- print(f"Post-warmup avg reward: {warmup_reward:.4f} "
657
- f"(Δ vs baseline: {warmup_reward - baseline_reward:+.4f})")
658
- print("Post-warmup behavior:")
659
- for t in post_warmup["traces"]:
660
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
661
- f"steps={t['steps']} actions={t['actions']}")
662
-
663
- optimizer = AdamW(model.parameters(), lr=learning_rate)
664
- print(f"\n{'='*60}")
665
- print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations × "
666
- f"{trajectories_per_iter} trajectories (true RL)")
667
- print(f"{'='*60}\n")
668
-
669
- reward_history, eval_history, loss_history = [], [], []
670
- for iteration in range(n_iterations):
671
- print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
672
- trajectories = collect_trajectories(
673
- env, model, tokenizer, trajectories_per_iter, max_steps,
674
- task_levels=task_levels, task_weights=None
675
- )
676
- avg_reward = float(np.mean([sum(t.rewards) for t in trajectories]))
677
- reward_history.append(avg_reward)
678
- print(f" Collect avg reward: {avg_reward:+.4f}")
679
-
680
- metrics = ppo_update(
681
- trajectories, model, tokenizer, optimizer,
682
- n_epochs=n_epochs, clip_epsilon=clip_epsilon,
683
- entropy_coef=entropy_coef, gamma=gamma
684
- )
685
- loss_history.append(float(metrics["loss"]))
686
- print(f" Update loss={metrics['loss']:.4f} "
687
- f"policy={metrics['policy_loss']:.4f} "
688
- f"entropy={metrics['entropy']:.4f}")
689
-
690
- eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
691
- max_steps=max_steps, task_levels=task_levels,
692
- verbose=False)
693
- eval_history.append(eval_m["avg_reward"])
694
- delta = eval_m["avg_reward"] - baseline_reward
695
- print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
696
- f"(Δ baseline: {delta:+.4f})")
697
-
698
- print("\n" + "="*60)
699
- print("PHASE 3 – FINAL EVALUATION (after all training)")
700
- print("="*60)
701
- final = evaluate_policy(env, model, tokenizer, n_episodes=5,
702
- max_steps=max_steps, task_levels=task_levels,
703
- verbose=True)
704
- print(f"Final avg reward: {final['avg_reward']:.4f} "
705
- f"(min={final['min_reward']:.4f}, max={final['max_reward']:.4f})")
706
- print("Final behavior:")
707
- for t in final["traces"]:
708
- print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
709
- f"steps={t['steps']} actions={t['actions']}")
710
-
711
- total_improvement = final["avg_reward"] - baseline_reward
712
- ppo_improvement = final["avg_reward"] - warmup_reward
713
- print(f"\n{'='*60}")
714
- print("TRAINING SUMMARY")
715
- print(f" Baseline reward: {baseline_reward:+.4f}")
716
- print(f" Post-warmup reward: {warmup_reward:+.4f} "
717
- f"(warmup Δ: {warmup_reward - baseline_reward:+.4f})")
718
- print(f" Final reward: {final['avg_reward']:+.4f} "
719
- f"(PPO Δ: {ppo_improvement:+.4f})")
720
- print(f" Total improvement: {total_improvement:+.4f}")
721
- print(f" Reward trend (PPO): {' → '.join(f'{r:+.3f}' for r in reward_history)}")
722
- print(f" Loss trend (PPO): {' → '.join(f'{l:.4f}' for l in loss_history)}")
723
- if total_improvement > 0:
724
- print(f" ✓ Agent IMPROVED by {total_improvement:+.4f}")
725
- else:
726
- print(f" ✗ No overall improvement detected")
727
- print(f"{'='*60}")
728
-
729
- # Plots
730
- iters = list(range(1, n_iterations + 1))
731
-
732
- if warmup_losses:
733
- fig, ax = plt.subplots(figsize=(7, 4))
734
- ax.plot(range(1, len(warmup_losses)+1), warmup_losses,
735
- marker="o", linewidth=2, color="tab:purple")
736
- ax.set_title("Warm-up Loss (supervised, per epoch)", fontsize=13, fontweight="bold")
737
- ax.set_xlabel("Epoch"); ax.set_ylabel("Cross-Entropy Loss")
738
- ax.grid(alpha=0.3); fig.tight_layout()
739
- fig.savefig("warmup_loss.png", dpi=150); plt.close(fig)
740
-
741
- fig, ax = plt.subplots(figsize=(9,5))
742
- ax.plot(iters, reward_history, marker="o", linewidth=2,
743
- label="Collect reward", color="tab:blue")
744
- ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
745
- label="Eval reward", color="tab:green")
746
- ax.axhline(y=baseline_reward, color="tab:gray", linestyle=":",
747
- linewidth=1.5, label=f"Baseline ({baseline_reward:+.3f})")
748
- ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
749
- linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
750
- ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
751
- ax.set_xlabel("Iteration"); ax.set_ylabel("Average Reward")
752
- ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)
753
- fig.tight_layout(); fig.savefig("reward_curve.png", dpi=150); plt.close(fig)
754
-
755
- fig, ax = plt.subplots(figsize=(9,5))
756
- ax.plot(iters, loss_history, marker="o", linewidth=2,
757
- label="Total loss", color="tab:red")
758
- ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
759
- ax.set_xlabel("Iteration"); ax.set_ylabel("Loss")
760
- ax.legend(loc="best"); ax.grid(alpha=0.3)
761
- fig.tight_layout(); fig.savefig("loss_curve.png", dpi=150); plt.close(fig)
762
-
763
- print("Plots saved: warmup_loss.png, reward_curve.png, loss_curve.png")
764
- print("="*60)
765
 
 
766
  if __name__ == "__main__":
767
- train_ppo()
 
 
1
+ ```python
2
+ # training.py – Clean PPO + QLoRA Code Review Agent (evidence-driven)
3
+
4
  import os
5
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
6
 
 
9
  import torch.nn.functional as F
10
  from torch.optim import AdamW
11
  from dataclasses import dataclass
12
+ from typing import List, Optional
13
  import numpy as np
 
14
  import random
15
  import matplotlib
16
+ matplotlib.use("Agg")
17
  import matplotlib.pyplot as plt
18
+ from collections import Counter
19
 
20
+ from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
 
 
 
 
 
 
21
  from peft import LoraConfig, get_peft_model, TaskType
22
 
23
  from environment import CodeReviewEnv
24
  from redteam import BUG_DB
25
+ from models import map_to_env as model_map_to_env
26
+
27
+ # =========================================================
28
+ # DEVICE
29
+ # =========================================================
30
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
31
 
32
+ # =========================================================
33
+ # DATA STRUCTURES
34
+ # =========================================================
35
  @dataclass
36
  class AgentAction:
37
  action_type: str
38
  content: Optional[str] = None
39
 
40
+ @dataclass
41
+ class Trajectory:
42
+ states: List[str]
43
+ actions: List[str]
44
+ rewards: List[float]
45
+ logprobs: List[float]
46
+ dones: List[bool]
47
+
48
+ # =========================================================
49
+ # ACTION PARSER
50
+ # =========================================================
51
  def parse_action(output: str) -> AgentAction:
52
  try:
53
  data = json.loads(output)
54
+ return AgentAction(data.get("action_type", ""), data.get("content"))
 
 
 
55
  except:
56
+ return AgentAction("skip", None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def map_to_env(action: AgentAction):
59
  return model_map_to_env(action.action_type, action.content)
60
 
61
+ # =========================================================
62
+ # MODEL
63
+ # =========================================================
64
def load_model():
    """Load the 4-bit Phi-3 model, attach LoRA adapters, and load its tokenizer.

    Returns:
        ``(model, tokenizer)``: the PEFT-wrapped causal LM with gradient
        checkpointing enabled, and the matching tokenizer whose pad token is
        set to its EOS token.
    """
    model_name = "microsoft/Phi-3-mini-4k-instruct"

    # NF4 4-bit weights with bfloat16 compute.
    bnb = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_quant_type="nf4",
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb,
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token

    # Turn checkpointing on before wrapping with PEFT and make the embedding
    # output require grad. The base weights are frozen, so without this the
    # checkpointed blocks may see no grad-requiring input and the adapter
    # gradients can come back empty.
    model.gradient_checkpointing_enable()
    model.enable_input_require_grads()

    lora = LoraConfig(
        r=16,
        lora_alpha=32,
        # Phi-3 attention uses a fused "qkv_proj" plus "o_proj"; the separate
        # q_proj/k_proj/v_proj names do not exist in this architecture, so
        # listing them would leave only o_proj adapted.
        target_modules=["qkv_proj", "o_proj"],
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora)

    return model, tokenizer
94
 
95
+ # =========================================================
96
+ # GENERATION
97
+ # =========================================================
98
def generate_action(prompt, model, tokenizer, temperature):
    """Generate one completion for ``prompt`` and score it.

    Args:
        prompt: Text placed in a single user turn of the chat template.
        model: Causal LM exposing ``generate``.
        tokenizer: Tokenizer matching ``model``.
        temperature: Sampling temperature; a value <= 0 selects greedy decoding.

    Returns:
        ``(text, logprob)``. ``text`` is the decoded completion with special
        tokens stripped. ``logprob`` is a float: the sum, over every generated
        token, of its log-probability under the model's unprocessed logits.
        When nothing was generated the pair is
        ``('{"action_type":"skip"}', -100.0)``.
    """
    messages = [{"role": "user", "content": prompt}]
    formatted = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )

    inputs = tokenizer(formatted, return_tensors="pt", truncation=True).to(DEVICE)
    prompt_len = inputs["input_ids"].shape[1]

    sampling = temperature > 0

    # Generate in eval mode and restore the caller's mode afterwards. After a
    # PPO update the model is left in train mode, where dropout is active and
    # (with gradient checkpointing) transformers switches the KV cache off.
    was_training = model.training
    model.eval()
    try:
        outputs = model.generate(
            **inputs,
            max_new_tokens=128,
            do_sample=sampling,
            temperature=temperature if sampling else None,
            return_dict_in_generate=True,
            # Raw logits rather than `scores`: scores are post-processed by
            # the sampling warpers (temperature, top-k, ...), whereas the PPO
            # update scores actions with a plain forward pass.
            output_logits=True,
        )
    finally:
        model.train(was_training)

    gen_ids = outputs.sequences[0][prompt_len:]
    text = tokenizer.decode(gen_ids, skip_special_tokens=True)

    step_logits = outputs.logits
    n_scored = min(len(step_logits), gen_ids.shape[0])
    if n_scored == 0:
        return '{"action_type":"skip"}', -100.0

    # One row of vocabulary logits per generated token (batch index 0), in
    # float32 so the log-softmax is not computed in reduced precision.
    logits = torch.stack([step[0] for step in step_logits[:n_scored]]).float()
    chosen = gen_ids[:n_scored].to(logits.device).unsqueeze(1)
    token_logprobs = F.log_softmax(logits, dim=-1).gather(1, chosen).squeeze(1)

    return text, token_logprobs.sum().item()
 
 
 
 
 
 
127
 
128
+ # =========================================================
129
+ # TRAJECTORY COLLECTION
130
+ # =========================================================
131
def collect_trajectory(env, model, tokenizer, max_steps, temperature):
    """Run one episode in ``env`` and record what happened at every step.

    Args:
        env: Environment providing ``reset()`` and ``step(action)``.
        model: Policy model passed through to ``generate_action``.
        tokenizer: Tokenizer passed through to ``generate_action``.
        max_steps: Upper bound on the number of steps taken.
        temperature: Sampling temperature passed through to ``generate_action``.

    Returns:
        ``(trajectory, metrics)``: a ``Trajectory`` of the recorded steps and a
        dict with per-step ``"test_score"`` and ``"actions"`` lists.
    """
    observation = env.reset()

    prompts, raw_actions, step_rewards, step_logprobs, done_flags = [], [], [], [], []
    test_scores, action_types = [], []

    for _step in range(max_steps):
        prompt = f"Code:\n{observation.code_snippet}\nRespond JSON action."
        prompts.append(prompt)

        completion, logprob = generate_action(prompt, model, tokenizer, temperature)
        raw_actions.append(completion)
        step_logprobs.append(logprob)

        parsed = parse_action(completion)
        observation, reward, finished, _ = env.step(map_to_env(parsed))

        # Rewards are stored clipped to [-1, 1].
        clipped = np.clip(reward.value, -1, 1)
        step_rewards.append(float(clipped))
        done_flags.append(finished)

        # Observations without a test score are recorded as 0.0.
        test_scores.append(getattr(observation, "current_test_score", 0.0))
        action_types.append(parsed.action_type)

        if finished:
            break

    trajectory = Trajectory(prompts, raw_actions, step_rewards, step_logprobs, done_flags)
    metrics = {"test_score": test_scores, "actions": action_types}
    return trajectory, metrics
163
+
164
+ # =========================================================
165
+ # PPO UPDATE (FIXED)
166
+ # =========================================================
167
def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2, entropy_coef=0.01):
    """Run one clipped-surrogate policy update per recorded step.

    For every step the prompt plus the recorded action text is re-scored with
    a forward pass, and the summed log-probability of the action tokens is
    compared with the log-probability stored in the trajectory. The advantage
    of a step is the undiscounted sum of that step's and all later rewards.

    Args:
        trajectories: Iterable of trajectories exposing parallel ``states``,
            ``actions``, ``rewards`` and ``logprobs`` lists.
        model: Policy model; switched to train mode.
        tokenizer: Tokenizer matching ``model``.
        optimizer: Optimizer over the trainable parameters of ``model``.
        clip: PPO clipping range for the probability ratio.
        entropy_coef: Weight of the mean per-token entropy bonus.

    Returns:
        ``(mean_loss, mean_kl)`` as floats over the steps that were applied,
        where the KL value of a step is ``old_logprob - new_logprob``. Returns
        ``(0.0, 0.0)`` when no step was applied.
    """
    model.train()

    losses = []
    kls = []

    for traj in trajectories:
        # Reward-to-go built with a plain reverse scan. The previous
        # `np.cumsum(...)[::-1]` produced a negative-stride NumPy view, which
        # torch refuses to convert into a tensor.
        returns = []
        running = 0.0
        for reward in reversed(traj.rewards):
            running += reward
            returns.append(running)
        returns.reverse()

        steps = zip(traj.states, traj.actions, traj.logprobs, returns)
        for state, action, old_lp, adv in steps:
            messages = [{"role": "user", "content": state}]
            formatted = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )

            inputs = tokenizer(
                formatted + action, return_tensors="pt", truncation=True
            ).to(DEVICE)
            input_ids = inputs["input_ids"]

            # Count the prompt tokens with the same tokenizer settings as the
            # full text, so any special tokens it adds are counted too, and
            # take the action tokens from the full encoding itself rather
            # than re-encoding the action in isolation.
            # NOTE(review): assumes the prompt tokenizes to the same prefix on
            # its own as it does inside the full text; confirm for the template.
            prefix_len = len(tokenizer(formatted, truncation=True)["input_ids"])
            start = max(prefix_len, 1)  # position 0 has no preceding logits
            if start >= input_ids.shape[1]:
                continue  # nothing of the action survived tokenisation/truncation

            logits = model(**inputs).logits

            # logits[t] predicts token t + 1, so the action tokens at
            # positions start.. are scored by logits at start-1 .. end-1.
            pred = logits[0, start - 1:-1].float()
            targets = input_ids[0, start:].to(pred.device)
            log_probs = F.log_softmax(pred, dim=-1)

            new_lp = log_probs.gather(1, targets.unsqueeze(1)).sum()
            # Kept in the graph so the bonus actually contributes a gradient;
            # averaged per token so it does not scale with action length.
            entropy = -(log_probs.exp() * log_probs).sum(dim=-1).mean()

            # Bound the log-ratio so a badly mismatched stored log-prob
            # (e.g. the -100.0 fallback) cannot overflow exp() to inf.
            log_ratio = torch.clamp(new_lp - old_lp, -20.0, 20.0)
            ratio = torch.exp(log_ratio)
            unclipped = ratio * adv
            clipped = torch.clamp(ratio, 1 - clip, 1 + clip) * adv

            loss = -torch.min(unclipped, clipped) - entropy_coef * entropy

            if not torch.isfinite(loss):
                continue

            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()

            kls.append(float(old_lp) - new_lp.item())
            losses.append(loss.item())

    if not losses:
        # Every step was skipped; avoid the NaN that np.mean([]) would give.
        return 0.0, 0.0

    return float(np.mean(losses)), float(np.mean(kls))
235
+
236
+ # =========================================================
237
+ # TRAIN LOOP
238
+ # =========================================================
239
def _save_curve(xs, ys, title, path, baseline=None):
    """Plot ``ys`` against ``xs``, optionally with a horizontal baseline, and save it."""
    plt.figure()
    plt.plot(xs, ys)
    if baseline is not None:
        plt.axhline(y=baseline)
    plt.title(title)
    plt.savefig(path)
    # Release the figure; pyplot otherwise keeps every figure alive.
    plt.close()


def train(iterations=15, episodes_per_iter=6, max_steps=6, baseline_episodes=5, lr=3e-5):
    """Train the policy with PPO and save reward, success and KL plots.

    Args:
        iterations: Number of collect-then-update rounds.
        episodes_per_iter: Episodes collected before each update (>= 1).
        max_steps: Maximum environment steps per episode.
        baseline_episodes: Greedy episodes run before training to obtain the
            baseline reward drawn on the reward plot.
        lr: AdamW learning rate.

    Raises:
        ValueError: If ``episodes_per_iter`` is smaller than 1.

    Writes ``reward.png``, ``success.png`` and ``kl.png`` to the working
    directory.
    """
    if episodes_per_iter < 1:
        raise ValueError("episodes_per_iter must be at least 1")

    model, tokenizer = load_model()
    env = CodeReviewEnv()

    # Only parameters that require grad are handed to the optimizer; the
    # rest never receive gradients.
    trainable = [p for p in model.parameters() if p.requires_grad]
    optimizer = AdamW(trainable, lr=lr)

    reward_hist, success_hist, kl_hist = [], [], []

    task_levels = list(BUG_DB.keys())

    print("Baseline evaluation...")
    baseline = []

    for _ in range(baseline_episodes):
        # Sample tasks the same way the training episodes do, so the baseline
        # is measured on the same task distribution it is compared against.
        env.set_task(random.choice(task_levels))
        traj, _ = collect_trajectory(env, model, tokenizer, max_steps, 0.0)
        baseline.append(sum(traj.rewards))

    baseline_reward = np.mean(baseline) if baseline else 0.0
    print("Baseline:", baseline_reward)

    for it in range(iterations):
        print(f"\nIteration {it+1}")

        # Linear decay from 0.7 over the run, floored at 0.1.
        temperature = max(0.7 * (1 - it / iterations), 0.1)

        trajectories = []
        successes = 0
        action_counter = Counter()

        for _ in range(episodes_per_iter):
            env.set_task(random.choice(task_levels))

            traj, metrics = collect_trajectory(env, model, tokenizer, max_steps, temperature)
            trajectories.append(traj)
            action_counter.update(metrics["actions"])

            # An episode counts as a success when its total reward is positive.
            if sum(traj.rewards) > 0:
                successes += 1

        avg_reward = np.mean([sum(t.rewards) for t in trajectories])
        success_rate = successes / len(trajectories)

        _, kl = ppo_update(trajectories, model, tokenizer, optimizer)

        reward_hist.append(avg_reward)
        success_hist.append(success_rate)
        kl_hist.append(kl)

        print("Reward:", avg_reward)
        print("Success:", success_rate)
        print("KL:", kl)
        print("Actions:", dict(action_counter))

    iters = list(range(1, len(reward_hist) + 1))

    _save_curve(iters, reward_hist, "Reward Curve", "reward.png", baseline=baseline_reward)
    _save_curve(iters, success_hist, "Success Rate", "success.png")
    _save_curve(iters, kl_hist, "KL Divergence", "kl.png")

    print("Training complete. Plots saved.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
318
 
319
+ # =========================================================
320
# Run training only when the file is executed as a script.
if __name__ == "__main__":
    train()