from typing import Literal

import torch
from torch import nn
from transformers import (
    AutoConfig,
    AutoModel,
    AutoModelForCausalLM,
    AutoModelForMaskedLM,
    DynamicCache,
)
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)

from .backbone_custom_modeling_qwen3 import CustomQwen3ForCausalLM

try:
    from torch.nn.attention.flex_attention import BlockMask
except ImportError:
    # Flex attention (and BlockMask) only exists in newer PyTorch builds;
    # fall back to None so the annotations in forward() still evaluate.
    BlockMask = None

AUTO_MODEL_CLS = {
    "AutoModel": AutoModel,
    "AutoModelForCausalLM": AutoModelForCausalLM,
    "AutoModelForMaskedLM": AutoModelForMaskedLM,
}
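# The string passed as `automodel_cls` to AutoModelFromPreTrained.__init__
# below selects which of these auto classes performs the from_pretrained load.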


class AutoModelFromPreTrained(nn.Module):
    """Thin wrapper that loads a Hugging Face auto-model from a pre-trained
    checkpoint (or re-initializes it from config) and can truncate the
    backbone to a subset of its transformer layers.
    """

    def __init__(
        self,
        automodel_cls: Literal[
            "AutoModel",
            "AutoModelForCausalLM",
            "AutoModelForMaskedLM",
        ],
        pretrained_model_name_or_path: str,
        trust_remote_code: bool = True,
        num_layers: int = -1,
        keep_top_layers: bool = False,
        reinit_model: bool = False,
        use_causal_mask: bool = False,
        **automodel_init_kwargs,
    ):
        super().__init__()
        self.use_causal_mask = use_causal_mask
        if reinit_model:
            # Build the backbone from config alone (randomly initialized
            # weights). Note: this branch always uses the custom Qwen3
            # causal-LM class, regardless of automodel_cls. Only override
            # the depth when num_layers is set explicitly; -1 means "keep
            # the depth from the checkpoint config".
            config_overrides = {}
            if num_layers != -1:
                config_overrides["num_hidden_layers"] = num_layers
            auto_config = AutoConfig.from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **config_overrides,
                **automodel_init_kwargs,
            )
            self.model = CustomQwen3ForCausalLM(auto_config)
        else:
            self.model = AUTO_MODEL_CLS[automodel_cls].from_pretrained(
                pretrained_model_name_or_path,
                trust_remote_code=trust_remote_code,
                **automodel_init_kwargs,
            )
            num_layers = (
                len(self.model.model.layers) if num_layers == -1 else num_layers
            )
            # E.g. with 12 layers and num_layers=4: keep_top_layers=True
            # keeps layers 8-11, otherwise layers 0-3.
            if keep_top_layers:
                self.model.model.layers = self.model.model.layers[-num_layers:]
            else:
                self.model.model.layers = self.model.model.layers[:num_layers]

    def forward(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor | BlockMask | None = None,
        position_ids: torch.LongTensor | None = None,
        cache_position: torch.LongTensor | None = None,
        past_key_values: DynamicCache | None = None,
        fix_cache_length: bool = False,
        return_updated_cache: bool = False,
        **kwargs,
    ) -> CausalLMOutputWithPast | BaseModelOutputWithPast:
        # Snapshot the per-layer cache lengths so the cache can be trimmed
        # back to its pre-forward size afterwards.
        prev_cache_len = None
        if past_key_values is not None and fix_cache_length:
            prev_cache_len = [
                past_key_values[i][0].shape[-2]
                for i in range(len(past_key_values))
            ]
        if self.use_causal_mask:
            # Drop any explicit mask and fall back to the backbone's
            # default causal masking.
            attention_mask = None
        model_output = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            cache_position=cache_position,
            past_key_values=past_key_values,
            **kwargs,
        )
        if return_updated_cache:
            return BaseModelOutputWithPast(
                past_key_values=model_output.past_key_values
            )
        if (
            prev_cache_len is not None
            and model_output.get("past_key_values", None) is not None
        ):
            # Trim each layer's keys/values back to their length before this
            # forward pass (the cache's sequence dimension is dim -2).
            for i, cache_len in enumerate(prev_cache_len):
                model_output.past_key_values.key_cache[i] = (
                    model_output.past_key_values.key_cache[i][..., :cache_len, :]
                )
                model_output.past_key_values.value_cache[i] = (
                    model_output.past_key_values.value_cache[i][..., :cache_len, :]
                )
        return model_output
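
# Minimal usage sketch (not part of the module's API): the checkpoint name
# and tensor shapes below are illustrative assumptions. Because of the
# relative import above, run this as a module within its package, e.g.
# `python -m <package>.<this_module>`.
if __name__ == "__main__":
    backbone = AutoModelFromPreTrained(
        automodel_cls="AutoModelForCausalLM",
        pretrained_model_name_or_path="Qwen/Qwen3-0.6B",  # example checkpoint
        num_layers=4,  # keep only the first 4 transformer layers
        use_causal_mask=True,  # rely on the backbone's default causal mask
    )
    vocab_size = backbone.model.config.vocab_size
    prompt_ids = torch.randint(0, vocab_size, (1, 8))

    # Prefill once and keep the updated KV cache.
    prefill = backbone(
        prompt_ids,
        past_key_values=DynamicCache(),
        return_updated_cache=True,
    )
    prefill_len = prefill.past_key_values[0][0].shape[-2]  # == 8

    # A follow-up pass with fix_cache_length=True computes outputs for the
    # new tokens but restores the cache to its prefill length, so the extra
    # tokens leave no trace in the cache.
    draft_ids = torch.randint(0, vocab_size, (1, 2))
    out = backbone(
        draft_ids,
        past_key_values=prefill.past_key_values,
        fix_cache_length=True,
    )
    assert out.past_key_values[0][0].shape[-2] == prefill_len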