# test-api / vllm_modeling_embedder.py
# Commit "test api" by JalalKhal (d86cecb, verified)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.compilation.decorators import support_torch_compile # ty: ignore[unresolved-import]
from vllm.config import VllmConfig # ty: ignore[unresolved-import]
from vllm.model_executor.models.bert_with_rope import NomicBertModel # ty: ignore[unresolved-import]
from vllm.model_executor.models.interfaces_base import default_pooling_type # ty: ignore[unresolved-import]
from vllm.model_executor.models.utils import WeightsMapper # ty: ignore[unresolved-import]
class EncoderBlock(nn.Module):
    """Feed-forward block with a residual connection and post-LayerNorm.

    Computes ``LayerNorm(x + proj(dropout(relu(linear(x)))))``.

    The attribute names ``net``, ``norm``, ``dropout`` and ``proj`` determine
    the state_dict keys, so they must not be renamed if trained weights are to
    be loaded into this module.

    Args:
        dim: Size of the last (feature) dimension of the input and output.
        hidden_dim: Width of the intermediate expansion layer.
        dropout: Dropout probability applied to the intermediate activations.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        # Expansion dim -> hidden_dim followed by ReLU. The expansion Linear is
        # created before the output Linear so seeded initialisation matches.
        expand = nn.Linear(dim, hidden_dim)
        self.net = nn.Sequential(expand, nn.ReLU())
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        # Contraction hidden_dim -> dim so the residual add lines up.
        self.proj = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block to ``x``; its last dimension must equal ``dim``."""
        hidden = self.dropout(self.net(x))
        update = self.proj(hidden)
        # Residual add happens before normalisation (post-norm layout).
        return cast(torch.Tensor, self.norm(update + x))
class Head(nn.Module):
    """Projection head: a stack of ``EncoderBlock`` layers, then a final Linear.

    Args:
        dim: Feature size; every block is built with ``hidden_dim == dim``.
        num_blocks: Number of ``EncoderBlock`` layers to stack. With 0 the
            stack is an empty ``nn.Sequential`` (identity) and only the final
            projection is applied.
        dropout: Dropout probability forwarded to each block.
    """

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        super().__init__()
        stack = nn.Sequential()
        for _ in range(num_blocks):
            # Appended modules get positional names "0", "1", ..., which keeps
            # the "blocks.<i>.*" parameter keys stable.
            stack.append(EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout))
        self.blocks = stack
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through every block in order, then the final projection."""
        return self.proj(self.blocks(x))
@support_torch_compile
@default_pooling_type("CLS")
class EmbedderModel(nn.Module):
    """
    vLLM wrapper for HF-trained EmbedderModel
    (encoder + custom graph head).

    Holds a ``NomicBertModel`` encoder (``self.encoder``) and a custom
    ``Head`` (``self.head``). Pooling is declared as CLS via
    ``default_pooling_type`` and is applied by vLLM to this module's output.
    """

    # HF state_dict keys start with "model."; stripping that prefix presumably
    # maps them onto this module's ``encoder.*`` / ``head.*`` attributes.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.hf_config = vllm_config.model_config.hf_config

        # --------------------------------------------------
        # Base encoder (identical to training)
        # --------------------------------------------------
        # Join prefixes the way vLLM's ``maybe_prefix`` helper does: with an
        # empty parent prefix (the default) this must be "encoder", not
        # ".encoder", because vLLM identifies layers by these prefix strings.
        encoder_prefix = f"{prefix}.encoder" if prefix else "encoder"
        self.encoder = NomicBertModel(
            vllm_config=vllm_config,
            prefix=encoder_prefix,
            add_pooling_layer=False,
        )

        # --------------------------------------------------
        # Custom head (must match HF exactly)
        # --------------------------------------------------
        # ``num_blocks`` and ``dropout`` are custom fields of the HF config.
        self.head = Head(
            dim=self.hf_config.hidden_size,
            num_blocks=self.hf_config.num_blocks,
            dropout=self.hf_config.dropout,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate the token-id -> input-embedding lookup to the encoder."""
        return self.encoder.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode tokens and, unless configured encoder-only, apply the head.

        Args:
            input_ids: Token ids forwarded to the encoder.
            positions: Token positions forwarded to the encoder.
            intermediate_tensors: Unused; accepted for interface compatibility.
            inputs_embeds: Optional precomputed embeddings forwarded to the encoder.
            token_type_ids: Optional token-type ids forwarded to the encoder.

        Returns:
            The encoder hidden states if ``hf_config.encoder_only`` is truthy;
            otherwise the head output L2-normalised along the last dimension.
        """
        # vLLM manages attention & KV internally
        hidden_states = self.encoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
        )
        if self.hf_config.encoder_only:
            return hidden_states
        # Head + normalize (same as HF). The head acts on the last dimension
        # only, so applying it to every token before pooling gives the same
        # pooled result as applying it to the pooled token alone.
        return F.normalize(self.head(hidden_states), dim=-1)