File size: 3,807 Bytes
f316449
 
 
 
 
 
 
 
 
 
 
fec514c
f316449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e091f09
f316449
 
 
 
e091f09
 
 
f316449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fec514c
07296dc
e091f09
 
f316449
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import os
import sys
import logging
import importlib
import shutil
import torch
import pickle
import json
import random
import numpy as np
from argparse import Namespace
from huggingface_hub import hf_hub_download, snapshot_download


def logging_config(log_file=None):
    """
    Configure the root logger to stream to stdout and, optionally, a file.
    :param log_file: str - full path of the log file, or None for stdout only
    :return: None
    """
    # Always log to stdout; mirror into a UTF-8 file when one is given.
    log_handlers = [logging.StreamHandler(sys.stdout)]

    if log_file:
        log_handlers.append(logging.FileHandler(log_file, "w", "utf-8"))
        print("Logging to {}".format(log_file))

    # force=True replaces any handlers installed by a previous call.
    logging.basicConfig(
        level=logging.INFO,
        handlers=log_handlers,
        format="%(levelname)s\t%(name)s\t%(asctime)s\t%(message)s",
        datefmt="%a, %d %b %Y %H:%M:%S",
        force=True,
    )


def load_object(name, kwargs):
    """
    Dynamically import a callable by its dotted path and invoke it.
    :param name: str - fully qualified name, e.g. "package.module.ClassName"
    :param kwargs: dict - keyword arguments forwarded to the callable
    :return: object - the result of calling the named class/function
    """
    # Split "pkg.mod.Attr" into the module path and the attribute name.
    module_path, attr_name = name.rsplit(".", 1)
    module = importlib.import_module(module_path)
    target = getattr(module, attr_name)
    return target(**kwargs)


def make_output_dirs(path, subdirs=None, overwrite=True):
    """
    Create root directory and any other sub-directories
    :param path: str - root directory
    :param subdirs: List[str] - list of sub-directories (default: none)
    :param overwrite: boolean - to overwrite the directory or not
    :return: None
    """
    # Use a None sentinel instead of a mutable default argument (subdirs=[]),
    # which would be shared across calls.
    if subdirs is None:
        subdirs = []

    if overwrite:
        # ignore_errors: a missing path on first run is not an error
        shutil.rmtree(path, ignore_errors=True)

    os.makedirs(path)

    for subdir in subdirs:
        os.makedirs(os.path.join(path, subdir))


def load_checkpoint(model_path):
    """
    Load model given the model path
    :param model_path: str - path to model
    :return: tagger - Nested.trainers.BaseTrainer - the tagger model
             vocab - arabicner.utils.data.Vocab - indexed tags
             train_config - argparse.Namespace - training configurations
    """
    # NOTE(review): model_path is never used below — all artifacts come from
    # the hard-coded "SinaLab/Nested" hub repo and a relative local pickle
    # path; confirm whether the parameter should drive these locations.
    with open("Nested/utils/tag_vocab.pkl", "rb") as fh:
        tag_vocab = pickle.load(fh)

    # Load train configurations from checkpoint
    # The downloaded args.json is loaded wholesale into a Namespace so it can
    # be accessed with attribute syntax (train_config.loss, etc.).
    train_config = Namespace()
    args_path = hf_hub_download(repo_id="SinaLab/Nested", filename="args.json")
    
    with open(args_path, "r") as fh:
        train_config.__dict__ = json.load(fh)

    # Initialize the loss function, not used for inference, but evaluation
    loss = load_object(train_config.loss["fn"], train_config.loss["kwargs"])

    # Load BERT tagger
    # Network class and its kwargs are described in the downloaded config;
    # DataParallel wrapping matches how the checkpoint was saved.
    model = load_object(train_config.network_config["fn"], train_config.network_config["kwargs"])
    model = torch.nn.DataParallel(model)

    if torch.cuda.is_available():
        model = model.cuda()

    # Update arguments for the tagger
    # Attach the model, loss (used for evaluations cases)
    train_config.trainer_config["kwargs"]["model"] = model
    train_config.trainer_config["kwargs"]["loss"] = loss

    tagger = load_object(train_config.trainer_config["fn"], train_config.trainer_config["kwargs"])
    # checkpoint_path = hf_hub_download(repo_id="SinaLab/Nested", local_dir="checkpoints")
    # snapshot_download returns the local snapshot root containing checkpoints/
    checkpoint_path = snapshot_download(repo_id="SinaLab/Nested", allow_patterns="checkpoints/")

    tagger.load(checkpoint_path)
    return tagger, tag_vocab, train_config


def set_seed(seed):
    """
    Seed every random number generator in use (random, numpy, torch
    CPU and CUDA) and configure cuDNN for deterministic behavior so
    repeated runs with the same seed give identical results.

    :param seed: int
    """
    # Seed all RNG sources the project may draw from.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Trade speed for reproducibility: turn off cuDNN auto-tuning and
    # non-deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.enabled = False