| from typing import Iterable, Tuple | |
| import torch | |
| from torch import nn | |
| from transformers import LlamaConfig | |
| from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType | |
| from sglang.srt.model_executor.model_runner import ForwardBatch | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.models.llama import LlamaModel | |
| from sglang.srt.utils import add_prefix | |
class LlamaEmbeddingModel(nn.Module):
    """Llama backbone followed by a pooler, used only to produce embeddings.

    The module owns a ``LlamaModel`` (``self.model``) and a ``Pooler`` built
    with ``PoolingType.LAST`` and ``normalize=True`` (``self.pooler``).
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config=None,
        prefix: str = "",
    ) -> None:
        """Build the backbone and the pooler.

        Args:
            config: Llama configuration, forwarded to ``LlamaModel``.
            quant_config: Forwarded unchanged to ``LlamaModel``.
            prefix: Parent prefix; the backbone is created under
                ``add_prefix("model", prefix)``.
        """
        super().__init__()
        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Run the backbone and pool its output.

        Args:
            input_ids: Forwarded to ``self.model``.
            positions: Forwarded to ``self.model``.
            forward_batch: Forwarded to ``self.model`` and to ``self.pooler``.
            input_embeds: Optional, forwarded to ``self.model``.
            get_embedding: Must be truthy; this class has no other mode.

        Returns:
            Whatever ``self.pooler`` returns for the backbone's hidden states.

        Raises:
            AssertionError: If ``get_embedding`` is falsy.
        """
        assert (
            get_embedding
        ), "LlamaEmbeddingModel / MistralModel is only used for embedding"
        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        return self.pooler(hidden_states, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Copy checkpoint tensors into the parameters of ``self.model``.

        Parameter names are resolved against ``self.model.named_parameters()``.
        Checkpoint names containing ``q_proj`` / ``k_proj`` / ``v_proj`` are
        redirected to the ``qkv_proj`` parameter, and ``gate_proj`` /
        ``up_proj`` to ``gate_up_proj``; the shard id from the mapping is
        passed to that parameter's ``weight_loader``. Every other tensor is
        loaded with the parameter's ``weight_loader`` if it has one, else with
        ``default_weight_loader``.

        An entry is skipped, and loading carries on with the next entry, when
        its name contains ``rotary_emb.inv_freq``, ``projector``,
        ``rotary_emb.cos_cached`` or ``rotary_emb.sin_cached``; when it starts
        with ``model.vision_tower`` and matches no parameter; or when it ends
        in ``.bias`` and matches no parameter.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a non-skipped name matches no parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        for name, loaded_weight in weights:
            # Every skip below uses `continue`, never `return`: one ignorable
            # entry must not stop the remaining tensors from being loaded.
            if "rotary_emb.inv_freq" in name or "projector" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            if name.startswith("model.vision_tower") and name not in params_dict:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. This `continue`
                # only advances the mapping loop; the renamed bias then
                # reaches the `else` branch, which skips the whole entry.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Runs only when no stacked mapping loaded the tensor.
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
class MistralModel(LlamaEmbeddingModel):
    """``LlamaEmbeddingModel`` under the ``MistralModel`` name; adds nothing."""
# The model classes defined in this module, collected in one list.
# NOTE(review): presumably read by sglang's model loader to register these
# classes by name — confirm against the registry code.
EntryClass = [LlamaEmbeddingModel, MistralModel]
Xet Storage Details
- Size:
- 3.25 kB
- Xet hash:
- 7feaf80497410967fbd7fd195a6c14ff7050dd848afc401e20a9ef77e60c3b50
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.