| | """Inference-only BaiChuan model compatible with HuggingFace weights.""" |
| | import math |
| | from collections.abc import Iterable |
| | from typing import Optional, Union |
| |
|
| | import torch |
| | from torch import nn |
| | from transformers import PretrainedConfig |
| |
|
| | from vllm.attention import Attention |
| | from vllm.compilation.decorators import support_torch_compile |
| | from vllm.config import CacheConfig, VllmConfig |
| | from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank, |
| | get_tensor_model_parallel_world_size) |
| | from vllm.model_executor.layers.activation import SiluAndMul |
| | from vllm.model_executor.layers.layernorm import RMSNorm |
| | from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, |
| | QKVParallelLinear, |
| | RowParallelLinear) |
| | from vllm.model_executor.layers.logits_processor import LogitsProcessor |
| | from vllm.model_executor.layers.quantization import QuantizationConfig |
| | from vllm.model_executor.layers.rotary_embedding import get_rope |
| | from vllm.model_executor.layers.vocab_parallel_embedding import ( |
| | ParallelLMHead, VocabParallelEmbedding) |
| | from vllm.model_executor.model_loader.weight_utils import ( |
| | default_weight_loader, row_parallel_weight_loader) |
| | from vllm.model_executor.sampling_metadata import SamplingMetadata |
| | from vllm.sequence import IntermediateTensors |
| |
|
| | from .interfaces import SupportsLoRA, SupportsPP, SupportsQuant |
| | from .utils import (AutoWeightsLoader, is_pp_missing_parameter, |
| | make_empty_intermediate_tensors_factory, make_layers) |
| |
|
| |
|
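# NOTE: The slope schedule below follows ALiBi (Press et al.): for ``n``
# attention heads the slopes form a geometric sequence with ratio 2**(-8/n).
# For example, with 8 heads the base is 2**(-8/8) = 0.5 and the slopes are
# 1/2, 1/4, ..., 1/256. When ``n`` is not a power of two, the remaining slopes
# are filled in from the schedule for the next power of two, taking every
# other entry.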
def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(
        2**(-(2**-(math.log2(closest_power_of_2) - 3))),
        dtype=torch.float32,
    )
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32,
        )
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(start=1,
                                    end=1 + 2 * num_remaining_heads,
                                    step=2,
                                    dtype=torch.int32)
        slopes = torch.cat(
            [slopes, torch.pow(extra_base, extra_powers)], dim=0)
    return slopes

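# BaiChuanMLP is the gated (SwiGLU-style) feed-forward block: the gate and up
# projections are fused into a single column-parallel GEMM, SiluAndMul computes
# silu(gate) * up, and the row-parallel down projection reduces the result
# across tensor-parallel ranks.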
class BaiChuanMLP(nn.Module):

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x

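# Baichuan checkpoints pack the query, key and value projections into a single
# "W_pack" weight; QKVParallelLinear shards it across tensor-parallel ranks and
# forward() splits the packed output back into q, k and v. Depending on the
# checkpoint, the layer uses either rotary embeddings ("ROPE") or ALiBi biases
# ("ALIBI"), with the ALiBi slopes sliced to this rank's local heads.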
class BaiChuanAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        position_embedding: str,
        rope_theta: float = 10000,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.hidden_size = hidden_size
        tensor_model_parallel_world_size = (
            get_tensor_model_parallel_world_size())
        self.total_num_heads = num_heads
        assert self.total_num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = (self.total_num_heads //
                          tensor_model_parallel_world_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.position_embedding = position_embedding
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.W_pack = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_heads,
            bias=False,
            quant_config=quant_config,
        )
        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        if self.position_embedding == "ALIBI":
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = _get_alibi_slopes(self.total_num_heads)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()

            scaling = self.head_dim**-0.5
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  scaling,
                                  alibi_slopes=alibi_slopes,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
        else:
            self.rotary_emb = get_rope(
                self.head_dim,
                rotary_dim=self.head_dim,
                max_position=self.max_position_embeddings,
                base=self.rope_theta,
            )
            self.scaling = self.head_dim**-0.5
            self.attn = Attention(self.num_heads,
                                  self.head_dim,
                                  self.scaling,
                                  cache_config=cache_config,
                                  quant_config=quant_config,
                                  prefix=f"{prefix}.attn")
    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.W_pack(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        if self.position_embedding != "ALIBI":
            q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output

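# BaiChuanDecoderLayer follows the usual pre-norm transformer layout. vLLM's
# RMSNorm fuses the residual add with normalization: when called with a
# residual it returns (normalized, updated_residual), which is why forward()
# threads ``residual`` alongside ``hidden_states`` instead of adding it back
# explicitly after each sub-layer.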
class BaiChuanDecoderLayer(nn.Module):

    def __init__(self,
                 config: PretrainedConfig,
                 position_embedding: str,
                 cache_config: Optional[CacheConfig] = None,
                 quant_config: Optional[QuantizationConfig] = None,
                 prefix: str = ""):
        super().__init__()
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = BaiChuanAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            position_embedding=position_embedding,
            rope_theta=rope_theta,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        self.mlp = BaiChuanMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
            quant_config=quant_config,
        )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual

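# BaiChuanModel is the bare decoder stack. Under pipeline parallelism, only
# the first rank embeds input_ids; later ranks receive hidden_states and the
# running residual via IntermediateTensors, and only the last rank applies the
# final RMSNorm. The @support_torch_compile decorator lets vLLM compile this
# module's forward pass.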
@support_torch_compile
class BaiChuanModel(nn.Module):

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ) -> None:
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.config = config
        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: BaiChuanDecoderLayer(config,
                                                position_embedding,
                                                cache_config,
                                                quant_config,
                                                prefix=prefix),
            prefix=f"{prefix}.layers",
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            assert intermediate_tensors is not None
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(
                positions,
                hidden_states,
                residual,
            )
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual,
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

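    # HuggingFace Baichuan checkpoints keep separate gate_proj/up_proj
    # weights, while this model fuses them into a single gate_up_proj
    # parameter. The stacked_params_mapping below routes each checkpoint shard
    # into the right slice of the fused weight; everything else is loaded
    # one-to-one.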
    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue

            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params

class BaiChuanBaseForCausalLM(nn.Module, SupportsLoRA, SupportsPP,
                              SupportsQuant):
    packed_modules_mapping = {
        "W_pack": ["W_pack"],
        "gate_up_proj": [
            "gate_proj",
            "up_proj",
        ],
    }

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
        position_embedding: str = "ROPE",
    ):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        lora_config = vllm_config.lora_config
        self.config = config
        self.lora_config = lora_config
        self.tp_size = get_tensor_model_parallel_world_size()
        self.quant_config = quant_config
        self.model = BaiChuanModel(vllm_config=vllm_config,
                                   prefix=prefix,
                                   position_embedding=position_embedding)
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        self.lm_head.weight.weight_loader = self.lm_head_weight_loader
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def lm_head_weight_loader(self, param: nn.Parameter,
                              loaded_weight: torch.Tensor):
        # Unlike Baichuan, Baichuan2 normalizes the output-projection
        # ("lm_head") weights in its reference implementation, so the same
        # normalization is applied here at load time. The two families are
        # distinguished by vocabulary size: Baichuan2 uses a 125,696-token
        # vocabulary.
        is_baichuan2 = self.config.vocab_size == 125696
        if is_baichuan2:
            loaded_weight = torch.nn.functional.normalize(loaded_weight)
        if self.tp_size > 1:
            row_parallel_weight_loader(param, loaded_weight)
        else:
            default_weight_loader(param, loaded_weight)

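# As the docstrings below note, Baichuan-13B and the Baichuan2 models use the
# lower-case "BaichuanForCausalLM" architecture name, while the original
# Baichuan-7B uses the upper-case "BaiChuanForCausalLM". The 7B configuration
# (hidden_size 4096) uses rotary embeddings and the 13B configuration uses
# ALiBi, which is why the hidden size selects the position embedding scheme.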
class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 13B and Baichuan2 7B/13B.
    NOTE: the class name has a lower case 'c'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        config = vllm_config.model_config.hf_config
        if config.hidden_size == 4096:
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ROPE")
        else:
            super().__init__(vllm_config=vllm_config,
                             prefix=prefix,
                             position_embedding="ALIBI")


class BaiChuanForCausalLM(BaiChuanBaseForCausalLM):
    """Baichuan 7B.
    NOTE: the class name has an upper case 'C'.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__(vllm_config=vllm_config,
                         prefix=prefix,
                         position_embedding="ROPE")