add missing parts
Browse files- automodel.py +19 -0
automodel.py
CHANGED
|
@@ -5,6 +5,7 @@ import torch
|
|
| 5 |
import torch.nn as nn
|
| 6 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 7 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
|
|
|
| 8 |
from transformers import BertPreTrainedModel
|
| 9 |
from transformers.modeling_outputs import (MaskedLMOutput,
|
| 10 |
SequenceClassifierOutput)
|
|
@@ -233,6 +234,24 @@ class ClinicalMosaicForSequenceClassification(BertPreTrainedModel):
|
|
| 233 |
attentions=None,
|
| 234 |
)
|
| 235 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
class BertLMPredictionHead(nn.Module):
|
| 237 |
|
| 238 |
def __init__(self, config, bert_model_embedding_weights):
|
|
|
|
| 5 |
import torch.nn as nn
|
| 6 |
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
|
| 7 |
from torch.nn.modules.utils import consume_prefix_in_state_dict_if_present
|
| 8 |
+
from transformers.activations import ACT2FN
|
| 9 |
from transformers import BertPreTrainedModel
|
| 10 |
from transformers.modeling_outputs import (MaskedLMOutput,
|
| 11 |
SequenceClassifierOutput)
|
|
|
|
| 234 |
attentions=None,
|
| 235 |
)
|
| 236 |
|
| 237 |
+
|
| 238 |
+
class BertPredictionHeadTransform(nn.Module):
|
| 239 |
+
|
| 240 |
+
def __init__(self, config):
|
| 241 |
+
super().__init__()
|
| 242 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
| 243 |
+
if isinstance(config.hidden_act, str):
|
| 244 |
+
self.transform_act_fn = ACT2FN[config.hidden_act]
|
| 245 |
+
else:
|
| 246 |
+
self.transform_act_fn = config.hidden_act
|
| 247 |
+
self.LayerNorm = torch.nn.LayerNorm(config.hidden_size, eps=1e-12)
|
| 248 |
+
|
| 249 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
| 250 |
+
hidden_states = self.dense(hidden_states)
|
| 251 |
+
hidden_states = self.transform_act_fn(hidden_states)
|
| 252 |
+
hidden_states = self.LayerNorm(hidden_states)
|
| 253 |
+
return hidden_states
|
| 254 |
+
|
| 255 |
class BertLMPredictionHead(nn.Module):
|
| 256 |
|
| 257 |
def __init__(self, config, bert_model_embedding_weights):
|