import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, PretrainedConfig

class LexicalConfig(PretrainedConfig):
    """Configuration for the lexical (bag-of-embeddings) encoder."""

    model_type = "lexical_embedding"

    def __init__(
        self, 
        vocab_size=30522, 
        embed_dim=2048, 
        padding_idx=0, 
        **kwargs
    ):
        super().__init__(**kwargs)
        self.vocab_size = vocab_size
        self.embed_dim = embed_dim
        self.padding_idx = padding_idx

class LexicalHFModel(PreTrainedModel):
    """Embedding-table encoder: masked mean pooling + L2 normalization."""

    config_class = LexicalConfig

    def __init__(self, config):
        super().__init__(config)
        # PreTrainedModel.__init__ already stores config on self.config.

        # The entire model is a single embedding table: each token id maps
        # to a dense vector of size embed_dim.
        self.embedding = nn.Embedding(
            config.vocab_size,
            config.embed_dim,
            padding_idx=config.padding_idx,
        )
        
    def forward(self, input_ids, attention_mask=None, **kwargs):
        embeds = self.embedding(input_ids)  # (batch, seq_len, embed_dim)

        if attention_mask is None:
            attention_mask = torch.ones_like(input_ids)

        # Masked mean pooling: zero out padding positions, sum over the
        # sequence dimension, then divide by the number of real tokens.
        # The clamp guards against division by zero for all-padding rows.
        mask_expanded = attention_mask.unsqueeze(-1).expand(embeds.size()).float()
        sum_embeddings = torch.sum(embeds * mask_expanded, dim=1)
        sum_mask = torch.clamp(mask_expanded.sum(dim=1), min=1e-9)
        mean_pooled = sum_embeddings / sum_mask

        # L2-normalize so cosine similarity reduces to a dot product.
        return F.normalize(mean_pooled, p=2, dim=1)
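

# ---------------------------------------------------------------------------
# Minimal usage sketch, not part of the model itself. Assumptions: the
# "bert-base-uncased" tokenizer is an illustrative choice picked only because
# its vocabulary size matches the default vocab_size of 30522; the file does
# not prescribe any particular tokenizer.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = LexicalHFModel(LexicalConfig())
    model.eval()

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    batch = tokenizer(
        ["an example sentence", "a second, longer example sentence"],
        padding=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        embeddings = model(batch["input_ids"], batch["attention_mask"])

    # Each row is a unit-norm sentence vector of size embed_dim (2048 here).
    print(embeddings.shape)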