from __future__ import annotations
from transformers import (
    BertModel,
    BertConfig,
    PreTrainedModel,
)
from transformers.tokenization_utils_base import BatchEncoding
import torch
import torch.nn as nn
import torch.nn.functional as F

class SimCSEInferenceModel(PreTrainedModel):
    config_class = BertConfig          # must match the BERT config used at inference time

    def __init__(self, config):
        super().__init__(config)
        # Build empty encoders from the config so no extra pretrained weights are downloaded
        base_cfg = BertConfig(**config.to_dict())
        self.encoder_input  = BertModel(base_cfg)
        self.encoder_output = BertModel(base_cfg)

        # SimCSE-style pooler heads: a Linear + Tanh projection over each encoder's [CLS] state
        hidden = self.encoder_input.config.hidden_size
        self.dense_input  = nn.Linear(hidden, hidden)
        self.dense_output = nn.Linear(hidden, hidden)
        self.activation   = nn.Tanh()
        self.temperature  = getattr(config, "simcse_temperature", 0.05)

    @torch.no_grad()
    def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
        # Pool the [CLS] token (position 0), then apply the dense + tanh head
        h = self.encoder_input(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_input(h))

    @torch.no_grad()
    def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
        # Same [CLS] pooling for the output-side encoder
        h = self.encoder_output(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_output(h))

    def forward(
        self,
        tokenized_texts_1: BatchEncoding,
        tokenized_texts_2: BatchEncoding,
        labels: torch.Tensor,
        **_
    ):
        device = next(self.parameters()).device
        z1 = F.normalize(self.encode_input(tokenized_texts_1.to(device)), dim=-1)
        z2 = F.normalize(self.encode_output(tokenized_texts_2.to(device)), dim=-1)
        # Cosine-similarity matrix over the batch; in-batch InfoNCE loss with temperature scaling
        sim = torch.matmul(z1, z2.T)
        loss = F.cross_entropy(sim / self.temperature, labels.to(device))
        return {"loss": loss, "logits": sim}