"""Inference-only Deepseek model.""" |
from collections.abc import Iterable
from typing import Any, Optional, Union

import torch
from torch import nn
from transformers import PretrainedConfig

from vllm.attention import Attention
from vllm.config import CacheConfig, VllmConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
                              get_tensor_model_parallel_world_size,
                              tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.activation import SiluAndMul
from vllm.model_executor.layers.fused_moe import fused_moe
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                               QKVParallelLinear,
                                               ReplicatedLinear,
                                               RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
    ParallelLMHead, VocabParallelEmbedding)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.sampling_metadata import SamplingMetadata
from vllm.sequence import IntermediateTensors

from .interfaces import SupportsPP
from .utils import (AutoWeightsLoader, extract_layer_index,
                    is_pp_missing_parameter,
                    make_empty_intermediate_tensors_factory, make_layers,
                    maybe_prefix)
|
|
class DeepseekMLP(nn.Module): |
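    """Feed-forward block: a merged gate/up column-parallel projection,
    SiluAndMul activation, and a row-parallel down projection.
    """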

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
        quant_config: Optional[QuantizationConfig] = None,
        reduce_results: bool = True,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden_size, [intermediate_size] * 2,
            bias=False,
            quant_config=quant_config)
        self.down_proj = RowParallelLinear(intermediate_size,
                                           hidden_size,
                                           bias=False,
                                           quant_config=quant_config,
                                           reduce_results=reduce_results)
        if hidden_act != "silu":
            raise ValueError(f"Unsupported activation: {hidden_act}. "
                             "Only silu is supported for now.")
        self.act_fn = SiluAndMul()

    def forward(self, x):
        gate_up, _ = self.gate_up_proj(x)
        x = self.act_fn(gate_up)
        x, _ = self.down_proj(x)
        return x
|
|
class DeepseekMoE(nn.Module): |
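    """Mixture-of-experts layer.

    A replicated linear gate produces router logits; the routed experts are
    plain DeepseekMLP blocks whose weights are packed into stacked tensors and
    executed with the fused_moe kernel. Optional shared experts are added to
    the routed output before the tensor-parallel all-reduce.
    """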

    def __init__(
        self,
        config: PretrainedConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.rank = get_tensor_model_parallel_rank()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.n_routed_experts = config.n_routed_experts
        self.top_k = config.num_experts_per_tok
        if self.tp_size > self.n_routed_experts:
            raise ValueError(
                f"Tensor parallel size {self.tp_size} is greater than "
                f"the number of experts {self.n_routed_experts}.")

        self.experts = nn.ModuleList([
            DeepseekMLP(hidden_size=config.hidden_size,
                        intermediate_size=config.moe_intermediate_size,
                        hidden_act=config.hidden_act,
                        quant_config=quant_config,
                        reduce_results=False)
            for idx in range(self.n_routed_experts)
        ])
        self.pack_params()

        self.gate = ReplicatedLinear(config.hidden_size,
                                     self.n_routed_experts,
                                     bias=False,
                                     quant_config=None)

        if config.n_shared_experts is not None:
            intermediate_size = (config.moe_intermediate_size *
                                 config.n_shared_experts)
            self.shared_experts = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def pack_params(self):
        # Flatten the per-expert gate_up and down projection weights into two
        # stacked tensors (self.w1, self.w2) for the fused_moe kernel, while
        # re-pointing each expert's parameters at views of the shared storage.
        w1 = []
        w2 = []
        for expert in self.experts:
            w1.append(expert.gate_up_proj.weight)
            w2.append(expert.down_proj.weight)
        self.w1 = torch._utils._flatten_dense_tensors(w1)
        w1s = torch._utils._unflatten_dense_tensors(self.w1, w1)
        for data, param in zip(w1s, w1):
            param.data = data
        self.w1 = self.w1.view(len(w1), *w1s[0].shape)

        self.w2 = torch._utils._flatten_dense_tensors(w2)
        w2s = torch._utils._unflatten_dense_tensors(self.w2, w2)
        for data, param in zip(w2s, w2):
            param.data = data

        self.w2 = self.w2.view(len(w2), *w2s[0].shape)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_dim = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_dim)
        if self.config.n_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)

        router_logits, _ = self.gate(hidden_states)
        final_hidden_states = fused_moe(hidden_states,
                                        self.w1,
                                        self.w2,
                                        router_logits,
                                        self.top_k,
                                        renormalize=self.config.norm_topk_prob,
                                        inplace=True)

        if self.config.n_shared_experts is not None:
            final_hidden_states = final_hidden_states + shared_output
        final_hidden_states = tensor_model_parallel_all_reduce(
            final_hidden_states)

        return final_hidden_states.view(num_tokens, hidden_dim)
|
|
class DeepseekAttention(nn.Module): |
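    """Multi-head attention with rotary embeddings and tensor-parallel
    QKV/output projections. KV heads are partitioned (or replicated) across
    tensor-parallel ranks.
    """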

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        num_kv_heads: int,
        rope_theta: float = 10000,
        rope_scaling: Optional[dict[str, Any]] = None,
        max_position_embeddings: int = 8192,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = hidden_size
        tp_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = num_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.total_num_kv_heads = num_kv_heads
        if self.total_num_kv_heads >= tp_size:
            # Number of KV heads is greater than TP size, so we partition
            # the KV heads across multiple tensor parallel GPUs.
            assert self.total_num_kv_heads % tp_size == 0
        else:
            # Number of KV heads is less than TP size, so we replicate
            # the KV heads across multiple tensor parallel GPUs.
            assert tp_size % self.total_num_kv_heads == 0
        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
        self.head_dim = hidden_size // self.total_num_heads
        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim
        self.scaling = self.head_dim**-0.5
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings

        self.qkv_proj = QKVParallelLinear(
            hidden_size,
            self.head_dim,
            self.total_num_heads,
            self.total_num_kv_heads,
            bias=False,
            quant_config=quant_config,
        )

        self.o_proj = RowParallelLinear(
            self.total_num_heads * self.head_dim,
            hidden_size,
            bias=False,
            quant_config=quant_config,
        )

        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=max_position_embeddings,
            base=rope_theta,
            rope_scaling=rope_scaling,
        )
        self.attn = Attention(self.num_heads,
                              self.head_dim,
                              self.scaling,
                              num_kv_heads=self.num_kv_heads,
                              cache_config=cache_config,
                              quant_config=quant_config,
                              prefix=f"{prefix}.attn")

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v)
        output, _ = self.o_proj(attn_output)
        return output
|
|
class DeepseekDecoderLayer(nn.Module): |
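    """Transformer decoder layer: RMSNorm, self-attention, RMSNorm, MLP.
    The MLP is DeepseekMoE for layers selected by first_k_dense_replace and
    moe_layer_freq, and a dense DeepseekMLP otherwise.
    """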

    def __init__(
        self,
        config: PretrainedConfig,
        cache_config: Optional[CacheConfig] = None,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        layer_idx = extract_layer_index(prefix)
        self.hidden_size = config.hidden_size
        rope_theta = getattr(config, "rope_theta", 10000)
        rope_scaling = getattr(config, "rope_scaling", None)
        max_position_embeddings = getattr(config, "max_position_embeddings",
                                          8192)
        self.self_attn = DeepseekAttention(
            hidden_size=self.hidden_size,
            num_heads=config.num_attention_heads,
            num_kv_heads=config.num_key_value_heads,
            rope_theta=rope_theta,
            rope_scaling=rope_scaling,
            max_position_embeddings=max_position_embeddings,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.self_attn",
        )
        if (config.n_routed_experts is not None
                and layer_idx >= config.first_k_dense_replace
                and layer_idx % config.moe_layer_freq == 0):
            self.mlp = DeepseekMoE(config=config,
                                   quant_config=quant_config,
                                   prefix=f"{prefix}.mlp")
        else:
            self.mlp = DeepseekMLP(
                hidden_size=config.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size,
                                       eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size,
                                                eps=config.rms_norm_eps)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        residual: Optional[torch.Tensor],
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(
                hidden_states, residual)
        hidden_states = self.self_attn(
            positions=positions,
            hidden_states=hidden_states,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(
            hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual
|
|
class DeepseekModel(nn.Module): |
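    """Bare Deepseek transformer: token embeddings, a pipeline-parallel slice
    of decoder layers, and the final RMSNorm.
    """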

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()

        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config

        self.vocab_size = config.vocab_size

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
        )
        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekDecoderLayer(
                config, cache_config, quant_config=quant_config, prefix=prefix
            ),
            prefix=f"{prefix}.layers")
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.make_empty_intermediate_tensors = (
            make_empty_intermediate_tensors_factory(
                ["hidden_states", "residual"], config.hidden_size))

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors],
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.get_input_embeddings(input_ids)
            residual = None
        else:
            hidden_states = intermediate_tensors["hidden_states"]
            residual = intermediate_tensors["residual"]
        for layer in self.layers[self.start_layer:self.end_layer]:
            hidden_states, residual = layer(positions, hidden_states, residual)
        if not get_pp_group().is_last_rank:
            return IntermediateTensors({
                "hidden_states": hidden_states,
                "residual": residual
            })
        hidden_states, _ = self.norm(hidden_states, residual)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]

        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            for (param_name, weight_name, shard_id) in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                # Skip layers on other pipeline-parallel ranks.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # Skip experts that are not assigned to this worker.
                if (("mlp.experts." in name or "mlp.shared_experts." in name)
                        and name not in params_dict):
                    continue
                # Skip layers on other pipeline-parallel ranks.
                if is_pp_missing_parameter(name, self):
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
            loaded_params.add(name)
        return loaded_params
|
|
class DeepseekForCausalLM(nn.Module, SupportsPP): |
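    """Deepseek model with a parallel LM head and logits processor; supports
    pipeline parallelism via SupportsPP.
    """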

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        self.model = DeepseekModel(vllm_config=vllm_config,
                                   prefix=maybe_prefix(prefix, "model"))
        self.lm_head = ParallelLMHead(config.vocab_size,
                                      config.hidden_size,
                                      quant_config=quant_config)
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors)

    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
        return self.model.get_input_embeddings(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: Optional[IntermediateTensors] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, IntermediateTensors]:
        hidden_states = self.model(input_ids, positions, intermediate_tensors,
                                   inputs_embeds)
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> Optional[torch.Tensor]:
        logits = self.logits_processor(self.lm_head, hidden_states,
                                       sampling_metadata)
        return logits

    def load_weights(self, weights: Iterable[tuple[str,
                                                   torch.Tensor]]) -> set[str]:
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)