Xsmos committed on
Commit
983234c
·
verified ·
1 Parent(s): dd85898

Fix import error and add source_files to config

Browse files
Files changed (3) hide show
  1. config.json +1 -7
  2. foundation_bert.py +3 -7
  3. yaml_util.py +24 -0
config.json CHANGED
@@ -7,13 +7,7 @@
7
  "architectures": [
8
  "FoundationBert"
9
  ],
10
- "source_files": [
11
- "foundation_bert.py",
12
- "utils/__init__.py",
13
- "utils/masked_data_modeling_loss.py",
14
- "utils/yaml_util.py",
15
- "train_config.yaml"
16
- ],
17
  "attention_probs_dropout_prob": 0.1,
18
  "classifier_dropout": null,
19
  "hidden_act": "gelu",
 
7
  "architectures": [
8
  "FoundationBert"
9
  ],
10
+
 
 
 
 
 
 
11
  "attention_probs_dropout_prob": 0.1,
12
  "classifier_dropout": null,
13
  "hidden_act": "gelu",
foundation_bert.py CHANGED
@@ -2,15 +2,11 @@ import sys
2
  import os
3
  from pathlib import Path
4
 
5
- current_dir = os.path.dirname(os.path.abspath(__file__))
6
- if current_dir not in sys.path:
7
- sys.path.append(current_dir)
8
-
9
  import torch
10
  import yaml
11
- from .utils.masked_data_modeling_loss import MaskedDataLossWithSoftmax
12
  # from ..utils.contrastive_loss import ContrastiveLoss
13
- from .utils.yaml_util import MyLoader
14
  from dataclasses import dataclass
15
  from transformers import BertModel, BertConfig, PretrainedConfig
16
  from typing import Optional, Union
@@ -128,7 +124,7 @@ class FoundationBert(BertModel):
128
 
129
  self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) # isn't used currently
130
  self.xval_loss = torch.nn.MSELoss(reduction='none') # isn't used currently
131
- self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none') # isn't used currently
132
  self.distributed_loss = False
133
 
134
  @classmethod
 
2
  import os
3
  from pathlib import Path
4
 
 
 
 
 
5
  import torch
6
  import yaml
7
+ # from masked_data_modeling_loss import MaskedDataLossWithSoftmax
8
  # from ..utils.contrastive_loss import ContrastiveLoss
9
+ from yaml_util import MyLoader
10
  from dataclasses import dataclass
11
  from transformers import BertModel, BertConfig, PretrainedConfig
12
  from typing import Optional, Union
 
124
 
125
  self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False) # isn't used currently
126
  self.xval_loss = torch.nn.MSELoss(reduction='none') # isn't used currently
127
+ #self.mlm_loss = MaskedDataLossWithSoftmax(ignore=-100, reduction='none') # isn't used currently
128
  self.distributed_loss = False
129
 
130
  @classmethod
yaml_util.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import yaml
2
+ class MyLoader(yaml.SafeLoader):
3
+ # returns
4
+ def construct_mapping(self, *args, **kwargs):
5
+ super().add_constructor(None, construct_undefined)
6
+ # when loading we want to skip keys that require construction,
7
+ mapping = super().construct_mapping(*args, **kwargs)
8
+
9
+ return mapping
10
+ import typing
11
+ class Tagged(typing.NamedTuple):
12
+ tag: str
13
+ value: object
14
+
15
+ def construct_undefined(self, node):
16
+ if isinstance(node, yaml.nodes.ScalarNode):
17
+ value = self.construct_scalar(node)
18
+ elif isinstance(node, yaml.nodes.SequenceNode):
19
+ value = self.construct_sequence(node)
20
+ elif isinstance(node, yaml.nodes.MappingNode):
21
+ value = self.construct_mapping(node)
22
+ else:
23
+ assert False, f"unexpected node: {node!r}"
24
+ return Tagged(node.tag, value)