rjzevallos commited on
Commit
e170dc2
·
verified ·
1 Parent(s): e841861

Upload util.py

Browse files
Files changed (1) hide show
  1. util.py +45 -0
util.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import yaml
3
+ import torch
4
+ from transformers import AlbertConfig, AlbertModel
5
+
6
+ class CustomAlbert(AlbertModel):
7
+ def forward(self, *args, **kwargs):
8
+ # Call the original forward method
9
+ outputs = super().forward(*args, **kwargs)
10
+
11
+ # Only return the last_hidden_state
12
+ return outputs.last_hidden_state
13
+
14
+
15
+ def load_plbert(log_dir):
16
+ config_path = os.path.join(log_dir, "config.yml")
17
+ plbert_config = yaml.safe_load(open(config_path))
18
+
19
+ albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
20
+ bert = CustomAlbert(albert_base_configuration)
21
+
22
+ files = os.listdir(log_dir)
23
+ ckpts = []
24
+ for f in os.listdir(log_dir):
25
+ if f.startswith("step_"): ckpts.append(f)
26
+
27
+ iters = [int(f.split('_')[-1].split('.')[0]) for f in ckpts if os.path.isfile(os.path.join(log_dir, f))]
28
+ iters = sorted(iters)[-1]
29
+
30
+ checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location='cpu')
31
+ state_dict = checkpoint['net']
32
+ from collections import OrderedDict
33
+ new_state_dict = OrderedDict()
34
+ for k, v in state_dict.items():
35
+ name = k[7:] # remove `module.`
36
+ if name.startswith('encoder.'):
37
+ name = name[8:] # remove `encoder.`
38
+ new_state_dict[name] = v
39
+ try:
40
+ del new_state_dict["embeddings.position_ids"]
41
+ except KeyError:
42
+ pass
43
+ bert.load_state_dict(new_state_dict, strict=False)
44
+
45
+ return bert