Nick Sorros
commited on
Commit
·
b4da537
1
Parent(s):
9c3f2b9
Update model
Browse files
model.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
from transformers import AutoModel
|
| 2 |
import torch
|
| 3 |
|
| 4 |
|
|
@@ -15,7 +15,7 @@ class MultiLabelAttention(torch.nn.Module):
|
|
| 15 |
return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
|
| 16 |
|
| 17 |
|
| 18 |
-
class BertMesh(
|
| 19 |
def __init__(
|
| 20 |
self,
|
| 21 |
pretrained_model,
|
|
@@ -24,7 +24,8 @@ class BertMesh(torch.nn.Module):
|
|
| 24 |
dropout=0,
|
| 25 |
multilabel_attention=False,
|
| 26 |
):
|
| 27 |
-
super().__init__()
|
|
|
|
| 28 |
self.pretrained_model = pretrained_model
|
| 29 |
self.num_labels = num_labels
|
| 30 |
self.hidden_size = hidden_size
|
|
|
|
| 1 |
+
from transformers import AutoModel, AutoConfig, PreTrainedModel
|
| 2 |
import torch
|
| 3 |
|
| 4 |
|
|
|
|
| 15 |
return torch.matmul(torch.transpose(attention_weights, 2, 1), x)
|
| 16 |
|
| 17 |
|
| 18 |
+
class BertMesh(PreTrainedModel):
|
| 19 |
def __init__(
|
| 20 |
self,
|
| 21 |
pretrained_model,
|
|
|
|
| 24 |
dropout=0,
|
| 25 |
multilabel_attention=False,
|
| 26 |
):
|
| 27 |
+
super().__init__(config=AutoConfig.from_pretrained(pretrained_model))
|
| 28 |
+
self.config.auto_map = {"AutoModel": "transformers_model.BertMesh"}
|
| 29 |
self.pretrained_model = pretrained_model
|
| 30 |
self.num_labels = num_labels
|
| 31 |
self.hidden_size = hidden_size
|