File size: 3,930 Bytes
2551344
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import sys
import logging
import importlib
import shutil
import torch
import pickle
import json
import random
import numpy as np
from argparse import Namespace
from huggingface_hub import hf_hub_download, snapshot_download


def logging_config(log_file=None):
    """
    Configure the root logger to write to stdout and, optionally, to a file

    Any handlers already attached to the root logger are replaced
    (``force=True``), so the function can be called more than once.

    :param log_file: str - full path to the log file; when given (and truthy)
                     the file is opened in write mode, truncating old content

    :return: None
    """
    sinks = [logging.StreamHandler(sys.stdout)]

    if log_file:
        # Open the file sink first so a bad path fails before anything is printed
        file_sink = logging.FileHandler(log_file, mode="w", encoding="utf-8")
        sinks.append(file_sink)
        print(f"Logging to {log_file}")

    settings = {
        "level": logging.INFO,
        "handlers": sinks,
        "format": "%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        "datefmt": "%a, %d %b %Y %H:%M:%S",
        "force": True,
    }
    logging.basicConfig(**settings)


def load_object(name, kwargs):
    """
    Resolve a dotted name to a callable and call it with keyword arguments

    The text before the last dot is imported as a module; the text after it is
    looked up as an attribute of that module (a class or a function) and
    called with ``**kwargs``.

    :param name: str - fully qualified name, e.g. "package.module.ClassName";
                 a name without any dot raises ValueError

    :param kwargs: dict - keyword arguments forwarded to the resolved callable

    :return: object - whatever the resolved callable returns
    """
    module_path, attr_name = name.rsplit(".", 1)
    factory = getattr(importlib.import_module(module_path), attr_name)
    return factory(**kwargs)


def make_output_dirs(path, subdirs=None, overwrite=True):
    """
    Create root directory and any other sub-directories

    :param path: str - root directory

    :param subdirs: List[str] - sub-directories to create under ``path``;
                    None (the default) creates only the root directory

    :param overwrite: boolean - when True, an existing tree at ``path`` is
                      removed before it is re-created; when False, an
                      existing ``path`` makes ``os.makedirs`` raise
                      FileExistsError

    :return: None
    """
    # None stands in for "no sub-directories" so the signature carries no
    # mutable default shared between calls.
    if subdirs is None:
        subdirs = ()

    if overwrite:
        # ignore_errors=True: a missing path is fine, there is nothing to wipe
        shutil.rmtree(path, ignore_errors=True)

    os.makedirs(path)

    for subdir in subdirs:
        os.makedirs(os.path.join(path, subdir))


def load_checkpoint(model_path):
    """
    Build the tagger from the published training configuration and load its checkpoint

    The tag vocabulary is unpickled from "Nested/utils/tag_vocab.pkl" (a path
    relative to the current working directory), while "args.json" and the
    checkpoint files are fetched from the "SinaLab/Nested" Hugging Face repo.

    NOTE(review): ``model_path`` is never read in this function - confirm
    whether callers expect it to select the checkpoint location.

    :param model_path: str - path to model (currently unused, see note above)

    :return: tagger - object built from train_config.trainer_config
                      (presumably a Nested.trainers.BaseTrainer - confirm)

             tag_vocab - object unpickled from tag_vocab.pkl
                         (presumably the indexed tags vocab - confirm)

             train_config - argparse.Namespace - training configurations
                            read from args.json
    """
    # NOTE: pickle.load can execute arbitrary code; this file must be trusted.
    with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
        tag_vocab = pickle.load(fh)

    # Load train configurations: download args.json from the hub and expose
    # its top-level keys as attributes of the Namespace
    train_config = Namespace()
    args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")

    with open(args_path, "r") as fh:
        train_config.__dict__ = json.load(fh)

    # Initialize the loss function, not used for inference, but evaluation
    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])

    # Build the network named in network_config (the BERT tagger, per the
    # original comment) and wrap it in DataParallel
    model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
    model = torch.nn.DataParallel(model)

    # Move the wrapped model to the GPU only when CUDA is available
    if torch.cuda.is_available():
        model = model.cuda()

    # Update arguments for the tagger
    # Attach the model, loss (used for evaluations cases)
    train_config.trainer_config["kwargs"]["model"] = model
    train_config.trainer_config["kwargs"]["loss"] = loss

    tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
    # checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
    # Fetch the repo files matching "checkpoints/"; the returned local path is
    # handed to tagger.load (presumably the snapshot root - confirm against
    # what the trainer's load() expects)
    checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")

    tagger.load(checkpoint_path)
    return tagger, tag_vocab, train_config


def set_seed(seed):
    """
    Seed the random number generators and set the cuDNN flags

    Seeds NumPy, Python's ``random`` module and PyTorch (CPU and CUDA) with
    the same value, then sets the cuDNN backend flags, with the aim of getting
    deterministic results across multiple runs with the same seed.

    :param seed: int

    :return: None
    """
    seeders = (
        np.random.seed,
        random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cuDNN flags: request deterministic algorithms, turn off benchmark mode,
    # and finally switch the cuDNN enabled flag off
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    cudnn.enabled = False