TymaaHammouda commited on
Commit
e091f09
·
1 Parent(s): 39af8fe

Update paths

Browse files
Files changed (1) hide show
  1. Nested/utils/helpers.py +8 -3
Nested/utils/helpers.py CHANGED
@@ -9,6 +9,7 @@ import json
9
  import random
10
  import numpy as np
11
  from argparse import Namespace
 
12
 
13
 
14
  def logging_config(log_file=None):
@@ -70,12 +71,14 @@ def load_checkpoint(model_path):
70
  vocab - arabicner.utils.data.Vocab - indexed tags
71
  train_config - argparse.Namespace - training configurations
72
  """
73
- with open(os.path.join(model_path, "tag_vocab.pkl"), "rb") as fh:
74
  tag_vocab = pickle.load(fh)
75
 
76
  # Load train configurations from checkpoint
77
  train_config = Namespace()
78
- with open(os.path.join(model_path, "args.json"), "r") as fh:
 
 
79
  train_config.__dict__ = json.load(fh)
80
 
81
  # Initialize the loss function, not used for inference, but evaluation
@@ -94,7 +97,9 @@ def load_checkpoint(model_path):
94
  train_config.trainer_config["kwargs"]["loss"] = loss
95
 
96
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
97
- tagger.load(os.path.join(model_path, "checkpoints"))
 
 
98
  return tagger, tag_vocab, train_config
99
 
100
 
 
9
  import random
10
  import numpy as np
11
  from argparse import Namespace
12
+ from huggingface_hub import hf_hub_download
13
 
14
 
15
  def logging_config(log_file=None):
 
71
  vocab - arabicner.utils.data.Vocab - indexed tags
72
  train_config - argparse.Namespace - training configurations
73
  """
74
+ with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
75
  tag_vocab = pickle.load(fh)
76
 
77
  # Load train configurations from checkpoint
78
  train_config = Namespace()
79
+ args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
80
+
81
+ with open(args_path, "r") as fh:
82
  train_config.__dict__ = json.load(fh)
83
 
84
  # Initialize the loss function, not used for inference, but evaluation
 
97
  train_config.trainer_config["kwargs"]["loss"] = loss
98
 
99
  tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
+ checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", filename="checkpoints/checkpoint_2.pt")
101
+
102
+ tagger.load(checkpoint_path)
103
  return tagger, tag_vocab, train_config
104
 
105