Uploading model.pt
Browse files
model.py
CHANGED
|
@@ -1,6 +1,13 @@
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
-
from transformers import PreTrainedModel,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
# Expert class using pre-trained BERT
|
|
@@ -69,7 +76,7 @@ class GatingNetwork(nn.Module):
|
|
| 69 |
|
| 70 |
# Mixture of Experts for sentence embeddings using BERT
|
| 71 |
class EmbeddingMoE(PreTrainedModel):
|
| 72 |
-
config_class =
|
| 73 |
|
| 74 |
def __init__(self, config):
|
| 75 |
super().__init__(config)
|
|
|
|
| 1 |
import torch
|
| 2 |
from torch import nn
|
| 3 |
+
from transformers import PreTrainedModel, PretrainedConfig , AutoModel
|
| 4 |
+
|
| 5 |
+
class EmbeddingMoEConfig(PretrainedConfig):
    """Configuration for the mixture-of-experts sentence-embedding model.

    Stores the model-specific hyperparameters and defers everything else
    (e.g. serialization, `name_or_path`, generic config kwargs) to the
    `transformers` `PretrainedConfig` base class.

    Args:
        output_dim: Dimensionality of the produced embedding vector.
        num_experts: Number of expert sub-networks in the mixture.
        dropout_rate: Dropout probability used by the model.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # NOTE(review): transformers custom configs conventionally also define a
    # class-level `model_type` string for AutoConfig registration — confirm
    # whether this model is meant to be loadable via Auto classes.

    def __init__(self, output_dim=128, num_experts=2, dropout_rate=0.1, **kwargs):
        # Let the base class consume/validate the generic kwargs first.
        super().__init__(**kwargs)
        # Model-specific hyperparameters (independent assignments).
        self.dropout_rate = dropout_rate
        self.num_experts = num_experts
        self.output_dim = output_dim
|
| 11 |
|
| 12 |
|
| 13 |
# Expert class using pre-trained BERT
|
|
|
|
| 76 |
|
| 77 |
# Mixture of Experts for sentence embeddings using BERT
|
| 78 |
class EmbeddingMoE(PreTrainedModel):
|
| 79 |
+
config_class = EmbeddingMoEConfig
|
| 80 |
|
| 81 |
def __init__(self, config):
|
| 82 |
super().__init__(config)
|