from typing import Any, Optional, cast

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_embedder import EmbedderConfig


class EncoderBlock(nn.Module):
    """Residual feed-forward block: Linear -> ReLU -> Dropout -> Linear -> ReLU, with a post-norm residual connection."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.net(x)
        x = self.dropout(x)
        x = self.relu(self.proj(x))
        # Add the residual and apply LayerNorm (post-norm).
        return cast(torch.Tensor, self.norm(x + residual))


class Head(nn.Module):
    """Projection head: a stack of EncoderBlocks followed by a final linear layer."""

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0.0):
        super().__init__()
        self.blocks = nn.Sequential(
            *[EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout) for _ in range(num_blocks)]
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.blocks(x)
        x = self.proj(x)
        return x


class EmbedderModel(PreTrainedModel):
    """A frozen pretrained encoder followed by a trainable projection head."""

    config_class = EmbedderConfig
    base_model_prefix = "model"
    _supports_attention_backend = True

    def __init__(self, config: EmbedderConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_config(
            config.encoder_config,
            trust_remote_code=True,
        )
        # Freeze the encoder; only the head receives gradients.
        self._init_requires_grad(self.encoder)
        self.head = Head(
            dim=self.encoder.embeddings.word_embeddings.embedding_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
        )

    def _init_requires_grad(self, module: nn.Module) -> None:
        for p in module.parameters():
            p.requires_grad = False

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> BaseModelOutput:
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # In encoder-only mode, return the raw encoder hidden states; otherwise project them through the head.
        if self.config.encoder_only:
            return BaseModelOutput(last_hidden_state=hidden_states)
        emb = self.head(hidden_states)
        return BaseModelOutput(last_hidden_state=emb)


EmbedderModel.register_for_auto_class()
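

# Minimal usage sketch (illustrative only, not part of the model code). It assumes that
# EmbedderConfig accepts `encoder_config`, `num_blocks`, `dropout`, and `encoder_only` as
# keyword arguments and that the wrapped encoder is a BERT-style model; the checkpoint name
# is a placeholder. Note that AutoModel.from_config builds a randomly initialized encoder,
# so real use would load pretrained weights.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    encoder_config = AutoConfig.from_pretrained("bert-base-uncased")
    config = EmbedderConfig(
        encoder_config=encoder_config,
        num_blocks=2,
        dropout=0.1,
        encoder_only=False,
    )
    model = EmbedderModel(config).eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["an example sentence"], return_tensors="pt")
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(output.last_hidden_state.shape)  # (batch_size, seq_len, hidden_dim)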
|