Spaces:
Running
Running
Commit
·
7f85e18
1
Parent(s):
0b8923b
Update checkpoint_path
Browse files- Nested/trainers/BaseTrainer.py +2 -1
- Nested/utils/helpers.py +1 -1
- app.py +1 -0
Nested/trainers/BaseTrainer.py
CHANGED
|
@@ -107,7 +107,8 @@ class BaseTrainer:
|
|
| 107 |
:param checkpoint_path: str - path/to/checkpoints
|
| 108 |
:return: None
|
| 109 |
"""
|
| 110 |
-
checkpoint_path = natsort.natsorted("
|
|
|
|
| 111 |
# checkpoint_path = checkpoint_path[-1]
|
| 112 |
|
| 113 |
logger.info("Loading checkpoint %s", checkpoint_path)
|
|
|
|
| 107 |
:param checkpoint_path: str - path/to/checkpoints
|
| 108 |
:return: None
|
| 109 |
"""
|
| 110 |
+
# checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
|
| 111 |
+
checkpoint_path = natsort.natsorted(checkpoint_path)
|
| 112 |
# checkpoint_path = checkpoint_path[-1]
|
| 113 |
|
| 114 |
logger.info("Loading checkpoint %s", checkpoint_path)
|
Nested/utils/helpers.py
CHANGED
|
@@ -97,7 +97,7 @@ def load_checkpoint(model_path):
|
|
| 97 |
train_config.trainer_config["kwargs"]["loss"] = loss
|
| 98 |
|
| 99 |
tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
|
| 100 |
-
checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", filename="checkpoints
|
| 101 |
|
| 102 |
tagger.load(checkpoint_path)
|
| 103 |
return tagger, tag_vocab, train_config
|
|
|
|
| 97 |
train_config.trainer_config["kwargs"]["loss"] = loss
|
| 98 |
|
| 99 |
tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
|
| 100 |
+
checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", filename="checkpoints")
|
| 101 |
|
| 102 |
tagger.load(checkpoint_path)
|
| 103 |
return tagger, tag_vocab, train_config
|
app.py
CHANGED
|
@@ -27,6 +27,7 @@ checkpoint_path = hf_hub_download(
|
|
| 27 |
repo_id="SinaLab/Nested",
|
| 28 |
filename="checkpoints/checkpoint_2.pt"
|
| 29 |
)
|
|
|
|
| 30 |
|
| 31 |
args_path = hf_hub_download(
|
| 32 |
repo_id="SinaLab/Nested",
|
|
|
|
| 27 |
repo_id="SinaLab/Nested",
|
| 28 |
filename="checkpoints/checkpoint_2.pt"
|
| 29 |
)
|
| 30 |
+
print("checkpoint_path : ", checkpoint_path)
|
| 31 |
|
| 32 |
args_path = hf_hub_download(
|
| 33 |
repo_id="SinaLab/Nested",
|