File size: 3,724 Bytes
d86cecb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.compilation.decorators import support_torch_compile  # ty: ignore[unresolved-import]
from vllm.config import VllmConfig  # ty: ignore[unresolved-import]
from vllm.model_executor.models.bert_with_rope import NomicBertModel  # ty: ignore[unresolved-import]
from vllm.model_executor.models.interfaces_base import default_pooling_type  # ty: ignore[unresolved-import]
from vllm.model_executor.models.utils import WeightsMapper  # ty: ignore[unresolved-import]


class EncoderBlock(nn.Module):
    """Residual MLP block: Linear -> ReLU -> Dropout -> Linear, then LayerNorm.

    The attribute names (``net``, ``norm``, ``dropout``, ``proj``) and the
    ``net`` Sequential layout determine the state_dict keys
    (``net.0.weight``, ``norm.bias``, ``proj.weight``, ...), so they must stay
    aligned with the checkpoint being loaded.

    Args:
        dim: Width of the block's input and output features.
        hidden_dim: Width of the intermediate (post-ReLU) features.
        dropout: Drop probability applied to the intermediate features.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        # Build the expansion layer first so parameter creation order (and
        # therefore RNG consumption at init) is unchanged.
        expand = nn.Linear(dim, hidden_dim)
        self.net = nn.Sequential(expand, nn.ReLU())
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``norm(proj(dropout(relu(linear(x)))) + x)``."""
        hidden = self.dropout(self.net(x))
        # Post-norm residual: add the untouched input before normalizing.
        summed = self.proj(hidden) + x
        return cast(torch.Tensor, self.norm(summed))


class Head(nn.Module):
    """Projection head: a stack of ``EncoderBlock``s followed by a Linear layer.

    Args:
        dim: Feature width; every block (``hidden_dim=dim``) and the final
            projection preserve it.
        num_blocks: Number of stacked ``EncoderBlock``s. With 0 the stack is
            an empty ``nn.Sequential`` (identity) and only ``proj`` applies.
        dropout: Drop probability forwarded to each ``EncoderBlock``.
    """

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        super().__init__()
        layers = (
            EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout)
            for _ in range(num_blocks)
        )
        self.blocks = nn.Sequential(*layers)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block stack, then the output projection."""
        return self.proj(self.blocks(x))


@support_torch_compile
@default_pooling_type("CLS")
class EmbedderModel(nn.Module):
    """vLLM wrapper for an HF-trained EmbedderModel (NomicBert encoder + custom head).

    ``forward`` returns an L2-normalised vector for every input token. Pooling
    itself is not done here; NOTE(review): presumably vLLM's pooling machinery
    selects the CLS position per ``default_pooling_type("CLS")`` — confirm.

    Config fields read from ``hf_config``: ``hidden_size``, ``num_blocks``,
    ``dropout`` and ``encoder_only``.
    """

    # HF state_dict keys start with "model."; strip that prefix so they match
    # this module's parameter names.
    hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""})

    def __init__(
        self,
        *,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.hf_config = vllm_config.model_config.hf_config
        # Read the flag once so a config that lacks it fails here, at
        # construction, instead of inside compiled tracing of forward().
        self.encoder_only: bool = bool(self.hf_config.encoder_only)
        # --------------------------------------------------
        # Base encoder (identical to training)
        # --------------------------------------------------
        self.encoder = NomicBertModel(
            vllm_config=vllm_config,
            # Same rule as vLLM's maybe_prefix(): with an empty parent prefix
            # the child is named "encoder", not ".encoder", so layer names
            # match the real module path.
            prefix=f"{prefix}.encoder" if prefix else "encoder",
            add_pooling_layer=False,
        )
        # --------------------------------------------------
        # Custom head (must match HF exactly). It is built even when
        # encoder_only is set; forward() then simply skips it.
        # --------------------------------------------------
        self.head = Head(
            dim=self.hf_config.hidden_size,
            num_blocks=self.hf_config.num_blocks,
            dropout=self.hf_config.dropout,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the encoder's input embeddings for ``input_ids``."""
        return self.encoder.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Encode tokens and return L2-normalised per-token embeddings.

        Args:
            input_ids: Token ids passed to the encoder.
            positions: Token positions passed to the encoder.
            intermediate_tensors: Accepted for interface compatibility but not
                used (NOTE(review): presumably pipeline parallelism is not
                supported by this wrapper — confirm).
            inputs_embeds: Optional precomputed input embeddings.
            token_type_ids: Optional segment ids passed to the encoder.

        Returns:
            Encoder output, passed through ``head`` unless ``encoder_only``,
            normalised to unit L2 norm along the last dimension.
        """
        # vLLM manages attention & KV internally
        hidden_states = self.encoder(
            input_ids=input_ids,
            positions=positions,
            inputs_embeds=inputs_embeds,
            token_type_ids=token_type_ids,
        )
        # Head + normalize (same as HF); encoder-only configs skip the head.
        emb = hidden_states if self.encoder_only else self.head(hidden_states)
        return F.normalize(emb, dim=-1)