lv12 committed on
Commit
a42c485
·
verified ·
1 Parent(s): b003d9b

Uploading model.pt

Browse files
Files changed (2) hide show
  1. config.json +8 -3
  2. model.py +108 -0
config.json CHANGED
@@ -1,9 +1,14 @@
1
  {
 
 
 
2
  "model_type": "EmbeddingMoE",
3
  "base_model": "bert-base-uncased",
4
  "output_dim": 128,
5
  "dropout_rate": 0.1,
6
  "num_experts": 2,
7
- "hidden_dim": 256
8
- }
9
-
 
 
 
1
  {
2
+ "architectures": [
3
+ "EmbeddingMoE"
4
+ ],
5
  "model_type": "EmbeddingMoE",
6
  "base_model": "bert-base-uncased",
7
  "output_dim": 128,
8
  "dropout_rate": 0.1,
9
  "num_experts": 2,
10
+ "hidden_dim": 256,
11
+ "auto_map": {
12
+ "AutoModel": "model.EmbeddingMoE"
13
+ }
14
+ }
model.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from transformers import PreTrainedModel, AutoConfig, AutoTokenizer, AutoModel
4
+
5
+
6
+ # Expert class using pre-trained BERT
7
+ class EmbeddingExpert(nn.Module):
8
+ def __init__(self, model_name, output_dim, dropout_rate=0.1):
9
+ super().__init__()
10
+ self.base = AutoModel.from_pretrained(model_name)
11
+ self.layer_norm = nn.LayerNorm(self.base.config.hidden_size)
12
+ self.dropout = nn.Dropout(dropout_rate)
13
+ for param in self.base.parameters():
14
+ param.requires_grad = False
15
+
16
+ # Projection layer to get the final embedding
17
+ self.projection = nn.Linear(self.base.config.hidden_size, output_dim)
18
+ nn.init.xavier_uniform_(self.projection.weight)
19
+ nn.init.zeros_(self.projection.bias)
20
+
21
+ def mean_pooling(self, model_output, attention_mask):
22
+ # Mean pooling - take attention mask into account for averaging
23
+ token_embeddings = model_output.last_hidden_state
24
+ input_mask_expanded = (
25
+ attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
26
+ )
27
+ return torch.sum(token_embeddings * input_mask_expanded, 1) / torch.clamp(
28
+ input_mask_expanded.sum(1), min=1e-9
29
+ )
30
+
31
+ def forward(self, input_ids, attention_mask):
32
+ outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
33
+ pooled_output = self.mean_pooling(outputs, attention_mask)
34
+ pooled_output = self.layer_norm(pooled_output)
35
+ pooled_output = self.dropout(pooled_output)
36
+ embedding = self.projection(pooled_output)
37
+ embedding = F.normalize(embedding, p=2, dim=1)
38
+
39
+ return embedding
40
+
41
+
42
# Gating Network
class GatingNetwork(nn.Module):
    """Softmax gate mapping a ``(batch, input_dim)`` tensor to ``(batch, num_experts)`` mixture weights."""

    def __init__(self, input_dim, hidden_dim, num_experts, dropout_rate=0.1):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.dropout = nn.Dropout(dropout_rate)
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_dim, num_experts)
        self.softmax = nn.Softmax(dim=-1)

        # Xavier-initialized weights, zero biases for both projections.
        for layer in (self.linear1, self.linear2):
            nn.init.xavier_uniform_(layer.weight)
            nn.init.zeros_(layer.bias)

    def forward(self, x):
        # Normalize and regularize the pooled input, then run the 2-layer MLP.
        normed = self.dropout(self.layer_norm(x))
        hidden = self.relu(self.linear1(normed))
        logits = self.linear2(hidden)
        # Clamp logits for numerical stability before converting to weights.
        return self.softmax(logits.clamp(min=-10, max=10))
68
+
69
+
70
# Mixture of Experts for sentence embeddings using BERT
class EmbeddingMoE(PreTrainedModel):
    """Two-expert mixture producing L2-normalized sentence embeddings.

    Both experts embed the input; their average feeds the gating network,
    whose softmax weights blend the two expert embeddings.
    """

    config_class = AutoConfig

    def __init__(self, config):
        super().__init__(config)
        # BUG FIX: read every hyper-parameter from the config instead of
        # hard-coding them; the shipped config.json already carries
        # hidden_dim, base_model, and dropout_rate, but the original code
        # ignored them. Fallbacks preserve the previous hard-coded defaults.
        output_dim = getattr(config, "output_dim", 128)
        num_experts = getattr(config, "num_experts", 2)
        hidden_dim = getattr(config, "hidden_dim", 256)
        base_model = getattr(config, "base_model", "bert-base-uncased")
        dropout_rate = getattr(config, "dropout_rate", 0.1)

        # NOTE(review): the mixing in forward() supports exactly two experts;
        # num_experts only sizes the gate's output. Confirm before raising it.
        self.expert1 = EmbeddingExpert(base_model, output_dim, dropout_rate)
        self.expert2 = EmbeddingExpert(base_model, output_dim, dropout_rate)
        self.gating = GatingNetwork(output_dim, hidden_dim, num_experts)

    def forward(self, input_ids, attention_mask):
        """Return a gated, L2-normalized ``(batch, output_dim)`` embedding."""
        # Get embeddings from both experts
        expert1_output = self.expert1(input_ids, attention_mask)
        expert2_output = self.expert2(input_ids, attention_mask)

        # Average the output as input to gating
        gating_input = (expert1_output + expert2_output) / 2

        # Get gating weights
        gating_output = self.gating(gating_input)

        # Combine expert outputs
        mixed_output = (
            gating_output[:, 0].unsqueeze(1) * expert1_output
            + gating_output[:, 1].unsqueeze(1) * expert2_output
        )

        # Normalize the embedding to unit length
        embedding = torch.nn.functional.normalize(mixed_output, p=2, dim=1)

        return embedding

    def encode_sentence(self, input_ids, attention_mask):
        """Helper method to get the embedding for a single sentence"""
        with torch.no_grad():
            return self.forward(input_ids, attention_mask)