# test / modeling_embedder.py
# Uploaded by JalalKhal — HF model and vLLM integration files
# Revision d12d1be (verified)
from typing import Any, Optional, cast
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput
from .configuration_embedder import EmbedderConfig
class EncoderBlock(nn.Module):
    """Residual MLP block.

    Pipeline: Linear(dim -> hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim -> dim) -> ReLU, then a residual add with the
    input followed by LayerNorm (post-norm). Input and output share
    shape (..., dim).
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        # NOTE: attribute names and registration order are kept exactly as
        # in the original so state_dict keys and parameter-init RNG draws
        # are unchanged.
        self.net = nn.Sequential(nn.Linear(dim, hidden_dim), nn.ReLU())
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the block; returns a tensor with the same shape as ``x``."""
        hidden = self.dropout(self.net(x))
        out = self.relu(self.proj(hidden))
        # Residual connection, then normalization.
        return cast(torch.Tensor, self.norm(out + x))
class Head(nn.Module):
    """Projection head: a stack of EncoderBlocks plus a final Linear.

    Each block is same-width (hidden_dim == dim), so the last-dimension
    size ``dim`` is preserved end to end.
    """

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        super().__init__()
        # nn.Sequential keeps the original state_dict layout
        # (blocks.0.*, blocks.1.*, ...).
        self.blocks = nn.Sequential(
            *(
                EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout)
                for _ in range(num_blocks)
            )
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the block stack, then the final projection."""
        return self.proj(self.blocks(x))
class EmbedderModel(PreTrainedModel):
    """Embedding model: a frozen pretrained encoder plus a trainable Head.

    The backbone encoder is loaded from ``config.encoder_config`` and fully
    frozen; only the ``Head`` on top is trainable. When
    ``config.encoder_only`` is True, the head is skipped and the raw encoder
    hidden states are returned.
    """

    config_class = EmbedderConfig  # type: ignore[assignment]
    base_model_prefix = "model"
    _supports_attention_backend = True

    def __init__(self, config: EmbedderConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_config(
            config.encoder_config,
            trust_remote_code=True,
        )
        # Freeze the backbone; only the head is trained.
        self._init_requires_grad(self.encoder)
        self.head = Head(
            # Head width matches the encoder's token-embedding dimension.
            dim=self.encoder.embeddings.word_embeddings.embedding_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
        )

    def _init_requires_grad(self, module: nn.Module) -> None:
        """Disable gradients for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = False

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> BaseModelOutput:
        """Encode ``input_ids`` and (unless encoder_only) apply the head.

        Returns a BaseModelOutput whose ``last_hidden_state`` has shape
        (B, T, D).
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        # BUG FIX: the original assigned `emb` only when encoder_only was
        # False but returned `emb` unconditionally, raising UnboundLocalError
        # whenever config.encoder_only is True. Return the raw encoder
        # states in that case instead.
        if self.config.encoder_only:
            return BaseModelOutput(last_hidden_state=hidden_states)
        return BaseModelOutput(last_hidden_state=self.head(hidden_states))  # B, T, D


EmbedderModel.register_for_auto_class()