| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from dataclasses import dataclass |
| from typing import Optional, Union |
|
|
| import torch |
| from transformers.cache_utils import Cache |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
@dataclass
class CausalLMOutputForPPO(CausalLMOutputWithPast):
    """Causal-LM output container extended with two optional PPO fields.

    Inherits every field of ``CausalLMOutputWithPast`` and adds ``log_probs``
    and ``entropy``. Both default to ``None``. In this file they are filled by
    ``forward_with_torch_backend`` and ``forward_with_triton_backend`` with the
    two values returned by the respective fused kernel.
    """

    # First value returned by the fused kernel; None when not computed.
    log_probs: Optional[torch.FloatTensor] = None
    # Second value returned by the fused kernel; None when not computed.
    entropy: Optional[torch.FloatTensor] = None
|
|
|
|
def forward_base_model(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Cache] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
) -> CausalLMOutputWithPast:
    r"""Run ``self.model`` and return its output unchanged.

    Adapted from the LLaMA forward in Liger-Kernel:
    https://github.com/linkedin/Liger-Kernel/blob/main/src/liger_kernel/transformers/model/llama.py

    Intended to be generic enough for all pure text models. Only
    ``self.model`` is invoked here; no LM head is applied in this function.

    ``output_attentions`` and ``output_hidden_states`` fall back to the
    corresponding attributes of ``self.config`` when passed as ``None``.
    Every other argument, including ``return_dict``, is forwarded as given.

    Returns:
        Whatever ``self.model(...)`` returns.
    """
    # Resolve the two flags that have a config-level default.
    if output_attentions is None:
        output_attentions = self.config.output_attentions
    if output_hidden_states is None:
        output_hidden_states = self.config.output_hidden_states

    model_kwargs = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
        "past_key_values": past_key_values,
        "inputs_embeds": inputs_embeds,
        "use_cache": use_cache,
        "output_attentions": output_attentions,
        "output_hidden_states": output_hidden_states,
        "return_dict": return_dict,
        "cache_position": cache_position,
    }
    return self.model(**model_kwargs)
|
|
|
|
def forward_with_torch_backend(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: int | torch.Tensor = 0,
    temperature: float = 1.0,
    **loss_kwargs,
) -> tuple | CausalLMOutputForPPO:
    """Compute PPO log-probs and entropy with the fused torch kernel.

    Runs the base model through ``forward_base_model``, takes the first
    element of its output as the hidden states, and feeds them together with
    ``self.lm_head.weight`` into ``FusedLinearForPPO``.

    The target ids handed to the kernel are ``labels`` (or ``input_ids`` when
    ``labels`` is ``None``) rolled by -1 along the last dimension. Each
    position therefore holds the id that originally followed it, and the last
    position holds the wrapped-around first id.
    NOTE(review): callers presumably mask that last position - confirm.

    Args:
        return_dict: Must be truthy. The default ``None`` raises
            ``NotImplementedError``, because only the dict-style output is
            implemented.
        labels: Target ids. Falls back to ``input_ids`` when ``None``.
        temperature: Forwarded to ``FusedLinearForPPO.forward``.
        logits_to_keep: Accepted but not used by this function.
        **loss_kwargs: Accepted but not used by this function.

        The remaining arguments are forwarded to ``forward_base_model``.

    Returns:
        ``CausalLMOutputForPPO`` carrying ``log_probs`` and ``entropy`` from
        the kernel plus ``past_key_values``, ``hidden_states`` and
        ``attentions`` from the base-model output. No other field is set here.

    Raises:
        NotImplementedError: If ``return_dict`` is falsy.
        RuntimeError: If both ``labels`` and ``input_ids`` are ``None``.
    """
    # Validate cheap argument constraints first, so an invalid call fails
    # before the kernel import and the base-model forward pass.
    if not return_dict:
        raise NotImplementedError("forward_with_torch_backend has to return_dict")

    if labels is not None:
        target_ids = labels
    elif input_ids is not None:
        target_ids = input_ids
    else:
        raise RuntimeError("To use forward_with_torch_backend, either labels or input_ids must be provided.")

    # Imported lazily so the kernel module is only loaded when this backend runs.
    from verl.utils.experimental.torch_functional import FusedLinearForPPO

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        # Forward return_dict so the base model yields an object exposing
        # .past_key_values / .hidden_states / .attentions, which are read
        # below, instead of falling back to its own default.
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    # Shift targets left by one along the last dimension.
    rolled_labels = torch.roll(target_ids, shifts=-1, dims=-1)

    fused_linear_for_ppo = FusedLinearForPPO()
    log_probs, entropy = fused_linear_for_ppo.forward(
        hidden_states=hidden_states,
        vocab_weights=self.lm_head.weight,
        input_ids=rolled_labels,
        temperature=temperature,
    )

    return CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|
|
|
def forward_with_triton_backend(
    self,
    input_ids: Optional[torch.LongTensor] = None,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_values: Optional[Union["Cache", list[torch.FloatTensor]]] = None,
    inputs_embeds: Optional[torch.FloatTensor] = None,
    labels: Optional[torch.LongTensor] = None,
    use_cache: Optional[bool] = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    cache_position: Optional[torch.LongTensor] = None,
    logits_to_keep: int | torch.Tensor = 0,
    temperature: float = 1.0,
    **loss_kwargs,
) -> tuple | CausalLMOutputForPPO:
    """Compute PPO log-probs and entropy with the triton kernel.

    Runs the base model through ``forward_base_model``, takes the first
    element of its output as the hidden states, and passes them together with
    ``self.lm_head.weight`` to ``linear_cross_entropy``.

    The target ids handed to the kernel are ``labels`` (or ``input_ids`` when
    ``labels`` is ``None``) rolled by -1 along the last dimension. Each
    position therefore holds the id that originally followed it, and the last
    position holds the wrapped-around first id.
    NOTE(review): callers presumably mask that last position - confirm.

    Args:
        return_dict: Must be truthy. The default ``None`` raises
            ``NotImplementedError``, because only the dict-style output is
            implemented.
        labels: Target ids. Falls back to ``input_ids`` when ``None``.
        temperature: Passed positionally to ``linear_cross_entropy``.
        logits_to_keep: Accepted but not used by this function.
        **loss_kwargs: Accepted but not used by this function.

        The remaining arguments are forwarded to ``forward_base_model``.

    Returns:
        ``CausalLMOutputForPPO`` carrying ``log_probs`` and ``entropy`` from
        the kernel plus ``past_key_values``, ``hidden_states`` and
        ``attentions`` from the base-model output. No other field is set here.

    Raises:
        NotImplementedError: If ``return_dict`` is falsy.
        RuntimeError: If both ``labels`` and ``input_ids`` are ``None``.
    """
    # Validate cheap argument constraints first, so an invalid call fails
    # before the kernel import and the base-model forward pass.
    if not return_dict:
        raise NotImplementedError("forward_with_triton_backend has to return_dict")

    if labels is not None:
        target_ids = labels
    elif input_ids is not None:
        target_ids = input_ids
    else:
        raise RuntimeError("To use forward_with_triton_backend, either labels or input_ids must be provided.")

    # Imported lazily so the kernel module is only loaded when this backend runs.
    from verl.utils.kernel.linear_cross_entropy import linear_cross_entropy

    outputs = forward_base_model(
        self,
        input_ids=input_ids,
        attention_mask=attention_mask,
        position_ids=position_ids,
        past_key_values=past_key_values,
        inputs_embeds=inputs_embeds,
        use_cache=use_cache,
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        cache_position=cache_position,
    )

    hidden_states = outputs[0]

    # Shift targets left by one along the last dimension.
    rolled_labels = torch.roll(target_ids, shifts=-1, dims=-1)

    log_probs, entropy = linear_cross_entropy(
        hidden_states,
        self.lm_head.weight,
        rolled_labels,
        temperature,
        "none",  # NOTE(review): presumably the reduction mode - confirm against the kernel signature
    )

    return CausalLMOutputForPPO(
        log_probs=log_probs,
        entropy=entropy,
        past_key_values=outputs.past_key_values,
        hidden_states=outputs.hidden_states,
        attentions=outputs.attentions,
    )
|
|