# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
| from typing import Iterable, Optional, Tuple | |
| import torch | |
| from torch import nn | |
| from transformers import LlamaConfig | |
| from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType | |
| from sglang.srt.layers.quantization.base_config import QuantizationConfig | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel | |
| from sglang.srt.utils import add_prefix | |
class LlamaForClassification(nn.Module):
    """Llama backbone with a linear head applied to the pooled hidden state.

    The module runs ``LlamaModel``, pools its hidden states with
    ``PoolingType.LAST`` (normalization disabled), and projects the pooled
    vector through a bias-free ``nn.Linear`` whose output width is
    ``config.classification_out_size``.
    """

    def __init__(
        self,
        config: LlamaConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """Build the backbone, the classification head, and the pooler.

        Args:
            config: Model configuration. Must expose ``hidden_size`` and
                ``classification_out_size``.
            quant_config: Optional quantization settings, forwarded to the
                backbone only; the head is a plain ``nn.Linear``.
            prefix: Parameter-name prefix; the backbone is created with
                ``add_prefix("model", prefix)``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = LlamaModel(
            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        self.classification_head = nn.Linear(
            config.hidden_size, config.classification_out_size, bias=False
        )
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=False)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = True,
    ) -> EmbeddingPoolerOutput:
        """Run the backbone and return classification scores.

        Args:
            input_ids: Token ids passed to the backbone.
            positions: Position tensor passed to the backbone.
            forward_batch: Batch metadata used by the backbone and the pooler.
            input_embeds: Optional embeddings passed to the backbone as its
                fourth positional argument.
            get_embedding: Must be true; this model has no generation path.

        Returns:
            An ``EmbeddingPoolerOutput`` wrapping the output of the
            classification head.

        Raises:
            AssertionError: If ``get_embedding`` is false.
        """
        # Raised explicitly (rather than via an ``assert`` statement) so the
        # guard still fires when Python runs with ``-O``; the exception type
        # and message are unchanged.
        if not get_embedding:
            raise AssertionError(
                "LlamaForClassification is only used for embedding. Please add --is-embedding when you launch the server."
            )

        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
        last_token_hidden = self.pooler(hidden_states, forward_batch).embeddings
        scores = self.classification_head(last_token_hidden)
        return EmbeddingPoolerOutput(scores)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load checkpoint tensors into the head and the backbone.

        Routing is by substring match on the tensor name:

        * names containing ``"classification_head"`` are loaded here, using
          the parameter's own ``weight_loader`` attribute when it has one and
          ``default_weight_loader`` otherwise;
        * names containing ``"lm_head"`` are skipped, since this module
          defines no ``lm_head``;
        * every other tensor is handed to ``LlamaForCausalLM.load_weights``
          one ``(name, tensor)`` pair at a time.

        Args:
            weights: Iterable of ``(name, tensor)`` pairs.

        Raises:
            KeyError: If a ``classification_head`` name is not among this
                module's named parameters.
        """
        params_dict = dict(self.named_parameters())

        for name, loaded_weight in weights:
            if "classification_head" in name:
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
            elif "lm_head" in name:
                continue
            else:
                # Called unbound with this instance as ``self`` so the causal
                # LM's loading logic is reused without inheriting from it.
                LlamaForCausalLM.load_weights(self, [(name, loaded_weight)])
# Module-level alias for the model class defined in this file.
# NOTE(review): presumably read by the model loader to discover the class;
# confirm against the registry code.
EntryClass = LlamaForClassification
Xet Storage Details
- Size: 3.11 kB
- Xet hash: fc67e1a675f385e0506bc35d028c2cdae8404729749d645279fcbf1935528672
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.