parth-1 committed on
Commit
7de0535
·
verified ·
1 Parent(s): c16c504

Update grpo_train.py

Browse files
Files changed (1) hide show
  1. grpo_train.py +73 -18
grpo_train.py CHANGED
@@ -13,6 +13,16 @@ from trl import GRPOTrainer, GRPOConfig
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
 
 
 
 
 
 
 
 
 
 
16
  # =========================
17
  # CONFIG
18
  # =========================
@@ -37,7 +47,10 @@ ALLOWED_ACTIONS = [
37
  # =========================
38
 
39
  def ensure_env_ready():
40
- for _ in range(20):
 
 
 
41
  try:
42
  r = requests.post(
43
  f"{ENV_URL}/reset",
@@ -45,11 +58,20 @@ def ensure_env_ready():
45
  timeout=5
46
  )
47
  if r.status_code == 200:
 
 
 
48
  print("✅ Environment ready")
49
  return
50
- except:
 
 
 
51
  pass
52
  time.sleep(1)
 
 
 
53
  raise RuntimeError("❌ ENV not reachable")
54
 
55
  # =========================
@@ -240,21 +262,39 @@ def build_dataset():
240
  # REWARD FUNCTION (FIXED)
241
  # =========================
242
 
 
 
243
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
244
- """Shaped reward for GRPO.
245
-
246
- Pure env reward is too sparse (mostly -0.05) to give clear gradients.
247
- We add explicit shaping:
248
- - invalid JSON / invalid action_type -> -1.0 (strong negative signal)
249
- - valid action env REJECTS (wrong phase / API failure) -> -0.5
250
- - valid action env ACCEPTS (advances state) -> +0.5 + env_reward
251
- - terminal correct decision -> env_reward already contains +1.0 bonus
252
- """
 
 
 
 
 
 
 
253
  client = EnvClient(ENV_URL)
254
  rewards = []
255
 
256
- for completion, t_id, setup in zip(completions, task_id, setup_actions):
 
 
 
 
 
 
257
  parsed = extract_json(completion)
 
 
 
258
  if not parsed:
259
  rewards.append(-1.0)
260
  continue
@@ -301,14 +341,22 @@ def reward_environment(prompts, completions, task_id=None, setup_actions=None, *
301
  # =========================
302
 
303
  if torch.cuda.is_available():
304
- _vram = torch.cuda.get_device_properties(0).total_memory
305
- _name = torch.cuda.get_device_name(0)
306
- print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB")
 
 
307
  else:
308
  _vram = 0
309
  _name = "CPU"
 
 
 
 
310
 
311
- USE_4BIT = _vram < 40 * 1024**3 # True for T4 (15 GB) and L4 (24 GB); False for A100 (80 GB)
 
 
312
 
313
  model, tokenizer = FastLanguageModel.from_pretrained(
314
  model_name="unsloth/Llama-3.1-8B-Instruct",
@@ -337,6 +385,10 @@ model = FastLanguageModel.get_peft_model(
337
 
338
  dataset = build_dataset()
339
 
 
 
 
 
340
  trainer = GRPOTrainer(
341
  model=model,
342
  reward_funcs=[reward_environment],
@@ -351,8 +403,8 @@ trainer = GRPOTrainer(
351
  max_completion_length=128,
352
  logging_steps=3 if USE_4BIT else 5,
353
  warmup_steps=5 if USE_4BIT else 10,
354
- bf16=not USE_4BIT,
355
- fp16=USE_4BIT,
356
  report_to="none",
357
  ),
358
  train_dataset=dataset,
@@ -366,6 +418,9 @@ trainer = GRPOTrainer(
366
  if __name__ == "__main__":
367
  ensure_env_ready()
368
 
 
 
 
369
  print("Starting GRPO training...")
370
  trainer.train()
371
 
 
13
 
14
  PatchFastRL("GRPO", FastLanguageModel)
15
 
16
+ # #region agent log
17
+ import pathlib as _pl
18
+ _DLOG = _pl.Path("debug-851b5f.log")
19
+ def _dlog(hyp, loc, msg, data=None):
20
+ import time as _t
21
+ entry = json.dumps({"sessionId":"851b5f","hypothesisId":hyp,"location":loc,"message":msg,"data":data or {},"timestamp":int(_t.time()*1000)})
22
+ with open(_DLOG, "a") as f: f.write(entry + "\n")
23
+ print(f"[DBG:{hyp}] {msg} {data or ''}", flush=True)
24
+ # #endregion
25
+
26
  # =========================
27
  # CONFIG
28
  # =========================
 
47
  # =========================
48
 
49
  def ensure_env_ready():
50
+ # #region agent log
51
+ _dlog("B", "grpo_train.py:ensure_env_ready", "Checking env", {"ENV_URL": ENV_URL})
52
+ # #endregion
53
+ for i in range(20):
54
  try:
55
  r = requests.post(
56
  f"{ENV_URL}/reset",
 
58
  timeout=5
59
  )
60
  if r.status_code == 200:
61
+ # #region agent log
62
+ _dlog("B", "grpo_train.py:ensure_env_ready", "Env ready", {"attempt": i+1, "status": r.status_code})
63
+ # #endregion
64
  print("✅ Environment ready")
65
  return
66
+ except Exception as e:
67
+ # #region agent log
68
+ if i == 0: _dlog("B", "grpo_train.py:ensure_env_ready", "Env connection failed", {"attempt": i+1, "error": str(e)[:200]})
69
+ # #endregion
70
  pass
71
  time.sleep(1)
72
+ # #region agent log
73
+ _dlog("B", "grpo_train.py:ensure_env_ready", "ENV UNREACHABLE after 20 attempts", {})
74
+ # #endregion
75
  raise RuntimeError("❌ ENV not reachable")
76
 
77
  # =========================
 
262
  # REWARD FUNCTION (FIXED)
263
  # =========================
264
 
265
+ _reward_call_count = [0]
266
+
267
  def reward_environment(prompts, completions, task_id=None, setup_actions=None, **kwargs):
268
+ """Shaped reward for GRPO."""
269
+ _reward_call_count[0] += 1
270
+ _call = _reward_call_count[0]
271
+ # #region agent log
272
+ _dlog("C", "grpo_train.py:reward_env", f"reward call #{_call}", {
273
+ "n_prompts": len(prompts) if prompts else 0,
274
+ "n_completions": len(completions) if completions else 0,
275
+ "completions_type": type(completions).__name__,
276
+ "first_completion_type": type(completions[0]).__name__ if completions else "N/A",
277
+ "first_completion_preview": str(completions[0])[:150] if completions else "N/A",
278
+ "task_id_is_none": task_id is None,
279
+ "setup_actions_is_none": setup_actions is None,
280
+ "kwargs_keys": list(kwargs.keys()),
281
+ })
282
+ # #endregion
283
+
284
  client = EnvClient(ENV_URL)
285
  rewards = []
286
 
287
+ if task_id is None or setup_actions is None:
288
+ # #region agent log
289
+ _dlog("D", "grpo_train.py:reward_env", "task_id or setup_actions is None — returning -1 for all", {"call": _call})
290
+ # #endregion
291
+ return [-1.0] * len(completions)
292
+
293
+ for idx, (completion, t_id, setup) in enumerate(zip(completions, task_id, setup_actions)):
294
  parsed = extract_json(completion)
295
+ # #region agent log
296
+ if _call <= 3: _dlog("D", "grpo_train.py:reward_loop", f"call#{_call} item#{idx}", {"parsed_ok": parsed is not None, "action": parsed.get("action_type") if parsed else None, "raw_preview": str(completion)[:120], "task_id": t_id})
297
+ # #endregion
298
  if not parsed:
299
  rewards.append(-1.0)
300
  continue
 
341
  # =========================
342
 
343
  if torch.cuda.is_available():
344
+ _props = torch.cuda.get_device_properties(0)
345
+ _vram = _props.total_memory
346
+ _name = _props.name
347
+ _cc = (_props.major, _props.minor) # compute capability
348
+ print(f"GPU: {_name} VRAM: {_vram / 1024**3:.1f} GB Compute: {_cc[0]}.{_cc[1]}")
349
  else:
350
  _vram = 0
351
  _name = "CPU"
352
+ _cc = (0, 0)
353
+
354
+ USE_4BIT = _vram < 40 * 1024**3 # T4 (15 GB), L4 (24 GB) → 4-bit; A100 (80 GB) → bf16
355
+ USE_BF16 = _cc >= (8, 0) # Ampere+ (A100, L4) support bf16; Turing (T4) does not
356
 
357
+ # #region agent log
358
+ _dlog("A", "grpo_train.py:gpu_detect", "GPU config resolved", {"name":_name,"vram_gb":round(_vram/1024**3,1),"cc":list(_cc),"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16})
359
+ # #endregion
360
 
361
  model, tokenizer = FastLanguageModel.from_pretrained(
362
  model_name="unsloth/Llama-3.1-8B-Instruct",
 
385
 
386
  dataset = build_dataset()
387
 
388
+ # #region agent log
389
+ _dlog("A", "grpo_train.py:trainer_init", "Creating GRPOTrainer", {"USE_4BIT":USE_4BIT,"USE_BF16":USE_BF16,"epochs":1 if USE_4BIT else 3,"batch":1 if USE_4BIT else 2,"gens":2 if USE_4BIT else 4,"dataset_len":len(dataset)})
390
+ # #endregion
391
+
392
  trainer = GRPOTrainer(
393
  model=model,
394
  reward_funcs=[reward_environment],
 
403
  max_completion_length=128,
404
  logging_steps=3 if USE_4BIT else 5,
405
  warmup_steps=5 if USE_4BIT else 10,
406
+ bf16=USE_BF16,
407
+ fp16=not USE_BF16,
408
  report_to="none",
409
  ),
410
  train_dataset=dataset,
 
418
  if __name__ == "__main__":
419
  ensure_env_ready()
420
 
421
+ # #region agent log
422
+ _dlog("E", "grpo_train.py:train_start", "About to call trainer.train()", {"gpu_mem_allocated_gb": round(torch.cuda.memory_allocated()/1024**3, 2) if torch.cuda.is_available() else 0})
423
+ # #endregion
424
  print("Starting GRPO training...")
425
  trainer.train()
426