Spaces:
Sleeping
Sleeping
Create helpers.py
Browse files- Nested/utils/helpers.py +123 -0
Nested/utils/helpers.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import logging
|
| 4 |
+
import importlib
|
| 5 |
+
import shutil
|
| 6 |
+
import torch
|
| 7 |
+
import pickle
|
| 8 |
+
import json
|
| 9 |
+
import random
|
| 10 |
+
import numpy as np
|
| 11 |
+
from argparse import Namespace
|
| 12 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
def logging_config(log_file=None):
|
| 16 |
+
"""
|
| 17 |
+
Initialize custom logger
|
| 18 |
+
:param log_file: str - path to log file, full path
|
| 19 |
+
:return: None
|
| 20 |
+
"""
|
| 21 |
+
handlers = [logging.StreamHandler(sys.stdout)]
|
| 22 |
+
|
| 23 |
+
if log_file:
|
| 24 |
+
handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
|
| 25 |
+
print("Logging to {}".format(log_file))
|
| 26 |
+
|
| 27 |
+
logging.basicConfig(
|
| 28 |
+
level=logging.INFO,
|
| 29 |
+
handlers=handlers,
|
| 30 |
+
format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
|
| 31 |
+
datefmt="%a, %d %b %Y %H:%M:%S",
|
| 32 |
+
force=True
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_object(name, kwargs):
|
| 37 |
+
"""
|
| 38 |
+
Load objects dynamically given the object name and its arguments
|
| 39 |
+
:param name: str - object name, class name or function name
|
| 40 |
+
:param kwargs: dict - keyword arguments
|
| 41 |
+
:return: object
|
| 42 |
+
"""
|
| 43 |
+
object_module, object_name = name.rsplit(".", 1)
|
| 44 |
+
object_module = importlib.import_module(object_module)
|
| 45 |
+
fn = getattr(object_module, object_name)(**kwargs)
|
| 46 |
+
return fn
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def make_output_dirs(path, subdirs=[], overwrite=True):
|
| 50 |
+
"""
|
| 51 |
+
Create root directory and any other sub-directories
|
| 52 |
+
:param path: str - root directory
|
| 53 |
+
:param subdirs: List[str] - list of sub-directories
|
| 54 |
+
:param overwrite: boolean - to overwrite the directory or not
|
| 55 |
+
:return: None
|
| 56 |
+
"""
|
| 57 |
+
if overwrite:
|
| 58 |
+
shutil.rmtree(path, ignore_errors=True)
|
| 59 |
+
|
| 60 |
+
os.makedirs(path)
|
| 61 |
+
|
| 62 |
+
for subdir in subdirs:
|
| 63 |
+
os.makedirs(os.path.join(path, subdir))
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def load_checkpoint(model_path):
|
| 67 |
+
"""
|
| 68 |
+
Load model given the model path
|
| 69 |
+
:param model_path: str - path to model
|
| 70 |
+
:return: tagger - Nested.trainers.BaseTrainer - the tagger model
|
| 71 |
+
vocab - arabicner.utils.data.Vocab - indexed tags
|
| 72 |
+
train_config - argparse.Namespace - training configurations
|
| 73 |
+
"""
|
| 74 |
+
with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
|
| 75 |
+
tag_vocab = pickle.load(fh)
|
| 76 |
+
|
| 77 |
+
# Load train configurations from checkpoint
|
| 78 |
+
train_config = Namespace()
|
| 79 |
+
args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
|
| 80 |
+
|
| 81 |
+
with open(args_path, "r") as fh:
|
| 82 |
+
train_config.__dict__ = json.load(fh)
|
| 83 |
+
|
| 84 |
+
# Initialize the loss function, not used for inference, but evaluation
|
| 85 |
+
loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])
|
| 86 |
+
|
| 87 |
+
# Load BERT tagger
|
| 88 |
+
model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
|
| 89 |
+
model = torch.nn.DataParallel(model)
|
| 90 |
+
|
| 91 |
+
if torch.cuda.is_available():
|
| 92 |
+
model = model.cuda()
|
| 93 |
+
|
| 94 |
+
# Update arguments for the tagger
|
| 95 |
+
# Attach the model, loss (used for evaluations cases)
|
| 96 |
+
train_config.trainer_config["kwargs"]["model"] = model
|
| 97 |
+
train_config.trainer_config["kwargs"]["loss"] = loss
|
| 98 |
+
|
| 99 |
+
tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
|
| 100 |
+
# checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
|
| 101 |
+
checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")
|
| 102 |
+
|
| 103 |
+
tagger.load(checkpoint_path)
|
| 104 |
+
return tagger, tag_vocab, train_config
|
| 105 |
+
|
| 106 |
+
|
| 107 |
+
def set_seed(seed):
|
| 108 |
+
"""
|
| 109 |
+
Set the seed for random intialization and set
|
| 110 |
+
CUDANN parameters to ensure determmihstic results across
|
| 111 |
+
multiple runs with the same seed
|
| 112 |
+
|
| 113 |
+
:param seed: int
|
| 114 |
+
"""
|
| 115 |
+
np.random.seed(seed)
|
| 116 |
+
random.seed(seed)
|
| 117 |
+
torch.manual_seed(seed)
|
| 118 |
+
torch.cuda.manual_seed(seed)
|
| 119 |
+
torch.cuda.manual_seed_all(seed)
|
| 120 |
+
|
| 121 |
+
torch.backends.cudnn.deterministic = True
|
| 122 |
+
torch.backends.cudnn.benchmark = False
|
| 123 |
+
torch.backends.cudnn.enabled = False
|