| from typing import Optional, Union |
|
|
| import deepspeed |
| import torch |
| import torch.nn as nn |
| from flash_attn.utils.distributed import all_gather |
| from peft import LoraConfig, get_peft_model |
| from peft.tuners.lora import LoraLayer |
| from transformers import AutoConfig, AutoModel, BitsAndBytesConfig |
| from transformers.integrations.deepspeed import HfDeepSpeedConfig |
|
|
| from openrlhf.utils.logging_utils import init_logger |
|
|
| from .ring_attn_utils import convert_ring_attn_params |
| from .utils import reset_position_ids |
|
|
# Module-level logger obtained from the project's ``init_logger`` helper.
logger = init_logger(__name__)
|
|
|
|
| |
| |
def get_llm_for_sequence_regression(
    model_name_or_path: str,
    model_type: str,
    *,
    bf16=True,
    load_in_4bit=False,
    lora_rank=0,
    lora_alpha=16,
    target_modules=None,
    lora_dropout=0,
    normalize_reward=False,
    use_flash_attention_2=False,
    ds_config: Optional[dict] = None,
    init_value_head: bool = False,
    value_head_prefix="score",
    device_map=None,
    packing_samples=False,
    **kwargs,
) -> nn.Module:
    """Load a pretrained transformer and attach a scalar (sequence regression) head.

    The backbone class is looked up in ``AutoModel``'s mapping for the checkpoint's
    config type. A reward or critic wrapper class is then built around it; the wrapper
    adds a bias-free ``hidden_size -> 1`` linear head stored under ``value_head_prefix``.

    Args:
        model_name_or_path (str): Path or hub id of the pretrained model.
        model_type (str): Either "reward" or "critic".
        bf16 (bool, optional): Load weights in bfloat16 (otherwise dtype "auto"). Defaults to True.
        load_in_4bit (bool, optional): Load with bitsandbytes NF4 quantization; requires bf16.
            Defaults to False.
        lora_rank (int, optional): LoRA rank; 0 disables LoRA. Defaults to 0.
        lora_alpha (int, optional): LoRA alpha. Defaults to 16.
        target_modules (list, optional): Modules LoRA is applied to. Defaults to None.
        lora_dropout (float, optional): LoRA dropout rate. Defaults to 0.
        normalize_reward (bool, optional): Stored on the config as ``normalize_reward`` and
            read by the wrapper model. Defaults to False.
        use_flash_attention_2 (bool, optional): Use "flash_attention_2" instead of "eager"
            attention. Defaults to False.
        ds_config (dict, optional): DeepSpeed config. When ``zero_optimization.stage`` is 3,
            an ``HfDeepSpeedConfig`` is created before loading. A config without a
            ``zero_optimization`` section is treated as non-ZeRO-3. Defaults to None.
        init_value_head (bool, optional): Re-initialize the value head weight with
            N(0, 1 / (hidden_size + 1)). Defaults to False.
        value_head_prefix (str, optional): Attribute name of the value head. It is overridden
            by ``config.value_head_prefix`` when the checkpoint config defines it.
            Defaults to "score".
        device_map (dict, optional): Passed to ``from_pretrained``. Defaults to None.
        packing_samples (bool, optional): Whether inputs are packed into a single row.
            Defaults to False.
        **kwargs: Forwarded unchanged to ``from_pretrained``.

    Returns:
        nn.Module: The wrapped model. It is a PEFT model when ``lora_rank > 0``.

    Raises:
        AssertionError: If ``model_type`` is invalid, or if ``load_in_4bit`` is set
            without ``bf16``.
    """
    assert model_type in ("critic", "reward"), f"invalid model_type: {model_type}, should be critic or reward."

    config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)
    config.normalize_reward = normalize_reward
    config._attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"

    # A value_head_prefix saved in the checkpoint config takes precedence over the argument.
    value_head_prefix = getattr(config, "value_head_prefix", value_head_prefix)
    logger.info(f"set value_head_prefix to `{value_head_prefix}`")

    base_class = AutoModel._model_mapping[type(config)]
    base_pretrained_class = base_class.__base__
    if model_type == "reward":
        cls_class = _get_reward_model(base_pretrained_class, base_class, value_head_prefix, packing_samples)
    else:
        cls_class = _get_critic_model(base_pretrained_class, base_class, value_head_prefix, packing_samples)

    # Tolerate DeepSpeed configs that omit the zero_optimization section (stage 0 by default).
    zero_stage = None
    if ds_config is not None:
        zero_stage = ds_config.get("zero_optimization", {}).get("stage")
    # The HfDeepSpeedConfig object must stay referenced while from_pretrained runs so that
    # transformers detects ZeRO-3 (the transformers DeepSpeed docs say "keep this object alive").
    # It is also reused below to decide how to initialize the value head.
    dschf = HfDeepSpeedConfig(ds_config) if zero_stage == 3 else None

    if load_in_4bit:
        assert bf16, "we only support bnb_4bit_compute_dtype = bf16"
        nf4_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    else:
        nf4_config = None

    model = cls_class.from_pretrained(
        model_name_or_path,
        config=config,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16 if bf16 else "auto",
        quantization_config=nf4_config,
        device_map=device_map,
        **kwargs,
    )

    if lora_rank > 0:
        model.enable_input_require_grads()
        lora_config = LoraConfig(
            r=lora_rank,
            lora_alpha=lora_alpha,
            target_modules=target_modules,
            lora_dropout=lora_dropout,
            bias="none",
        )
        model = get_peft_model(model, lora_config)

        if load_in_4bit:
            # Fix dtypes after quantized loading: LoRA layers, the value head and embeddings
            # go to bf16; norm modules go to fp32. nn.Module.to() casts in place.
            for name, module in model.named_modules():
                if isinstance(module, LoraLayer):
                    module.to(torch.bfloat16)
                if "norm" in name:
                    module.to(torch.float32)
                if value_head_prefix in name or "embed_tokens" in name:
                    if hasattr(module, "weight"):
                        module.to(torch.bfloat16)

    # MoE configs expose output_router_logits; turn it on (presumably so a router
    # balancing loss can be computed during training).
    if "output_router_logits" in model.config.to_dict():
        logger.info("[MoE] set output_router_logits as True")
        model.config.output_router_logits = True

    # Disable the KV cache (presumably unneeded and problematic during training).
    model.config.use_cache = False

    if init_value_head:
        value_head = getattr(model, value_head_prefix)
        init_std = 1 / (config.hidden_size + 1)
        if dschf is not None:
            logger.info("initialize value_head for ZeRO-3 reward model training.")
            # Under ZeRO-3 the weight is partitioned: gather it, let rank 0 write it,
            # and let DeepSpeed propagate rank 0's modification (modifier_rank=0).
            with deepspeed.zero.GatheredParameters([value_head.weight], modifier_rank=0):
                if torch.distributed.get_rank() == 0:
                    value_head.weight.data.normal_(mean=0.0, std=init_std)
        else:
            value_head.weight.data.normal_(mean=0.0, std=init_std)

    return model
|
|
|
|
def _get_reward_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
    """Create a ``RewardModel`` class derived from ``base_pretrained_model``.

    Instances hold a ``base_llm_model(config)`` backbone under the class's
    ``base_model_prefix``. They also hold a bias-free ``hidden_size -> 1`` linear head
    under ``value_head_prefix``. ``forward`` yields one scalar per sequence: the head
    output at that sequence's last token.

    Args:
        base_pretrained_model: Parent class of the generated model.
        base_llm_model: Backbone class, instantiated with the model config.
        value_head_prefix: Attribute name under which the linear head is stored.
        packing_samples: Whether inputs arrive packed into a single row.

    Returns:
        The generated ``RewardModel`` class (not an instance).
    """

    class RewardModel(base_pretrained_model):
        supports_gradient_checkpointing = True

        def __init__(self, config: AutoConfig):
            super().__init__(config)
            backbone = base_llm_model(config)
            setattr(self, self.base_model_prefix, backbone)

            self.value_head_prefix = value_head_prefix
            head = nn.Linear(config.hidden_size, 1, bias=False)
            setattr(self, value_head_prefix, head)

            self.packing_samples = packing_samples

            # Normalization statistics live in non-persistent buffers (excluded from state_dict).
            self.normalize_reward = config.normalize_reward
            self.register_buffer("mean", torch.zeros(1), persistent=False)
            self.register_buffer("std", torch.ones(1), persistent=False)

            # Statistics recorded on the config replace the defaults above.
            if hasattr(config, "mean"):
                self.mean[0] = config.mean
                self.std[0] = config.std

        def forward(
            self,
            input_ids: torch.LongTensor = None,
            attention_mask: Optional[torch.Tensor] = None,
            return_output=False,
            ring_attn_group=None,
            packed_seq_lens=None,
        ) -> torch.Tensor:
            if self.packing_samples:
                if ring_attn_group is None:
                    position_ids = reset_position_ids(attention_mask)
                else:
                    input_ids, attention_mask, position_ids = convert_ring_attn_params(
                        input_ids, attention_mask, packed_seq_lens, ring_attn_group
                    )
                # Packed rows are described by position_ids only; the mask is dropped on purpose.
                attention_mask = None
            else:
                # Positions count attended tokens only; masked slots get a placeholder of 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)

            backbone = getattr(self, self.base_model_prefix)
            outputs = backbone(input_ids, attention_mask=attention_mask, position_ids=position_ids)
            head = getattr(self, self.value_head_prefix)
            values = head(outputs["last_hidden_state"]).squeeze(-1)

            if self.packing_samples:
                if ring_attn_group is None:
                    full_values = values
                else:
                    full_values = all_gather(values, ring_attn_group).reshape(1, -1)
                # Each packed sequence ends at cumsum(lengths) - 1 in the flattened row.
                seq_lens = torch.tensor(packed_seq_lens, device=values.device)
                last_token_idx = seq_lens.cumsum(dim=0) - 1
                reward = full_values.squeeze(0).gather(dim=0, index=last_token_idx)
            else:
                # Position of the last 1 in each mask row: flip, argmax, mirror back.
                flipped_first = attention_mask.long().fliplr().argmax(dim=1, keepdim=True)
                last_token_idx = attention_mask.size(1) - 1 - flipped_first
                reward = values.gather(dim=1, index=last_token_idx).squeeze(1)

            # Normalization is applied only outside training mode.
            if self.normalize_reward and not self.training:
                reward = (reward - self.mean) / self.std

            if return_output:
                return reward, outputs
            return reward

    return RewardModel
|
|
|
|
def _get_critic_model(base_pretrained_model, base_llm_model, value_head_prefix="score", packing_samples=False):
    """Create a ``CriticModel`` class derived from ``base_pretrained_model``.

    Instances hold a ``base_llm_model(config)`` backbone under the class's
    ``base_model_prefix``. They also hold a bias-free ``hidden_size -> 1`` linear head
    under ``value_head_prefix``. ``forward`` returns per-token values for the trailing
    ``num_actions`` tokens of each sequence.

    Args:
        base_pretrained_model: Parent class of the generated model.
        base_llm_model: Backbone class, instantiated with the model config.
        value_head_prefix: Attribute name under which the linear head is stored.
        packing_samples: Whether inputs arrive packed into a single row.

    Returns:
        The generated ``CriticModel`` class (not an instance).
    """

    class CriticModel(base_pretrained_model):
        supports_gradient_checkpointing = True

        def __init__(self, config: AutoConfig):
            super().__init__(config)
            backbone = base_llm_model(config)
            setattr(self, self.base_model_prefix, backbone)

            self.value_head_prefix = value_head_prefix
            head = nn.Linear(config.hidden_size, 1, bias=False)
            setattr(self, value_head_prefix, head)

            self.packing_samples = packing_samples

            # Normalization statistics live in non-persistent buffers (excluded from state_dict).
            self.normalize_reward = config.normalize_reward
            self.register_buffer("mean", torch.zeros(1), persistent=False)
            self.register_buffer("std", torch.ones(1), persistent=False)

            # Statistics recorded on the config replace the defaults above.
            if hasattr(config, "mean"):
                self.mean[0] = config.mean
                self.std[0] = config.std

        def forward(
            self,
            input_ids: torch.LongTensor = None,
            num_actions: Optional[Union[int, list[int]]] = None,
            attention_mask: Optional[torch.Tensor] = None,
            return_output=False,
            packed_seq_lens=None,
        ) -> torch.Tensor:
            if self.packing_samples:
                position_ids = reset_position_ids(attention_mask)
                # Packed rows are described by position_ids only; the mask is dropped on purpose.
                attention_mask = None
            else:
                # Positions count attended tokens only; masked slots get a placeholder of 1.
                position_ids = attention_mask.long().cumsum(-1) - 1
                position_ids.masked_fill_(attention_mask == 0, 1)

            backbone = getattr(self, self.base_model_prefix)
            outputs = backbone(input_ids, attention_mask=attention_mask, position_ids=position_ids)
            head = getattr(self, self.value_head_prefix)
            # Drop the final position so that values[:, i] lines up with token i + 1.
            values = head(outputs["last_hidden_state"]).squeeze(-1)[:, :-1]

            # Unlike the reward model, normalization here applies in train and eval mode.
            if self.normalize_reward:
                values = (values - self.mean) / self.std

            if num_actions is None:
                assert return_output
                return outputs

            if self.packing_samples:
                assert isinstance(num_actions, list) and len(num_actions) == len(packed_seq_lens)
                pieces = []
                seq_start = 0
                for n_act, seq_len in zip(num_actions, packed_seq_lens):
                    seq_end = seq_start + seq_len
                    first = max(0, seq_end - n_act - 1)
                    pieces.append(values[:, first : seq_end - 1])
                    seq_start = seq_end
                action_values = torch.cat(pieces, dim=1)
            else:
                action_values = values[:, -num_actions:]

            return (action_values, outputs) if return_output else action_values

    return CriticModel
|
|