from typing import TYPE_CHECKING

import torch
import torch.nn as nn

if TYPE_CHECKING:
    # Type-only dependency: avoids a hard runtime requirement on transformers.
    from transformers import AutoModel


class LabelEmbCls(nn.Module):
    """Classify by similarity between the [CLS] embedding and fixed label embeddings.

    Logits are the temperature-scaled dot products between the encoder's
    [CLS] token representation and a frozen matrix of per-label embeddings.
    """

    def __init__(self, base: "AutoModel", lbl_emb: torch.Tensor):
        """
        Args:
            base: a HuggingFace-style encoder whose output exposes
                ``last_hidden_state`` of shape (batch, seq, hidden).
            lbl_emb: label-embedding matrix of shape (num_labels, hidden);
                kept frozen (never updated by the optimizer).
        """
        super().__init__()
        self.bert = base
        # Frozen label embeddings belong in a buffer, not a Parameter with
        # requires_grad=False: buffers are excluded from .parameters() (so
        # optimizers never see them), yet still move with .to(device) and
        # are saved in the state_dict under the same "lbl_E" key.
        self.register_buffer("lbl_E", lbl_emb)
        # Learnable temperature. NOTE(review): an unconstrained tau can reach
        # zero during training and blow up the logits — consider learning
        # log-tau if that becomes an issue.
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return logits of shape (batch, num_labels)."""
        # [CLS] embedding: first token of the last hidden layer,
        # shape (batch, hidden) — hidden size depends on the base model.
        cls = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state[:, 0]
        # Dot-product similarity against each label embedding, divided by
        # the learnable temperature.
        return torch.matmul(cls, self.lbl_E.T) / self.tau