rio11user commited on
Commit
e5c1ba1
Β·
verified Β·
1 Parent(s): 43308fd

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +5 -2
  2. model.safetensors +1 -1
  3. modeling_simcse.py +48 -0
config.json CHANGED
@@ -23,5 +23,8 @@
23
  "transformers_version": "4.51.3",
24
  "type_vocab_size": 2,
25
  "use_cache": true,
26
- "vocab_size": 32768
27
- }
 
 
 
 
23
  "transformers_version": "4.51.3",
24
  "type_vocab_size": 2,
25
  "use_cache": true,
26
+ "vocab_size": 32768,
27
+ "auto_map": {
28
+ "AutoModel": "modeling_simcse.SimCSEInferenceModel"
29
+ }
30
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:56fa743a6730f3ae52e52f46365c5ef7f6433974240b5f0df3761378b7cafca7
3
  size 894432952
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dc70afc6def7daeb474328a741d0ff0139f7c27291cf7c795c5c401c1f4c5ce4
3
  size 894432952
modeling_simcse.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from transformers import (
3
+ BertModel,
4
+ BertConfig,
5
+ PreTrainedModel,
6
+ )
7
+ from transformers.tokenization_utils_base import BatchEncoding
8
+ import torch, torch.nn as nn, torch.nn.functional as F
9
+
10
+ class SimCSEInferenceModel(PreTrainedModel):
11
+ config_class = BertConfig # ζŽ¨θ«–ζ™‚γ― BERT Config γ¨εˆγ‚γ›γ‚‹
12
+
13
+ def __init__(self, config):
14
+ super().__init__(config)
15
+ # θΏ½εŠ γƒ€γ‚¦γƒ³γƒ­γƒΌγƒ‰γ‚’ιΏγ‘γ‚‹γŸγ‚ from_config で空ヒデルを硄み立てる
16
+ base_cfg = BertConfig(**config.to_dict())
17
+ self.encoder_input = BertModel(base_cfg)
18
+ self.encoder_output = BertModel(base_cfg)
19
+
20
+ hidden = self.encoder_input.config.hidden_size
21
+ self.dense_input = nn.Linear(hidden, hidden)
22
+ self.dense_output = nn.Linear(hidden, hidden)
23
+ self.activation = nn.Tanh()
24
+ self.temperature = getattr(config, "simcse_temperature", 0.05)
25
+
26
+ @torch.no_grad()
27
+ def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
28
+ h = self.encoder_input(**tok).last_hidden_state[:, 0]
29
+ return self.activation(self.dense_input(h))
30
+
31
+ @torch.no_grad()
32
+ def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
33
+ h = self.encoder_output(**tok).last_hidden_state[:, 0]
34
+ return self.activation(self.dense_output(h))
35
+
36
+ def forward(
37
+ self,
38
+ tokenized_texts_1: BatchEncoding,
39
+ tokenized_texts_2: BatchEncoding,
40
+ labels: torch.Tensor,
41
+ **_
42
+ ):
43
+ device = next(self.parameters()).device
44
+ z1 = F.normalize(self.encode_input(tokenized_texts_1.to(device)), dim=-1)
45
+ z2 = F.normalize(self.encode_output(tokenized_texts_2.to(device)), dim=-1)
46
+ sim = torch.matmul(z1, z2.T)
47
+ loss = F.cross_entropy(sim / self.temperature, labels.to(device))
48
+ return {"loss": loss, "logits": sim}