| | from typing import Optional |
| |
|
| | import torch |
| | import torch.distributed as dist |
| | import torch.nn as nn |
| | from peft import LoraConfig, TaskType, get_peft_model |
| | from peft.tuners.lora import LoraLayer |
| | from transformers import AutoModelForCausalLM, BitsAndBytesConfig |
| | from transformers.integrations.deepspeed import HfDeepSpeedConfig |
| |
|
| | from .ring_attn_utils import gather_and_pad_tensor, unpad_and_slice_tensor |
| | from .utils import compute_entropy, log_probs_from_logits |
| |
|
| |
|
class Actor(nn.Module):
    """
    Base class for Actor models in reinforcement learning.

    Wraps either a HuggingFace model name/path (which is loaded here, optionally
    with Flash-Attention 2, 4-bit NF4 quantization, LoRA adapters, Liger kernels
    and DeepSpeed ZeRO-3 partitioning) or an already-constructed ``nn.Module``.

    Args:
        pretrain_or_model (str or nn.Module): A model name/path to load, or a
            ready-made model instance to wrap as-is.
        use_flash_attention_2 (bool, optional): Whether to utilize Flash Attention 2.0
            for improved performance. Defaults to False.
        bf16 (bool, optional): Enable bfloat16 precision for model computations.
            Defaults to True.
        load_in_4bit (bool, optional): Load the model in 4-bit NF4 precision
            (requires ``bf16=True``). Defaults to False.
        lora_rank (int, optional): Rank for LoRA adaptation; 0 disables LoRA.
            Defaults to 0.
        lora_alpha (int, optional): Alpha parameter for LoRA. Defaults to 16.
        lora_dropout (float, optional): Dropout rate for LoRA layers. Defaults to 0.
        target_modules (list, optional): List of target modules for applying LoRA.
            Defaults to None.
        ds_config (dict, optional): DeepSpeed configuration; a ZeRO stage-3 config
            enables partitioned (zero.Init) model loading. Defaults to None.
        device_map (dict, optional): Device mapping for loading the model onto
            specific devices. Defaults to None.
        packing_samples (bool, optional): Whether samples are packed (unpadded and
            concatenated) during training. Defaults to False.
        temperature (float, optional): Temperature applied to logits when computing
            log-probs. Defaults to 1.0.
        use_liger_kernel (bool, optional): Whether to use Liger Kernel for the model.
            Defaults to False.
    """

    def __init__(
        self,
        pretrain_or_model,
        use_flash_attention_2=False,
        bf16=True,
        load_in_4bit=False,
        lora_rank=0,
        lora_alpha=16,
        lora_dropout=0,
        target_modules=None,
        ds_config=None,
        device_map=None,
        packing_samples=False,
        temperature=1.0,
        use_liger_kernel=False,
        **kwargs,
    ) -> None:
        super().__init__()
        self.temperature = temperature
        # Record packing mode unconditionally: forward() reads self.packing_samples
        # regardless of which constructor branch ran. (Previously this was only set
        # in the from-pretrained branch, so wrapping an existing nn.Module crashed
        # on the first forward() call.)
        self.packing_samples = packing_samples

        if isinstance(pretrain_or_model, str):
            attn_implementation = "flash_attention_2" if use_flash_attention_2 else "eager"

            # Keep the HfDeepSpeedConfig handle in function scope so the DeepSpeed
            # zero.Init hook only affects this from_pretrained call, not the whole
            # process. It must be alive while from_pretrained runs.
            if ds_config is not None and ds_config["zero_optimization"]["stage"] == 3:
                dschf = HfDeepSpeedConfig(ds_config)
            else:
                dschf = None

            if load_in_4bit:
                assert bf16, "we only support bnb_4bit_compute_dtype = bf16"
                nf4_config = BitsAndBytesConfig(
                    load_in_4bit=True,
                    bnb_4bit_quant_type="nf4",
                    bnb_4bit_use_double_quant=True,
                    bnb_4bit_compute_dtype=torch.bfloat16,
                )
            else:
                nf4_config = None

            if use_liger_kernel:
                # Imported lazily so liger_kernel is only required when enabled.
                from liger_kernel.transformers import AutoLigerKernelForCausalLM

                model_class = AutoLigerKernelForCausalLM
            else:
                model_class = AutoModelForCausalLM

            self.model = model_class.from_pretrained(
                pretrain_or_model,
                trust_remote_code=True,
                attn_implementation=attn_implementation,
                quantization_config=nf4_config,
                torch_dtype=torch.bfloat16 if bf16 else "auto",
                device_map=device_map,
            )

            # LoRA: https://github.com/huggingface/peft
            if lora_rank > 0:
                # Needed so gradients flow into the (frozen) embedding inputs when
                # only adapter weights are trainable.
                self.model.enable_input_require_grads()
                lora_config = LoraConfig(
                    task_type=TaskType.CAUSAL_LM,
                    r=lora_rank,
                    lora_alpha=lora_alpha,
                    target_modules=target_modules,
                    lora_dropout=lora_dropout,
                    bias="none",
                )
                self.model = get_peft_model(self.model, lora_config)

                if load_in_4bit:
                    # QLoRA dtype hygiene: adapters in bf16, norms in fp32 for
                    # stability, lm_head/embeddings back to bf16.
                    for name, module in self.model.named_modules():
                        if isinstance(module, LoraLayer):
                            module = module.to(torch.bfloat16)
                        if "norm" in name:
                            module = module.to(torch.float32)
                        if "lm_head" in name or "embed_tokens" in name:
                            if hasattr(module, "weight"):
                                module = module.to(torch.bfloat16)

            # MoE models: ask for router logits so the load-balancing aux loss
            # can be computed during training.
            model_config = self.model.config.to_dict()
            if "output_router_logits" in model_config:
                print("[MoE] set output_router_logits as True")
                self.model.config.output_router_logits = True

            # KV cache is useless during training forward passes and wastes memory.
            self.model.config.use_cache = False
        else:
            # A fully-constructed model was supplied; wrap it unchanged.
            self.model = pretrain_or_model

    def forward(
        self,
        sequences: torch.LongTensor,
        action_mask: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        return_output=False,
        allgather_logits=False,
        return_logprobs=False,
        ring_attn_group: Optional[dist.ProcessGroup] = None,
        packed_seq_lens: Optional[list[int]] = None,
        return_entropy=False,
    ) -> torch.Tensor:
        """Run the wrapped model and return per-token action log-probs.

        Args:
            sequences: (batch, seqlen) token ids.
            action_mask: mask over the trailing action tokens; when given, the
                returned log-probs are sliced/masked to the action region.
            attention_mask: (batch, seqlen) padding mask.
            return_output: also return the raw model output.
            allgather_logits: when packing with ring attention, gather logits
                back to the full (batch, seqlen) layout before returning.
            return_logprobs: return full-sequence log-probs instead of
                action-only log-probs.
            ring_attn_group: process group for ring attention (packing mode).
            packed_seq_lens: per-sample lengths when samples are packed.
            return_entropy: attach per-token entropy to the output
                (requires return_output).

        Returns:
            Action log-probs, full log-probs, or the raw model output depending
            on the flags above (tuples when ``return_output`` is set).
        """
        batch, seqlen = sequences.size()
        forward_attention_mask = attention_mask
        if self.packing_samples:
            # Unpad, slice across the ring-attention group, and remember how to
            # restore the original (batch, seqlen) layout afterwards.
            sequences, position_ids, rolled_sequences, ring_attn_pad_len, indices = unpad_and_slice_tensor(
                sequences, attention_mask, ring_attn_group
            )
            forward_attention_mask = None
        else:
            # Targets for next-token prediction: sequence shifted left by one.
            rolled_sequences = torch.roll(sequences, shifts=-1, dims=1)
            # explicitly ignore attention_mask for packing_samples
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)

        output = self.model(sequences, attention_mask=forward_attention_mask, position_ids=position_ids)
        # float32 for numerically stable log-softmax / entropy.
        output["logits"] = output["logits"].to(torch.float32)

        if return_entropy:
            assert return_output
            entropy = compute_entropy(output["logits"])
            if self.packing_samples:
                entropy = gather_and_pad_tensor(entropy, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen)
            # Drop the last position: it predicts beyond the sequence end.
            setattr(output, "entropy", entropy[:, :-1])

        return_action_log_probs = action_mask is not None
        if not return_action_log_probs and not return_logprobs:
            assert return_output
            if allgather_logits and self.packing_samples:
                output["logits"] = gather_and_pad_tensor(
                    output["logits"], ring_attn_group, ring_attn_pad_len, indices, batch, seqlen
                )
            return output

        log_probs = log_probs_from_logits(output["logits"], rolled_sequences, temperature=self.temperature)

        if self.packing_samples:
            log_probs = gather_and_pad_tensor(log_probs, ring_attn_group, ring_attn_pad_len, indices, batch, seqlen)

        # Drop the last position (prediction past sequence end).
        log_probs = log_probs[:, :-1]
        if not return_action_log_probs and return_logprobs:
            return (log_probs, output) if return_output else log_probs

        action_log_probs = log_probs[:, -action_mask.shape[1] :] * action_mask.float()

        return (action_log_probs, output) if return_output else action_log_probs

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        """Enable gradient checkpointing on the wrapped model.

        ``gradient_checkpointing_kwargs`` defaults to ``{"use_reentrant": False}``
        (the previous behavior); a ``None`` sentinel is used instead of a mutable
        default argument.
        """
        if gradient_checkpointing_kwargs is None:
            gradient_checkpointing_kwargs = {"use_reentrant": False}
        self.model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        """Disable gradient checkpointing on the wrapped model."""
        self.model.gradient_checkpointing_disable()

    def print_trainable_parameters(self):
        """Print the wrapped model's trainable-parameter summary (PEFT helper)."""
        self.model.print_trainable_parameters()
| |
|