| import os
|
| import yaml
|
| import torch
|
| from transformers import AlbertConfig, AlbertModel
|
|
|
class CustomAlbert(AlbertModel):
    """``AlbertModel`` whose forward pass returns only the final hidden states."""

    def forward(self, *args, **kwargs):
        """Run ``AlbertModel.forward`` and return its last hidden state.

        All positional and keyword arguments are passed through unchanged to
        the parent implementation.

        Returns:
            The ``last_hidden_state`` of the parent's output.
        """
        outputs = super().forward(*args, **kwargs)
        # With ``return_dict=False`` (per call or via the config) the parent
        # returns a plain tuple whose first element is the last hidden state;
        # otherwise it returns a ModelOutput exposing the named attribute.
        if isinstance(outputs, tuple):
            return outputs[0]
        return outputs.last_hidden_state
|
|
|
|
|
def _find_latest_checkpoint(log_dir):
    """Return the path of the highest-numbered ``step_<N>.t7`` file in *log_dir*.

    Only regular files named exactly ``step_`` + decimal digits + ``.t7`` are
    considered. Stray entries (other extensions, non-numeric suffixes,
    directories) are skipped rather than crashing the parse. The matching
    file's real name is returned, so zero-padded step numbers still resolve.

    Raises:
        FileNotFoundError: if no matching checkpoint file exists.
    """
    prefix, suffix = "step_", ".t7"
    candidates = []
    for fname in os.listdir(log_dir):
        if not (fname.startswith(prefix) and fname.endswith(suffix)):
            continue
        step = fname[len(prefix):-len(suffix)]
        if step.isdecimal() and os.path.isfile(os.path.join(log_dir, fname)):
            candidates.append((int(step), fname))
    if not candidates:
        raise FileNotFoundError(
            f"no '{prefix}<N>{suffix}' checkpoint found in {log_dir!r}"
        )
    _, latest = max(candidates)
    return os.path.join(log_dir, latest)


def _clean_state_dict(state_dict):
    """Strip wrapper prefixes from checkpoint keys so they match ``CustomAlbert``.

    First a leading ``module.`` is removed. The previous code unconditionally
    sliced off 7 characters, i.e. ``len("module.")``; this prefix is presumably
    added by a DataParallel/DDP wrapper, so confirm against the training script.
    Then a leading ``encoder.`` is removed. Keys lacking a prefix are kept as-is
    instead of being truncated.
    """
    from collections import OrderedDict

    cleaned = OrderedDict()
    for key, value in state_dict.items():
        # Order matters: "module." wraps "encoder.".
        for prefix in ("module.", "encoder."):
            if key.startswith(prefix):
                key = key[len(prefix):]
        cleaned[key] = value
    return cleaned


def load_plbert(log_dir):
    """Build a ``CustomAlbert`` from *log_dir* and load its newest checkpoint.

    Args:
        log_dir: Directory containing ``config.yml``, whose ``model_params``
            mapping is passed to ``AlbertConfig``. It must also contain one or
            more ``step_<N>.t7`` checkpoints storing weights under ``'net'``.

    Returns:
        The ``CustomAlbert`` model with the weights of the highest-step
        checkpoint loaded non-strictly (mismatched keys are ignored).

    Raises:
        FileNotFoundError: if ``config.yml`` or any checkpoint is missing.
    """
    config_path = os.path.join(log_dir, "config.yml")
    with open(config_path, encoding="utf-8") as fh:
        plbert_config = yaml.safe_load(fh)

    albert_base_configuration = AlbertConfig(**plbert_config["model_params"])
    bert = CustomAlbert(albert_base_configuration)

    # SECURITY: torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(_find_latest_checkpoint(log_dir), map_location="cpu")
    new_state_dict = _clean_state_dict(checkpoint["net"])

    # Drop the position_ids entry when the model has no such attribute;
    # pop() tolerates checkpoints that never stored the key.
    if not hasattr(bert.embeddings, "position_ids"):
        new_state_dict.pop("embeddings.position_ids", None)

    bert.load_state_dict(new_state_dict, strict=False)

    return bert
|
|
|