# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.compilation.decorators import support_torch_compile  # ty: ignore[unresolved-import]
from vllm.config import VllmConfig  # ty: ignore[unresolved-import]
from vllm.model_executor.models.bert_with_rope import NomicBertModel  # ty: ignore[unresolved-import]
from vllm.model_executor.models.interfaces_base import default_pooling_type  # ty: ignore[unresolved-import]
from vllm.model_executor.models.utils import WeightsMapper  # ty: ignore[unresolved-import]


class EncoderBlock(nn.Module):
    """Residual feed-forward block: Linear -> ReLU -> Dropout -> Linear -> LayerNorm.

    The submodule attribute names (``net``, ``norm``, ``dropout``, ``proj``)
    determine the state_dict keys, so they must not be renamed.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        """Create the block.

        Args:
            dim: Size of the input and output feature dimension.
            hidden_dim: Size of the intermediate feature dimension.
            dropout: Probability passed to ``nn.Dropout``.
        """
        super().__init__()
        # Expansion layer with its activation, registered as ``net.0`` / ``net.1``.
        expand = nn.Linear(dim, hidden_dim)
        self.net = nn.Sequential(expand, nn.ReLU())
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the feed-forward path, add the input back, and layer-normalize."""
        hidden = self.proj(self.dropout(self.net(x)))
        # Residual connection: the untouched input is added after the projection.
        return cast(torch.Tensor, self.norm(hidden + x))


class Head(nn.Module):
    """A stack of ``EncoderBlock`` modules followed by a final linear projection."""

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        """Create the head.

        Args:
            dim: Feature dimension; also used as each block's hidden dimension.
            num_blocks: Number of ``EncoderBlock`` modules to chain.
            dropout: Dropout probability forwarded to every block.
        """
        super().__init__()
        stacked = (
            EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout)
            for _ in range(num_blocks)
        )
        self.blocks = nn.Sequential(*stacked)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in order, then through the projection."""
        hidden = self.blocks(x)
        return self.proj(hidden)


@support_torch_compile
@default_pooling_type("CLS")
class EmbedderModel(nn.Module):
    """
    vLLM wrapper for HF-trained EmbedderModel (encoder + custom graph head)
    """

    # HF state_dict keys start with "model."; the mapper strips that prefix.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        """Build the encoder and the head from ``vllm_config``.

        Args:
            vllm_config: vLLM configuration; its ``model_config.hf_config``
                supplies ``hidden_size``, ``num_blocks`` and ``dropout`` here.
            prefix: Prefix under which the encoder is created.
        """
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.hf_config = hf_config

        # Base encoder (identical to training).
        self.encoder = NomicBertModel(
            vllm_config=vllm_config,
            prefix=f"{prefix}.encoder",
            add_pooling_layer=False,
        )

        # Custom head (must match HF exactly).
        self.head = Head(
            dim=hf_config.hidden_size,
            num_blocks=hf_config.num_blocks,
            dropout=hf_config.dropout,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate input-id embedding to the wrapped encoder."""
        return self.encoder.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode the inputs and, unless ``encoder_only`` is set, apply the head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            positions: Position ids forwarded to the encoder.
            intermediate_tensors: Accepted but not read by this method.
            inputs_embeds: Optional embeddings forwarded to the encoder.
            token_type_ids: Optional token type ids forwarded to the encoder.

        Returns:
            The encoder hidden states when ``hf_config.encoder_only`` is truthy;
            otherwise the head output, L2-normalized along the last dimension.
        """
        # vLLM manages attention & KV internally
        hidden_states = self.encoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
        )

        if self.hf_config.encoder_only:
            return hidden_states

        # Head + normalize (same as HF)
        return F.normalize(self.head(hidden_states), dim=-1)