wojood-api / Nested /utils /helpers.py
TymaaHammouda's picture
Update
ebfe92b
raw
history blame
3.8 kB
import os
import sys
import logging
import importlib
import shutil
import torch
import pickle
import json
import random
import numpy as np
from argparse import Namespace
from huggingface_hub import hf_hub_download, snapshot_download
def logging_config(log_file=None):
    """
    Configure the root logger to stream to stdout and, optionally, a file.

    :param log_file: str - full path to a log file; when given, log lines are
        also written there (file is truncated, UTF-8 encoded)
    :return: None
    """
    sinks = [logging.StreamHandler(sys.stdout)]

    if log_file:
        sinks.append(logging.FileHandler(log_file, "w", "utf-8"))
        print("Logging to {}".format(log_file))

    # force=True replaces any handlers installed by an earlier basicConfig call.
    logging.basicConfig(
        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        level=logging.INFO,
        handlers=sinks,
        force=True,
    )
def load_object(name, kwargs):
    """
    Dynamically instantiate an object from its fully-qualified dotted name.

    :param name: str - dotted path such as "package.module.ClassName"
    :param kwargs: dict - keyword arguments forwarded to the resolved callable
    :return: the object produced by calling the resolved attribute
    """
    module_path, attr_name = name.rsplit(".", 1)
    module = importlib.import_module(module_path)
    factory = getattr(module, attr_name)
    return factory(**kwargs)
def make_output_dirs(path, subdirs=(), overwrite=True):
    """
    Create a root directory and any requested sub-directories.

    :param path: str - root directory
    :param subdirs: Iterable[str] - sub-directory names to create under root
    :param overwrite: boolean - when True, delete any existing tree first;
        when False, keep whatever already exists in place
    :return: None
    """
    if overwrite:
        # Start from a clean slate; ignore_errors covers "path doesn't exist yet".
        shutil.rmtree(path, ignore_errors=True)

    # exist_ok avoids FileExistsError when overwrite=False and the directories
    # are already present (the original raised in that case). The default for
    # subdirs is an immutable tuple instead of the mutable-default [] anti-pattern.
    os.makedirs(path, exist_ok=True)
    for subdir in subdirs:
        os.makedirs(os.path.join(path, subdir), exist_ok=True)
def load_checkpoint(model_path):
    """
    Load the Nested NER tagger, its tag vocabulary and training configuration.

    :param model_path: str - path to model (currently unused; all artifacts
        are fetched from the SinaLab/Nested Hugging Face Hub repository —
        TODO confirm whether this parameter should override the Hub download)
    :return: tagger - Nested.trainers.BaseTrainer - the tagger model
             vocab - indexed tags unpickled from Nested/utils/tag_vocab.pkl
             train_config - argparse.Namespace - training configurations
    """
    # NOTE(review): pickle.load is only safe because this file ships with the
    # repo itself; never point it at untrusted data.
    with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
        tag_vocab = pickle.load(fh)

    # Load train configurations from the repo's args.json
    train_config = Namespace()
    args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
    with open(args_path, "r") as fh:
        train_config.__dict__ = json.load(fh)

    # Initialize the loss function; not used for inference, only evaluation
    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])

    # Load BERT tagger
    model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
    model = torch.nn.DataParallel(model)
    if torch.cuda.is_available():
        model = model.cuda()

    # Attach the model and loss (used for evaluation cases) to the trainer args
    train_config.trainer_config["kwargs"]["model"] = model
    train_config.trainer_config["kwargs"]["loss"] = loss
    tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])

    # BUG FIX: repo_type="checkpoints" is not a valid Hub repo type (only
    # "model", "dataset" and "space" are accepted), so snapshot_download
    # raised a ValueError before any download happened. Omitting repo_type
    # uses the default "model" repo, which is where SinaLab/Nested lives.
    checkpoint_path = snapshot_download(repo_id="SinaLab/Nested")
    tagger.load(checkpoint_path)

    return tagger, tag_vocab, train_config
def set_seed(seed):
    """
    Seed every random number generator in use (random, numpy, torch on both
    CPU and CUDA) and switch cuDNN into deterministic mode so repeated runs
    with the same seed produce identical results.

    :param seed: int
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade cuDNN autotuning speed for bitwise reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.enabled = False