from typing import Optional

import torch
import torch.distributed as dist
import torch.nn as nn
from peft import LoraConfig, TaskType, get_peft_model
from peft.tuners.lora import LoraLayer
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers.integrations.deepspeed import HfDeepSpeedConfig

from .ring_attn_utils import gather_and_pad_tensor, unpad_and_slice_tensor
from .utils import compute_entropy, log_probs_from_logits


class Actor(nn.Module):
    """
    Base class for Actor models in reinforcement learning.

    This class serves as a foundation for implementing various actor models,
    which are responsible for selecting actions based on the policy learned from the environment.

    Args:
        pretrain_or_model (str | nn.Module): A pretrained model name/path to load, or an
            already-constructed model instance to wrap as the actor.
        use_flash_attention_2 (bool, optional): Whether to utilize Flash Attention 2.0 for improved performance.
            Defaults to False.
        bf16 (bool, optional): Enable bfloat16 precision for model computations. Defaults to True.
        load_in_4bit (bool, optional): Load the model in 4-bit precision. Defaults to False.
        lora_rank (int, optional): Rank for LoRA adaptation. Defaults to 0 (LoRA disabled).
        lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16.
        lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0.
        target_modules (list, optional): List of target modules for applying LoRA. Defaults to None.
        ds_config (dict, optional): Configuration for DeepSpeed, enabling model partitioning
            across multiple GPUs. Defaults to None.
        device_map (dict, optional): Device mapping for loading the model onto specific devices.
            Defaults to None.
        packing_samples (bool, optional): Whether to pack samples during training. Defaults to False.
        temperature (float, optional): Temperature for action selection. Defaults to 1.0.
        use_liger_kernel (bool, optional): Whether to use Liger Kernel for the model. Defaults to False.
    """

    def __init__(
        self,
        pretrain_or_model,
        use_flash_attention_2=False,
        bf16=True,
        load_in_4bit=False,
        lora_rank=0,
        lora_alpha=16,
        lora_dropout=0,
        target_modules=None,
        ds_config=None,
        device_map=None,
        packing_samples=False,
        temperature=1.0,
        use_liger_kernel=False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.temperature = temperature

        if isinstance(pretrain_or_model, str):
            attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"

            # Note: dschf is defined in function scope to avoid global effects
            # https://huggingface.co/docs/transformers/deepspeed#non-trainer-deepspeed-integration
            # Keeping the HfDeepSpeedConfig object alive during from_pretrained enables
            # ZeRO-3 partitioned loading.
            if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
                dschf = HfDeepSpeedConfig(ds_config)
            else:
                dschf = None

            if load_in_4bit:
                assert bf16, "we only support bnb_4bit_compute_dtype = bf16"
                nf4_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            else:
                nf4_config = None

            if use_liger_kernel:
                # Imported lazily so liger_kernel is only required when enabled.
                from liger_kernel.transformers import AutoLigerKernelForCausalLM

                model_class = AutoLigerKernelForCausalLM
            else:
                model_class = AutoModelForCausalLM

            self.model = model_class.from_pretrained(
                pretrain_or_model,
                trust_remote_code=True,
                attn_implementation=attn_implementation,
                quantization_config=nf4_config,
                torch_dtype=torch.bfloat16 if bf16 else "auto",
                device_map=device_map,
            )

            # LoRA
            if lora_rank > 0:
                # https://github.com/huggingface/peft/issues/137
                self.model.enable_input_require_grads()
                lora_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    target_modules=target_modules,
                    lora_dropout=lora_dropout,
                    bias="none",
                )
                self.model = get_peft_model(self.model, lora_config)

                if load_in_4bit:
                    # Cast submodules to the dtypes expected for 4-bit QLoRA training:
                    # LoRA layers in bf16, norms in fp32, lm_head/embeddings in bf16.
                    # nn.Module.to() is in-place, so rebinding the loop variable is fine.
                    for name, module in self.model.named_modules():
                        if isinstance(module, LoraLayer):
                            module = module.to(torch.bfloat16)
                        if "norm" in name:
                            module = module.to(torch.float32)
                        if "lm_head" in name or "embed_tokens" in name:
                            if hasattr(module, "weight"):
                                module = module.to(torch.bfloat16)

            # MoE - balancing loss
            model_config = self.model.config.to_dict()
            if "output_router_logits" in model_config:
                print("[MoE] set output_router_logits as True")
                self.model.config.output_router_logits = True

            # https://github.com/huggingface/transformers/issues/26877
            # Use `model.generate(use_cache=True)` instead.`
            self.model.config.use_cache = False

            # packing samples using Flash Attention 2
            self.packing_samples = packing_samples
        else:
            # Already-constructed model: wrap it as-is.
            self.model = pretrain_or_model

    def forward(
        self,
        sequences: torch.LongTensor,
        action_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_output=False,
        allgather_logits=False,
        return_logprobs=False,
        ring_attn_group: Optional[dist.ProcessGroup] = None,
        packed_seq_lens: Optional[list[int]] = None,
        return_entropy=False,
    ) -> torch.Tensor:
        """Returns action log probs.

        Depending on the flags, returns the raw model output, per-token log
        probs, or action log probs (masked by ``action_mask``), optionally
        paired with the model output. ``packed_seq_lens`` is accepted for
        interface compatibility but not used here.
        """
        batch, seqlen = sequences.size()
        forward_attention_mask = attention_mask
        if self.packing_samples:
            # Remove padding and slice across the ring-attention group; the
            # packed layout carries positions via position_ids, not a mask.
            sequences, position_ids, rolled_sequences, ring_attn_pad_len, indices = unpad_and_slice_tensor(
                sequences, attention_mask, ring_attn_group
            )
            forward_attention_mask = None
        else:
            # https://github.com/OpenRLHF/OpenRLHF/issues/217
            rolled_sequences = torch.roll(sequences, shifts=-1, dims=1)
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)

        output = self.model(sequences, attention_mask=forward_attention_mask, position_ids=position_ids)
        # https://github.com/OpenRLHF/OpenRLHF/pull/634
        output["logits"] = output["logits"].to(torch.float32)

        if return_entropy:
            assert return_output
            entropy = compute_entropy(output["logits"])
            if self.packing_samples:
                entropy = gather_and_pad_tensor(entropy, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen)
            # Drop the last position: entropy is aligned with next-token targets.
            setattr(output, "entropy", entropy[:, :-1])

        return_action_log_probs = action_mask is not None
        if not return_action_log_probs and not return_logprobs:
            assert return_output
            if allgather_logits and self.packing_samples:
                output["logits"] = gather_and_pad_tensor(
                    output["logits"], ring_attn_group, ring_attn_pad_len, indices, batch, seqlen
                )
            return output

        log_probs = log_probs_from_logits(output["logits"], rolled_sequences, temperature=self.temperature)
        if self.packing_samples:
            log_probs = gather_and_pad_tensor(log_probs, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen)
        # Drop the last position (no next token to predict).
        log_probs = log_probs[:, :-1]

        if not return_action_log_probs and return_logprobs:
            return (log_probs, output) if return_output else log_probs

        # Keep only the action (response) portion and zero out masked positions.
        action_log_probs = log_probs[:, -action_mask.shape[1] :] * action_mask.float()
        return (action_log_probs, output) if return_output else action_log_probs

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the wrapped model.

        Uses a ``None`` sentinel instead of a mutable default dict; the
        effective default remains ``{"use_reentrant": False}``.
        """
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": False}
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the wrapped model."""
        self.model.gradient_checkpointing_disable()

    def print_trainable_parameters(self):
        """Print the number of trainable parameters (PEFT helper passthrough)."""
        self.model.print_trainable_parameters()