| # Copyright 2023-2024 SGLang Team | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # ============================================================================== | |
| """Inference-only OPT model compatible with HuggingFace weights.""" | |
| import logging | |
| from collections.abc import Iterable | |
| from typing import Optional, Union | |
| import torch | |
| from torch import nn | |
| from transformers import OPTConfig | |
| from sglang.srt.distributed import ( | |
| get_pp_group, | |
| get_tensor_model_parallel_rank, | |
| get_tensor_model_parallel_world_size, | |
| ) | |
| from sglang.srt.layers.linear import ( | |
| ColumnParallelLinear, | |
| QKVParallelLinear, | |
| ReplicatedLinear, | |
| RowParallelLinear, | |
| ) | |
| from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput | |
| from sglang.srt.layers.pooler import Pooler, PoolingType | |
| from sglang.srt.layers.quantization.base_config import QuantizationConfig | |
| from sglang.srt.layers.radix_attention import RadixAttention | |
| from sglang.srt.layers.utils import get_layer_id | |
| from sglang.srt.layers.vocab_parallel_embedding import ( | |
| ParallelLMHead, | |
| VocabParallelEmbedding, | |
| ) | |
| from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors | |
| from sglang.srt.model_loader.weight_utils import ( | |
| default_weight_loader, | |
| kv_cache_scales_loader, | |
| ) | |
| from sglang.srt.utils import add_prefix, make_layers | |
| from sglang.utils import get_exception_traceback | |
logger = logging.getLogger(__name__)


def get_activation(name="relu"):
    """Select an activation function module by name.

    Args:
        name: str
            Activation function name (case-insensitive), one of
            ["relu", "gelu", "swish" / "silu", "sigmoid"]; default "relu".

    Returns:
        A freshly constructed ``nn.Module`` implementing the activation.
        Unrecognized names fall back to ``nn.Identity()`` (kept for backward
        compatibility), but a warning is logged so the misconfiguration is
        visible instead of silently producing a linear MLP.
    """
    name = name.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    # "swish" was documented before but previously fell through to Identity.
    if name in ("swish", "silu"):
        return nn.SiLU()
    if name == "sigmoid":
        return torch.nn.Sigmoid()
    logger.warning(
        "Unknown activation function %r; falling back to nn.Identity().", name
    )
    return nn.Identity()
class OPTLearnedPositionalEmbedding(nn.Embedding):
    """Learned absolute position embedding with OPT's fixed index offset.

    OPT is set up so that, when a padding_idx is specified, embedding ids are
    offset by 2. The table therefore holds ``num_embeddings + 2`` rows, and
    every position id is shifted by 2 before lookup (so rows 0 and 1 are never
    hit by non-negative positions). Other models don't have this hack.
    """

    def __init__(self, num_embeddings: int, embedding_dim: int):
        shift = 2
        super().__init__(num_embeddings + shift, embedding_dim)
        # Exposed as an attribute so forward() and callers can read it.
        self.offset = shift

    def forward(self, positions: torch.Tensor):
        shifted_positions = positions + self.offset
        return super().forward(shifted_positions)
class OPTAttention(nn.Module):
    """Multi-head self-attention for OPT, sharded across tensor-parallel ranks.

    Q, K and V come from one fused ``QKVParallelLinear`` projection and the
    output goes through a ``RowParallelLinear``. OPT uses as many key/value
    heads as query heads, so ``num_kv_heads == num_heads`` on every rank.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        layer_id: int = 0,
        bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        """
        Args:
            embed_dim: Model hidden size; must be divisible by ``num_heads``.
            num_heads: Total attention heads across all TP ranks; must be
                divisible by the tensor-parallel world size.
            layer_id: Index of the owning decoder layer (passed to attention).
            bias: Whether the projections carry a bias term.
            quant_config: Optional quantization config for the projections.
            prefix: Module-name prefix used for the sub-layer prefixes.

        Raises:
            ValueError: if ``embed_dim`` is not divisible by ``num_heads``.
        """
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        # Fail early: otherwise head_dim is floored and the attention output
        # width no longer matches out_proj's input width.
        if embed_dim % total_num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({total_num_heads})"
            )
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        # The prefix must match the attribute / checkpoint name ("out_proj");
        # it was previously the Llama name "o_proj".
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("out_proj", prefix),
        )
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Apply self-attention to ``hidden_states``.

        ``hidden_states`` is presumably (num_tokens, embed_dim) — TODO confirm.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        # The local qkv shard holds equal-width q, k, v slices back to back
        # (num_kv_heads == num_heads), so an even 3-way split is correct.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.out_proj(attn_output)
        return output
class OPTDecoderLayer(nn.Module):
    """One OPT transformer block: self-attention followed by a two-layer MLP.

    Both sub-blocks are wrapped in a residual connection.
    ``config.do_layer_norm_before`` selects where the LayerNorm goes:
    pre-LN normalizes the sub-block input (125m, 1.7B, ..., 175B), while
    post-LN normalizes after the residual add (350m).
    """

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        self.embed_dim = hidden

        self.self_attn = OPTAttention(
            embed_dim=hidden,
            num_heads=config.num_attention_heads,
            layer_id=layer_id,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.do_layer_norm_before = config.do_layer_norm_before
        affine = config.layer_norm_elementwise_affine
        self.self_attn_layer_norm = nn.LayerNorm(hidden, elementwise_affine=affine)

        # MLP: hidden -> ffn_dim (column-parallel) -> activation -> hidden
        # (row-parallel).
        use_bias = config.enable_bias
        self.fc1 = ColumnParallelLinear(
            hidden,
            config.ffn_dim,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.activation_fn = get_activation(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            hidden,
            bias=use_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )
        self.final_layer_norm = nn.LayerNorm(hidden, elementwise_affine=affine)

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        pre_norm = self.do_layer_norm_before

        # --- Self-attention sub-block ---
        shortcut = hidden_states
        attn_in = (
            self.self_attn_layer_norm(hidden_states) if pre_norm else hidden_states
        )
        x = shortcut + self.self_attn(hidden_states=attn_in, forward_batch=forward_batch)
        if not pre_norm:
            x = self.self_attn_layer_norm(x)

        # --- Feed-forward sub-block ---
        shortcut = x
        if pre_norm:
            x = self.final_layer_norm(x)
        x, _ = self.fc1(x)
        x = self.activation_fn(x)
        x, _ = self.fc2(x)
        x = shortcut + x
        if not pre_norm:
            x = self.final_layer_norm(x)
        return x
class OPTDecoder(nn.Module):
    """OPT decoder stack: token/position embeddings, decoder layers, final norm.

    Also owns the optional ``project_in`` / ``project_out`` linear maps used
    when ``word_embed_proj_dim`` differs from ``hidden_size``. Under pipeline
    parallelism only the first rank embeds tokens, and every rank except the
    last returns its hidden states wrapped in ``PPProxyTensors``.
    """

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        """
        Args:
            config: HF ``OPTConfig``.
            layer_id: Unused; kept for signature compatibility. Each decoder
                layer gets its own index from ``make_layers``.
            quant_config: Optional quantization config for linear layers.
            prefix: Module-name prefix of this decoder (e.g. "model.decoder").
        """
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            prefix=add_prefix("embed_tokens", prefix),
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Project out & in are replicated, and only exist when the embedding
        # width differs from the hidden width.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_out", prefix),
            )
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_in", prefix),
            )
        else:
            self.project_out = None
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm = None

        # Derive the layer prefix from this module's own prefix so that it
        # matches the real parameter names ("<prefix>.layers.N...") instead of
        # the previously hardcoded "model.layers".
        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: OPTDecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=add_prefix("layers", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Run this pipeline rank's slice of the decoder.

        Returns:
            ``PPProxyTensors({"hidden_states": ...})`` on non-last PP ranks;
            otherwise the final hidden states (after the optional final
            LayerNorm and ``project_out``).
        """
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                input_embeds = self.embed_tokens(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                input_embeds, _ = self.project_in(input_embeds)
            hidden_states = input_embeds + pos_embeds
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states = layer(
                hidden_states=hidden_states, forward_batch=forward_batch
            )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors({"hidden_states": hidden_states})

        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        # Skipped when word_embed_proj_dim == hidden_size (project_out is None).
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states
class OPTModel(nn.Module):
    """Wrapper that holds the :class:`OPTDecoder` under ``self.decoder``.

    Also exposes the decoder's pipeline-parallel layer range as
    ``start_layer`` / ``end_layer``, so code that inspects ``model.start_layer``
    sees the real values.
    """

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()
        self.decoder = OPTDecoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("decoder", prefix),
        )

    @property
    def start_layer(self) -> int:
        """Index of the first decoder layer run by this pipeline rank."""
        return self.decoder.start_layer

    @property
    def end_layer(self) -> int:
        """One past the last decoder layer run by this pipeline rank."""
        return self.decoder.end_layer

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors],
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        """Delegate to the decoder; see ``OPTDecoder.forward``."""
        return self.decoder(
            input_ids,
            positions,
            pp_proxy_tensors=pp_proxy_tensors,
            input_embeds=input_embeds,
            forward_batch=forward_batch,
        )

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        """Assign per-layer KV-cache scaling factors read from a file.

        The same factor is written to both ``k_scale`` and ``v_scale`` of each
        local layer's attention module.

        Raises:
            RuntimeError: if a local layer's attention has no ``k_scale``.
        """
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            layer = self.decoder.layers[layer_idx]
            # nn.Identity entries are presumably placeholders for layers owned
            # by other pipeline ranks; they have no attention to update.
            if isinstance(layer, nn.Identity):
                continue
            attn = layer.self_attn.attn
            if not hasattr(attn, "k_scale"):
                raise RuntimeError(
                    "Self attention has no KV cache scaling " "factor attribute!"
                )
            attn.k_scale = scaling_factor
            attn.v_scale = scaling_factor
class OPTForCausalLM(nn.Module):
    """OPT decoder plus a language-modeling head.

    On the last pipeline-parallel rank ``forward`` returns logits (or pooled
    embeddings when ``get_embedding=True``). On earlier ranks it returns the
    decoder's ``PPProxyTensors`` unchanged. The token embedding lives at
    ``self.model.decoder.embed_tokens``, and is shared with ``lm_head`` when
    ``config.tie_word_embeddings`` is set.
    """

    # BitandBytes specific attributes
    # In TP, these weights are partitioned along the column dimension (dim=-1).
    # For OPT these are the RowParallelLinear layers: attention ``out_proj``
    # and MLP ``fc2``. The Llama names (".down_proj.", ".o_proj.") never match
    # OPT parameters.
    column_parallel_weights_modules = [".fc2.", ".out_proj."]

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = OPTModel(
            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.word_embed_proj_dim,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.capture_aux_hidden_states = False
        self.pp_group = get_pp_group()
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if not self.pp_group.is_last_rank:
            return hidden_states
        if get_embedding:
            return self.pooler(hidden_states, forward_batch)
        return self.logits_processor(
            input_ids,
            hidden_states,
            self.lm_head,
            forward_batch,
            aux_hidden_states=aux_hidden_states,
        )

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        """Load checkpoint tensors into this model's parameters.

        Separate q/k/v checkpoint weights are routed into the fused
        ``qkv_proj`` via ``self.stacked_params_mapping``. Decoder layers that
        belong to another pipeline rank are skipped. Unknown names are
        skipped with a warning.
        """
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        decoder = self.model.decoder
        for name, loaded_weight in weights:
            # Checkpoints saved from a bare OPTModel lack the "model." prefix.
            if name.startswith("decoder."):
                name = "model." + name

            # Skip decoder layers owned by another pipeline-parallel rank;
            # they have no parameters here. The previous check looked for
            # `start_layer` on OPTModel, which never had it, so nothing was
            # skipped and the stacked path raised KeyError under PP.
            layer_id = get_layer_id(name)
            if layer_id is not None and not (
                decoder.start_layer <= layer_id < decoder.end_layer
            ):
                continue

            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    break
                param = params_dict[name]
                param.weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict.get(name)
                if param is None:
                    # Previously unreachable; surface unexpected names instead
                    # of silently dropping them.
                    logger.warning(f"Parameter {name} not found in params_dict")
                    continue
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

    @property
    def start_layer(self):
        """First decoder layer index run by this pipeline rank."""
        return self.model.decoder.start_layer

    @property
    def end_layer(self):
        """One past the last decoder layer index run by this pipeline rank."""
        return self.model.decoder.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.decoder.embed_tokens

    def get_module_name_from_weight_name(self, name):
        """Map a checkpoint weight name to ``(module_name, num_shards)``.

        Stacked weights (q/k/v) map to the fused module. The shard count is
        the number of mapping entries that feed that fused module (3 for
        qkv_proj). Other names map to themselves with one shard.
        ``name`` is assumed to end with ".weight".
        """
        for param_name, weight_name, _ in self.stacked_params_mapping:
            if weight_name in name:
                num_shard = sum(
                    1
                    for fused_name, _shard_name, _shard_id in self.stacked_params_mapping
                    if fused_name == param_name
                )
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.

        Returns the first ``truncate_size`` rows as nested Python lists of
        float32 values, or ``None`` (after logging the traceback) on error.
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                embed_weight = self.model.decoder.embed_tokens.weight.data
                return (
                    embed_weight[:truncate_size]
                    .cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]

            weight = param.data
            if mapped_shard_id in ("q", "k", "v"):
                # q, k and v sit back to back along dim 0 of the local qkv
                # shard; OPT has as many KV heads as query heads, so all three
                # slices have the same size.
                head_dim = self.config.hidden_size // self.config.num_attention_heads
                shard_size = (self.config.num_attention_heads // tp_size) * head_dim
                offset = ("q", "k", "v").index(mapped_shard_id) * shard_size
                weight = weight.narrow(0, offset, shard_size)

            # Row-parallel weights (out_proj, fc2) are sharded along dim 1;
            # gather the full matrix. Biases are not sharded, so skip them.
            is_row_parallel = ".out_proj." in name or ".fc2." in name
            if tp_size > 1 and is_row_parallel and name.endswith(".weight"):
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            # Truncate before the host copy to avoid transferring the full tensor.
            return weight[:truncate_size].cpu().to(torch.float32).numpy().tolist()
        except Exception:
            logger.error(
                f"Error getting weights by name {name} in OPTForCausalLM: {get_exception_traceback()}"
            )
            return None

    def get_embed_and_head(self):
        return self.model.decoder.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        # NOTE(review): with tie_word_embeddings, lm_head is the same module
        # as embed_tokens, so the second `del` would fail; presumably only
        # called on untied models — confirm with callers.
        embed_tokens = self.model.decoder.embed_tokens
        del embed_tokens.weight
        del self.lm_head.weight
        embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        return self.model.decoder.embed_tokens.weight

    def set_embed(self, embed):
        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
        if (
            hasattr(self.config, "target_hidden_size")
            and self.config.target_hidden_size != self.config.hidden_size
        ):
            return
        embed_tokens = self.model.decoder.embed_tokens
        del embed_tokens.weight
        embed_tokens.weight = embed
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)
# Model class(es) this module exposes. The `EntryClass` name is presumably
# what SGLang's model registry looks up; verify against the loader.
EntryClass = [
    OPTForCausalLM,
]
Xet Storage Details
- Size:
- 23.4 kB
- Xet hash:
- 066f1a4dce8e4f1d000e36f1493f8fa88c633cc933d66453b96be7bef249117d
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.