from __future__ import annotations

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertConfig, BertModel, PreTrainedModel
from transformers.tokenization_utils_base import BatchEncoding


class SimCSEInferenceModel(PreTrainedModel):
    config_class = BertConfig  # align with BertConfig at inference time

    def __init__(self, config):
        super().__init__(config)
        # Assemble empty encoders from the config alone (the constructor is
        # equivalent to from_config) to avoid an extra pretrained download;
        # weights are filled in later, e.g. via from_pretrained.
        base_cfg = BertConfig(**config.to_dict())
        self.encoder_input = BertModel(base_cfg)
        self.encoder_output = BertModel(base_cfg)
        hidden = self.encoder_input.config.hidden_size
        # Independent pooler heads for the two towers: [CLS] -> dense -> tanh.
        self.dense_input = nn.Linear(hidden, hidden)
        self.dense_output = nn.Linear(hidden, hidden)
        self.activation = nn.Tanh()
        self.temperature = getattr(config, "simcse_temperature", 0.05)

    @torch.no_grad()
    def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
        # Sentence embedding = pooled [CLS] representation of the input tower.
        h = self.encoder_input(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_input(h))

    @torch.no_grad()
    def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
        h = self.encoder_output(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_output(h))

    def forward(
        self,
        tokenized_texts_1: BatchEncoding,
        tokenized_texts_2: BatchEncoding,
        labels: torch.Tensor,
        **_  # swallow extra keyword arguments, e.g. from a Trainer
    ):
        device = next(self.parameters()).device
        z1 = F.normalize(self.encode_input(tokenized_texts_1.to(device)), dim=-1)
        z2 = F.normalize(self.encode_output(tokenized_texts_2.to(device)), dim=-1)
        # Pairwise cosine similarities; off-diagonal entries act as in-batch negatives.
        sim = torch.matmul(z1, z2.T)
        # Note: encode_* runs under no_grad, so this InfoNCE-style loss is for
        # evaluation only, not for backpropagation.
        loss = F.cross_entropy(sim / self.temperature, labels.to(device))
        return {"loss": loss, "logits": sim}
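

# Usage sketch (illustrative addition, not from the original source): load a
# trained checkpoint and score a query/answer pair with the two towers. The
# checkpoint path and tokenizer name below are assumptions; substitute
# whatever the encoders were actually trained with.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed tokenizer
    model = SimCSEInferenceModel.from_pretrained("path/to/simcse-checkpoint")  # assumed path
    model.eval()

    query = tokenizer(["how do I reset my password?"], padding=True, return_tensors="pt")
    answer = tokenizer(["click 'forgot password' on the login page"], padding=True, return_tensors="pt")

    # Embeddings are L2-normalized, so the dot product is a cosine similarity.
    z1 = F.normalize(model.encode_input(query), dim=-1)
    z2 = F.normalize(model.encode_output(answer), dim=-1)
    print(float(z1 @ z2.T))  # similarity score in [-1, 1]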