from typing import Any, Optional, cast

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_embedder import EmbedderConfig


class EncoderBlock(nn.Module):
    """Residual MLP block with post-addition LayerNorm.

    Computes ``LayerNorm(relu(proj(dropout(relu(Linear(x))))) + x)``;
    input and output share the same last dimension ``dim``.
    """

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        # Expansion: dim -> hidden_dim with ReLU.
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        # Projection back: hidden_dim -> dim, so the residual add is shape-safe.
        self.proj = nn.Linear(hidden_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.net(x)
        x = self.dropout(x)
        x = self.relu(self.proj(x))
        # Post-norm residual: normalize after adding the skip connection.
        return cast(torch.Tensor, self.norm(x + residual))


class Head(nn.Module):
    """Stack of ``num_blocks`` EncoderBlocks followed by a linear projection.

    Hidden width equals ``dim`` throughout; output has the same last
    dimension as the input.
    """

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        super().__init__()
        self.blocks = nn.Sequential(
            *[EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout) for _ in range(num_blocks)]
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.blocks(x)
        x = self.proj(x)
        return x


class EmbedderModel(PreTrainedModel):
    """Frozen pretrained encoder with a trainable embedding head.

    The backbone encoder is built from ``config.encoder_config`` and has all
    parameters frozen; only the ``Head`` on top is trainable.
    """

    config_class = EmbedderConfig  # type: ignore[assignment]
    base_model_prefix = "model"
    _supports_attention_backend = True

    def __init__(self, config: EmbedderConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_config(
            config.encoder_config,
            trust_remote_code=True,
        )
        # Freeze the backbone: only the head trains.
        self._init_requires_grad(self.encoder)
        self.head = Head(
            # Embedding dim of the backbone's word embeddings = hidden size.
            dim=self.encoder.embeddings.word_embeddings.embedding_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
        )

    def _init_requires_grad(self, module: nn.Module) -> None:
        """Disable gradients for every parameter of ``module``."""
        for p in module.parameters():
            p.requires_grad = False

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs: Any,
    ) -> BaseModelOutput:
        """Encode tokens and (unless ``encoder_only``) refine with the head.

        Returns a ``BaseModelOutput`` whose ``last_hidden_state`` is the
        per-token embeddings (B, T, D).
        """
        hidden_states = self.encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).last_hidden_state
        if self.config.encoder_only:
            # BUG FIX: the original assigned `emb` only in the non-encoder_only
            # branch, so `encoder_only=True` raised NameError at return.
            # In encoder_only mode, return the raw encoder states.
            emb = hidden_states
        else:
            emb = self.head(hidden_states)  # B, T, D
        return BaseModelOutput(last_hidden_state=emb)


EmbedderModel.register_for_auto_class()