Spaces:
Runtime error
Runtime error
Update Utils/PLBERT/util.py
Browse files- Utils/PLBERT/util.py +19 -19
Utils/PLBERT/util.py
CHANGED
|
@@ -3,7 +3,6 @@ import yaml
|
|
| 3 |
import torch
|
| 4 |
from transformers import AlbertConfig, AlbertModel
|
| 5 |
|
| 6 |
-
|
| 7 |
class CustomAlbert(AlbertModel):
|
| 8 |
def forward(self, *args, **kwargs):
|
| 9 |
# Call the original forward method
|
|
@@ -16,34 +15,35 @@ class CustomAlbert(AlbertModel):
|
|
| 16 |
def load_plbert(log_dir):
|
| 17 |
config_path = os.path.join(log_dir, "config.yml")
|
| 18 |
plbert_config = yaml.safe_load(open(config_path))
|
| 19 |
-
|
| 20 |
-
albert_base_configuration = AlbertConfig(**plbert_config[
|
| 21 |
bert = CustomAlbert(albert_base_configuration)
|
| 22 |
|
| 23 |
files = os.listdir(log_dir)
|
| 24 |
ckpts = []
|
| 25 |
for f in os.listdir(log_dir):
|
| 26 |
-
if f.startswith("step_"):
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
iters = [
|
| 30 |
-
int(f.split("_")[-1].split(".")[0])
|
| 31 |
-
for f in ckpts
|
| 32 |
-
if os.path.isfile(os.path.join(log_dir, f))
|
| 33 |
-
]
|
| 34 |
iters = sorted(iters)[-1]
|
| 35 |
|
| 36 |
-
checkpoint = torch.load(log_dir + "/step_" + str(iters) + ".t7", map_location=
|
| 37 |
-
state_dict = checkpoint[
|
| 38 |
from collections import OrderedDict
|
| 39 |
-
|
| 40 |
new_state_dict = OrderedDict()
|
| 41 |
for k, v in state_dict.items():
|
| 42 |
-
name = k[7:]
|
| 43 |
-
if name.startswith(
|
| 44 |
-
name = name[8:]
|
| 45 |
new_state_dict[name] = v
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
bert.load_state_dict(new_state_dict, strict=False)
|
| 48 |
-
|
| 49 |
return bert
|
|
|
|
|
|
|
|
|
| 3 |
import torch
|
| 4 |
from transformers import AlbertConfig, AlbertModel
|
| 5 |
|
|
|
|
| 6 |
class CustomAlbert(AlbertModel):
|
| 7 |
def forward(self, *args, **kwargs):
|
| 8 |
# Call the original forward method
|
|
|
|
| 15 |
def load_plbert(log_dir):
    """Build a ``CustomAlbert`` and load the newest PL-BERT checkpoint into it.

    ``log_dir`` must contain a ``config.yml`` whose ``model_params`` mapping
    is forwarded to ``AlbertConfig``, plus one or more checkpoint files named
    ``step_<N>.t7``. The checkpoint with the largest ``<N>`` is loaded.

    Args:
        log_dir: Directory holding ``config.yml`` and the ``step_*.t7`` files.

    Returns:
        The ``CustomAlbert`` instance with the checkpoint weights applied.

    Raises:
        FileNotFoundError: If ``config.yml`` is missing or ``log_dir``
            contains no ``step_*`` checkpoint file.
        KeyError: If the config lacks ``model_params`` or the checkpoint
            lacks the ``net`` entry.
    """
    from collections import OrderedDict

    config_path = os.path.join(log_dir, "config.yml")
    # Use a context manager so the config file handle is always closed.
    with open(config_path) as config_file:
        plbert_config = yaml.safe_load(config_file)

    albert_base_configuration = AlbertConfig(**plbert_config['model_params'])
    bert = CustomAlbert(albert_base_configuration)

    # Step numbers of every regular file that looks like a checkpoint.
    iters = [
        int(f.split('_')[-1].split('.')[0])
        for f in os.listdir(log_dir)
        if f.startswith("step_") and os.path.isfile(os.path.join(log_dir, f))
    ]
    if not iters:
        # Fail with an actionable message instead of an IndexError on [-1].
        raise FileNotFoundError(
            "No 'step_*' checkpoint found in " + repr(log_dir)
        )
    latest_iter = max(iters)

    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint_path = os.path.join(log_dir, "step_" + str(latest_iter) + ".t7")
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint['net']

    module_prefix = 'module.'
    encoder_prefix = 'encoder.'
    new_state_dict = OrderedDict()
    for k, v in state_dict.items():
        # Strip the `module.` prefix only when it is actually present, so
        # keys saved without that prefix are not mangled.
        name = k[len(module_prefix):] if k.startswith(module_prefix) else k
        if name.startswith(encoder_prefix):
            # Keep only encoder weights, renamed to match the bare model.
            # NOTE(review): the original indentation is not recoverable from
            # this view; non-encoder entries would be ignored by
            # strict=False below in any case -- confirm.
            new_state_dict[name[len(encoder_prefix):]] = v

    # Drop the position-id entry when the model has no such attribute.
    # pop() with a default never raises if the checkpoint lacks the key.
    if not hasattr(bert.embeddings, 'position_ids'):
        new_state_dict.pop("embeddings.position_ids", None)

    # strict=False tolerates keys that are missing or unexpected for the
    # installed transformers version.
    bert.load_state_dict(new_state_dict, strict=False)

    return bert
|
| 48 |
+
|
| 49 |
+
|