aaljabari commited on
Commit
f26a4b0
·
verified ·
1 Parent(s): ef877e8

Create helpers.py

Browse files
Files changed (1) hide show
  1. Nested/utils/helpers.py +123 -0
Nested/utils/helpers.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import logging
4
+ import importlib
5
+ import shutil
6
+ import torch
7
+ import pickle
8
+ import json
9
+ import random
10
+ import numpy as np
11
+ from argparse import Namespace
12
+ from huggingface_hub import hf_hub_download, snapshot_download
13
+
14
+
15
+ def logging_config(log_file=None):
16
+ """
17
+ Initialize custom logger
18
+ :param log_file: str - path to log file, full path
19
+ :return: None
20
+ """
21
+ handlers = [logging.StreamHandler(sys.stdout)]
22
+
23
+ if log_file:
24
+ handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
25
+ print("Logging to {}".format(log_file))
26
+
27
+ logging.basicConfig(
28
+ level=logging.INFO,
29
+ handlers=handlers,
30
+ format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
31
+ datefmt="%a, %d %b %Y %H:%M:%S",
32
+ force=True
33
+ )
34
+
35
+
36
+ def load_object(name, kwargs):
37
+ """
38
+ Load objects dynamically given the object name and its arguments
39
+ :param name: str - object name, class name or function name
40
+ :param kwargs: dict - keyword arguments
41
+ :return: object
42
+ """
43
+ object_module, object_name = name.rsplit(".", 1)
44
+ object_module = importlib.import_module(object_module)
45
+ fn = getattr(object_module, object_name)(**kwargs)
46
+ return fn
47
+
48
+
49
+ def make_output_dirs(path, subdirs=[], overwrite=True):
50
+ """
51
+ Create root directory and any other sub-directories
52
+ :param path: str - root directory
53
+ :param subdirs: List[str] - list of sub-directories
54
+ :param overwrite: boolean - to overwrite the directory or not
55
+ :return: None
56
+ """
57
+ if overwrite:
58
+ shutil.rmtree(path, ignore_errors=True)
59
+
60
+ os.makedirs(path)
61
+
62
+ for subdir in subdirs:
63
+ os.makedirs(os.path.join(path, subdir))
64
+
65
+
66
+ def load_checkpoint(model_path):
67
+ """
68
+ Load model given the model path
69
+ :param model_path: str - path to model
70
+ :return: tagger - Nested.trainers.BaseTrainer - the tagger model
71
+ vocab - arabicner.utils.data.Vocab - indexed tags
72
+ train_config - argparse.Namespace - training configurations
73
+ """
74
+ with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
75
+ tag_vocab = pickle.load(fh)
76
+
77
+ # Load train configurations from checkpoint
78
+ train_config = Namespace()
79
+ args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
80
+
81
+ with open(args_path, "r") as fh:
82
+ train_config.__dict__ = json.load(fh)
83
+
84
+ # Initialize the loss function, not used for inference, but evaluation
85
+ loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])
86
+
87
+ # Load BERT tagger
88
+ model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
89
+ model = torch.nn.DataParallel(model)
90
+
91
+ if torch.cuda.is_available():
92
+ model = model.cuda()
93
+
94
+ # Update arguments for the tagger
95
+ # Attach the model, loss (used for evaluations cases)
96
+ train_config.trainer_config["kwargs"]["model"] = model
97
+ train_config.trainer_config["kwargs"]["loss"] = loss
98
+
99
+ tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
100
+ # checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
101
+ checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
102
+
103
+ tagger.load(checkpoint_path)
104
+ return tagger, tag_vocab, train_config
105
+
106
+
107
+ def set_seed(seed):
108
+ """
109
+ Set the seed for random intialization and set
110
+ CUDANN parameters to ensure determmihstic results across
111
+ multiple runs with the same seed
112
+
113
+ :param seed: int
114
+ """
115
+ np.random.seed(seed)
116
+ random.seed(seed)
117
+ torch.manual_seed(seed)
118
+ torch.cuda.manual_seed(seed)
119
+ torch.cuda.manual_seed_all(seed)
120
+
121
+ torch.backends.cudnn.deterministic = True
122
+ torch.backends.cudnn.benchmark = False
123
+ torch.backends.cudnn.enabled = False