| # Adapted from qwen2.py | |
| from typing import Iterable, Optional, Tuple | |
| import torch | |
| from torch import nn | |
| from sglang.srt.layers.logits_processor import LogitsProcessor | |
| from sglang.srt.layers.pooler import Pooler, PoolingType | |
| from sglang.srt.layers.quantization.base_config import QuantizationConfig | |
| from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2Model | |
| from sglang.srt.utils import add_prefix | |
# Placeholder name: in the visible code it is used only as the annotation of
# the `config` constructor parameters; no dedicated MiMo config class is
# imported in this file.
MiMoConfig = None
class MiMoModel(Qwen2Model):
    """MiMo backbone implemented as a thin specialization of ``Qwen2Model``.

    The subclass adds no state or behavior of its own; it only fixes the
    ``decoder_layer_type`` handed to the base class to ``Qwen2DecoderLayer``.
    """

    def __init__(
        self,
        config: MiMoConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Forward all arguments to ``Qwen2Model.__init__``.

        Args:
            config: Model configuration, passed through unchanged.
            quant_config: Optional quantization settings, passed through
                unchanged.
            prefix: Parameter-name prefix, passed through unchanged.
        """
        # All construction work is delegated to the base class; the only
        # choice made here is the decoder layer type.
        super().__init__(
            decoder_layer_type=Qwen2DecoderLayer,
            config=config,
            quant_config=quant_config,
            prefix=prefix,
        )
class MiMoForCausalLM(nn.Module):
    """MiMo causal language model: a ``MiMoModel`` backbone plus an LM head.

    ``forward`` returns either the output of the ``LogitsProcessor`` or the
    output of a ``PoolingType.LAST`` pooler (``normalize=True``), depending on
    its ``get_embedding`` flag.
    """

    # BitandBytes specific attributes
    default_bitsandbytes_target_modules = [
        ".gate_proj.",
        ".down_proj.",
        ".up_proj.",
        ".q_proj.",
        ".k_proj.",
        ".v_proj.",
        ".o_proj.",
    ]
    bitsandbytes_stacked_params_mapping = {
        # shard_name, weight_name, index
        "q_proj": ("qkv_proj", 0),
        "k_proj": ("qkv_proj", 1),
        "v_proj": ("qkv_proj", 2),
        "gate_proj": ("gate_up_proj", 0),
        "up_proj": ("gate_up_proj", 1),
    }

    def __init__(
        self,
        config: "MiMoConfig",
        quant_config: Optional["QuantizationConfig"] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, the LM head and the output processors.

        Args:
            config: Model configuration; ``tie_word_embeddings``,
                ``vocab_size`` and ``hidden_size`` are read here.
            quant_config: Optional quantization settings, forwarded to the
                backbone and to an untied LM head.
            prefix: Parameter-name prefix, combined with the sub-module names
                through ``add_prefix``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = MiMoModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if config.tie_word_embeddings:
            # Tied embeddings: the head IS the input embedding module, so both
            # attributes refer to the same object.
            self.lm_head = self.model.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.hidden_size,
                quant_config=quant_config,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the backbone's input embeddings for ``input_ids``."""
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: "ForwardBatch",
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Run the backbone and post-process its hidden states.

        Args:
            input_ids: Token ids, passed to the backbone and the logits
                processor.
            positions: Position tensor passed to the backbone.
            forward_batch: Batch metadata passed to the backbone and to the
                selected post-processor.
            input_embeds: Optional embeddings passed to the backbone as its
                fourth argument.
            get_embedding: When true, return the pooler output instead of the
                logits-processor output.

        Returns:
            The result of ``self.pooler`` when ``get_embedding`` is true,
            otherwise the result of ``self.logits_processor``.
        """
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into this module's parameters.

        Separate ``q_proj``/``k_proj``/``v_proj`` and ``gate_proj``/``up_proj``
        checkpoint entries are routed to the fused ``qkv_proj`` and
        ``gate_up_proj`` parameters together with their shard id. Every other
        entry is loaded under its own name.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: If an entry that is not skipped has no parameter of the
                (possibly remapped) name.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # Entries whose name contains any of these substrings are never loaded.
        skipped_markers = (
            "rotary_emb.inv_freq",
            "projector",
            "mtp_layers",
            # Models trained using ColossalAI may include these tensors in
            # the checkpoint. Skip them.
            "rotary_emb.cos_cached",
            "rotary_emb.sin_cached",
        )
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            if any(marker in name for marker in skipped_markers):
                continue
            # With tied embeddings the head shares the embedding weight, so a
            # separate head tensor in the checkpoint is ignored.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # Keep the remapped name in its own variable: overwriting
                # `name` would leak it into the remaining iterations, where
                # e.g. "v_proj" also matches inside "qkv_proj".
                stacked_name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if stacked_name.endswith(".bias") and stacked_name not in params_dict:
                    break
                param = params_dict[stacked_name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # No stacked mapping applied: load under the checkpoint name.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    def get_embed_and_head(self):
        """Return ``(embed_tokens.weight, lm_head.weight)``."""
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        """Replace the embedding and LM-head weights with the given tensors.

        The existing weights are deleted before the new ones are assigned.
        When the head is the embedding module itself (tied embeddings) the
        shared weight is deleted only once; deleting it a second time would
        raise ``AttributeError``.

        Args:
            embed: New value for ``self.model.embed_tokens.weight``.
            head: New value for ``self.lm_head.weight``. With tied embeddings
                this is assigned last and is therefore the value both
                attributes end up with.
        """
        embed_module = self.model.embed_tokens
        head_module = self.lm_head
        del embed_module.weight
        if head_module is not embed_module:
            del head_module.weight
        embed_module.weight = embed
        head_module.weight = head
        # Only touch the CUDA runtime when it exists; `synchronize` raises on
        # hosts without CUDA.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Delegate KV-cache scale loading to the backbone model."""
        self.model.load_kv_cache_scales(quantization_param_path)
# Module-level alias for the causal-LM class defined in this file.
# NOTE(review): presumably the name the sglang model loader looks up to find
# the model class — confirm against the loader.
EntryClass = MiMoForCausalLM
Xet Storage Details
- Size:
- 5.66 kB
- Xet hash:
- 79271a46a174e415921accf62b07ec4528d52f531f5b5ca6bc9845c5561d1db2
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.