Added support for AdapterHub
Browse files- modeling_m5_encoder.py +8 -1
modeling_m5_encoder.py
CHANGED
|
@@ -60,7 +60,14 @@ class M5Encoder(PreTrainedModel):
|
|
| 60 |
def __init__(self, config):
    """Set up the wrapper, delegating all encoding work to an inner model.

    Args:
        config: model configuration object forwarded to both the base
            ``PreTrainedModel`` initializer and the inner ``M5EncoderModel``.
    """
    super().__init__(config)
    # The wrapped encoder carries all parameters; this class only adapts it.
    self.model = M5EncoderModel(config)
|
| 63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
|
| 65 |
return self.model(input_ids=input_ids,
|
| 66 |
attention_mask=attention_mask,
|
|
|
|
| 60 |
def __init__(self, config):
    """Initialize the encoder wrapper.

    Args:
        config: configuration passed through to ``super().__init__`` and
            used to construct the underlying ``M5EncoderModel``.
    """
    super().__init__(config)
    # All weights live in the wrapped model; this class is a thin facade.
    self.model = M5EncoderModel(config)
|
| 63 |
+
|
| 64 |
+
def get_input_embeddings(self):
    """Return the shared input token-embedding module of the wrapped model."""
    embeddings = self.model.shared
    return embeddings
|
| 66 |
+
|
| 67 |
+
def set_input_embeddings(self, new_embeddings):
    """Install *new_embeddings* as the shared input-embedding module.

    Both the wrapped model's ``shared`` attribute and the encoder's
    ``embed_tokens`` are pointed at the same module so they stay in sync.
    """
    self.model.encoder.embed_tokens = new_embeddings  # keep encoder in sync
    self.model.shared = new_embeddings
|
| 70 |
+
|
| 71 |
def forward(self, input_ids, attention_mask=None, relative_position=None, **kwargs):
|
| 72 |
return self.model(input_ids=input_ids,
|
| 73 |
attention_mask=attention_mask,
|