rio11user committed on
Commit
6664fab
·
verified ·
1 Parent(s): 81bd0eb

Upload folder using huggingface_hub

Browse files
Files changed (3) hide show
  1. config.json +5 -2
  2. model.safetensors +1 -1
  3. modeling_simcse.py +37 -0
config.json CHANGED
@@ -23,5 +23,8 @@
23
  "transformers_version": "4.51.3",
24
  "type_vocab_size": 2,
25
  "use_cache": true,
26
- "vocab_size": 32768
27
- }
 
 
 
 
23
  "transformers_version": "4.51.3",
24
  "type_vocab_size": 2,
25
  "use_cache": true,
26
+ "vocab_size": 32768,
27
+ "auto_map": {
28
+ "AutoModel": "modeling_simcse.SimCSEInferenceModel"
29
+ }
30
+ }
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d17f22fdf6834dd2fae0baad70d8a15bc36644056ba215f582a5a4b5c4012b4c
3
  size 894432952
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c89b84d9d78f111b2c00bb3bda2063b07ffa943fc59581b095c6cab6fd4b181b
3
  size 894432952
modeling_simcse.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+ from transformers import BertModel, BertConfig, PreTrainedModel
3
+ from transformers.tokenization_utils_base import BatchEncoding
4
+ import torch, torch.nn as nn, torch.nn.functional as F
5
+
6
class SimCSEInferenceModel(PreTrainedModel):
    """Dual-tower SimCSE model with separate "input" and "output" BERT encoders.

    Each tower is an independent ``BertModel`` followed by its own
    Linear + Tanh projection head over the [CLS] hidden state. The forward
    pass scores every input/output pair with a temperature-scaled similarity
    matrix of L2-normalized embeddings and returns an InfoNCE-style
    cross-entropy loss alongside the raw similarity logits.
    """

    config_class = BertConfig

    def __init__(self, config):
        super().__init__(config)
        # Build each tower from an independent copy of the shared config so
        # the two encoders do not alias one another's weights.
        tower_cfg = BertConfig(**config.to_dict())
        self.encoder_input = BertModel(tower_cfg)
        self.encoder_output = BertModel(tower_cfg)

        dim = self.encoder_input.config.hidden_size
        self.dense_input = nn.Linear(dim, dim)
        self.dense_output = nn.Linear(dim, dim)
        self.activation = nn.Tanh()
        # Softmax temperature for the contrastive loss; defaults to 0.05
        # when the config carries no "simcse_temperature" attribute.
        self.temperature = getattr(config, "simcse_temperature", 0.05)

    @torch.no_grad()
    def encode_input(self, tok: BatchEncoding) -> torch.Tensor:
        """Embed tokenized texts with the input tower ([CLS] pooling + head)."""
        cls_state = self.encoder_input(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_input(cls_state))

    @torch.no_grad()
    def encode_output(self, tok: BatchEncoding) -> torch.Tensor:
        """Embed tokenized texts with the output tower ([CLS] pooling + head)."""
        cls_state = self.encoder_output(**tok).last_hidden_state[:, 0]
        return self.activation(self.dense_output(cls_state))

    def forward(self, tokenized_texts_1: BatchEncoding, tokenized_texts_2: BatchEncoding, labels: torch.Tensor, **_):
        """Score all pairs and return contrastive loss plus similarity logits.

        NOTE(review): both ``encode_*`` helpers run under ``torch.no_grad()``,
        so the returned loss carries no gradients — consistent with the
        "Inference" naming, but confirm before reusing this forward for
        training.
        """
        target_device = next(self.parameters()).device
        left = F.normalize(self.encode_input(tokenized_texts_1.to(target_device)), dim=-1)
        right = F.normalize(self.encode_output(tokenized_texts_2.to(target_device)), dim=-1)
        scores = torch.matmul(left, right.T)
        loss = F.cross_entropy(scores / self.temperature, labels.to(target_device))
        return {"loss": loss, "logits": scores}