File size: 738 Bytes
638223b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
import torch.nn as nn
from transformers import AutoModel

class LabelEmbCls(nn.Module):
    """Score inputs against a fixed set of label embeddings.

    The hidden state of the first token produced by the wrapped encoder is
    compared with every row of the label-embedding matrix by dot product,
    and the resulting scores are divided by a learnable temperature.
    """

    def __init__(self, base: nn.Module, lbl_emb: torch.Tensor):
        """Store the encoder, the label embeddings and the temperature.

        Args:
            base: Encoder module (e.g. a model loaded through ``AutoModel``)
                whose output exposes ``last_hidden_state``.
            lbl_emb: 2-D matrix with one row per label; its second dimension
                must equal the encoder's hidden size so that the product in
                ``forward`` is defined.
        """
        super().__init__()
        self.bert = base
        # Registered as a parameter so it follows the module across devices
        # and is saved in the state dict, but it is excluded from gradients.
        self.lbl_E = nn.Parameter(lbl_emb, requires_grad=False)
        # Learnable temperature, initialised to 1.0 (i.e. no scaling).
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return temperature-scaled label scores of shape [batch, num_labels].

        Args:
            input_ids: Token ids passed to the encoder.
            attention_mask: Attention mask passed to the encoder.
            token_type_ids: Optional segment ids. Forwarded to the encoder
                only when given, so encoders whose ``forward`` has no such
                parameter can be used as ``base``.
        """
        encoder_inputs = {
            "input_ids": input_ids,
            "attention_mask": attention_mask,
        }
        # Omitting the argument is equivalent to passing None for encoders
        # that support it, and avoids a TypeError for those that do not.
        if token_type_ids is not None:
            encoder_inputs["token_type_ids"] = token_type_ids

        # Hidden state at position 0 (the [CLS] slot for BERT-style
        # encoders): shape [batch, hidden].
        cls = self.bert(**encoder_inputs).last_hidden_state[:, 0]

        # Dot product with every label embedding, scaled by the temperature.
        return torch.matmul(cls, self.lbl_E.T) / self.tau