| # SPDX-License-Identifier: Apache-2.0 | |
| from typing import Iterable, Optional, Set, Tuple | |
| import torch | |
| from torch import nn | |
| from sglang.srt.distributed import get_tensor_model_parallel_world_size | |
| from sglang.srt.layers.activation import get_act_fn | |
| from sglang.srt.layers.linear import ( | |
| ColumnParallelLinear, | |
| QKVParallelLinear, | |
| RowParallelLinear, | |
| ) | |
| from sglang.srt.layers.pooler import CrossEncodingPooler, Pooler, PoolingType | |
| from sglang.srt.layers.quantization.base_config import QuantizationConfig | |
| from sglang.srt.layers.radix_attention import AttentionType, RadixAttention | |
| from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch | |
| from sglang.srt.model_loader.weight_utils import default_weight_loader | |
| from sglang.srt.utils import add_prefix | |
# Placeholder binding so that `BertConfig` can be used as an annotation name
# by the classes below. It is `None` at runtime; those classes only read
# attributes from whatever `config` object they are handed.
BertConfig = None
class BertEmbedding(nn.Module):
    """Sum of word, position and token-type embeddings followed by LayerNorm.

    Only ``position_embedding_type == "absolute"`` is accepted.
    """

    def __init__(self, config: BertConfig):
        """Create the three embedding tables and the LayerNorm.

        Args:
            config: Object exposing ``hidden_size``, ``vocab_size``,
                ``max_position_embeddings``, ``type_vocab_size``,
                ``layer_norm_eps`` and ``position_embedding_type``.

        Raises:
            ValueError: If ``config.position_embedding_type`` is not
                ``"absolute"``.
        """
        super().__init__()
        hidden = config.hidden_size
        max_positions = config.max_position_embeddings

        self.size = hidden
        self.word_embeddings = VocabParallelEmbedding(config.vocab_size, hidden)
        self.position_embeddings = VocabParallelEmbedding(max_positions, hidden)
        self.token_type_embeddings = VocabParallelEmbedding(
            config.type_vocab_size, hidden
        )
        self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
        # Registered as a parameter but never read by forward().
        # NOTE(review): presumably present so a checkpoint entry with the same
        # name has a destination during weight loading -- confirm.
        self.position_ids = nn.Parameter(torch.empty((1, max_positions)))

        self.position_embedding_type = config.position_embedding_type
        if self.position_embedding_type != "absolute":
            raise ValueError("Only 'absolute' position_embedding_type is supported")

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Embed ``input_ids`` and return the normalized sum of embeddings.

        Args:
            input_ids: Token ids looked up in ``word_embeddings``.
            positions: Position ids looked up in ``position_embeddings``.
            forward_batch: Supplies ``token_type_ids``; when that attribute is
                ``None`` an all-zero id tensor shaped like ``input_ids`` is
                used instead.

        Returns:
            LayerNorm applied to word + token-type + position embeddings.
        """
        word_embeds = self.word_embeddings(input_ids)
        pos_embeds = self.position_embeddings(positions)

        type_ids = forward_batch.token_type_ids
        if type_ids is None:
            type_ids = torch.zeros(
                input_ids.size(), dtype=torch.long, device=word_embeds.device
            )
        type_embeds = self.token_type_embeddings(type_ids)

        return self.LayerNorm(word_embeds + type_embeds + pos_embeds)
class BertPooler(nn.Module):
    """Dense layer plus tanh applied to the first row of the hidden states."""

    def __init__(self, config: BertConfig):
        """Build a square ``hidden_size`` linear layer and a tanh activation."""
        super().__init__()
        width = config.hidden_size
        self.dense = nn.Linear(width, width)
        self.activation = nn.Tanh()

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Pool ``hidden_states`` down to a single vector.

        Args:
            hidden_states: Tensor indexed as ``[0, :]``, i.e. the entry at
                index 0 of the first dimension is the one that gets pooled.
            forward_batch: Accepted for interface compatibility; not used.

        Returns:
            ``tanh(dense(hidden_states[0, :]))``.
        """
        # Pool by simply taking the hidden state corresponding to the first
        # token (index 0 along the first dimension).
        first_token = hidden_states[0, :]
        return self.activation(self.dense(first_token))
class BertEncoder(nn.Module):
    """Sequential stack of ``config.num_hidden_layers`` ``BertLayer`` modules."""

    def __init__(
        self,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Instantiate one ``BertLayer`` per hidden layer.

        Args:
            config: Model configuration; ``num_hidden_layers`` sets the depth.
            quant_config: Forwarded unchanged to every layer.
            prefix: Name prefix; layer ``i`` receives ``"{prefix}.layer.{i}"``.
        """
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        blocks = []
        for idx in range(config.num_hidden_layers):
            blocks.append(
                BertLayer(
                    config=config,
                    layer_id=idx,
                    quant_config=quant_config,
                    prefix=f"{prefix}.layer.{idx}",
                )
            )
        self.layer = nn.ModuleList(blocks)

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Feed ``hidden_states`` through every layer in order."""
        current = hidden_states
        for block in self.layer:
            current = block(current, forward_batch)
        return current
class BertLayer(nn.Module):
    """One encoder block: attention, intermediate projection, output projection."""

    def __init__(
        self,
        config: BertConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the three sub-blocks of the layer.

        Args:
            config: Supplies ``hidden_size``, ``num_attention_heads``,
                ``intermediate_size``, ``hidden_act`` and ``layer_norm_eps``.
            layer_id: Index of this layer, forwarded to the attention block.
            quant_config: Forwarded unchanged to every sub-block.
            prefix: Name prefix; sub-blocks get ``.attention``,
                ``.intermediate`` and ``.output`` appended.
        """
        super().__init__()
        self.layer_id = layer_id

        hidden = config.hidden_size
        inner = config.intermediate_size
        eps = config.layer_norm_eps

        self.attention = BertAttention(
            hidden_size=hidden,
            num_attention_heads=config.num_attention_heads,
            layer_id=layer_id,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.attention",
        )
        self.intermediate = BertIntermediate(
            hidden_size=hidden,
            intermediate_size=inner,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
            prefix=f"{prefix}.intermediate",
        )
        self.output = BertOutput(
            hidden_size=hidden,
            intermediate_size=inner,
            layer_norm_eps=eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(self, hidden_states: torch.Tensor, forward_batch: ForwardBatch):
        """Run attention, then the feed-forward pair.

        The attention result is passed to ``output`` a second time as its
        ``input_tensor`` argument, next to the intermediate activation.
        """
        attended = self.attention(hidden_states, forward_batch)
        expanded = self.intermediate(attended)
        return self.output(expanded, attended)
class BertAttention(nn.Module):
    """Self-attention followed by the attention output projection."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_norm_eps: float,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the self-attention block and its output block.

        Args:
            hidden_size: Model width, forwarded to both sub-blocks.
            num_attention_heads: Total number of attention heads.
            layer_norm_eps: Epsilon for the LayerNorm in the output block.
            layer_id: Index of the enclosing layer, forwarded to attention.
            quant_config: Forwarded unchanged to both sub-blocks.
            prefix: Name prefix of this module; each child receives the
                prefix extended with its own attribute name.
        """
        super().__init__()
        self.self_attn = BertSelfAttention(
            hidden_size=hidden_size,
            num_attention_heads=num_attention_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            # The prefix must name this submodule ("self_attn"). It was
            # previously "{prefix}.output", which duplicated the prefix of the
            # sibling `output` block and did not match the parameter names
            # produced for this module.
            prefix=f"{prefix}.self_attn",
        )
        self.output = BertSelfOutput(
            hidden_size=hidden_size,
            layer_norm_eps=layer_norm_eps,
            quant_config=quant_config,
            prefix=f"{prefix}.output",
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Attend over ``hidden_states`` and combine with the input.

        The attention result and the original ``hidden_states`` are both
        handed to ``self.output``, whose return value is returned.
        """
        self_output = self.self_attn(hidden_states, forward_batch)
        return self.output(self_output, hidden_states)
class BertSelfAttention(nn.Module):
    """Fused QKV projection followed by encoder-only ``RadixAttention``."""

    def __init__(
        self,
        hidden_size: int,
        num_attention_heads: int,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Derive the per-rank head layout and build the projection/attention.

        Args:
            hidden_size: Model width; must be divisible by
                ``num_attention_heads``.
            num_attention_heads: Total head count; must be divisible by the
                tensor-parallel world size.
            layer_id: Layer index handed to ``RadixAttention``.
            quant_config: Forwarded to the QKV projection.
            prefix: Name prefix; children get ``.qkv_proj`` and ``.attn``.

        Raises:
            AssertionError: If either divisibility requirement is violated.
        """
        super().__init__()
        self.hidden_size = hidden_size
        world_size = get_tensor_model_parallel_world_size()

        total_heads = num_attention_heads
        self.total_num_heads = total_heads
        assert self.total_num_heads % world_size == 0
        self.num_heads = total_heads // world_size
        # There are as many key/value heads as query heads.
        self.total_num_kv_heads = total_heads

        head_dim = hidden_size // total_heads
        self.head_dim = head_dim
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.num_kv_heads = max(1, self.total_num_kv_heads // world_size)
        self.q_size = self.num_heads * head_dim
        self.kv_size = self.num_kv_heads * head_dim
        self.scaling = head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            hidden_size=hidden_size,
            head_size=head_dim,
            total_num_heads=self.total_num_heads,
            total_num_kv_heads=self.total_num_kv_heads,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.qkv_proj",
        )
        self.attn = RadixAttention(
            num_heads=self.num_heads,
            head_dim=head_dim,
            scaling=self.scaling,
            num_kv_heads=self.num_kv_heads,
            layer_id=layer_id,
            prefix=f"{prefix}.attn",
            attn_type=AttentionType.ENCODER_ONLY,
        )

    def forward(
        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
    ) -> torch.Tensor:
        """Project to Q/K/V, split along the last dim, and run attention."""
        fused, _ = self.qkv_proj(hidden_states)
        sections = [self.q_size, self.kv_size, self.kv_size]
        query, key, value = torch.split(fused, sections, dim=-1)
        return self.attn(query, key, value, forward_batch)
class BertSelfOutput(nn.Module):
    """Projection of the attention result with a residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build a ``hidden_size -> hidden_size`` projection and a LayerNorm.

        Args:
            hidden_size: Input and output width of the projection.
            layer_norm_eps: Epsilon for the LayerNorm.
            quant_config: Forwarded to the projection.
            prefix: Name prefix; the projection gets ``.dense`` appended.
        """
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=hidden_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # `dense` returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertIntermediate(nn.Module):
    """Projection from ``hidden_size`` to ``intermediate_size`` plus activation."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the projection and resolve the activation function.

        Args:
            hidden_size: Input width of the projection.
            intermediate_size: Output width of the projection.
            hidden_act: Activation name resolved through ``get_act_fn``.
            quant_config: Forwarded to the projection.
            prefix: Name prefix; the projection gets ``.dense`` appended.
        """
        super().__init__()
        self.dense = ColumnParallelLinear(
            input_size=hidden_size,
            output_size=intermediate_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.intermediate_act_fn = get_act_fn(hidden_act)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Return ``intermediate_act_fn(dense(hidden_states))``."""
        # `dense` returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.intermediate_act_fn(projected)
class BertOutput(nn.Module):
    """Projection back to ``hidden_size`` with a residual add and LayerNorm."""

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        layer_norm_eps: float,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build an ``intermediate_size -> hidden_size`` projection and LayerNorm.

        Args:
            hidden_size: Output width of the projection and LayerNorm width.
            intermediate_size: Input width of the projection.
            layer_norm_eps: Epsilon for the LayerNorm.
            quant_config: Forwarded to the projection.
            prefix: Name prefix; the projection gets ``.dense`` appended.
        """
        super().__init__()
        self.dense = RowParallelLinear(
            input_size=intermediate_size,
            output_size=hidden_size,
            bias=True,
            quant_config=quant_config,
            prefix=f"{prefix}.dense",
        )
        self.LayerNorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)

    def forward(
        self, hidden_states: torch.Tensor, input_tensor: torch.Tensor
    ) -> torch.Tensor:
        """Return ``LayerNorm(dense(hidden_states) + input_tensor)``."""
        # `dense` returns a pair; only the first element is used.
        projected, _ = self.dense(hidden_states)
        return self.LayerNorm(projected + input_tensor)
class BertModel(nn.Module):
    """BERT encoder (embeddings + layer stack) used for embedding requests.

    When ``use_bert_pooler`` is false, ``forward`` applies a last-token,
    normalizing ``Pooler``. When it is true, a ``BertPooler`` is created as
    ``self.pooler`` but ``forward`` returns the raw encoder output and leaves
    pooling to the caller.
    """

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        use_bert_pooler: bool = False,
        prefix: str = "",
    ):
        """Build the embeddings, encoder and pooler.

        Args:
            config: Model configuration passed to the sub-modules.
            quant_config: Forwarded to the encoder.
            use_bert_pooler: Selects ``BertPooler`` instead of the generic
                ``Pooler`` (see class docstring for the effect on forward).
            prefix: Name prefix; the encoder receives
                ``add_prefix("encoder", prefix)``.
        """
        super().__init__()
        self.use_bert_pooler = use_bert_pooler
        self.config = config
        self.embeddings = BertEmbedding(config)
        self.encoder = BertEncoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("encoder", prefix),
        )
        self.pooler = (
            BertPooler(config)
            if self.use_bert_pooler
            else Pooler(pooling_type=PoolingType.LAST, normalize=True)
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Encode the batch and, unless ``use_bert_pooler`` is set, pool it.

        Args:
            input_ids: Token ids for the embedding layer.
            positions: Position ids for the embedding layer.
            forward_batch: Batch metadata passed to every sub-module.
            input_embeds: Accepted for interface compatibility; not used.
            get_embedding: Must be truthy; this model only serves embedding
                requests.

        Returns:
            The pooler output, or the encoder hidden states when
            ``use_bert_pooler`` is true.

        Raises:
            AssertionError: If ``get_embedding`` is falsy.
        """
        assert get_embedding, "BertModel only supports get_embedding=True"

        hidden_states = self.embeddings(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
        )
        hidden_states = self.encoder(hidden_states, forward_batch=forward_batch)

        # With the BERT pooler the caller is responsible for pooling, so the
        # raw hidden states are returned untouched.
        if not self.use_bert_pooler:
            hidden_states = self.pooler(hidden_states, forward_batch)
        return hidden_states

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]:
        """Copy checkpoint tensors into this module's parameters.

        Checkpoint names are rewritten so that ``self`` becomes ``self_attn``
        and the separate ``query``/``key``/``value`` tensors are routed into
        the fused ``qkv_proj`` parameter as shards ``q``/``k``/``v``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (after renaming) that were loaded.

        Raises:
            KeyError: If a renamed, non-skipped name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "query", "q"),
            ("qkv_proj", "key", "k"),
            ("qkv_proj", "value", "v"),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: Set[str] = set()

        for name, loaded_weight in weights:
            # The attention submodule is called `self_attn` in this file.
            name = name.replace("self", "self_attn")
            # Pooler weights have no destination unless the BERT pooler is used.
            if not self.use_bert_pooler and "pooler" in name:
                continue

            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models. After the rename no
                # other mapping matches, so control reaches the `else` clause
                # below, which skips the weight for the same reason.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", None)
                if weight_loader is None:
                    weight_loader = default_weight_loader
                weight_loader(param, loaded_weight)
                loaded_params.add(name)

        return loaded_params
class Contriever(BertModel):
    """Alias of ``BertModel`` under a second class name.

    Adds no attributes or overrides; it is listed separately in ``EntryClass``.
    NOTE(review): presumably exists so checkpoints declaring a "Contriever"
    architecture resolve to this implementation -- confirm against the loader.
    """

    pass
class BertForSequenceClassification(nn.Module):
    """``BertModel`` with a linear classifier applied through a cross-encoding pooler."""

    def __init__(
        self,
        *,
        config: BertConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """Build the backbone, the classifier and the pooler that joins them.

        Args:
            config: Supplies ``hidden_size`` and ``num_labels`` and is passed
                on to the backbone and the pooler.
            quant_config: Forwarded to the backbone.
            prefix: Name prefix; the backbone receives
                ``add_prefix("bert", prefix)``.
        """
        super().__init__()
        self.num_labels = config.num_labels
        # The backbone is created with its BERT pooler so that the pooler
        # below can reuse it.
        self.bert = BertModel(
            config=config,
            quant_config=quant_config,
            use_bert_pooler=True,
            prefix=add_prefix("bert", prefix),
        )
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.pooler = CrossEncodingPooler(config, self.classifier, self.bert.pooler)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        """Load backbone weights, then classifier weights.

        Entries whose name starts with ``"bert."`` are streamed to the
        backbone with that prefix removed. Of the remaining entries only
        those starting with ``"classifier"`` are loaded; the rest are ignored.
        """
        deferred = []

        def backbone_weights():
            # Single pass over `weights`: yield backbone entries and stash
            # everything else for the second stage below.
            for full_name, tensor in weights:
                if not full_name.startswith("bert."):
                    deferred.append((full_name, tensor))
                    continue
                yield (full_name[len("bert.") :], tensor)

        self.bert.load_weights(backbone_weights())

        own_params = dict(self.named_parameters())
        for full_name, tensor in deferred:
            if not full_name.startswith("classifier"):
                continue
            target = own_params[full_name]
            loader = getattr(target, "weight_loader", default_weight_loader)
            loader(target, tensor)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
        get_embedding: bool = False,
    ) -> torch.Tensor:
        """Run the backbone and feed its output to the pooler.

        Raises:
            AssertionError: If ``get_embedding`` does not compare equal to
                ``True``.
        """
        assert get_embedding == True
        encoded = self.bert(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
            get_embedding=get_embedding,
        )
        return self.pooler(encoded, forward_batch)
# Model classes exported by this module.
# NOTE(review): presumably read by the sglang model registry to map
# architecture names to implementations -- confirm against the loader.
EntryClass = [BertModel, Contriever, BertForSequenceClassification]
Xet Storage Details
- Size:
- 15.8 kB
- Xet hash:
- bbc4bfe3a60f197319bc0a9cf50885fcec73e9a585de2103b5f8d66a7fad29ce
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.