from typing import Any, Optional, cast

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import BaseModelOutput

from .configuration_embedder import EmbedderConfig


class EncoderBlock(nn.Module):
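    """Residual feed-forward block: Linear -> ReLU -> Dropout -> Linear -> ReLU,
    after which the input is added back and the sum is passed through LayerNorm."""
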
    def __init__(self, dim: int, hidden_dim: int, dropout: float):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, dim)
        self.relu = nn.ReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = self.net(x)
        x = self.dropout(x)
        x = self.relu(self.proj(x))
        return cast(torch.Tensor, self.norm(x + residual))


class Head(nn.Module):
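    """A stack of EncoderBlocks followed by a final linear projection back to `dim`."""
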
    def __init__(self, dim: int, num_blocks: int = 1, dropout: float = 0):
        super().__init__()
        self.blocks = nn.Sequential(
            *[EncoderBlock(dim=dim, hidden_dim=dim, dropout=dropout) for _ in range(num_blocks)]
        )
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.blocks(x)
        x = self.proj(x)
        return x


class EmbedderModel(PreTrainedModel):
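    """Embedding model for the Hugging Face interface: a frozen backbone encoder
    (built with AutoModel.from_config) followed by a small trainable Head."""
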
    config_class = EmbedderConfig  # type: ignore[assignment]
    base_model_prefix = "model"
    _supports_attention_backend = True

    def __init__(self, config: EmbedderConfig):
        super().__init__(config)
        self.encoder = AutoModel.from_config(
            config.encoder_config,
            trust_remote_code=True,
        )
        self._init_requires_grad(self.encoder)
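        # The head operates at the encoder's hidden size, read off the word-embedding table.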
        self.head = Head(
            dim=self.encoder.embeddings.word_embeddings.embedding_dim,
            num_blocks=config.num_blocks,
            dropout=config.dropout,
        )

    def _init_requires_grad(self, module: nn.Module) -> None:
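        """Freeze every parameter of `module`; used to keep the backbone encoder frozen."""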
        for p in module.parameters():
            p.requires_grad = False

    def forward(
        self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, **kwargs: Any
    ) -> BaseModelOutput:
        hidden_states = self.encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        if self.config.encoder_only:
            # Bypass the head and expose the raw encoder hidden states.
            emb = hidden_states
        else:
            emb = self.head(hidden_states)  # B, T, D
        return BaseModelOutput(last_hidden_state=emb)


EmbedderModel.register_for_auto_class()
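

# Minimal usage sketch (hypothetical checkpoint path; custom code registered with
# `register_for_auto_class` is normally loaded through the Auto classes):
#
#   from transformers import AutoModel
#   model = AutoModel.from_pretrained("org/embedder-checkpoint", trust_remote_code=True)
#   out = model(input_ids=input_ids, attention_mask=attention_mask)
#   embeddings = out.last_hidden_state  # (batch, seq_len, hidden_dim)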