TymaaHammouda commited on
Commit
7f85e18
·
1 Parent(s): 0b8923b

Update checkpoint_path

Browse files
Nested/trainers/BaseTrainer.py CHANGED
@@ -107,7 +107,8 @@ class BaseTrainer:
107
  :param checkpoint_path: str - path/to/checkpoints
108
  :return: None
109
  """
110
- checkpoint_path = natsort.natsorted("checkpoint_2.pt")
 
111
  # checkpoint_path = checkpoint_path[-1]
112
 
113
  logger.info("Loading checkpoint %s", checkpoint_path)
 
107
  :param checkpoint_path: str - path/to/checkpoints
108
  :return: None
109
  """
110
+ # checkpoint_path = natsort.natsorted(glob.glob(f"{checkpoint_path}/checkpoint_*.pt"))
111
+ checkpoint_path = natsort.natsorted(checkpoint_path)
112
  # checkpoint_path = checkpoint_path[-1]
113
 
114
  logger.info("Loading checkpoint %s", checkpoint_path)
Nested/utils/helpers.py CHANGED
@@ -97,7 +97,7 @@ def load_checkpoint(model_path):
97
  train_config.trainer_config["kwargs"]["loss"] = loss
98
 
99
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
- checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", filename="checkpoints/checkpoint_2.pt")
101
 
102
  tagger.load(checkpoint_path)
103
  return tagger, tag_vocab, train_config
 
97
  train_config.trainer_config["kwargs"]["loss"] = loss
98
 
99
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
+ checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", filename="checkpoints")
101
 
102
  tagger.load(checkpoint_path)
103
  return tagger, tag_vocab, train_config
app.py CHANGED
@@ -27,6 +27,7 @@ checkpoint_path = hf_hub_download(
27
  repo_id="SinaLab/Nested",
28
  filename="checkpoints/checkpoint_2.pt"
29
  )
 
30
 
31
  args_path = hf_hub_download(
32
  repo_id="SinaLab/Nested",
 
27
  repo_id="SinaLab/Nested",
28
  filename="checkpoints/checkpoint_2.pt"
29
  )
30
+ print("checkpoint_path : ", checkpoint_path)
31
 
32
  args_path = hf_hub_download(
33
  repo_id="SinaLab/Nested",