Buckets:
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from pytorch_optimizer import SOAP | |
| import os | |
| import copy | |
| from src.tokenizer import CharTokenizer | |
| from src.model import TinyReasonerModel | |
| from src.sampler import Sampler | |
| from src.rewards import get_total_reward | |
| from src.prompts import get_random_prompt | |
def _gather_token_log_probs(logits, target_tokens):
    """Return, for each position, the log-probability of the next token.

    ``logits`` is indexed as ``logits[0, :-1, :]``, i.e. it is treated as a
    single-sequence batch (shape assumed (1, seq_len, vocab) — matches how
    ``compute_grpo_loss`` builds its input). The result has one entry per
    element of ``target_tokens``.
    """
    log_probs_full = F.log_softmax(logits[0, :-1, :], dim=-1)
    positions = torch.arange(len(target_tokens), device=target_tokens.device)
    return log_probs_full[positions, target_tokens]


def compute_grpo_loss(model, ref_model, tokens, old_log_probs, mask, advantages, clip_eps=0.2, beta=0.0001):
    """Compute the clipped GRPO surrogate loss plus a KL penalty to a reference.

    Each completion is processed independently because completions may have
    different lengths. Only next-token positions whose mask value equals 1 are
    scored.

    Args:
        model: Policy being trained. Must expose ``embedding.weight`` (used to
            find the device) and return ``(logits, _)`` when called on a
            (1, seq_len) LongTensor.
        ref_model: Frozen reference policy with the same call convention; it
            is only evaluated under ``torch.no_grad()``.
        tokens: Sequence of token-id lists, one per completion.
        old_log_probs: Per-completion 1-D log-probs recorded at rollout time
            for the scored positions (lengths may differ per completion).
        mask: Per-completion 1-D masks aligned with ``tokens[i][1:]``;
            positions equal to 1 are scored.
        advantages: Per-completion scalar advantages, indexable by position.
        clip_eps: Clipping radius for the probability ratio.
        beta: Weight of the KL penalty.

    Returns:
        Scalar tensor: the sum of per-completion losses divided by
        ``len(tokens)``. Completions that cannot be aligned with their
        ``old_log_probs``, or that have no scored positions, contribute nothing
        (but still count in the denominator). If nothing contributes
        (including ``tokens`` being empty), a zero tensor with
        ``requires_grad=True`` is returned so ``loss.backward()`` still works.
    """
    device = model.embedding.weight.device
    per_sample_losses = []

    for i, seq in enumerate(tokens):
        t = torch.tensor([seq], dtype=torch.long, device=device)
        m = torch.as_tensor(mask[i], device=device)
        # Rollout log-probs are constants for the ratio; make sure they live
        # on the same device as the freshly computed ones.
        old_lp = torch.as_tensor(old_log_probs[i], device=device).detach()
        adv = advantages[i]
        target_tokens = t[0, 1:]

        logits, _ = model(t)
        current_lp_all = _gather_token_log_probs(logits, target_tokens)

        # Align mask and per-token log-probs to a common length. The same
        # length and selection are reused for the reference model below so
        # the two vectors always line up.
        n = min(len(m), len(current_lp_all))
        selected = m[:n] == 1
        current_lp = current_lp_all[:n][selected]

        if len(current_lp) < len(old_lp):
            # Should not happen if sampler and mask are consistent; the
            # sample cannot be aligned, so it is skipped.
            continue
        # Positions beyond those recorded at rollout time are dropped.
        current_lp = current_lp[:len(old_lp)]
        if current_lp.numel() == 0:
            # No scored tokens: .mean() of an empty tensor is NaN and would
            # poison the whole batch loss.
            continue

        ratio = torch.exp(current_lp - old_lp)
        surr1 = ratio * adv
        surr2 = torch.clamp(ratio, 1.0 - clip_eps, 1.0 + clip_eps) * adv
        policy_loss = -torch.min(surr1, surr2).mean()

        # KL penalty against the frozen reference model. Only the reference
        # forward pass is gradient-free; the KL itself must stay
        # differentiable w.r.t. current_lp.
        with torch.no_grad():
            ref_logits, _ = ref_model(t)
            ref_lp_all = _gather_token_log_probs(ref_logits, target_tokens)
            ref_lp = ref_lp_all[:n][selected][:len(current_lp)]
        log_ratio = ref_lp - current_lp
        kl = (torch.exp(log_ratio) - log_ratio - 1).mean()

        per_sample_losses.append(policy_loss + beta * kl)

    if not per_sample_losses:
        return torch.zeros((), device=device, requires_grad=True)
    return torch.stack(per_sample_losses).sum() / len(tokens)
def _load_initial_weights(model, load_model, device):
    """Load starting weights into ``model`` in place.

    Priority order: explicit ``load_model`` path, then ``models/rl_model.pt``,
    then ``models/sft_model.pt``, then ``models/pretrained.pt``. If none exist
    the model keeps its fresh initialisation.
    """
    if load_model and os.path.exists(load_model):
        model.load_state_dict(torch.load(load_model, map_location=device))
        print(f"Loaded model from {load_model}.")
        return
    if load_model:
        # An explicit path was requested but is missing; say so instead of
        # silently falling back to a different checkpoint.
        print(f"Warning: load_model path {load_model} not found; falling back to default checkpoints.")
    if os.path.exists("models/rl_model.pt"):
        model.load_state_dict(torch.load("models/rl_model.pt", map_location=device))
        print("Loaded existing RL model.")
    elif os.path.exists("models/sft_model.pt"):
        model.load_state_dict(torch.load("models/sft_model.pt", map_location=device))
        print("Loaded SFT model.")
    else:
        print("Warning: No model found. Starting from scratch or pretrained.")
        if os.path.exists("models/pretrained.pt"):
            model.load_state_dict(torch.load("models/pretrained.pt", map_location=device))
            print("Loaded pretrained model.")


def _curriculum_level(iteration):
    """Map an iteration index to a prompt level: 0 below 300, 1 below 600, else 2."""
    if iteration < 300:
        return 0
    if iteration < 600:
        return 1
    return 2


def _group_advantages(rewards):
    """Normalise a group's rewards into GRPO advantages.

    With more than one reward: zero mean, unit std (epsilon-guarded so a
    constant group yields all-zero advantages). With a single reward only the
    mean is subtracted, because the std of one element is undefined.
    """
    if len(rewards) > 1:
        return (rewards - rewards.mean()) / (rewards.std() + 1e-8)
    return rewards - rewards.mean()


def train_grpo(num_iterations=500, group_size=32, load_model=None):
    """Fine-tune the model with GRPO on curriculum-sampled prompts.

    Each iteration samples one prompt, rolls out ``group_size`` completions,
    scores them with ``get_total_reward``, and takes one optimizer step on the
    group-normalised GRPO loss. A frozen copy of the initial weights is the KL
    reference.

    Args:
        num_iterations: Number of iterations to run. With the default of 500
            the curriculum only reaches level 1 (level 2 starts at 600).
        group_size: Number of rollouts per prompt.
        load_model: Optional checkpoint path to start from. If it is missing,
            the default checkpoints under ``models/`` are tried instead.

    Side effects:
        Writes ``models/rl_model.pt`` every 10 iterations and at the end,
        creating the ``models`` directory if needed. Prints progress.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    tokenizer = CharTokenizer()
    model = TinyReasonerModel(tokenizer.vocab_size).to(device)
    _load_initial_weights(model, load_model, device)
    # The curriculum always restarts from iteration 0 (level 0), even when
    # resuming from an RL checkpoint, to ensure grounding.
    start_iteration = 0

    # Frozen reference policy for the KL penalty.
    ref_model = copy.deepcopy(model)
    ref_model.eval()
    for p in ref_model.parameters():
        p.requires_grad = False

    embedding_params = list(model.embedding.parameters())
    other_params = [p for n, p in model.named_parameters() if "embedding" not in n]
    param_groups = [
        {"params": other_params},
        # NOTE(review): presumably caps SOAP preconditioner size for the
        # embedding matrix — confirm against pytorch_optimizer's SOAP docs.
        {"params": embedding_params, "max_precond_dim": 1},
    ]
    optimizer = SOAP(param_groups, lr=1e-5)  # Smaller LR for RL
    sampler = Sampler(model, tokenizer, device=device)

    checkpoint_path = "models/rl_model.pt"
    # A from-scratch run may not have the directory yet; torch.save would fail.
    os.makedirs("models", exist_ok=True)

    for i in range(start_iteration, start_iteration + num_iterations):
        level = _curriculum_level(i)
        prompt_text, ref_answer, task_type = get_random_prompt(level=level)
        prompt = f"[BOS]{prompt_text}"

        # 1. Rollout, alternating between weight noise (even iterations) and
        # a slightly higher temperature (odd iterations) for variety.
        # NOTE(review): after the first update the model is left in train()
        # mode during rollouts; confirm whether Sampler switches to eval().
        with torch.no_grad():
            use_noise = i % 2 == 0
            completions, log_probs, masks = sampler.grpo_rollout(
                prompt,
                num_rollouts=group_size,
                temperature=1.0 if use_noise else 1.1,
                noise_std=0.03 if use_noise else 0.0,
            )

        # 2. Rewards. Force a float dtype: integer rewards would otherwise
        # produce a long tensor on which .mean()/.std() raise.
        rewards = torch.tensor(
            [get_total_reward(prompt_text, c, ref_answer, task_type) for c in completions],
            dtype=torch.float32,
            device=device,
        )
        unique_completions = len(set(completions))
        print(f"Iter {i} (Level {level}), Prompt: {prompt}, Mean Reward: {rewards.mean().item():.4f}, Unique: {unique_completions}/{group_size}", flush=True)
        print(f"Sample Completion: {completions[0]}", flush=True)

        # 3. Advantages
        adv = _group_advantages(rewards)

        # 4. Update
        model.train()
        optimizer.zero_grad()
        # Re-encode completions to tokens
        all_tokens = [tokenizer.encode(c) for c in completions]
        loss = compute_grpo_loss(model, ref_model, all_tokens, log_probs, masks, adv)
        # Guard against a non-differentiable loss (e.g. every sample skipped)
        # so one bad group doesn't crash the whole run.
        if torch.is_tensor(loss) and loss.requires_grad:
            loss.backward()
            optimizer.step()

        if (i + 1) % 10 == 0:
            torch.save(model.state_dict(), checkpoint_path)
            print(f"Saved checkpoint at iter {i+1}")

    torch.save(model.state_dict(), checkpoint_path)
    print("RL training complete. Model saved to models/rl_model.pt")
if __name__ == "__main__":
    # Command-line entry point: forward the three CLI options to train_grpo.
    import argparse

    cli = argparse.ArgumentParser()
    cli.add_argument("--iterations", type=int, default=500)
    cli.add_argument("--group_size", type=int, default=32)
    cli.add_argument("--load_model", type=str, default=None)
    opts = cli.parse_args()

    train_grpo(
        num_iterations=opts.iterations,
        group_size=opts.group_size,
        load_model=opts.load_model,
    )
Xet Storage Details
- Size:
- 6.67 kB
- Xet hash:
- 23e030ea9457aa26e512a33c29dd992f41de0b549a1a763ac8f768f8017e1234
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.