from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F

from transformers import BertConfig, BertModel, PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding


class SimCSEInferenceModel(PreTrainedModel):
    """Dual-encoder SimCSE-style model that scores input/output text pairs."""

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)

        # Two independent BERT towers: one encodes "input" texts, the other
        # encodes "output" texts. Both are built from a copy of the config.
        base_cfg = BertConfig(**config.to_dict())
        self.encoder_input = BertModel(base_cfg)
        self.encoder_output = BertModel(base_cfg)

        # Pooler heads (Linear + Tanh) applied to the [CLS] representation.
        hidden = self.encoder_input.config.hidden_size
        self.dense_input = nn.Linear(hidden, hidden)
        self.dense_output = nn.Linear(hidden, hidden)
        self.activation = nn.Tanh()
        self.temperature = getattr(config, "simcse_temperature", 0.05)

    @torch.no_grad()
    def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
        # [CLS] hidden state -> input-side pooler head.
        h = self.encoder_input(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_input(h))

    @torch.no_grad()
    def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
        # [CLS] hidden state -> output-side pooler head.
        h = self.encoder_output(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_output(h))

    def forward(
        self,
        tokenized_texts_1: BatchEncoding,
        tokenized_texts_2: BatchEncoding,
        labels: torch.Tensor,
        **_,
    ):
        device = next(self.parameters()).device
        # L2-normalize both sides so the dot product is cosine similarity,
        # then score every input against every output in the batch.
        z1 = F.normalize(self.encode_input(tokenized_texts_1.to(device)), dim=-1)
        z2 = F.normalize(self.encode_output(tokenized_texts_2.to(device)), dim=-1)
        sim = torch.matmul(z1, z2.T)
        # Temperature-scaled cross-entropy over the in-batch candidates.
        loss = F.cross_entropy(sim / self.temperature, labels.to(device))
        return {"loss": loss, "logits": sim}
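

# --- Illustrative usage sketch (not part of the original module) ---
# A minimal, hedged example of how this model might be exercised end to end.
# The checkpoint name, sample texts, and label construction below are
# assumptions for demonstration only, not values taken from this repository.
if __name__ == "__main__":
    from transformers import BertTokenizerFast

    # Assumed checkpoint; any BERT-compatible tokenizer/config pair would do.
    tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
    config = BertConfig.from_pretrained("bert-base-uncased")
    model = SimCSEInferenceModel(config).eval()

    texts_1 = ["how do I reset my password?", "what is the refund policy?"]
    texts_2 = [
        "click 'forgot password' on the login page.",
        "refunds are issued within 14 days.",
    ]
    tok_1 = tokenizer(texts_1, padding=True, truncation=True, return_tensors="pt")
    tok_2 = tokenizer(texts_2, padding=True, truncation=True, return_tensors="pt")

    # With in-batch negatives, the i-th input is paired with the i-th output,
    # so the label for row i of the similarity matrix is simply i.
    labels = torch.arange(len(texts_1))

    out = model(tok_1, tok_2, labels)
    print(out["loss"].item(), out["logits"].shape)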