from typing import Any, Optional, cast

import torch
from torch import nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_embedder import EmbedderConfig


class EncoderBlock(nn.Module):
    """Residual feed-forward block: Linear -> ReLU -> Dropout -> Linear -> ReLU, then Add & LayerNorm."""

    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.net(x)
        x = self.dropout(x)
        x = self.relu(self.proj(x))
        # Residual connection followed by layer normalization.
        return cast(torch.Tensor, self.norm(x + residual))


class Head(nn.Module):
    """Stack of EncoderBlocks followed by a final linear projection."""

    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0.0):
        super().__init__()
        self.blocks = nn.Sequential(
            *[EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout) for _ in range(num_blocks)]
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.blocks(x)
        x = self.proj(x)
        return x


class EmbedderModel(PreTrainedModel):
    """Frozen pretrained encoder backbone with an optional trainable projection head."""

    config_class = EmbedderConfig
    base_model_prefix = "model"
    _supports_attention_backend = True

    def __init__(self, config: EmbedderConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_config(
            config.encoder_config,
            trust_remote_code=True,
        )
        # The backbone is frozen; only the head receives gradients.
        self._init_requires_grad(self.encoder)
        self.head = Head(
            dim=self.encoder.embeddings.word_embeddings.embedding_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
        )
        self.post_init()

    def _init_requires_grad(self, module: nn.Module) -> None:
        for p in module.parameters():
            p.requires_grad = False

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> BaseModelOutput:
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # In encoder-only mode, return the backbone hidden states unchanged;
        # otherwise project them through the head.
        if self.config.encoder_only:
            emb = hidden_states
        else:
            emb = self.head(hidden_states)
        return BaseModelOutput(last_hidden_state=emb)


EmbedderModel.register_for_auto_class()
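

# Minimal usage sketch, kept out of import-time execution. It assumes EmbedderConfig
# accepts the fields used above (encoder_config, num_blocks, dropout, encoder_only)
# and uses "bert-base-uncased" as a stand-in backbone; adjust to your setup.
if __name__ == "__main__":
    from transformers import AutoConfig, AutoTokenizer

    encoder_config = AutoConfig.from_pretrained("bert-base-uncased")  # hypothetical backbone
    config = EmbedderConfig(
        encoder_config=encoder_config,
        num_blocks=2,
        dropout=0.1,
        encoder_only=False,
    )
    model = EmbedderModel(config)  # randomly initialized backbone + head

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(["a short example sentence"], return_tensors="pt")
    with torch.no_grad():
        out = model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"])
    print(out.last_hidden_state.shape)  # (batch_size, seq_len, hidden_dim)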