TymaaHammouda commited on
Commit
fec514c
·
1 Parent(s): af2fbec

Add checkpoints dir load

Browse files
Files changed (1) hide show
  1. Nested/utils/helpers.py +4 -2
Nested/utils/helpers.py CHANGED
@@ -9,7 +9,7 @@ import json
9
  import random
10
  import numpy as np
11
  from argparse import Namespace
12
- from huggingface_hub import hf_hub_download
13
 
14
 
15
  def logging_config(log_file=None):
@@ -97,7 +97,9 @@ def load_checkpoint(model_path):
97
  train_config.trainer_config["kwargs"]["loss"] = loss
98
 
99
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
- checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
 
 
101
 
102
  tagger.load(checkpoint_path)
103
  return tagger, tag_vocab, train_config
 
9
  import random
10
  import numpy as np
11
  from argparse import Namespace
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
 
14
 
15
  def logging_config(log_file=None):
 
97
  train_config.trainer_config["kwargs"]["loss"] = loss
98
 
99
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
+ # checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
101
+ checkpoint_path = snapshot_download("SinaLab/Nested/checkpoints")
102
+
103
 
104
  tagger.load(checkpoint_path)
105
  return tagger, tag_vocab, train_config