File size: 4,879 Bytes
c668e80 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# -*- coding: utf-8 -*-
import torch
import random
from contextlib import contextmanager
import inspect
import numpy as np
import os
from copy import deepcopy
class RandomShuffler(object):
"""Use random functions while keeping track of the random state to make it
reproducible and deterministic.
taken from the torchtext Library"""
def __init__(self, random_state=None):
self._random_state = random_state
if self._random_state is None:
self._random_state = random.getstate()
@contextmanager
def use_internal_state(self):
"""Use a specific RNG state."""
old_state = random.getstate()
random.setstate(self._random_state)
yield
self._random_state = random.getstate()
random.setstate(old_state)
@property
def random_state(self):
return deepcopy(self._random_state)
def __call__(self, data):
"""Shuffle and return a new list."""
with self.use_internal_state():
return random.sample(data, len(data))
def check_path(path, exist_ok=False, log=print):
    """Make sure `path` can be written to.

    If `path` does not exist, its parent directory (when it has one) is
    created. If it already exists, a message is passed to `log` when
    `exist_ok` is true; otherwise IOError is raised.
    """
    if not os.path.exists(path):
        parent = os.path.dirname(path)
        # A bare file name has no directory component to create.
        if parent != "":
            os.makedirs(parent, exist_ok=True)
        return

    if not exist_ok:
        raise IOError(f"path {path} exists, stop.")
    log(f"path {path} exists, may overwrite...")
def sequence_mask(lengths, max_len=None):
    """
    Build a boolean mask from a tensor of sequence lengths.

    Row ``i`` of the result is true at positions ``0 .. lengths[i] - 1``.
    The mask width is ``max_len`` when that is truthy, otherwise
    ``lengths.max()``.
    """
    n_rows = lengths.numel()
    if not max_len:
        max_len = lengths.max()
    positions = torch.arange(0, max_len, device=lengths.device)
    positions = positions.type_as(lengths)
    # One copy of the position ramp per sequence, compared against its length.
    grid = positions.repeat(n_rows, 1)
    return grid.lt(lengths.unsqueeze(1))
def tile(x, count, dim=0):
    """
    Tiles x on dimension dim count times.

    Each slice along ``dim`` is repeated ``count`` times in place, so slices
    ``[a, b]`` become ``[a, a, b, b]`` (not ``[a, b, a, b]``).

    Args:
        x: tensor to tile.
        count: number of consecutive copies of every slice.
        dim: dimension along which to repeat.

    Returns:
        A new contiguous tensor whose size along ``dim`` is
        ``x.size(dim) * count``.
    """
    # repeat_interleave performs exactly this element-wise repetition and,
    # unlike a manual view(batch, -1) round trip, also copes with a
    # zero-length leading dimension.
    return x.repeat_interleave(count, dim=dim).contiguous()
def use_gpu(opt):
    """
    Tell whether the options request a GPU.

    True when `opt` has a non-empty `gpu_ranks`, or when it has a `gpu`
    attribute greater than -1.
    """
    if hasattr(opt, "gpu_ranks") and len(opt.gpu_ranks) > 0:
        return True
    if hasattr(opt, "gpu"):
        return opt.gpu > -1
    return False
def set_random_seed(seed, is_cuda):
    """Seed torch, `random` and numpy; does nothing unless `seed` > 0."""
    if not seed > 0:
        return

    torch.manual_seed(seed)
    # Needed by the random shuffler of batches: in multi-GPU runs it makes
    # every process read the datasets in the same order.
    random.seed(seed)
    # Some cudnn methods stay random even with a fixed seed unless cudnn is
    # told to be deterministic.
    torch.backends.cudnn.deterministic = True
    # Needed by the various transforms.
    np.random.seed(seed)

    if is_cuda:
        # Ensures the same initialization in multi-GPU mode.
        torch.cuda.manual_seed(seed)
def fn_args(fun):
    """Return the names of the positional parameters of `fun` as a list."""
    spec = inspect.getfullargspec(fun)
    return spec.args
def report_matrix(row_label, column_label, matrix):
    """Render `matrix` as a fixed-width text table.

    The entries of `row_label` form the header line, and each row of
    `matrix` is prefixed with the matching entry of `column_label`. In every
    row the first occurrence of the largest value is padded with ``*``
    instead of spaces. Rows must support ``.index`` (e.g. lists).
    """
    width = len(row_label)
    label_cell = "{:>10.10} "
    plain_cell = "{:>10.7f} "
    starred_cell = "{:*>10.7f} "

    header = (label_cell + "{:>10.7} " * width).format("", *row_label)
    lines = [header + "\n"]
    for word, row in zip(column_label, matrix):
        peak = row.index(max(row))
        cells = [plain_cell] * width
        # Only mark the maximum when it falls inside the printed columns.
        if peak < width:
            cells[peak] = starred_cell
        lines.append((label_cell + "".join(cells)).format(word, *row) + "\n")
    return "".join(lines)
def check_model_config(model_config, root):
    """Verify that every file referenced by `model_config` exists under `root`.

    Checks each entry of ``model_config["models"]`` and, when present, every
    value in ``model_config["tokenizer"]["params"]`` whose key ends with
    ``"path"``. Raises FileNotFoundError for the first missing file.
    """

    def _require(relative):
        # Resolve against root and fail with the model id if it is missing.
        full_path = os.path.join(root, relative)
        if not os.path.exists(full_path):
            raise FileNotFoundError(
                "{} from model {} does not exist".format(
                    full_path, model_config["id"]
                )
            )

    # The model files themselves.
    for model in model_config["models"]:
        _require(model)

    # Any tokenizer resources given as "*path" parameters.
    if "tokenizer" not in model_config:
        return
    tokenizer = model_config["tokenizer"]
    if "params" not in tokenizer:
        return
    for key, value in tokenizer["params"].items():
        if key.endswith("path"):
            _require(value)
|