# phone_40_11_v8 / modeling_simcse.py
# Uploaded via huggingface_hub by rio11user (commit a367f18, verified).
from __future__ import annotations
from transformers import (
BertModel,
BertConfig,
PreTrainedModel,
)
from transformers.tokenization_utils_base import BatchEncoding
import torch, torch.nn as nn, torch.nn.functional as F
class SimCSEInferenceModel(PreTrainedModel):
    """Dual-encoder SimCSE model intended for inference.

    Holds two independent BERT encoders — one for "input" texts and one
    for "output" texts — each followed by its own Linear + Tanh head
    applied to the [CLS] token state. ``forward`` scores every input
    embedding against every output embedding via temperature-scaled
    similarity of L2-normalized vectors and reports a cross-entropy
    loss against the supplied labels.

    NOTE(review): both ``encode_input`` and ``encode_output`` run under
    ``torch.no_grad()``, so the loss returned by ``forward`` carries no
    gradient — usable for evaluation/reporting, not for training.
    """

    # Use a plain BERT config at inference time.
    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        # Assemble empty models from the config (via from_config-style
        # construction) to avoid triggering any extra weight downloads.
        encoder_cfg = BertConfig(**config.to_dict())
        self.encoder_input = BertModel(encoder_cfg)
        self.encoder_output = BertModel(encoder_cfg)
        dim = self.encoder_input.config.hidden_size
        self.dense_input = nn.Linear(dim, dim)
        self.dense_output = nn.Linear(dim, dim)
        self.activation = nn.Tanh()
        # Softmax temperature for the similarity logits; falls back to
        # the common SimCSE default of 0.05 when absent from the config.
        self.temperature = getattr(config, "simcse_temperature", 0.05)

    @torch.no_grad()
    def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
        """Embed tokenized texts with the input-side encoder.

        Takes the [CLS] hidden state and passes it through the
        input-side Linear + Tanh head. Returns a (batch, hidden) tensor.
        """
        cls_state = self.encoder_input(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_input(cls_state))

    @torch.no_grad()
    def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
        """Embed tokenized texts with the output-side encoder.

        Takes the [CLS] hidden state and passes it through the
        output-side Linear + Tanh head. Returns a (batch, hidden) tensor.
        """
        cls_state = self.encoder_output(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_output(cls_state))

    def forward(
        self,
        tokenized_texts_1: BatchEncoding,
        tokenized_texts_2: BatchEncoding,
        labels: torch.Tensor,
        **_
    ):
        """Score all input/output pairs and compute a contrastive loss.

        Returns a dict with ``loss`` (cross-entropy over the
        temperature-scaled similarity rows vs. ``labels``) and
        ``logits`` (the raw similarity matrix, inputs x outputs).
        """
        device = next(self.parameters()).device
        emb_in = self.encode_input(tokenized_texts_1.to(device))
        emb_out = self.encode_output(tokenized_texts_2.to(device))
        # Unit-normalize so the matrix product below is cosine similarity.
        emb_in = F.normalize(emb_in, dim=-1)
        emb_out = F.normalize(emb_out, dim=-1)
        sim = emb_in @ emb_out.T
        loss = F.cross_entropy(sim / self.temperature, labels.to(device))
        return {"loss": loss, "logits": sim}