100XZX001 committed on
Commit
3d70980
·
verified ·
1 Parent(s): 374db2f

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +126 -297
training.py CHANGED
@@ -1,10 +1,7 @@
1
- # training.py – Memory‑safe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
2
  import os
3
- os.environ["TRITON_DISABLE"] = "1"
4
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
5
 
6
- import torch._dynamo
7
- torch._dynamo.config.disable = True
8
  import json
9
  import torch
10
  import torch.nn.functional as F
@@ -15,13 +12,16 @@ import numpy as np
15
  import re
16
  import random
17
  import matplotlib
18
- matplotlib.use('Agg') # ← add this line
19
  import matplotlib.pyplot as plt
20
 
21
- from unsloth import FastLanguageModel
22
- from transformers import TrainingArguments
23
- from trl import SFTTrainer
24
- from datasets import Dataset
 
 
 
25
 
26
  from environment import CodeReviewEnv
27
  from redteam import BUG_DB
@@ -74,25 +74,46 @@ def parse_action(output: str) -> AgentAction:
74
  def map_to_env(action: AgentAction):
75
  return model_map_to_env(action.action_type, action.content)
76
 
 
 
77
  # ======================================================================
78
  def load_model():
79
- model, tokenizer = FastLanguageModel.from_pretrained(
80
- model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
81
- max_seq_length=480,
82
- load_in_4bit=True,
 
 
 
 
 
 
 
 
 
 
 
83
  )
84
- model = FastLanguageModel.get_peft_model(
85
- model,
 
 
 
86
  r=16,
 
87
  target_modules=[
88
  "q_proj", "k_proj", "v_proj", "o_proj",
89
  "gate_proj", "up_proj", "down_proj"
90
  ],
91
- lora_alpha=32,
92
  lora_dropout=0.0,
 
 
93
  )
 
 
94
  return model, tokenizer
95
 
 
96
  def test_model_sanity(model, tokenizer) -> bool:
97
  print("\n" + "="*60)
98
  print("SANITY CHECK: Testing base model generation")
@@ -100,7 +121,7 @@ def test_model_sanity(model, tokenizer) -> bool:
100
  test_prompt = "Hello, how are you?"
101
  messages = [{"role": "user", "content": test_prompt}]
102
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
103
- inputs = tokenizer(formatted, return_tensors="pt", max_length=2048, truncation=True).to("cuda")
104
  with torch.no_grad():
105
  outputs = model.generate(
106
  **inputs,
@@ -123,10 +144,7 @@ def test_model_sanity(model, tokenizer) -> bool:
123
 
124
  # ======================================================================
125
  def _expert_fix_from_context(obs) -> str:
126
- """
127
- Build a conservative fix template named `fix` (required by tests).
128
- Uses bug hints + code snippet patterns to create realistic fixes.
129
- """
130
  bug = (getattr(obs, "bug_description", "") or "").lower()
131
  code = getattr(obs, "code_snippet", "") or ""
132
 
@@ -158,7 +176,6 @@ def _expert_fix_from_context(obs) -> str:
158
  " return users.get(user_id)"
159
  )
160
 
161
- # Concurrency-heavy tasks (harder/hardest).
162
  if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
163
  return (
164
  "import threading\n"
@@ -197,7 +214,6 @@ def _expert_fix_from_context(obs) -> str:
197
  " return True"
198
  )
199
 
200
- # Generic safe fallback keeps the RL pipeline alive for unknown bugs.
201
  return (
202
  "def fix(data):\n"
203
  " if data is None:\n"
@@ -207,10 +223,7 @@ def _expert_fix_from_context(obs) -> str:
207
 
208
 
209
  def _expert_supervised_policy(obs) -> str:
210
- """
211
- Real workflow policy:
212
- inspect -> tests/linter -> docs -> fix -> negotiate -> done.
213
- """
214
  author_msg = (getattr(obs, "author_response", "") or "").lower()
215
  tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
216
 
@@ -225,7 +238,6 @@ def _expert_supervised_policy(obs) -> str:
225
  if not getattr(obs, "docs_queried", False):
226
  return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
227
 
228
- # Use docs again on hard tasks when evidence is still weak.
229
  if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
230
  bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
231
  return json.dumps(
@@ -235,12 +247,10 @@ def _expert_supervised_policy(obs) -> str:
235
  }
236
  )
237
 
238
- # If test quality is poor, propose a concrete fix.
239
  if getattr(obs, "current_test_score", 0.0) < 0.95:
240
  fix_code = _expert_fix_from_context(obs)
241
  return json.dumps({"action_type": "fix", "content": fix_code})
242
 
243
- # If author is still unconvinced, provide causal explanation.
244
  if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
245
  return (
246
  '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
@@ -248,94 +258,79 @@ def _expert_supervised_policy(obs) -> str:
248
  'The change is intentionally small to reduce regression risk."}'
249
  )
250
 
251
- # If negotiation is strong enough and quality is good, terminate.
252
  conf = float(getattr(obs, "author_confidence", 0.0))
253
  threshold = float(getattr(obs, "author_threshold", 0.5))
254
  score = float(getattr(obs, "current_test_score", 0.0))
255
  if conf >= threshold and score >= 0.8:
256
  return '{"action_type": "done"}'
257
 
258
- # Nudge conversation forward when tests are okay but acceptance is pending.
259
  return (
260
  '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
261
  )
262
 
263
  # ======================================================================
264
- def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
265
- print("\n" + "="*60)
266
- print("SUPERVISED WARM-UP: Real environment demonstrations")
267
- print("="*60)
268
 
269
- examples = []
270
- tasks = ["easy", "medium", "hard", "harder", "hardest"]
271
- for ep in range(n_episodes):
272
- task = random.choice(tasks)
273
- env.set_task(task)
274
- obs = env.reset()
275
- history = []
276
- done = False
277
-
278
- steps = 0
279
- while not done and steps < max_steps:
280
- prompt = build_prompt(obs, history)
281
- action_text = _expert_supervised_policy(obs)
282
- action = parse_action(action_text)
283
- env_action = map_to_env(action)
284
- next_obs, _, done, _ = env.step(env_action)
285
 
286
- messages = [
287
- {"role": "user", "content": prompt},
288
- {"role": "assistant", "content": action_text},
289
- ]
290
- full_text = tokenizer.apply_chat_template(messages, tokenize=False)
291
- examples.append({"text": full_text})
292
 
293
- history.append(f"Agent: {action_text}")
294
- history.append(f"Env: {next_obs.last_tool_output}")
295
- history = history[-8:]
296
- obs = next_obs
297
- steps += 1
 
 
298
 
299
- print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
 
300
 
301
- if not examples:
302
- print("No supervised examples generated; skipping warm-up.")
303
- return
304
 
305
- dataset = Dataset.from_list(examples)
306
- trainer = SFTTrainer(
307
- model=model,
308
- tokenizer=tokenizer,
309
- train_dataset=dataset,
310
- dataset_text_field="text",
311
- max_seq_length=2048,
312
- args=TrainingArguments(
313
- output_dir="warmup_output",
314
- num_train_epochs=epochs,
315
- per_device_train_batch_size=2,
316
- gradient_accumulation_steps=2,
317
- learning_rate=2e-5,
318
- logging_steps=50,
319
- save_strategy="no",
320
- bf16=True,
321
- ),
322
- )
323
- print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
324
- trainer.train()
325
- print("✓ Supervised warm-up (real env) complete\n")
326
- torch.cuda.empty_cache()
327
 
328
  # ======================================================================
 
 
 
 
 
 
 
 
 
329
  def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
330
  messages = [{"role": "user", "content": prompt}]
331
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
332
- inputs = tokenizer(formatted, return_tensors="pt", max_length=2048, truncation=True).to("cuda")
333
-
 
334
  for attempt in range(max_retries):
335
  with torch.no_grad():
336
  outputs = model.generate(
337
  **inputs,
338
- max_new_tokens=64,
339
  do_sample=(temperature > 0),
340
  temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
341
  min_new_tokens=1,
@@ -344,7 +339,7 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
344
  )
345
  generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
346
  action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
347
-
348
  logprobs = []
349
  for idx, token_id in enumerate(generated_ids):
350
  if idx < len(outputs.scores):
@@ -352,7 +347,7 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
352
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
353
  logprobs.append(token_logprob)
354
  total_logprob = sum(logprobs) if logprobs else -100.0
355
-
356
  if not action_text:
357
  fallback_actions = [
358
  '{"action_type": "run_tests"}',
@@ -364,7 +359,7 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
364
  total_logprob = -50.0
365
  print(f"[WARN] Empty generation → using fallback: {action_text}")
366
  return action_text, total_logprob
367
-
368
  try:
369
  json.loads(action_text)
370
  return action_text, total_logprob
@@ -374,58 +369,6 @@ def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_
374
  continue
375
  return '{"action_type":"skip"}', -100.0
376
 
377
- # ======================================================================
378
- def build_prompt(obs, history_lines: List[str]) -> str:
379
- author_msg = getattr(obs, "author_response", "") or ""
380
- tool_output = getattr(obs, "last_tool_output", "") or ""
381
- author_personality = getattr(obs, "author_personality", "defensive")
382
-
383
- prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.
384
-
385
- The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
386
- - Tests pass (high pass ratio)
387
- - Lint is clean (zero errors)
388
- - Documentation or references are provided
389
- - Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)
390
-
391
- Workflow:
392
- 1. Use `inspect` to understand the code.
393
- 2. Use `run_tests` and `run_linter` to gather evidence.
394
- 3. Use `query_docs` when you need references or language-specific guidance.
395
- 4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
396
- 5. If the developer pushes back, read their response carefully and address their specific concern.
397
- 6. Once convinced, use `done` to finish.
398
-
399
- Code:
400
- {obs.code_snippet}
401
-
402
- Author says:
403
- {author_msg if author_msg else "(no response yet – start with inspection)"}
404
-
405
- Last tool output:
406
- {tool_output if tool_output else "(none)"}
407
-
408
- Available actions:
409
- run_tests, run_linter, inspect, query_docs, fix, comment, question, done
410
-
411
- Respond ONLY in JSON:
412
- {{"action_type": "...", "content": "..."}}"""
413
-
414
- if history_lines:
415
- history = "\n".join(history_lines[-6:])
416
- prompt += f"\n\nPrevious steps:\n{history}"
417
- return prompt
418
-
419
- # ======================================================================
420
- @dataclass
421
- class Trajectory:
422
- states: List[str]
423
- actions: List[str]
424
- rewards: List[float]
425
- logprobs: List[float]
426
- dones: List[bool]
427
- def __len__(self): return len(self.states)
428
-
429
  def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
430
  obs = env.reset()
431
  history_lines = []
@@ -466,10 +409,6 @@ def collect_trajectories(env, model, tokenizer, n_trajectories, max_steps=6,
466
  return trajectories
467
 
468
  def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
469
- """
470
- Compute discounted returns and REINFORCE-style baseline advantages.
471
- Advantages are centered and optionally standardised.
472
- """
473
  n = len(rewards)
474
  returns = [0.0]*n
475
  running = 0.0
@@ -507,7 +446,7 @@ def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsil
507
  messages = [{"role": "user", "content": state}]
508
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
509
  full_text = formatted + action
510
- inputs = tokenizer(full_text, return_tensors="pt", max_length=2048, truncation=True).to("cuda")
511
  outputs = model(**inputs)
512
  logits = outputs.logits
513
  action_ids = tokenizer.encode(action, add_special_tokens=False)
@@ -527,8 +466,7 @@ def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsil
527
  if not logprobs: continue
528
  new_logprob = sum(logprobs)
529
  avg_entropy = entropy / len(logprobs) if logprobs else 0.0
530
- log_ratio = torch.clamp(new_logprob - old_logprob, min=-10.0, max=10.0) # ← guard
531
- ratio = torch.exp(log_ratio)
532
  surr1 = ratio * advantage
533
  surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
534
  policy_loss = -torch.min(surr1, surr2)
@@ -548,12 +486,11 @@ def ppo_update(trajectories, model, tokenizer, optimizer, n_epochs=1, clip_epsil
548
 
549
  def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
550
  task_levels=None, verbose=False):
551
- """Evaluate the current policy across task levels. Returns metrics + optional traces."""
552
  model.eval()
553
  if task_levels is None:
554
  task_levels = list(BUG_DB.keys())
555
  total_rewards = []
556
- traces = [] # human-readable behavior logs
557
  for ep in range(n_episodes):
558
  task = task_levels[ep % len(task_levels)]
559
  env.set_task(task)
@@ -563,10 +500,8 @@ def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
563
  if verbose:
564
  actions_taken = []
565
  for a in traj.actions:
566
- try:
567
- actions_taken.append(json.loads(a).get("action_type", "?"))
568
- except Exception:
569
- actions_taken.append("?")
570
  traces.append({
571
  "task": task,
572
  "reward": round(ep_reward, 4),
@@ -582,15 +517,9 @@ def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
582
  }
583
 
584
  # ======================================================================
585
- # MANUAL WARM-UP (no SFTTrainer no multiprocessing OOM)
586
- # ======================================================================
587
  def json_warmup(model, tokenizer, json_path="training_data.json",
588
- n_episodes=20, epochs=2, lr=2e-5):
589
- """
590
- Supervised warm-up from pre-generated expert demonstrations.
591
- Uses raw cross-entropy on action tokens with manual gradient steps.
592
- NO SFTTrainer, NO multiprocessing – runs safely on any GPU.
593
- """
594
  print("\n" + "="*60)
595
  print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
596
  print("="*60)
@@ -598,7 +527,6 @@ def json_warmup(model, tokenizer, json_path="training_data.json",
598
  with open(json_path, encoding="utf-8") as f:
599
  data = json.load(f)
600
 
601
- # Each episode = 7 steps. Select n_episodes worth.
602
  steps_per_episode = 7
603
  max_examples = n_episodes * steps_per_episode
604
  if max_examples < len(data):
@@ -609,7 +537,7 @@ def json_warmup(model, tokenizer, json_path="training_data.json",
609
 
610
  model.train()
611
  warmup_opt = AdamW(model.parameters(), lr=lr)
612
- warmup_losses = [] # per-epoch avg loss
613
 
614
  for epoch in range(epochs):
615
  random.shuffle(data)
@@ -620,15 +548,13 @@ def json_warmup(model, tokenizer, json_path="training_data.json",
620
  prompt = example["prompt"]
621
  action = example["action"]
622
 
623
- # ---- tokenize full sequence (prompt + action) ----
624
  messages = [
625
  {"role": "user", "content": prompt},
626
  {"role": "assistant", "content": action},
627
  ]
628
  full_text = tokenizer.apply_chat_template(messages, tokenize=False)
629
- inputs = tokenizer(full_text, return_tensors="pt", max_length=2048, truncation=True).to("cuda")
630
 
631
- # ---- find where the action tokens start ----
632
  prompt_only = tokenizer.apply_chat_template(
633
  [{"role": "user", "content": prompt}],
634
  tokenize=False, add_generation_prompt=True
@@ -638,13 +564,11 @@ def json_warmup(model, tokenizer, json_path="training_data.json",
638
 
639
  total_len = inputs.input_ids.shape[1]
640
  if prompt_len >= total_len:
641
- continue # prompt was truncated away, skip
642
 
643
- # ---- cross-entropy on action tokens only ----
644
  outputs = model(**inputs)
645
  logits = outputs.logits
646
 
647
- # next-token prediction: logits[t] predicts token[t+1]
648
  shift_logits = logits[0, prompt_len - 1 : total_len - 1]
649
  shift_labels = inputs.input_ids[0, prompt_len : total_len]
650
 
@@ -680,11 +604,9 @@ def json_warmup(model, tokenizer, json_path="training_data.json",
680
 
681
  # ======================================================================
682
  # MAIN TRAINING PIPELINE
683
- # ======================================================================
684
  def train_ppo():
685
- # --- Hyperparameters ---
686
- n_iterations = 15 # enough for a clear upward trend
687
- trajectories_per_iter = 6 # on-policy data per iteration
688
  n_epochs = 2
689
  max_steps = 8
690
  learning_rate = 3e-5
@@ -692,23 +614,19 @@ def train_ppo():
692
  entropy_coef = 0.01
693
  gamma = 0.99
694
 
695
- # --- Pre-load embedder before LLM (Issue #13) ---
696
  from rltool import ToolBox
697
  print("Pre-loading sentence-transformer embedder...")
698
  ToolBox._get_embedder()
699
  print("✓ Embedder ready")
700
 
701
- # --- Load model ---
702
- print("Loading model...")
703
  model, tokenizer = load_model()
704
  if not test_model_sanity(model, tokenizer):
705
  return
706
  env = CodeReviewEnv()
707
  task_levels = list(BUG_DB.keys())
708
 
709
- # ==================================================================
710
- # PHASE 0: BASELINE (untrained policy)
711
- # ==================================================================
712
  print("\n" + "="*60)
713
  print("PHASE 0 – BASELINE EVALUATION (untrained)")
714
  print("="*60)
@@ -723,18 +641,10 @@ def train_ppo():
723
  print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
724
  f"steps={t['steps']} actions={t['actions']}")
725
 
726
- # ==================================================================
727
- # PHASE 1: SUPERVISED WARM-UP (expert demos, manual CE)
728
- # ==================================================================
729
- warmup_losses = json_warmup(
730
- model, tokenizer,
731
- json_path="training_data.json",
732
- n_episodes=30, # 140 examples (20 × 7 steps)
733
- epochs=3,
734
- lr=2e-5,
735
- )
736
 
737
- # Post-warmup evaluation
738
  print("="*60)
739
  print("POST WARM-UP EVALUATION")
740
  print("="*60)
@@ -749,25 +659,15 @@ def train_ppo():
749
  print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
750
  f"steps={t['steps']} actions={t['actions']}")
751
 
752
- # ==================================================================
753
- # PHASE 2: TRUE RL – PPO (on-policy, real environment interaction)
754
- # ==================================================================
755
  optimizer = AdamW(model.parameters(), lr=learning_rate)
756
  print(f"\n{'='*60}")
757
  print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations × "
758
  f"{trajectories_per_iter} trajectories (true RL)")
759
  print(f"{'='*60}\n")
760
 
761
- reward_history = []
762
- eval_history = []
763
- loss_history = []
764
- policy_loss_history = []
765
- entropy_history = []
766
-
767
  for iteration in range(n_iterations):
768
  print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
769
-
770
- # Collect on-policy trajectories from REAL environment
771
  trajectories = collect_trajectories(
772
  env, model, tokenizer, trajectories_per_iter, max_steps,
773
  task_levels=task_levels, task_weights=None
@@ -776,20 +676,16 @@ def train_ppo():
776
  reward_history.append(avg_reward)
777
  print(f" Collect avg reward: {avg_reward:+.4f}")
778
 
779
- # PPO policy gradient update
780
  metrics = ppo_update(
781
  trajectories, model, tokenizer, optimizer,
782
  n_epochs=n_epochs, clip_epsilon=clip_epsilon,
783
  entropy_coef=entropy_coef, gamma=gamma
784
  )
785
  loss_history.append(float(metrics["loss"]))
786
- policy_loss_history.append(float(metrics["policy_loss"]))
787
- entropy_history.append(float(metrics["entropy"]))
788
  print(f" Update loss={metrics['loss']:.4f} "
789
  f"policy={metrics['policy_loss']:.4f} "
790
  f"entropy={metrics['entropy']:.4f}")
791
 
792
- # Evaluate greedy policy after update
793
  eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
794
  max_steps=max_steps, task_levels=task_levels,
795
  verbose=False)
@@ -798,9 +694,6 @@ def train_ppo():
798
  print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
799
  f"(Δ baseline: {delta:+.4f})")
800
 
801
- # ==================================================================
802
- # PHASE 3: FINAL EVALUATION (proof of learning)
803
- # ==================================================================
804
  print("\n" + "="*60)
805
  print("PHASE 3 – FINAL EVALUATION (after all training)")
806
  print("="*60)
@@ -832,27 +725,19 @@ def train_ppo():
832
  print(f" ✗ No overall improvement detected")
833
  print(f"{'='*60}")
834
 
835
- # ==================================================================
836
- # PLOTS
837
- # ==================================================================
838
  iters = list(range(1, n_iterations + 1))
839
 
840
- # --- 1. Warm-up loss curve ---
841
  if warmup_losses:
842
  fig, ax = plt.subplots(figsize=(7, 4))
843
- ax.plot(range(1, len(warmup_losses) + 1), warmup_losses,
844
  marker="o", linewidth=2, color="tab:purple")
845
- ax.set_title("Warm-up Loss (supervised, per epoch)",
846
- fontsize=13, fontweight="bold")
847
- ax.set_xlabel("Epoch")
848
- ax.set_ylabel("Cross-Entropy Loss")
849
- ax.grid(alpha=0.3)
850
- fig.tight_layout()
851
- fig.savefig("warmup_loss.png", dpi=150)
852
- plt.close(fig)
853
-
854
- # --- 2. PPO reward curve ---
855
- fig, ax = plt.subplots(figsize=(9, 5))
856
  ax.plot(iters, reward_history, marker="o", linewidth=2,
857
  label="Collect reward", color="tab:blue")
858
  ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
@@ -862,75 +747,19 @@ def train_ppo():
862
  ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
863
  linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
864
  ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
865
- ax.set_xlabel("Iteration")
866
- ax.set_ylabel("Average Reward")
867
- ax.legend(loc="best", fontsize=8)
868
- ax.grid(alpha=0.3)
869
- fig.tight_layout()
870
- fig.savefig("reward_curve.png", dpi=150)
871
- plt.close(fig)
872
-
873
- # --- 3. PPO loss curve ---
874
- fig, ax = plt.subplots(figsize=(9, 5))
875
  ax.plot(iters, loss_history, marker="o", linewidth=2,
876
  label="Total loss", color="tab:red")
877
- ax.plot(iters, policy_loss_history, marker="^", linewidth=2, linestyle="--",
878
- label="Policy loss", color="tab:orange")
879
  ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
880
- ax.set_xlabel("Iteration")
881
- ax.set_ylabel("Loss")
882
- ax.legend(loc="best")
883
- ax.grid(alpha=0.3)
884
- fig.tight_layout()
885
- fig.savefig("loss_curve.png", dpi=150)
886
- plt.close(fig)
887
-
888
- # --- 4. Combined 3-panel summary ---
889
- fig, axes = plt.subplots(1, 3, figsize=(18, 5))
890
-
891
- # Panel A: warm-up loss
892
- if warmup_losses:
893
- axes[0].plot(range(1, len(warmup_losses) + 1), warmup_losses,
894
- marker="o", linewidth=2, color="tab:purple")
895
- axes[0].set_title("A. Warm-up Loss ↓")
896
- axes[0].set_xlabel("Epoch")
897
- axes[0].set_ylabel("CE Loss")
898
- axes[0].grid(alpha=0.3)
899
-
900
- # Panel B: PPO reward
901
- axes[1].plot(iters, reward_history, marker="o", linewidth=2,
902
- color="tab:blue", label="Collect")
903
- axes[1].plot(iters, eval_history, marker="s", linewidth=2,
904
- linestyle="--", color="tab:green", label="Eval")
905
- axes[1].axhline(y=baseline_reward, color="tab:gray", linestyle=":",
906
- linewidth=1.5, label="Baseline")
907
- axes[1].axhline(y=warmup_reward, color="tab:purple", linestyle=":",
908
- linewidth=1.5, label="Post-warmup")
909
- axes[1].set_title("B. PPO Reward ↑")
910
- axes[1].set_xlabel("Iteration")
911
- axes[1].set_ylabel("Avg Reward")
912
- axes[1].legend(fontsize=7)
913
- axes[1].grid(alpha=0.3)
914
-
915
- # Panel C: PPO loss
916
- axes[2].plot(iters, loss_history, marker="o", linewidth=2,
917
- color="tab:red", label="Total")
918
- axes[2].plot(iters, policy_loss_history, marker="^", linewidth=2,
919
- linestyle="--", color="tab:orange", label="Policy")
920
- axes[2].set_title("C. PPO Loss ↓")
921
- axes[2].set_xlabel("Iteration")
922
- axes[2].set_ylabel("Loss")
923
- axes[2].legend(fontsize=7)
924
- axes[2].grid(alpha=0.3)
925
-
926
- fig.suptitle("Code Review Agent – Full Training Evidence",
927
- fontsize=14, fontweight="bold")
928
- fig.tight_layout()
929
- fig.savefig("training_summary.png", dpi=150)
930
- plt.close(fig)
931
-
932
- print("Plots saved: warmup_loss.png, reward_curve.png, "
933
- "loss_curve.png, training_summary.png")
934
  print("="*60)
935
 
936
  if __name__ == "__main__":
 
1
+ # training.py – Vanilla bitsandbytes QLoRA + custom PPO (no unsloth, no Triton)
2
  import os
 
3
  os.environ["TOKENIZERS_PARALLELISM"] = "false"
4
 
 
 
5
  import json
6
  import torch
7
  import torch.nn.functional as F
 
12
  import re
13
  import random
14
  import matplotlib
15
+ matplotlib.use('Agg')
16
  import matplotlib.pyplot as plt
17
 
18
+ from transformers import (
19
+ AutoModelForCausalLM,
20
+ AutoTokenizer,
21
+ BitsAndBytesConfig,
22
+ TrainingArguments
23
+ )
24
+ from peft import LoraConfig, get_peft_model, TaskType
25
 
26
  from environment import CodeReviewEnv
27
  from redteam import BUG_DB
 
74
  def map_to_env(action: AgentAction):
75
  return model_map_to_env(action.action_type, action.content)
76
 
77
+ # ======================================================================
78
+ # Model loading – no unsloth, no Triton kernels
79
  # ======================================================================
80
def load_model():
    """Load Phi-3-mini in 4-bit (bitsandbytes NF4) and wrap it with LoRA adapters.

    Returns:
        A ``(model, tokenizer)`` tuple: the PEFT-wrapped causal LM and its
        tokenizer, with a pad token guaranteed to be set.
    """
    model_name = "microsoft/Phi-3-mini-4k-instruct"

    # 4-bit NF4 weights with double quantization; matmuls run in bfloat16.
    bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
    )

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        quantization_config=bnb_config,
        device_map="auto",
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
    )

    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # Only fall back to EOS when the checkpoint ships without a pad token, so
    # a pad token defined by the tokenizer config is not silently replaced.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    lora_config = LoraConfig(
        r=16,
        lora_alpha=32,
        target_modules=[
            # The upstream Phi-3 implementation fuses q/k/v into `qkv_proj`
            # and gate/up into `gate_up_proj`. Without these two names only
            # `o_proj` and `down_proj` would receive adapters.
            "qkv_proj", "gate_up_proj",
            # Split names are kept for re-packed checkpoints that expose
            # separate projections; PEFT skips names that match no module.
            "q_proj", "k_proj", "v_proj", "o_proj",
            "gate_proj", "up_proj", "down_proj",
        ],
        lora_dropout=0.0,
        bias="none",
        task_type=TaskType.CAUSAL_LM,
    )

    model = get_peft_model(model, lora_config)
    return model, tokenizer
115
 
116
+ # ======================================================================
117
  def test_model_sanity(model, tokenizer) -> bool:
118
  print("\n" + "="*60)
119
  print("SANITY CHECK: Testing base model generation")
 
121
  test_prompt = "Hello, how are you?"
122
  messages = [{"role": "user", "content": test_prompt}]
123
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
124
+ inputs = tokenizer(formatted, return_tensors="pt", max_length=256, truncation=True).to("cuda")
125
  with torch.no_grad():
126
  outputs = model.generate(
127
  **inputs,
 
144
 
145
  # ======================================================================
146
  def _expert_fix_from_context(obs) -> str:
147
+ """Build a conservative fix template based on bug hints."""
 
 
 
148
  bug = (getattr(obs, "bug_description", "") or "").lower()
149
  code = getattr(obs, "code_snippet", "") or ""
150
 
 
176
  " return users.get(user_id)"
177
  )
178
 
 
179
  if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
180
  return (
181
  "import threading\n"
 
214
  " return True"
215
  )
216
 
 
217
  return (
218
  "def fix(data):\n"
219
  " if data is None:\n"
 
223
 
224
 
225
  def _expert_supervised_policy(obs) -> str:
226
+ """Real workflow policy: inspect -> tests/linter -> docs -> fix -> negotiate -> done."""
 
 
 
227
  author_msg = (getattr(obs, "author_response", "") or "").lower()
228
  tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
229
 
 
238
  if not getattr(obs, "docs_queried", False):
239
  return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
240
 
 
241
  if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
242
  bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
243
  return json.dumps(
 
247
  }
248
  )
249
 
 
250
  if getattr(obs, "current_test_score", 0.0) < 0.95:
251
  fix_code = _expert_fix_from_context(obs)
252
  return json.dumps({"action_type": "fix", "content": fix_code})
253
 
 
254
  if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
255
  return (
256
  '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
 
258
  'The change is intentionally small to reduce regression risk."}'
259
  )
260
 
 
261
  conf = float(getattr(obs, "author_confidence", 0.0))
262
  threshold = float(getattr(obs, "author_threshold", 0.5))
263
  score = float(getattr(obs, "current_test_score", 0.0))
264
  if conf >= threshold and score >= 0.8:
265
  return '{"action_type": "done"}'
266
 
 
267
  return (
268
  '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
269
  )
270
 
271
  # ======================================================================
272
def build_prompt(obs, history_lines: List[str]) -> str:
    """Render the instruction prompt shown to the policy for one step.

    Args:
        obs: Environment observation. The fields ``author_response``,
            ``last_tool_output``, ``author_personality`` and ``code_snippet``
            are read with ``getattr``; any of them may be missing or ``None``.
        history_lines: Text lines describing earlier steps. Only the last six
            are appended. May be empty or ``None``.

    Returns:
        The full prompt text, ending with the JSON response format and, when
        history is present, a "Previous steps" section.
    """
    author_msg = getattr(obs, "author_response", "") or ""
    tool_output = getattr(obs, "last_tool_output", "") or ""
    author_personality = getattr(obs, "author_personality", "defensive")
    # Read defensively like the other fields: a missing or None snippet
    # becomes an empty section instead of raising or printing "None".
    code_snippet = getattr(obs, "code_snippet", "") or ""

    # Placeholders keep the section layout stable when a field is empty.
    author_section = author_msg if author_msg else "(no response yet start with inspection)"
    tool_section = tool_output if tool_output else "(none)"

    prompt = f"""You are an AI code review agent. Your goal is to convince a simulated human developer to accept your proposed fix and name your proposed fix function fix.

The developer has a **{author_personality}** personality and will only accept if you provide solid evidence:
- Tests pass (high pass ratio)
- Lint is clean (zero errors)
- Documentation or references are provided
- Your reasoning is clear, uses words like "because" or "therefore", and is detailed (over 30 words if needed)

Workflow:
1. Use `inspect` to understand the code.
2. Use `run_tests` and `run_linter` to gather evidence.
3. Use `query_docs` when you need references or language-specific guidance.
4. Propose a fix (`fix`) and explain why it works (`comment` or `question`).
5. If the developer pushes back, read their response carefully and address their specific concern.
6. Once convinced, use `done` to finish.

Code:
{code_snippet}

Author says:
{author_section}

Last tool output:
{tool_section}

Available actions:
run_tests, run_linter, inspect, query_docs, fix, comment, question, done

Respond ONLY in JSON:
{{"action_type": "...", "content": "..."}}"""

    if history_lines:
        # Cap the history so the prompt stays within the tokenizer budget.
        history = "\n".join(history_lines[-6:])
        prompt += f"\n\nPrevious steps:\n{history}"
    return prompt
 
 
 
 
 
 
 
 
 
312
 
313
  # ======================================================================
314
@dataclass
class Trajectory:
    """One collected episode, stored as parallel per-step lists.

    ``len(trajectory)`` is the number of entries in ``states``.
    """

    # Per-step inputs given to the policy (presumably the prompt text;
    # confirm against the collection code).
    states: List[str]
    # Per-step action strings produced for the matching state.
    actions: List[str]
    # Per-step scalar rewards.
    rewards: List[float]
    # Per-step log-probabilities recorded when the action was generated.
    logprobs: List[float]
    # Per-step episode-termination flags.
    dones: List[bool]

    def __len__(self):
        """Return the number of recorded steps."""
        return len(self.states)
322
+
323
  def generate_action_with_logprob(prompt, model, tokenizer, temperature=0.0, max_retries=2):
324
  messages = [{"role": "user", "content": prompt}]
325
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
326
+ # 1024 max length, no unsloth
327
+ inputs = tokenizer(formatted, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
328
+
329
  for attempt in range(max_retries):
330
  with torch.no_grad():
331
  outputs = model.generate(
332
  **inputs,
333
+ max_new_tokens=128,
334
  do_sample=(temperature > 0),
335
  temperature=max(temperature, 0.01) if temperature > 0 else 1.0,
336
  min_new_tokens=1,
 
339
  )
340
  generated_ids = outputs.sequences[0][inputs['input_ids'].shape[1]:]
341
  action_text = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
342
+
343
  logprobs = []
344
  for idx, token_id in enumerate(generated_ids):
345
  if idx < len(outputs.scores):
 
347
  token_logprob = F.log_softmax(token_logits, dim=-1)[token_id].item()
348
  logprobs.append(token_logprob)
349
  total_logprob = sum(logprobs) if logprobs else -100.0
350
+
351
  if not action_text:
352
  fallback_actions = [
353
  '{"action_type": "run_tests"}',
 
359
  total_logprob = -50.0
360
  print(f"[WARN] Empty generation → using fallback: {action_text}")
361
  return action_text, total_logprob
362
+
363
  try:
364
  json.loads(action_text)
365
  return action_text, total_logprob
 
369
  continue
370
  return '{"action_type":"skip"}', -100.0
371
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
372
  def collect_trajectory(env, model, tokenizer, max_steps=6, temperature=0.0):
373
  obs = env.reset()
374
  history_lines = []
 
409
  return trajectories
410
 
411
  def compute_returns_and_advantages(rewards, dones, gamma=0.99, standardize=True):
 
 
 
 
412
  n = len(rewards)
413
  returns = [0.0]*n
414
  running = 0.0
 
446
  messages = [{"role": "user", "content": state}]
447
  formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
448
  full_text = formatted + action
449
+ inputs = tokenizer(full_text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
450
  outputs = model(**inputs)
451
  logits = outputs.logits
452
  action_ids = tokenizer.encode(action, add_special_tokens=False)
 
466
  if not logprobs: continue
467
  new_logprob = sum(logprobs)
468
  avg_entropy = entropy / len(logprobs) if logprobs else 0.0
469
+ ratio = torch.exp(new_logprob - old_logprob)
 
470
  surr1 = ratio * advantage
471
  surr2 = torch.clamp(ratio, 1 - clip_epsilon, 1 + clip_epsilon) * advantage
472
  policy_loss = -torch.min(surr1, surr2)
 
486
 
487
  def evaluate_policy(env, model, tokenizer, n_episodes=3, max_steps=6,
488
  task_levels=None, verbose=False):
 
489
  model.eval()
490
  if task_levels is None:
491
  task_levels = list(BUG_DB.keys())
492
  total_rewards = []
493
+ traces = []
494
  for ep in range(n_episodes):
495
  task = task_levels[ep % len(task_levels)]
496
  env.set_task(task)
 
500
  if verbose:
501
  actions_taken = []
502
  for a in traj.actions:
503
+ try: actions_taken.append(json.loads(a).get("action_type", "?"))
504
+ except: actions_taken.append("?")
 
 
505
  traces.append({
506
  "task": task,
507
  "reward": round(ep_reward, 4),
 
517
  }
518
 
519
  # ======================================================================
520
+ # Manual warm-up from JSON (no SFTTrainer, no Unsloth)
 
521
  def json_warmup(model, tokenizer, json_path="training_data.json",
522
+ n_episodes=25, epochs=3, lr=2e-5):
 
 
 
 
 
523
  print("\n" + "="*60)
524
  print("SUPERVISED WARM-UP: training_data.json (manual cross-entropy)")
525
  print("="*60)
 
527
  with open(json_path, encoding="utf-8") as f:
528
  data = json.load(f)
529
 
 
530
  steps_per_episode = 7
531
  max_examples = n_episodes * steps_per_episode
532
  if max_examples < len(data):
 
537
 
538
  model.train()
539
  warmup_opt = AdamW(model.parameters(), lr=lr)
540
+ warmup_losses = []
541
 
542
  for epoch in range(epochs):
543
  random.shuffle(data)
 
548
  prompt = example["prompt"]
549
  action = example["action"]
550
 
 
551
  messages = [
552
  {"role": "user", "content": prompt},
553
  {"role": "assistant", "content": action},
554
  ]
555
  full_text = tokenizer.apply_chat_template(messages, tokenize=False)
556
+ inputs = tokenizer(full_text, return_tensors="pt", max_length=1024, truncation=True).to("cuda")
557
 
 
558
  prompt_only = tokenizer.apply_chat_template(
559
  [{"role": "user", "content": prompt}],
560
  tokenize=False, add_generation_prompt=True
 
564
 
565
  total_len = inputs.input_ids.shape[1]
566
  if prompt_len >= total_len:
567
+ continue
568
 
 
569
  outputs = model(**inputs)
570
  logits = outputs.logits
571
 
 
572
  shift_logits = logits[0, prompt_len - 1 : total_len - 1]
573
  shift_labels = inputs.input_ids[0, prompt_len : total_len]
574
 
 
604
 
605
  # ======================================================================
606
  # MAIN TRAINING PIPELINE
 
607
  def train_ppo():
608
+ n_iterations = 15
609
+ trajectories_per_iter = 6
 
610
  n_epochs = 2
611
  max_steps = 8
612
  learning_rate = 3e-5
 
614
  entropy_coef = 0.01
615
  gamma = 0.99
616
 
617
+ # Pre-load embedder (unchanged)
618
  from rltool import ToolBox
619
  print("Pre-loading sentence-transformer embedder...")
620
  ToolBox._get_embedder()
621
  print("✓ Embedder ready")
622
 
 
 
623
  model, tokenizer = load_model()
624
  if not test_model_sanity(model, tokenizer):
625
  return
626
  env = CodeReviewEnv()
627
  task_levels = list(BUG_DB.keys())
628
 
629
+ # Phase 0: baseline
 
 
630
  print("\n" + "="*60)
631
  print("PHASE 0 – BASELINE EVALUATION (untrained)")
632
  print("="*60)
 
641
  print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
642
  f"steps={t['steps']} actions={t['actions']}")
643
 
644
+ # Phase 1: supervised warm-up
645
+ warmup_losses = json_warmup(model, tokenizer, json_path="training_data.json",
646
+ n_episodes=25, epochs=3, lr=2e-5)
 
 
 
 
 
 
 
647
 
 
648
  print("="*60)
649
  print("POST WARM-UP EVALUATION")
650
  print("="*60)
 
659
  print(f" task={t['task']:8s} reward={t['reward']:+.4f} "
660
  f"steps={t['steps']} actions={t['actions']}")
661
 
 
 
 
662
  optimizer = AdamW(model.parameters(), lr=learning_rate)
663
  print(f"\n{'='*60}")
664
  print(f"PHASE 2 – PPO TRAINING: {n_iterations} iterations × "
665
  f"{trajectories_per_iter} trajectories (true RL)")
666
  print(f"{'='*60}\n")
667
 
668
+ reward_history, eval_history, loss_history = [], [], []
 
 
 
 
 
669
  for iteration in range(n_iterations):
670
  print(f"\n--- PPO Iteration {iteration + 1}/{n_iterations} ---")
 
 
671
  trajectories = collect_trajectories(
672
  env, model, tokenizer, trajectories_per_iter, max_steps,
673
  task_levels=task_levels, task_weights=None
 
676
  reward_history.append(avg_reward)
677
  print(f" Collect avg reward: {avg_reward:+.4f}")
678
 
 
679
  metrics = ppo_update(
680
  trajectories, model, tokenizer, optimizer,
681
  n_epochs=n_epochs, clip_epsilon=clip_epsilon,
682
  entropy_coef=entropy_coef, gamma=gamma
683
  )
684
  loss_history.append(float(metrics["loss"]))
 
 
685
  print(f" Update loss={metrics['loss']:.4f} "
686
  f"policy={metrics['policy_loss']:.4f} "
687
  f"entropy={metrics['entropy']:.4f}")
688
 
 
689
  eval_m = evaluate_policy(env, model, tokenizer, n_episodes=3,
690
  max_steps=max_steps, task_levels=task_levels,
691
  verbose=False)
 
694
  print(f" Eval avg reward: {eval_m['avg_reward']:+.4f} "
695
  f"(Δ baseline: {delta:+.4f})")
696
 
 
 
 
697
  print("\n" + "="*60)
698
  print("PHASE 3 – FINAL EVALUATION (after all training)")
699
  print("="*60)
 
725
  print(f" ✗ No overall improvement detected")
726
  print(f"{'='*60}")
727
 
728
+ # Plots
 
 
729
  iters = list(range(1, n_iterations + 1))
730
 
 
731
  if warmup_losses:
732
  fig, ax = plt.subplots(figsize=(7, 4))
733
+ ax.plot(range(1, len(warmup_losses)+1), warmup_losses,
734
  marker="o", linewidth=2, color="tab:purple")
735
+ ax.set_title("Warm-up Loss (supervised, per epoch)", fontsize=13, fontweight="bold")
736
+ ax.set_xlabel("Epoch"); ax.set_ylabel("Cross-Entropy Loss")
737
+ ax.grid(alpha=0.3); fig.tight_layout()
738
+ fig.savefig("warmup_loss.png", dpi=150); plt.close(fig)
739
+
740
+ fig, ax = plt.subplots(figsize=(9,5))
 
 
 
 
 
741
  ax.plot(iters, reward_history, marker="o", linewidth=2,
742
  label="Collect reward", color="tab:blue")
743
  ax.plot(iters, eval_history, marker="s", linewidth=2, linestyle="--",
 
747
  ax.axhline(y=warmup_reward, color="tab:purple", linestyle=":",
748
  linewidth=1.5, label=f"Post-warmup ({warmup_reward:+.3f})")
749
  ax.set_title("PPO Reward per Iteration", fontsize=14, fontweight="bold")
750
+ ax.set_xlabel("Iteration"); ax.set_ylabel("Average Reward")
751
+ ax.legend(loc="best", fontsize=8); ax.grid(alpha=0.3)
752
+ fig.tight_layout(); fig.savefig("reward_curve.png", dpi=150); plt.close(fig)
753
+
754
+ fig, ax = plt.subplots(figsize=(9,5))
 
 
 
 
 
755
  ax.plot(iters, loss_history, marker="o", linewidth=2,
756
  label="Total loss", color="tab:red")
 
 
757
  ax.set_title("PPO Loss per Iteration", fontsize=14, fontweight="bold")
758
+ ax.set_xlabel("Iteration"); ax.set_ylabel("Loss")
759
+ ax.legend(loc="best"); ax.grid(alpha=0.3)
760
+ fig.tight_layout(); fig.savefig("loss_curve.png", dpi=150); plt.close(fig)
761
+
762
+ print("Plots saved: warmup_loss.png, reward_curve.png, loss_curve.png")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
  print("="*60)
764
 
765
  if __name__ == "__main__":