from __future__ import annotations from typing import Iterable, Optional import torch import torch.nn as nn import torch.nn.functional as F from vllm.config import VllmConfig from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.utils import AutoWeightsLoader, WeightsMapper from vllm.distributed import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) def _build_rope_cos_sin( positions: torch.Tensor, d_head: int, device: torch.device, ) -> tuple[torch.Tensor, torch.Tensor]: js = torch.arange(d_head // 2, device=device, dtype=torch.float32) theta = 1.0 / (1024.0 ** (2.0 * js / d_head)) phi = positions.float().unsqueeze(-1) * theta.unsqueeze(0) cos = torch.cos(phi).repeat_interleave(2, dim=-1) sin = torch.sin(phi).repeat_interleave(2, dim=-1) return cos, sin def _apply_rope( x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, ) -> torch.Tensor: x_rot = torch.empty_like(x) x_rot[..., 0::2] = -x[..., 1::2] x_rot[..., 1::2] = x[..., 0::2] return (x * cos + x_rot * sin).to(x.dtype) class CloverLMAttention(nn.Module): def __init__( self, d: int, num_heads: int, num_kv_heads: int, head_dim: int, cache_config=None, quant_config: QuantizationConfig | None = None, prefix: str = "", ): super().__init__() tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() self.num_heads = num_heads // tp_size self.head_dim = head_dim self.q_size = self.num_heads * head_dim total_q_size = num_heads * 
head_dim total_kv_size = num_kv_heads * head_dim if num_kv_heads % tp_size == 0: self.num_kv_heads = num_kv_heads // tp_size kv_linear_cls = ColumnParallelLinear else: self.num_kv_heads = num_kv_heads kv_linear_cls = ReplicatedLinear self.kv_size = self.num_kv_heads * head_dim self.lq = ColumnParallelLinear( d, total_q_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.lq", ) self.lk = kv_linear_cls( d, total_kv_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.lk", ) self.lv = kv_linear_cls( d, total_kv_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.lv", ) self.lo = RowParallelLinear( total_q_size, d, bias=False, quant_config=quant_config, prefix=f"{prefix}.lo", ) self.scale = nn.Parameter( torch.empty(1, self.num_heads, 1, 1), requires_grad=False, ) heads_per_tp = self.num_heads def _scale_weight_loader(param, loaded_weight): start = tp_rank * heads_per_tp end = start + heads_per_tp param.data.copy_(loaded_weight[:, start:end, :, :]) self.scale.weight_loader = _scale_weight_loader self.attn = Attention( num_heads=self.num_heads, head_size=head_dim, scale=1.0, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: q, _ = self.lq(hidden_states) k, _ = self.lk(hidden_states) v, _ = self.lv(hidden_states) cos, sin = _build_rope_cos_sin( positions, self.head_dim, hidden_states.device, ) q = q.view(-1, self.num_heads, self.head_dim) k = k.view(-1, self.num_kv_heads, self.head_dim) q = _apply_rope(q, cos.unsqueeze(1), sin.unsqueeze(1)) k = _apply_rope(k, cos.unsqueeze(1), sin.unsqueeze(1)) q = F.normalize(q, dim=-1) k = F.normalize(k, dim=-1) # scale: (1, heads, 1, 1) → broadcast over (tokens, heads, head_dim) q = q * self.scale.squeeze(-1) q = q.reshape(-1, self.q_size) k = k.reshape(-1, self.kv_size) attn_output = self.attn(q, k, v) output, _ = self.lo(attn_output) return output 
class CloverLMMLP(nn.Module):
    """Two-layer feed-forward block with a squared-ReLU activation.

    The hidden width is ``4 * d``. ``l1`` is column-parallel and ``l2`` is
    row-parallel.
    """

    def __init__(
        self,
        d: int,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the two projections.

        Args:
            d: Model (hidden) width.
            quant_config: Forwarded to both linear layers.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        d_hidden = 4 * d
        # NOTE: the prefix intentionally ends in ".l1.0" (the checkpoint-side
        # name of this linear, see load_weights) while the attribute is "l1".
        self.l1 = ColumnParallelLinear(
            d,
            d_hidden,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.l1.0",
        )
        self.l2 = RowParallelLinear(
            d_hidden,
            d,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.l2",
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``l2(relu(l1(x)) ** 2)``."""
        x, _ = self.l1(x)
        x = F.relu(x) ** 2
        x, _ = self.l2(x)
        return x


class CloverLMBlock(nn.Module):
    """One transformer block: attention and MLP, each normalised *after* the
    sub-layer and then added to the residual stream."""

    def __init__(
        self,
        d: int,
        num_heads: int,
        num_kv_heads: int,
        head_dim: int,
        cache_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Create the attention, MLP and their output norms.

        Args:
            d: Model (hidden) width.
            num_heads: Total number of query heads.
            num_kv_heads: Total number of key/value heads.
            head_dim: Feature size of each head.
            cache_config: Forwarded to the attention module.
            quant_config: Forwarded to the attention and MLP modules.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.mhsa = CloverLMAttention(
            d,
            num_heads,
            num_kv_heads,
            head_dim,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=f"{prefix}.mhsa",
        )
        self.out_att_norm = RMSNorm(d)
        self.mlp = CloverLMMLP(
            d,
            quant_config=quant_config,
            prefix=f"{prefix}.mlp",
        )
        self.out_mlp_norm = RMSNorm(d)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor:
        """Apply the attention and MLP residual branches in order."""
        # Post-norm attention residual
        attn_out = self.mhsa(positions, hidden_states)
        attn_out = self.out_att_norm(attn_out)
        hidden_states = hidden_states + attn_out

        # Post-norm MLP residual
        mlp_out = self.mlp(hidden_states)
        mlp_out = self.out_mlp_norm(mlp_out)
        hidden_states = hidden_states + mlp_out

        return hidden_states


class CloverLMModel(nn.Module):
    """Token embedding, a stack of ``CloverLMBlock`` and a final RMSNorm.

    The model width is ``config.heads * config.d_head`` and each block uses
    ``config.heads // config.ratio`` key/value heads.
    """

    def __init__(
        self,
        config,
        cache_config=None,
        quant_config: QuantizationConfig | None = None,
        prefix: str = "",
    ):
        """Build the embedding, ``config.num_blocks`` blocks and output norm.

        Args:
            config: Model config; this constructor reads ``heads``,
                ``d_head``, ``ratio``, ``vocab_size`` and ``num_blocks``.
            cache_config: Forwarded to every block.
            quant_config: Forwarded to the embedding and every block.
            prefix: Dotted module path used to name the sub-layers.
        """
        super().__init__()
        self.config = config
        d = config.heads * config.d_head

        self.emb = VocabParallelEmbedding(
            config.vocab_size,
            d,
            quant_config=quant_config,
            prefix=f"{prefix}.emb",
        )
        self.blocks = nn.ModuleList([
            CloverLMBlock(
                d,
                config.heads,
                config.heads // config.ratio,
                config.d_head,
                cache_config=cache_config,
                quant_config=quant_config,
                prefix=f"{prefix}.blocks.{i}",
            )
            for i in range(config.num_blocks)
        ])
        self.out_norm = RMSNorm(d)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.emb(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the final normalised hidden states.

        ``inputs_embeds`` takes precedence over ``input_ids`` when given.
        ``intermediate_tensors`` is accepted but not used by this method.
        """
        if inputs_embeds is not None:
            hidden_states = inputs_embeds
        else:
            hidden_states = self.embed_input_ids(input_ids)

        for block in self.blocks:
            hidden_states = block(positions, hidden_states)

        return self.out_norm(hidden_states)


# Checkpoint-prefix mapping. Not referenced by load_weights below, which
# applies the same "transformer." -> "model." rewrite inline.
_HF_TO_VLLM = WeightsMapper(
    orig_to_new_prefix={"transformer.": "model."},
)

# Checkpoint-side prefix of the output projection. In this module that
# projection is `lm_head`, which lives outside `model`.
_HF_OUTPUT_HEAD_PREFIX = "transformer.linear."


def _child_prefix(prefix: str, name: str) -> str:
    """Join ``prefix`` and ``name`` with exactly one dot.

    An empty prefix yields ``name`` unchanged; a prefix that already ends in
    a dot is concatenated as-is.
    """
    if not prefix:
        return name
    if prefix.endswith("."):
        return f"{prefix}{name}"
    return f"{prefix}.{name}"


class CloverLMForCausalLM_vLLM(nn.Module):
    """Causal-LM wrapper: ``CloverLMModel`` plus an output head and logits
    processor.

    When ``config.weight_tying`` is true (the default if the attribute is
    absent) the output head shares its weight with the token embedding.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone and output head from ``vllm_config``.

        Args:
            vllm_config: Source of the HF config, cache config and
                quantization config.
            prefix: Dotted module path under which this model is mounted.
        """
        super().__init__()
        config = vllm_config.model_config.hf_config
        cache_config = vllm_config.cache_config
        quant_config = vllm_config.quant_config
        d = config.heads * config.d_head

        self.config = config
        self.model = CloverLMModel(
            config,
            cache_config=cache_config,
            quant_config=quant_config,
            prefix=_child_prefix(prefix, "model"),
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            d,
            bias=False,
            quant_config=quant_config,
            prefix=_child_prefix(prefix, "lm_head"),
        )
        self.logits_processor = LogitsProcessor(config.vocab_size)

        if getattr(config, "weight_tying", True):
            self.lm_head.weight = self.model.emb.weight

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings via the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors=None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Return the backbone's final hidden states (not logits)."""
        return self.model(
            input_ids,
            positions,
            intermediate_tensors,
            inputs_embeds,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to vocabulary logits."""
        return self.logits_processor(self.lm_head, hidden_states)

    def load_weights(
        self,
        weights: Iterable[tuple[str, torch.Tensor]],
    ) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Name translation, applied per checkpoint key:

        * ``transformer.linear.*`` is the output head. With weight tying the
          exact key ``transformer.linear.weight`` is skipped (the embedding
          weight is shared); otherwise the key is routed to ``lm_head.*``.
        * Any other key has its first ``transformer.`` replaced by
          ``model.`` and ``.mlp.l1.0.`` collapsed to ``.mlp.l1.``.

        Keys with no matching parameter are logged and ignored.

        Args:
            weights: ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names (in this module's naming) that
            received a tensor.
        """
        import logging

        logger = logging.getLogger(__name__)

        # Keep aliases (e.g. a tied lm_head.weight) addressable by name.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        tied = getattr(self.config, "weight_tying", True)
        tied_duplicate = _HF_OUTPUT_HEAD_PREFIX + "weight"

        loaded: set[str] = set()
        skipped: list[str] = []
        unmapped: list[str] = []

        for hf_name, loaded_weight in weights:
            if tied and hf_name == tied_duplicate:
                skipped.append(hf_name)
                continue

            if hf_name.startswith(_HF_OUTPUT_HEAD_PREFIX):
                # The output head is `lm_head` here, not a child of `model`.
                suffix = hf_name[len(_HF_OUTPUT_HEAD_PREFIX):]
                vllm_name = "lm_head." + suffix
            else:
                # Map HuggingFace names → vLLM names
                vllm_name = hf_name.replace("transformer.", "model.", 1)
                # In HuggingFace model, MLP l1 is Sequential(Linear, ReLU²),
                # so the linear weight is at "mlp.l1.0.weight". In our vLLM
                # model l1 is a flat ColumnParallelLinear → "mlp.l1.weight".
                vllm_name = vllm_name.replace(".mlp.l1.0.", ".mlp.l1.")

            param = params_dict.get(vllm_name)
            if param is None:
                unmapped.append(f"{hf_name} -> {vllm_name}")
                continue

            weight_loader = getattr(
                param, "weight_loader", default_weight_loader
            )
            weight_loader(param, loaded_weight)
            loaded.add(vllm_name)

        # Report against de-duplicated names so that a tied alias of an
        # already-loaded parameter is not flagged as missing.
        expected = {name for name, _ in self.named_parameters()}
        not_loaded = expected - loaded

        logger.info("Loaded %d/%d params, skipped %d, unmapped %d, "
                    "not_loaded %d",
                    len(loaded), len(expected),
                    len(skipped), len(unmapped), len(not_loaded))
        if unmapped:
            logger.warning("Unmapped HF keys: %s", unmapped)
        if not_loaded:
            logger.warning("Params not loaded: %s", sorted(not_loaded))

        return loaded