File size: 4,879 Bytes
c668e80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
# -*- coding: utf-8 -*-

import torch
import random
from contextlib import contextmanager
import inspect
import numpy as np
import os
from copy import deepcopy


class RandomShuffler(object):
    """Use random functions while keeping track of the random state to make it
    reproducible and deterministic.
    taken from the torchtext Library

    The shuffler owns a private copy of Python's ``random`` module state.
    Every shuffle temporarily installs that private state as the global one,
    draws from it, saves the advanced state back, and then restores the
    caller's global state.
    """

    def __init__(self, random_state=None):
        """Create a shuffler.

        Args:
            random_state: a state object as returned by ``random.getstate()``.
                When ``None``, a snapshot of the current global state is used.
        """
        self._random_state = random_state
        if self._random_state is None:
            self._random_state = random.getstate()

    @contextmanager
    def use_internal_state(self):
        """Use a specific RNG state.

        Inside the ``with`` body the global ``random`` module runs on this
        shuffler's private state. On exit, including exit via an exception,
        the advanced private state is saved and the caller's previous global
        state is restored.
        """
        old_state = random.getstate()
        random.setstate(self._random_state)
        try:
            yield
        finally:
            # Without ``finally`` an exception raised in the ``with`` body
            # would skip these lines and leave the global RNG stuck on the
            # private state (and lose the advanced private state).
            self._random_state = random.getstate()
            random.setstate(old_state)

    @property
    def random_state(self):
        """Return a deep copy of the current private RNG state."""
        return deepcopy(self._random_state)

    def __call__(self, data):
        """Shuffle and return a new list (``data`` itself is not modified)."""
        with self.use_internal_state():
            return random.sample(data, len(data))


def check_path(path, exist_ok=False, log=print):
    """Prepare ``path`` for writing.

    If ``path`` does not exist yet, its parent directory (if it has one) is
    created, including intermediate directories. If it already exists, either
    a warning is passed to ``log`` (``exist_ok=True``) or ``IOError`` is
    raised.

    Args:
        path: target file path.
        exist_ok: tolerate an existing ``path`` instead of raising.
        log: callable receiving the warning message.

    Raises:
        IOError: ``path`` exists and ``exist_ok`` is false.
    """
    if not os.path.exists(path):
        parent_dir = os.path.dirname(path)
        if parent_dir != "":
            os.makedirs(parent_dir, exist_ok=True)
        return
    if not exist_ok:
        raise IOError(f"path {path} exists, stop.")
    log(f"path {path} exists, may overwrite...")


def sequence_mask(lengths, max_len=None):
    """
    Build a boolean mask from a tensor of sequence lengths.

    Row ``i`` of the result is ``True`` in its first ``lengths[i]`` columns
    and ``False`` afterwards.

    Args:
        lengths: integer tensor of lengths (presumably 1-D, shape
            ``(batch,)`` - TODO confirm against callers).
        max_len: width of the mask; when falsy, ``lengths.max()`` is used.

    Returns:
        Boolean tensor of shape ``(lengths.numel(), max_len)`` on the same
        device as ``lengths``.
    """
    n_rows = lengths.numel()
    width = max_len if max_len else lengths.max()
    # Column indices 0..width-1, cast to the dtype of ``lengths`` so the
    # comparison below is between like-typed tensors.
    positions = torch.arange(0, width, device=lengths.device).type_as(lengths)
    position_grid = positions.repeat(n_rows, 1)
    return position_grid.lt(lengths.unsqueeze(1))


def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.

    Each slice along ``dim`` is repeated ``count`` times consecutively, so
    ``[a, b]`` becomes ``[a, a, b, b]``. This is *not* ``Tensor.repeat``,
    which would give ``[a, b, a, b]``.

    Args:
        x: tensor to tile.
        count: number of copies of each slice along ``dim``.
        dim: dimension to expand (negative indices allowed).

    Returns:
        A new contiguous tensor whose size along ``dim`` is
        ``x.size(dim) * count``; all other dimensions are unchanged.
    """
    # ``repeat_interleave`` performs exactly the permute/view/repeat/transpose
    # sequence this helper used to spell out by hand, and unlike the old
    # ``view(batch, -1)`` step it also works when the tiled dimension is empty.
    return x.repeat_interleave(count, dim=dim).contiguous()


def use_gpu(opt):
    """
    Return True when ``opt`` requests GPU execution.

    That is the case if ``opt`` has a non-empty ``gpu_ranks`` sequence, or a
    ``gpu`` attribute greater than -1. Missing attributes count as "no GPU".
    """
    if hasattr(opt, "gpu_ranks") and len(opt.gpu_ranks) > 0:
        return True
    return hasattr(opt, "gpu") and opt.gpu > -1


def set_random_seed(seed, is_cuda):
    """Seed the torch, Python ``random`` and NumPy RNGs for reproducibility.

    Does nothing unless ``seed`` is strictly positive. When ``is_cuda`` is
    true, the CUDA generator is seeded as well.
    """
    if not (seed > 0):
        return
    torch.manual_seed(seed)
    # Python's ``random`` drives the Random Shuffler of batches; seeding it
    # makes every GPU worker read the datasets in the same order.
    random.seed(seed)
    # Some cuDNN kernels stay nondeterministic even with a fixed seed unless
    # deterministic mode is requested explicitly.
    torch.backends.cudnn.deterministic = True
    # NumPy randomness is used by various transforms.
    np.random.seed(seed)
    if is_cuda:
        # Same initialization on every device in multi-GPU mode.
        torch.cuda.manual_seed(seed)


def fn_args(fun):
    """Return the names of ``fun``'s positional parameters.

    Relies on ``inspect.getfullargspec(...).args``, so ``*args``,
    ``**kwargs`` and keyword-only parameters are not included.
    """
    spec = inspect.getfullargspec(fun)
    return spec.args


def report_matrix(row_label, column_label, matrix):
    """Render ``matrix`` as a fixed-width text table.

    ``row_label`` provides the header cells (one per printed value column,
    each truncated to 7 characters). ``column_label`` provides the leading
    label of each printed row (truncated to 10 characters). Rows are paired
    with ``column_label`` via ``zip``, so the shorter of the two decides how
    many rows are printed. Values are shown with 7 decimals.

    In each row, the cell holding the row's maximum is left-padded with ``*``
    instead of spaces. Values beyond ``len(row_label)`` are not printed, and a
    maximum located there is not starred.

    Returns:
        The table as a string; every line, including the last, ends with a
        newline.
    """
    n_cols = len(row_label)
    plain_cell = "{:>10.7f} "
    starred_cell = "{:*>10.7f} "
    rendered = [("{:>10.10} " + "{:>10.7} " * n_cols).format("", *row_label)]
    for word, row in zip(column_label, matrix):
        peak = row.index(max(row))
        cells = [plain_cell] * n_cols
        if peak < n_cols:
            cells[peak] = starred_cell
        rendered.append(("{:>10.10} " + "".join(cells)).format(word, *row))
    return "".join(line + "\n" for line in rendered)


def check_model_config(model_config, root):
    """Verify that every file referenced by ``model_config`` exists under ``root``.

    Checked entries are:
    - each item of ``model_config["models"]``;
    - when a ``tokenizer`` section with ``params`` is present, every
      tokenizer param whose key ends with ``"path"``.

    Each entry is resolved with ``os.path.join(root, entry)``.

    Raises:
        FileNotFoundError: on the first missing path; the message names the
            resolved path and ``model_config["id"]``.
    """

    def _ensure_exists(relative):
        # ``model_config["id"]`` is only looked up when reporting a failure.
        candidate = os.path.join(root, relative)
        if os.path.exists(candidate):
            return
        raise FileNotFoundError(
            "{} from model {} does not exist".format(candidate, model_config["id"])
        )

    for model in model_config["models"]:
        _ensure_exists(model)

    if "tokenizer" not in model_config.keys():
        return
    tokenizer = model_config["tokenizer"]
    if "params" not in tokenizer.keys():
        return
    for key, value in tokenizer["params"].items():
        if key.endswith("path"):
            _ensure_exists(value)