| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional |
|
|
| import torch |
|
|
|
|
| def _fused_linear_for_ppo_fwd( |
| hidden_states: torch.FloatTensor, |
| vocab_weights: torch.FloatTensor, |
| input_ids: torch.LongTensor, |
| temperature: float = 1.0, |
| ) -> tuple[torch.FloatTensor, torch.FloatTensor]: |
| logits = (hidden_states @ vocab_weights.t()) / temperature |
| orig_dtype = logits.dtype |
| logits = logits.to(torch.float32) |
|
|
| |
| probs = logits.softmax(dim=-1) |
| log_probs = logits.log_softmax(dim=-1) |
|
|
| token_log_probs = log_probs.gather(-1, input_ids.unsqueeze(-1)).squeeze(-1) |
| entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1) |
|
|
| return token_log_probs.to(orig_dtype), entropy.to(orig_dtype) |
|
|
|
|
| def _fused_linear_for_ppo_bwd( |
| dlog_probs: Optional[torch.FloatTensor], |
| dentropy: Optional[torch.FloatTensor], |
| hidden_states: torch.FloatTensor, |
| vocab_weights: torch.FloatTensor, |
| input_ids: torch.LongTensor, |
| temperature: float = 1.0, |
| ) -> tuple[torch.FloatTensor, torch.FloatTensor]: |
| logits = (hidden_states @ vocab_weights.t()) / temperature |
| orig_dtype = logits.dtype |
| logits = logits.to(torch.float32) |
|
|
| probs = logits.softmax(dim=-1) |
|
|
| dlogits = 0 |
|
|
| |
| if dlog_probs is not None: |
| one_hot_input = torch.zeros_like(logits).scatter_(-1, input_ids.unsqueeze(-1), 1) |
| dlogits += dlog_probs.to(torch.float32).unsqueeze(-1) * (one_hot_input - probs) |
|
|
| |
| if dentropy is not None: |
| log_probs = logits.log_softmax(dim=-1) |
| entropy = torch.logsumexp(logits, dim=-1) - torch.sum(probs * logits, dim=-1) |
| dlogits += probs * (log_probs + entropy.unsqueeze(-1)) * (-dentropy.unsqueeze(-1)) |
|
|
| dlogits = dlogits.to(orig_dtype) / temperature |
|
|
| dhidden_states = dlogits @ vocab_weights |
| dvocab_weights = dlogits.t() @ hidden_states |
|
|
| return dhidden_states, dvocab_weights |
|
|
|
|
class FusedLinearForPPOFunction(torch.autograd.Function):
    """Chunked target log-probs and entropies of ``hidden_states @ vocab_weights.T``.

    Forward processes ``chunk_size`` rows at a time, so only a
    ``(chunk_size, V)`` block of logits exists at once.  Backward saves no
    logits; it recomputes them chunk by chunk.
    """

    @staticmethod
    def forward(
        ctx,
        hidden_states: torch.FloatTensor,
        vocab_weights: torch.FloatTensor,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
        chunk_size: int = 512,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Compute per-token log-probs and entropies.

        Args:
            hidden_states: ``(T, D)`` or ``(B, S, D)`` rows to project.
            vocab_weights: ``(V, D)`` projection matrix.
            input_ids: int64 targets shaped ``hidden_states.shape[:-1]``.
            temperature: Divisor applied to the logits.
            chunk_size: Positive number of rows processed per chunk.

        Returns:
            ``(log_probs, entropy)``, each shaped ``hidden_states.shape[:-1]``.
        """
        # Gradients for an unused output arrive as None rather than zeros.
        ctx.set_materialize_grads(False)

        orig_shape = hidden_states.shape
        orig_ndim = hidden_states.ndim
        assert orig_ndim in (2, 3), f"Invalid hidden_states shape, received {hidden_states.shape}"
        # Every row needs exactly one target id.  Checking the full shape (not
        # only the rank) prevents a silently misaligned gather when lengths differ.
        assert input_ids.shape == orig_shape[:-1], (
            f"input_ids shape doesn't match, {hidden_states.shape} {input_ids.shape}"
        )
        # A non-positive step would make the chunk loop raise (0) or never run (<0).
        assert chunk_size > 0, f"chunk_size must be positive, received {chunk_size}"

        # Decided on the caller's tensors, before any reshaping.
        output_requires_grad = hidden_states.requires_grad or vocab_weights.requires_grad

        # Work on a flat (T, D) layout.
        if orig_ndim == 3:
            hidden_states = hidden_states.flatten(0, 1)
            input_ids = input_ids.flatten(0, 1)

        T = hidden_states.shape[0]
        log_probs = hidden_states.new_zeros(T, requires_grad=output_requires_grad)
        entropy = hidden_states.new_zeros(T, requires_grad=output_requires_grad)

        for chunk_start in range(0, T, chunk_size):
            chunk = slice(chunk_start, min(chunk_start + chunk_size, T))
            chunk_log_probs, chunk_entropy = _fused_linear_for_ppo_fwd(
                hidden_states=hidden_states[chunk],
                vocab_weights=vocab_weights,
                input_ids=input_ids[chunk],
                temperature=temperature,
            )
            log_probs[chunk] = chunk_log_probs
            entropy[chunk] = chunk_entropy

        if orig_ndim == 3:
            # Use the recorded (B, S) rather than view(B, -1): -1 cannot be
            # inferred when the tensor is empty, e.g. for an empty batch.
            log_probs = log_probs.view(orig_shape[:-1])
            entropy = entropy.view(orig_shape[:-1])

        ctx.save_for_backward(hidden_states, vocab_weights, input_ids)
        ctx.orig_shape = orig_shape
        ctx.temperature = temperature
        ctx.chunk_size = chunk_size

        return log_probs, entropy

    @staticmethod
    def backward(ctx, dlog_probs: Optional[torch.FloatTensor], dentropy: Optional[torch.FloatTensor]):
        """Return gradients for (hidden_states, vocab_weights, input_ids, temperature, chunk_size).

        Only the first two can be non-None; each is None when the
        corresponding forward input did not require grad.
        """
        assert dlog_probs is not None or dentropy is not None

        hidden_states, vocab_weights, input_ids = ctx.saved_tensors
        temperature = ctx.temperature
        chunk_size = ctx.chunk_size
        # needs_input_grad describes the inputs exactly as passed to forward,
        # independent of any reshaping done there.
        needs_dhidden, needs_dvocab = ctx.needs_input_grad[0], ctx.needs_input_grad[1]

        # Bring upstream grads to the flat (T,) layout used in forward (no-op when already 1-D).
        if dlog_probs is not None:
            dlog_probs = dlog_probs.flatten()
        if dentropy is not None:
            dentropy = dentropy.flatten()

        T = hidden_states.shape[0]
        dhidden_states = torch.zeros_like(hidden_states) if needs_dhidden else None
        dvocab_weights = torch.zeros_like(vocab_weights) if needs_dvocab else None

        for chunk_start in range(0, T, chunk_size):
            chunk = slice(chunk_start, min(chunk_start + chunk_size, T))
            h, v = _fused_linear_for_ppo_bwd(
                dlog_probs=None if dlog_probs is None else dlog_probs[chunk],
                dentropy=None if dentropy is None else dentropy[chunk],
                hidden_states=hidden_states[chunk],
                vocab_weights=vocab_weights,
                input_ids=input_ids[chunk],
                temperature=temperature,
            )
            if dhidden_states is not None:
                # Row chunks are disjoint, so each slice is written exactly once.
                dhidden_states[chunk] += h
            if dvocab_weights is not None:
                # Every chunk contributes to the full weight gradient.
                dvocab_weights += v

        if dhidden_states is not None:
            # Undo the forward flattening (a same-shape view for 2-D inputs).
            dhidden_states = dhidden_states.view(ctx.orig_shape)

        return dhidden_states, dvocab_weights, None, None, None
|
|
|
|
class FusedLinearForPPO(torch.nn.Module):
    """Module wrapper around ``FusedLinearForPPOFunction``.

    Args:
        chunk_size: Number of rows processed per chunk; must be positive.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """

    def __init__(self, chunk_size: int = 512):
        super().__init__()
        # Fail at construction: a non-positive chunk size cannot partition the
        # rows into chunks, and would otherwise only surface at forward time.
        if chunk_size <= 0:
            raise ValueError(f"chunk_size must be positive, received {chunk_size}")
        self.chunk_size = chunk_size

    def extra_repr(self) -> str:
        """Show the configured chunk size when the module is printed."""
        return f"chunk_size={self.chunk_size}"

    def forward(
        self,
        hidden_states: torch.FloatTensor,
        vocab_weights: torch.FloatTensor,
        input_ids: torch.LongTensor,
        temperature: float = 1.0,
    ) -> tuple[torch.FloatTensor, torch.FloatTensor]:
        """Return ``(log_probs, entropy)`` for ``input_ids`` under the projection.

        ``input_ids`` is cast to int64 before being passed on, so other
        integer dtypes are accepted.
        """
        # Index tensors used for gathering must be int64.
        input_ids = input_ids.to(torch.int64)
        return FusedLinearForPPOFunction.apply(
            hidden_states,
            vocab_weights,
            input_ids,
            temperature,
            self.chunk_size,
        )
|
|