Fix import error and add source_files to config
Browse files- config.json +1 -7
- foundation_bert.py +3 -7
- yaml_util.py +24 -0
config.json
CHANGED
|
@@ -7,13 +7,7 @@
|
|
| 7 |
"architectures": [
|
| 8 |
"FoundationBert"
|
| 9 |
],
|
| 10 |
-
|
| 11 |
-
"foundation_bert.py",
|
| 12 |
-
"utils/__init__.py",
|
| 13 |
-
"utils/masked_data_modeling_loss.py",
|
| 14 |
-
"utils/yaml_util.py",
|
| 15 |
-
"train_config.yaml"
|
| 16 |
-
],
|
| 17 |
"attention_probs_dropout_prob": 0.1,
|
| 18 |
"classifier_dropout": null,
|
| 19 |
"hidden_act": "gelu",
|
|
|
|
| 7 |
"architectures": [
|
| 8 |
"FoundationBert"
|
| 9 |
],
|
| 10 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
"attention_probs_dropout_prob": 0.1,
|
| 12 |
"classifier_dropout": null,
|
| 13 |
"hidden_act": "gelu",
|
foundation_bert.py
CHANGED
|
@@ -2,15 +2,11 @@ import sys
|
|
| 2 |
import os
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
-
current_dir = os.path.dirname(os.path.abspath(__file__))
|
| 6 |
-
if current_dir not in sys.path:
|
| 7 |
-
sys.path.append(current_dir)
|
| 8 |
-
|
| 9 |
import torch
|
| 10 |
import yaml
|
| 11 |
-
from
|
| 12 |
# from ..utils.contrastive_loss import ContrastiveLoss
|
| 13 |
-
from
|
| 14 |
from dataclasses import dataclass
|
| 15 |
from transformers import BertModel, BertConfig, PretrainedConfig
|
| 16 |
from typing import Optional, Union
|
|
@@ -128,7 +124,7 @@ class FoundationBert(BertModel):
|
|
| 128 |
|
| 129 |
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) # isn't used currently
|
| 130 |
self.xval_loss = torch.nn.MSELoss(reduction='none') # isn't used currently
|
| 131 |
-
self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none') # isn't used currently
|
| 132 |
self.distributed_loss = False
|
| 133 |
|
| 134 |
@classmethod
|
|
|
|
| 2 |
import os
|
| 3 |
from pathlib import Path
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
import torch
|
| 6 |
import yaml
|
| 7 |
+
# from masked_data_modeling_loss import MaskedDataLossWithSoftmax
|
| 8 |
# from ..utils.contrastive_loss import ContrastiveLoss
|
| 9 |
+
from yaml_util import MyLoader
|
| 10 |
from dataclasses import dataclass
|
| 11 |
from transformers import BertModel, BertConfig, PretrainedConfig
|
| 12 |
from typing import Optional, Union
|
|
|
|
| 124 |
|
| 125 |
self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) # isn't used currently
|
| 126 |
self.xval_loss = torch.nn.MSELoss(reduction='none') # isn't used currently
|
| 127 |
+
#self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none') # isn't used currently
|
| 128 |
self.distributed_loss = False
|
| 129 |
|
| 130 |
@classmethod
|
yaml_util.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import yaml
|
| 2 |
+
class MyLoader(yaml.SafeLoader):
    """``yaml.SafeLoader`` variant that tolerates documents with unknown tags.

    Constructing any mapping (re)registers ``construct_undefined`` as the
    fallback constructor (tag ``None``), so nodes whose tag has no handler
    are wrapped rather than raising a ``ConstructorError``.
    """

    def construct_mapping(self, *args, **kwargs):
        # Registration must happen before the parent builds the mapping so
        # that unknown-tagged child nodes are routed to the fallback.
        super().add_constructor(None, construct_undefined)
        return super().construct_mapping(*args, **kwargs)
|
| 10 |
+
import typing
|
| 11 |
+
# Pairs an unrecognized YAML tag with its plainly-constructed value, so the
# tag information survives loading. Functional NamedTuple form: same class
# name, fields, repr, and tuple equality as the class-syntax definition.
Tagged = typing.NamedTuple("Tagged", [("tag", str), ("value", object)])
|
| 14 |
+
|
| 15 |
+
def construct_undefined(self, node):
    """Fallback constructor for YAML nodes whose tag has no registered handler.

    Builds the node's plain Python value via the loader's scalar/sequence/
    mapping machinery and wraps it in ``Tagged`` so the original tag string
    is preserved for the caller.

    Args:
        self: the loader instance (installed via ``add_constructor``, so it
            is called with the loader as the first argument).
        node: the ``yaml.nodes.Node`` being constructed.

    Returns:
        Tagged(node.tag, constructed_value)

    Raises:
        TypeError: if the node is not a scalar, sequence, or mapping node.
    """
    if isinstance(node, yaml.nodes.ScalarNode):
        value = self.construct_scalar(node)
    elif isinstance(node, yaml.nodes.SequenceNode):
        value = self.construct_sequence(node)
    elif isinstance(node, yaml.nodes.MappingNode):
        value = self.construct_mapping(node)
    else:
        # A bare `assert False` is stripped under `python -O`, which would let
        # execution fall through with `value` unbound; raise explicitly.
        raise TypeError(f"unexpected node: {node!r}")
    return Tagged(node.tag, value)
|