Spaces:
Sleeping
Sleeping
Update training.py
Browse files
- training.py +44 -19
training.py
CHANGED
|
@@ -251,41 +251,58 @@ def supervised_warmup(model, tokenizer, data_path="training_data.json", epochs=3
|
|
| 251 |
# =========================================================
|
| 252 |
# PPO UPDATE (FIXED advantage = return – baseline)
|
| 253 |
# =========================================================
|
| 254 |
-
def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
|
| 255 |
model.train()
|
| 256 |
|
| 257 |
losses = []
|
| 258 |
kls = []
|
| 259 |
|
| 260 |
-
#
|
|
|
|
|
|
|
| 261 |
all_returns = []
|
| 262 |
-
for traj in trajectories:
|
| 263 |
-
returns = np.cumsum(traj.rewards[::-1])[::-1]
|
| 264 |
-
all_returns.extend(returns)
|
| 265 |
-
baseline = np.mean(all_returns) if all_returns else 0.0
|
| 266 |
|
|
|
|
| 267 |
for traj in trajectories:
|
| 268 |
-
returns =
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
|
| 271 |
for i in range(len(traj.states)):
|
| 272 |
state = traj.states[i]
|
| 273 |
action = traj.actions[i]
|
|
|
|
| 274 |
old_lp = torch.tensor(traj.logprobs[i], device=DEVICE)
|
| 275 |
|
| 276 |
-
#
|
| 277 |
-
adv = returns[i] - baseline
|
| 278 |
|
| 279 |
messages = [{"role": "user", "content": state}]
|
| 280 |
-
formatted = tokenizer.apply_chat_template(
|
| 281 |
-
|
|
|
|
| 282 |
|
|
|
|
| 283 |
inputs = tokenizer(full, return_tensors="pt", truncation=True).to(DEVICE)
|
|
|
|
| 284 |
logits = model(**inputs).logits
|
| 285 |
|
| 286 |
action_ids = tokenizer.encode(action, add_special_tokens=False)
|
| 287 |
-
|
| 288 |
-
prefix_len = len(prefix_ids)
|
| 289 |
|
| 290 |
logps = []
|
| 291 |
entropy = 0.0
|
|
@@ -295,8 +312,9 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
|
|
| 295 |
if pos == 0 or pos >= logits.shape[1]:
|
| 296 |
continue
|
| 297 |
|
| 298 |
-
token_logits = logits[0, pos-1]
|
| 299 |
log_probs = F.log_softmax(token_logits, dim=-1)
|
|
|
|
| 300 |
lp = log_probs[action_ids[idx]]
|
| 301 |
logps.append(lp)
|
| 302 |
|
|
@@ -307,10 +325,15 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
|
|
| 307 |
continue
|
| 308 |
|
| 309 |
new_lp = torch.stack(logps).sum()
|
|
|
|
|
|
|
| 310 |
ratio = torch.exp(new_lp - old_lp)
|
|
|
|
| 311 |
s1 = ratio * adv
|
| 312 |
-
s2 = torch.clamp(ratio, 1-clip, 1+clip) * adv
|
| 313 |
-
|
|
|
|
|
|
|
| 314 |
|
| 315 |
if torch.isnan(loss):
|
| 316 |
continue
|
|
@@ -324,8 +347,10 @@ def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2):
|
|
| 324 |
kls.append(kl)
|
| 325 |
losses.append(loss.item())
|
| 326 |
|
| 327 |
-
return
|
| 328 |
-
|
|
|
|
|
|
|
| 329 |
# =========================================================
|
| 330 |
# MAIN TRAINING LOOP
|
| 331 |
# =========================================================
|
|
|
|
| 251 |
# =========================================================
|
| 252 |
# PPO UPDATE (FIXED advantage = return – baseline)
|
| 253 |
# =========================================================
|
| 254 |
+
def ppo_update(trajectories, model, tokenizer, optimizer, clip=0.2, gamma=0.99):
|
| 255 |
model.train()
|
| 256 |
|
| 257 |
losses = []
|
| 258 |
kls = []
|
| 259 |
|
| 260 |
+
# =========================
|
| 261 |
+
# Compute returns + baseline
|
| 262 |
+
# =========================
|
| 263 |
all_returns = []
|
|
|
|
|
|
|
|
|
|
|
|
|
| 264 |
|
| 265 |
+
traj_returns = []
|
| 266 |
for traj in trajectories:
|
| 267 |
+
returns = []
|
| 268 |
+
running = 0.0
|
| 269 |
+
|
| 270 |
+
for r in reversed(traj.rewards):
|
| 271 |
+
running = r + gamma * running
|
| 272 |
+
returns.insert(0, running)
|
| 273 |
+
|
| 274 |
+
returns = torch.tensor(returns, dtype=torch.float32, device=DEVICE)
|
| 275 |
+
traj_returns.append(returns)
|
| 276 |
+
all_returns.extend(returns.tolist())
|
| 277 |
+
|
| 278 |
+
baseline = torch.tensor(np.mean(all_returns), device=DEVICE) if all_returns else torch.tensor(0.0, device=DEVICE)
|
| 279 |
+
|
| 280 |
+
# =========================
|
| 281 |
+
# PPO update
|
| 282 |
+
# =========================
|
| 283 |
+
for traj, returns in zip(trajectories, traj_returns):
|
| 284 |
|
| 285 |
for i in range(len(traj.states)):
|
| 286 |
state = traj.states[i]
|
| 287 |
action = traj.actions[i]
|
| 288 |
+
|
| 289 |
old_lp = torch.tensor(traj.logprobs[i], device=DEVICE)
|
| 290 |
|
| 291 |
+
# Advantage (detached)
|
| 292 |
+
adv = (returns[i] - baseline).detach()
|
| 293 |
|
| 294 |
messages = [{"role": "user", "content": state}]
|
| 295 |
+
formatted = tokenizer.apply_chat_template(
|
| 296 |
+
messages, tokenize=False, add_generation_prompt=True
|
| 297 |
+
)
|
| 298 |
|
| 299 |
+
full = formatted + action
|
| 300 |
inputs = tokenizer(full, return_tensors="pt", truncation=True).to(DEVICE)
|
| 301 |
+
|
| 302 |
logits = model(**inputs).logits
|
| 303 |
|
| 304 |
action_ids = tokenizer.encode(action, add_special_tokens=False)
|
| 305 |
+
prefix_len = len(tokenizer.encode(formatted, add_special_tokens=False))
|
|
|
|
| 306 |
|
| 307 |
logps = []
|
| 308 |
entropy = 0.0
|
|
|
|
| 312 |
if pos == 0 or pos >= logits.shape[1]:
|
| 313 |
continue
|
| 314 |
|
| 315 |
+
token_logits = logits[0, pos - 1]
|
| 316 |
log_probs = F.log_softmax(token_logits, dim=-1)
|
| 317 |
+
|
| 318 |
lp = log_probs[action_ids[idx]]
|
| 319 |
logps.append(lp)
|
| 320 |
|
|
|
|
| 325 |
continue
|
| 326 |
|
| 327 |
new_lp = torch.stack(logps).sum()
|
| 328 |
+
|
| 329 |
+
# PPO ratio
|
| 330 |
ratio = torch.exp(new_lp - old_lp)
|
| 331 |
+
|
| 332 |
s1 = ratio * adv
|
| 333 |
+
s2 = torch.clamp(ratio, 1 - clip, 1 + clip) * adv
|
| 334 |
+
|
| 335 |
+
policy_loss = -torch.min(s1, s2)
|
| 336 |
+
loss = policy_loss - 0.01 * (entropy / len(logps))
|
| 337 |
|
| 338 |
if torch.isnan(loss):
|
| 339 |
continue
|
|
|
|
| 347 |
kls.append(kl)
|
| 348 |
losses.append(loss.item())
|
| 349 |
|
| 350 |
+
return (
|
| 351 |
+
float(np.mean(losses)) if losses else 0.0,
|
| 352 |
+
float(np.mean(kls)) if kls else 0.0,
|
| 353 |
+
)
|
| 354 |
# =========================================================
|
| 355 |
# MAIN TRAINING LOOP
|
| 356 |
# =========================================================
|