lv12 commited on
Commit
e24665a
·
verified ·
1 Parent(s): 287885e

Uploading model.pt

Browse files
Files changed (1) hide show
  1. model.py +9 -2
model.py CHANGED
@@ -1,6 +1,13 @@
1
  import torch
2
  from torch import nn
3
- from transformers import PreTrainedModel, AutoConfig, AutoModel
 
 
 
 
 
 
 
4
 
5
 
6
  # Expert class using pre-trained BERT
@@ -69,7 +76,7 @@ class GatingNetwork(nn.Module):
69
 
70
  # Mixture of Experts for sentence embeddings using BERT
71
  class EmbeddingMoE(PreTrainedModel):
72
- config_class = AutoConfig
73
 
74
  def __init__(self, config):
75
  super().__init__(config)
 
1
  import torch
2
  from torch import nn
3
+ from transformers import PreTrainedModel, PretrainedConfig , AutoModel
4
+
5
+ class EmbeddingMoEConfig(PretrainedConfig):
6
+ def __init__(self, output_dim=128, num_experts=2, dropout_rate=0.1, **kwargs):
7
+ super().__init__(**kwargs)
8
+ self.output_dim = output_dim
9
+ self.num_experts = num_experts
10
+ self.dropout_rate = dropout_rate
11
 
12
 
13
  # Expert class using pre-trained BERT
 
76
 
77
  # Mixture of Experts for sentence embeddings using BERT
78
  class EmbeddingMoE(PreTrainedModel):
79
+ config_class = EmbeddingMoEConfig
80
 
81
  def __init__(self, config):
82
  super().__init__(config)