100XZX001 commited on
Commit
e31002c
·
verified ·
1 Parent(s): 80d6603

Update training.py

Browse files
Files changed (1) hide show
  1. training.py +44 -19
training.py CHANGED
@@ -251,41 +251,58 @@ def supervised_warmup(model, tokenizer, data_path="training_data.json", epochs=3
251
  # =========================================================
252
  # PPO UPDATE (FIXED advantage = return – baseline)
253
  # =========================================================
254
- def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
255
  model.train()
256
 
257
  losses = []
258
  kls = []
259
 
260
- # Gather all returns and compute a global baseline (simple REINFORCE)
 
 
261
  all_returns = []
262
- for traj in trajectories:
263
- returns = np.cumsum(traj.rewards[::-1])[::-1]
264
- all_returns.extend(returns)
265
- baseline = np.mean(all_returns) if all_returns else 0.0
266
 
 
267
  for traj in trajectories:
268
- returns = np.cumsum(traj.rewards[::-1])[::-1]
269
- returns = torch.tensor(returns, device=DEVICE)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
270
 
271
  for i in range(len(traj.states)):
272
  state = traj.states[i]
273
  action = traj.actions[i]
 
274
  old_lp = torch.tensor(traj.logprobs[i], device=DEVICE)
275
 
276
- # Proper advantage: return – baseline
277
- adv = returns[i] - baseline
278
 
279
  messages = [{"role": "user", "content": state}]
280
- formatted = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
281
- full = formatted + action
 
282
 
 
283
  inputs = tokenizer(full, return_tensors="pt", truncation=True).to(DEVICE)
 
284
  logits = model(**inputs).logits
285
 
286
  action_ids = tokenizer.encode(action, add_special_tokens=False)
287
- prefix_ids = tokenizer.encode(formatted, add_special_tokens=False)
288
- prefix_len = len(prefix_ids)
289
 
290
  logps = []
291
  entropy = 0.0
@@ -295,8 +312,9 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
295
  if pos == 0 or pos >= logits.shape[1]:
296
  continue
297
 
298
- token_logits = logits[0, pos-1]
299
  log_probs = F.log_softmax(token_logits, dim=-1)
 
300
  lp = log_probs[action_ids[idx]]
301
  logps.append(lp)
302
 
@@ -307,10 +325,15 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
307
  continue
308
 
309
  new_lp = torch.stack(logps).sum()
 
 
310
  ratio = torch.exp(new_lp - old_lp)
 
311
  s1 = ratio * adv
312
- s2 = torch.clamp(ratio, 1-clip, 1+clip) * adv
313
- loss = -torch.min(s1, s2) - 0.01 * entropy
 
 
314
 
315
  if torch.isnan(loss):
316
  continue
@@ -324,8 +347,10 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
324
  kls.append(kl)
325
  losses.append(loss.item())
326
 
327
- return np.mean(losses) if losses else 0.0, np.mean(kls) if kls else 0.0
328
-
 
 
329
  # =========================================================
330
  # MAIN TRAINING LOOP
331
  # =========================================================
 
251
  # =========================================================
252
  # PPO UPDATE (FIXED advantage = return – baseline)
253
  # =========================================================
254
+ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2, gamma=0.99):
255
  model.train()
256
 
257
  losses = []
258
  kls = []
259
 
260
+ # =========================
261
+ # Compute returns + baseline
262
+ # =========================
263
  all_returns = []
 
 
 
 
264
 
265
+ traj_returns = []
266
  for traj in trajectories:
267
+ returns = []
268
+ running = 0.0
269
+
270
+ for r in reversed(traj.rewards):
271
+ running = r + gamma * running
272
+ returns.insert(0, running)
273
+
274
+ returns = torch.tensor(returns, dtype=torch.float32, device=DEVICE)
275
+ traj_returns.append(returns)
276
+ all_returns.extend(returns.tolist())
277
+
278
+ baseline = torch.tensor(np.mean(all_returns), device=DEVICE) if all_returns else torch.tensor(0.0, device=DEVICE)
279
+
280
+ # =========================
281
+ # PPO update
282
+ # =========================
283
+ for traj, returns in zip(trajectories, traj_returns):
284
 
285
  for i in range(len(traj.states)):
286
  state = traj.states[i]
287
  action = traj.actions[i]
288
+
289
  old_lp = torch.tensor(traj.logprobs[i], device=DEVICE)
290
 
291
+ # Advantage (detached)
292
+ adv = (returns[i] - baseline).detach()
293
 
294
  messages = [{"role": "user", "content": state}]
295
+ formatted = tokenizer.apply_chat_template(
296
+ messages, tokenize=False, add_generation_prompt=True
297
+ )
298
 
299
+ full = formatted + action
300
  inputs = tokenizer(full, return_tensors="pt", truncation=True).to(DEVICE)
301
+
302
  logits = model(**inputs).logits
303
 
304
  action_ids = tokenizer.encode(action, add_special_tokens=False)
305
+ prefix_len = len(tokenizer.encode(formatted, add_special_tokens=False))
 
306
 
307
  logps = []
308
  entropy = 0.0
 
312
  if pos == 0 or pos >= logits.shape[1]:
313
  continue
314
 
315
+ token_logits = logits[0, pos - 1]
316
  log_probs = F.log_softmax(token_logits, dim=-1)
317
+
318
  lp = log_probs[action_ids[idx]]
319
  logps.append(lp)
320
 
 
325
  continue
326
 
327
  new_lp = torch.stack(logps).sum()
328
+
329
+ # PPO ratio
330
  ratio = torch.exp(new_lp - old_lp)
331
+
332
  s1 = ratio * adv
333
+ s2 = torch.clamp(ratio, 1 - clip, 1 + clip) * adv
334
+
335
+ policy_loss = -torch.min(s1, s2)
336
+ loss = policy_loss - 0.01 * (entropy / len(logps))
337
 
338
  if torch.isnan(loss):
339
  continue
 
347
  kls.append(kl)
348
  losses.append(loss.item())
349
 
350
+ return (
351
+ float(np.mean(losses)) if losses else 0.0,
352
+ float(np.mean(kls)) if kls else 0.0,
353
+ )
354
  # =========================================================
355
  # MAIN TRAINING LOOP
356
  # =========================================================