import torch
import torch.nn as nn
from transformers import AutoModel
class LabelEmbCls(nn.Module):
    """Classifier that scores [CLS] embeddings against fixed label embeddings.

    Logits are the dot products between the encoder's [CLS] vector and each
    label embedding, scaled by a learnable temperature ``tau``.
    """

    def __init__(self, base: nn.Module, lbl_emb: torch.Tensor):
        """
        Args:
            base: encoder whose output exposes ``last_hidden_state`` of
                shape [batch, seq, hidden] (e.g. a HuggingFace ``AutoModel``).
            lbl_emb: fixed label-embedding matrix of shape
                [num_labels, hidden]; it is never trained.
        """
        super().__init__()
        self.bert = base
        # Register as a buffer, not a frozen nn.Parameter: a buffer is
        # excluded from model.parameters() (so optimizers never see it)
        # yet still moves with .to()/.cuda() and is saved in state_dict.
        self.register_buffer("lbl_E", lbl_emb)
        # Learnable temperature for scaling the logits.
        # NOTE(review): nothing prevents tau from reaching 0 during training,
        # which would make the division below blow up — consider clamping or
        # parameterizing as exp(log_tau) if that occurs in practice.
        self.tau = nn.Parameter(torch.tensor(1.0))

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return label logits of shape [batch, num_labels]."""
        # [CLS] token embedding: first position of the last hidden layer.
        cls = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        ).last_hidden_state[:, 0]  # [batch, hidden]
        # Temperature-scaled dot-product similarity against each label.
        return torch.matmul(cls, self.lbl_E.T) / self.tau