100XZX001 committed on
Commit
76f5801
·
verified ·
1 Parent(s): 0f1f590

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +92 -66
training.py CHANGED
@@ -1,4 +1,4 @@
1
- # training.py – Memorysafe: Phi‑3‑mini + Expert Demos + Fast PPO (2 iterations)
2
  import os
3
  os.environ["TRITON_DISABLE"] = "1"
4
 
@@ -14,6 +14,7 @@ import numpy as np
14
  import re
15
  import random
16
  import matplotlib.pyplot as plt
 
17
 
18
  from unsloth import FastLanguageModel
19
  from transformers import TrainingArguments
@@ -25,9 +26,15 @@ from redteam import BUG_DB
25
  from models import (
26
  RunTests, RunLinter, Inspect,
27
  ProposeFix, WriteComment, AskQuestion,
28
- Done, Skip, QueryDocs, map_to_env as model_map_to_env
29
  )
30
 
 
 
 
 
 
 
31
  # ======================================================================
32
  @dataclass
33
  class AgentAction:
@@ -69,13 +76,30 @@ def parse_action(output: str) -> AgentAction:
69
  return AgentAction("invalid", output)
70
 
71
  def map_to_env(action: AgentAction):
72
- return model_map_to_env(action.action_type, action.content)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
 
74
  # ======================================================================
75
  def load_model():
76
  model, tokenizer = FastLanguageModel.from_pretrained(
77
  model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
78
- max_seq_length=480, # smaller window for memory
79
  load_in_4bit=True,
80
  )
81
  model = FastLanguageModel.get_peft_model(
@@ -120,10 +144,7 @@ def test_model_sanity(model, tokenizer) -> bool:
120
 
121
  # ======================================================================
122
  def _expert_fix_from_context(obs) -> str:
123
- """
124
- Build a conservative fix template named `fix` (required by tests).
125
- Uses bug hints + code snippet patterns to create realistic fixes.
126
- """
127
  bug = (getattr(obs, "bug_description", "") or "").lower()
128
  code = getattr(obs, "code_snippet", "") or ""
129
 
@@ -134,19 +155,16 @@ def _expert_fix_from_context(obs) -> str:
134
  " return 0\n"
135
  " return sum(data) / len(data)"
136
  )
137
-
138
  if "operator" in bug or "sign" in bug:
139
  return (
140
  "def fix(a, b):\n"
141
  " return a + b"
142
  )
143
-
144
  if "off_by_one" in bug or "loop" in bug:
145
  return (
146
  "def fix(items):\n"
147
  " return len(items)"
148
  )
149
-
150
  if "null" in bug or "key" in bug or "dict" in code.lower():
151
  return (
152
  "def fix(payload):\n"
@@ -154,8 +172,6 @@ def _expert_fix_from_context(obs) -> str:
154
  " user_id = payload.get('id')\n"
155
  " return users.get(user_id)"
156
  )
157
-
158
- # Concurrency-heavy tasks (harder/hardest).
159
  if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
160
  return (
161
  "import threading\n"
@@ -167,7 +183,6 @@ def _expert_fix_from_context(obs) -> str:
167
  " return 0\n"
168
  " return counter + 1"
169
  )
170
-
171
  if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
172
  return (
173
  "import threading\n"
@@ -182,7 +197,6 @@ def _expert_fix_from_context(obs) -> str:
182
  " with second:\n"
183
  " return work() if callable(work) else work"
184
  )
185
-
186
  if "fork_join" in bug or "join" in bug:
187
  return (
188
  "import threading\n"
@@ -193,8 +207,6 @@ def _expert_fix_from_context(obs) -> str:
193
  " t.join()\n"
194
  " return True"
195
  )
196
-
197
- # Generic safe fallback keeps the RL pipeline alive for unknown bugs.
198
  return (
199
  "def fix(data):\n"
200
  " if data is None:\n"
@@ -202,12 +214,8 @@ def _expert_fix_from_context(obs) -> str:
202
  " return data"
203
  )
204
 
205
-
206
  def _expert_supervised_policy(obs) -> str:
207
- """
208
- Real workflow policy:
209
- inspect -> tests/linter -> docs -> fix -> negotiate -> done.
210
- """
211
  author_msg = (getattr(obs, "author_response", "") or "").lower()
212
  tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
213
 
@@ -222,22 +230,17 @@ def _expert_supervised_policy(obs) -> str:
222
  if not getattr(obs, "docs_queried", False):
223
  return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
224
 
225
- # Use docs again on hard tasks when evidence is still weak.
226
  if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
227
  bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
228
- return json.dumps(
229
- {
230
- "action_type": "query_docs",
231
- "content": f"python {bug_hint} lock ordering race condition mitigation patterns",
232
- }
233
- )
234
 
235
- # If test quality is poor, propose a concrete fix.
236
  if getattr(obs, "current_test_score", 0.0) < 0.95:
237
  fix_code = _expert_fix_from_context(obs)
238
  return json.dumps({"action_type": "fix", "content": fix_code})
239
 
240
- # If author is still unconvinced, provide causal explanation.
241
  if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
242
  return (
243
  '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
@@ -245,55 +248,72 @@ def _expert_supervised_policy(obs) -> str:
245
  'The change is intentionally small to reduce regression risk."}'
246
  )
247
 
248
- # If negotiation is strong enough and quality is good, terminate.
249
  conf = float(getattr(obs, "author_confidence", 0.0))
250
  threshold = float(getattr(obs, "author_threshold", 0.5))
251
  score = float(getattr(obs, "current_test_score", 0.0))
252
  if conf >= threshold and score >= 0.8:
253
  return '{"action_type": "done"}'
254
 
255
- # Nudge conversation forward when tests are okay but acceptance is pending.
256
  return (
257
- '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, the root cause, and how the fix prevents regressions?"}'
 
258
  )
259
 
260
  # ======================================================================
261
- def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8):
 
 
 
 
 
262
  print("\n" + "="*60)
263
  print("SUPERVISED WARM-UP: Real environment demonstrations")
264
  print("="*60)
265
 
266
  examples = []
267
- tasks = ["easy", "medium", "hard", "harder", "hardest"]
268
- for ep in range(n_episodes):
269
- task = random.choice(tasks)
270
- env.set_task(task)
271
- obs = env.reset()
272
- history = []
273
- done = False
274
-
275
- steps = 0
276
- while not done and steps < max_steps:
277
- prompt = build_prompt(obs, history)
278
- action_text = _expert_supervised_policy(obs)
279
- action = parse_action(action_text)
280
- env_action = map_to_env(action)
281
- next_obs, _, done, _ = env.step(env_action)
282
 
 
 
 
 
 
 
 
283
  messages = [
284
  {"role": "user", "content": prompt},
285
- {"role": "assistant", "content": action_text},
286
  ]
287
  full_text = tokenizer.apply_chat_template(messages, tokenize=False)
288
  examples.append({"text": full_text})
289
-
290
- history.append(f"Agent: {action_text}")
291
- history.append(f"Env: {next_obs.last_tool_output}")
292
- history = history[-8:]
293
- obs = next_obs
294
- steps += 1
295
-
296
- print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
 
298
  if not examples:
299
  print("No supervised examples generated; skipping warm-up.")
@@ -317,9 +337,9 @@ def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=
317
  bf16=True,
318
  ),
319
  )
320
- print(f"Training on {len(examples)} real env examples for {epochs} epochs...")
321
  trainer.train()
322
- print("✓ Supervised warm-up (real env) complete\n")
323
  torch.cuda.empty_cache()
324
 
325
  # ======================================================================
@@ -551,7 +571,7 @@ def evaluate_policy(env, model, tokenizer, n_episodes=2, max_steps=6):
551
  return {"avg_reward": np.mean(total_rewards), "std_reward": np.std(total_rewards)}
552
 
553
  # ======================================================================
554
- def train_ppo():
555
  n_iterations = 2
556
  trajectories_per_iter = 2
557
  n_epochs = 1
@@ -568,8 +588,11 @@ def train_ppo():
568
  return
569
  env = CodeReviewEnv()
570
 
571
- # Warm-up (real env demonstrations with expert policy)
572
- supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8)
 
 
 
573
 
574
  optimizer = AdamW(model.parameters(), lr=learning_rate)
575
  task_levels = list(BUG_DB.keys())
@@ -607,6 +630,9 @@ def train_ppo():
607
  plt.grid(alpha=0.3); plt.tight_layout(); plt.savefig("loss_curve.png", dpi=150); plt.close()
608
  print("Plots saved as reward_curve.png and loss_curve.png.")
609
  print("="*60)
610
-
 
611
  if __name__ == "__main__":
612
- train_ppo()
 
 
 
1
+ # training.py – Preloaded embedder, dual supervised warm‑up, 2 PPO iterations
2
  import os
3
  os.environ["TRITON_DISABLE"] = "1"
4
 
 
14
  import re
15
  import random
16
  import matplotlib.pyplot as plt
17
+ from pathlib import Path
18
 
19
  from unsloth import FastLanguageModel
20
  from transformers import TrainingArguments
 
26
  from models import (
27
  RunTests, RunLinter, Inspect,
28
  ProposeFix, WriteComment, AskQuestion,
29
+ Done, Skip, QueryDocs
30
  )
31
 
32
+ # Pre‑load the sentence‑transformer model to avoid OOM during warm‑up
33
+ from rltool import ToolBox
34
+ print("Pre‑loading documentation retriever …")
35
+ ToolBox._get_embedder()
36
+ print("Done.")
37
+
38
  # ======================================================================
39
  @dataclass
40
  class AgentAction:
 
76
  return AgentAction("invalid", output)
77
 
78
  def map_to_env(action: AgentAction):
79
+ if action.action_type == "run_tests":
80
+ return RunTests()
81
+ elif action.action_type == "run_linter":
82
+ return RunLinter()
83
+ elif action.action_type == "inspect":
84
+ return Inspect()
85
+ elif action.action_type == "fix":
86
+ return ProposeFix(fix_code=action.content or "")
87
+ elif action.action_type == "comment":
88
+ return WriteComment(comment_text=action.content or "")
89
+ elif action.action_type == "question":
90
+ return AskQuestion(question=action.content or "")
91
+ elif action.action_type == "query_docs":
92
+ return QueryDocs(query_topic=action.content or "")
93
+ elif action.action_type == "done":
94
+ return Done()
95
+ else:
96
+ return Skip()
97
 
98
  # ======================================================================
99
  def load_model():
100
  model, tokenizer = FastLanguageModel.from_pretrained(
101
  model_name="unsloth/Phi-3-mini-4k-instruct-bnb-4bit",
102
+ max_seq_length=480,
103
  load_in_4bit=True,
104
  )
105
  model = FastLanguageModel.get_peft_model(
 
144
 
145
  # ======================================================================
146
  def _expert_fix_from_context(obs) -> str:
147
+ """Build a conservative fix template based on bug hints."""
 
 
 
148
  bug = (getattr(obs, "bug_description", "") or "").lower()
149
  code = getattr(obs, "code_snippet", "") or ""
150
 
 
155
  " return 0\n"
156
  " return sum(data) / len(data)"
157
  )
 
158
  if "operator" in bug or "sign" in bug:
159
  return (
160
  "def fix(a, b):\n"
161
  " return a + b"
162
  )
 
163
  if "off_by_one" in bug or "loop" in bug:
164
  return (
165
  "def fix(items):\n"
166
  " return len(items)"
167
  )
 
168
  if "null" in bug or "key" in bug or "dict" in code.lower():
169
  return (
170
  "def fix(payload):\n"
 
172
  " user_id = payload.get('id')\n"
173
  " return users.get(user_id)"
174
  )
 
 
175
  if "race" in bug or "missing_lock" in bug or "thread_safe" in bug or "global_nonatomic" in bug:
176
  return (
177
  "import threading\n"
 
183
  " return 0\n"
184
  " return counter + 1"
185
  )
 
186
  if "deadlock" in bug or "double_lock" in bug or "lock order" in bug or "nested_lock" in bug:
187
  return (
188
  "import threading\n"
 
197
  " with second:\n"
198
  " return work() if callable(work) else work"
199
  )
 
200
  if "fork_join" in bug or "join" in bug:
201
  return (
202
  "import threading\n"
 
207
  " t.join()\n"
208
  " return True"
209
  )
 
 
210
  return (
211
  "def fix(data):\n"
212
  " if data is None:\n"
 
214
  " return data"
215
  )
216
 
 
217
  def _expert_supervised_policy(obs) -> str:
218
+ """Expert policy used during supervised warm‑up."""
 
 
 
219
  author_msg = (getattr(obs, "author_response", "") or "").lower()
220
  tool_output = (getattr(obs, "last_tool_output", "") or "").lower()
221
 
 
230
  if not getattr(obs, "docs_queried", False):
231
  return '{"action_type": "query_docs", "content": "python bug fixing best practices for edge cases and null safety"}'
232
 
 
233
  if getattr(obs, "current_test_score", 0.0) < 0.6 and getattr(obs, "step", 0) >= 3:
234
  bug_hint = (getattr(obs, "bug_description", "") or "concurrency bug").replace('"', "'")
235
+ return json.dumps({
236
+ "action_type": "query_docs",
237
+ "content": f"python {bug_hint} lock ordering race condition mitigation patterns"
238
+ })
 
 
239
 
 
240
  if getattr(obs, "current_test_score", 0.0) < 0.95:
241
  fix_code = _expert_fix_from_context(obs)
242
  return json.dumps({"action_type": "fix", "content": fix_code})
243
 
 
244
  if author_msg and ("not convinced" in author_msg or "explain" in author_msg or "brief" in author_msg):
245
  return (
246
  '{"action_type": "comment", "content": "This fix works because it handles the failing edge case directly, '
 
248
  'The change is intentionally small to reduce regression risk."}'
249
  )
250
 
 
251
  conf = float(getattr(obs, "author_confidence", 0.0))
252
  threshold = float(getattr(obs, "author_threshold", 0.5))
253
  score = float(getattr(obs, "current_test_score", 0.0))
254
  if conf >= threshold and score >= 0.8:
255
  return '{"action_type": "done"}'
256
 
 
257
  return (
258
+ '{"action_type": "question", "content": "Would you like a quick walkthrough of a failing scenario, '
259
+ 'the root cause, and how the fix prevents regressions?"}'
260
  )
261
 
262
  # ======================================================================
263
+ def supervised_warmup(model, tokenizer, env, n_episodes=16, epochs=1, max_steps=8,
264
+ json_path: Optional[str] = None):
265
+ """
266
+ Supervised warm‑up using either a JSON file of (prompt, action) pairs,
267
+ or a rule‑based expert playing in the real environment.
268
+ """
269
  print("\n" + "="*60)
270
  print("SUPERVISED WARM-UP: Real environment demonstrations")
271
  print("="*60)
272
 
273
  examples = []
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
274
 
275
+ if json_path and Path(json_path).exists():
276
+ print(f"Loading training examples from {json_path} ...")
277
+ with open(json_path, 'r', encoding='utf-8') as f:
278
+ raw_pairs = json.load(f)
279
+ for pair in raw_pairs:
280
+ prompt = pair["prompt"]
281
+ action = pair["action"]
282
  messages = [
283
  {"role": "user", "content": prompt},
284
+ {"role": "assistant", "content": action}
285
  ]
286
  full_text = tokenizer.apply_chat_template(messages, tokenize=False)
287
  examples.append({"text": full_text})
288
+ print(f"Loaded {len(examples)} examples from JSON.")
289
+ else:
290
+ # Fallback to real environment rollouts with expert policy
291
+ tasks = ["easy", "medium", "hard", "harder", "hardest"]
292
+ for ep in range(n_episodes):
293
+ task = random.choice(tasks)
294
+ env.set_task(task)
295
+ obs = env.reset()
296
+ history = []
297
+ done = False
298
+ steps = 0
299
+ while not done and steps < max_steps:
300
+ prompt = build_prompt(obs, history)
301
+ action_text = _expert_supervised_policy(obs)
302
+ action = parse_action(action_text)
303
+ env_action = map_to_env(action)
304
+ next_obs, _, done, _ = env.step(env_action)
305
+ messages = [
306
+ {"role": "user", "content": prompt},
307
+ {"role": "assistant", "content": action_text},
308
+ ]
309
+ full_text = tokenizer.apply_chat_template(messages, tokenize=False)
310
+ examples.append({"text": full_text})
311
+ history.append(f"Agent: {action_text}")
312
+ history.append(f"Env: {next_obs.last_tool_output}")
313
+ history = history[-8:]
314
+ obs = next_obs
315
+ steps += 1
316
+ print(f"Supervised episode {ep+1}: task={task}, steps={steps}, done={done}")
317
 
318
  if not examples:
319
  print("No supervised examples generated; skipping warm-up.")
 
337
  bf16=True,
338
  ),
339
  )
340
+ print(f"Training on {len(examples)} examples for {epochs} epochs...")
341
  trainer.train()
342
+ print("✓ Supervised warm-up complete\n")
343
  torch.cuda.empty_cache()
344
 
345
  # ======================================================================
 
571
  return {"avg_reward": np.mean(total_rewards), "std_reward": np.std(total_rewards)}
572
 
573
  # ======================================================================
574
+ def train_ppo(json_dataset_path: Optional[str] = None):
575
  n_iterations = 2
576
  trajectories_per_iter = 2
577
  n_epochs = 1
 
588
  return
589
  env = CodeReviewEnv()
590
 
591
+ # Run supervised warm‑up twice: with a JSON dataset both passes train on the same examples; without one, each pass collects fresh expert rollouts
592
+ supervised_warmup(model, tokenizer, env, n_episodes=12, epochs=1, max_steps=8,
593
+ json_path=json_dataset_path)
594
+ supervised_warmup(model, tokenizer, env, n_episodes=12, epochs=1, max_steps=8,
595
+ json_path=json_dataset_path)
596
 
597
  optimizer = AdamW(model.parameters(), lr=learning_rate)
598
  task_levels = list(BUG_DB.keys())
 
630
  plt.grid(alpha=0.3); plt.tight_layout(); plt.savefig("loss_curve.png", dpi=150); plt.close()
631
  print("Plots saved as reward_curve.png and loss_curve.png.")
632
  print("="*60)
633
+
634
+ # ======================================================================
635
  if __name__ == "__main__":
636
+ # Optionally provide a path to a JSON file of training pairs.
637
+ # The file must contain a JSON list of such pairs, e.g. [{"prompt": "You are a code review agent...", "action": "{\"action_type\": \"inspect\"}"}]
638
+ train_ppo(json_dataset_path=None) # set to your JSON file path if you have one