import glob
import logging
import os
import re
import shutil
import sys
import time
import types
from collections import defaultdict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch import nn


def tensors_to_scalars(metrics):
    """Recursively convert tensor values in a (possibly nested) metrics dict to Python scalars."""
    new_metrics = {}
    for k, v in metrics.items():
        if isinstance(v, torch.Tensor):
            v = v.item()
        if isinstance(v, dict):
            v = tensors_to_scalars(v)
        new_metrics[k] = v
    return new_metrics
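# Minimal usage sketch for tensors_to_scalars (hypothetical values):
#
#     tensors_to_scalars({'loss': torch.tensor(0.25), 'sub': {'acc': torch.tensor(0.9)}})
#     # -> {'loss': 0.25, 'sub': {'acc': 0.9}}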


class AvgrageMeter(object):

    def __init__(self):
        self.reset()

    def reset(self):
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.cnt += n
        self.avg = self.sum / self.cnt
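# Minimal usage sketch for AvgrageMeter (hypothetical values): the meter keeps
# a running weighted average, e.g. of a per-batch loss.
#
#     meter = AvgrageMeter()
#     meter.update(0.5, n=4)  # batch of 4 samples with mean loss 0.5
#     meter.update(1.0, n=4)
#     meter.avg               # -> 0.75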


def collate_1d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None, shift_id=1):
    """Convert a list of 1d tensors into a padded 2d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
            dst[0] = shift_id
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
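# Minimal usage sketch for collate_1d (hypothetical values):
#
#     collate_1d([torch.tensor([1, 2, 3]), torch.tensor([4, 5])], pad_idx=0)
#     # -> tensor([[1, 2, 3],
#     #            [4, 5, 0]])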


def collate_2d(values, pad_idx=0, left_pad=False, shift_right=False, max_len=None):
    """Convert a list of 2d tensors into a padded 3d tensor."""
    size = max(v.size(0) for v in values) if max_len is None else max_len
    res = values[0].new(len(values), size, values[0].shape[1]).fill_(pad_idx)

    def copy_tensor(src, dst):
        assert dst.numel() == src.numel()
        if shift_right:
            dst[1:] = src[:-1]
        else:
            dst.copy_(src)

    for i, v in enumerate(values):
        copy_tensor(v, res[i][size - len(v):] if left_pad else res[i][:len(v)])
    return res
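# Minimal usage sketch for collate_2d (hypothetical shapes): inputs are padded
# along the first (time) axis, so two tensors of shape [3, 80] and [5, 80]
# collate to a single tensor of shape [2, 5, 80].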


def _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
    if len(batch) == 0:
        return False
    if len(batch) == max_sentences:
        return True
    if num_tokens > max_tokens:
        return True
    return False


def batch_by_size(
        indices, num_tokens_fn, max_tokens=None, max_sentences=None,
        required_batch_size_multiple=1, distributed=False
):
    """
    Return a list of mini-batches of indices bucketed by size. Batches may
    contain sequences of different lengths.

    Args:
        indices (List[int]): ordered list of dataset indices
        num_tokens_fn (callable): function that returns the number of tokens at
            a given index
        max_tokens (int, optional): max number of tokens in each batch
            (default: None).
        max_sentences (int, optional): max number of sentences in each
            batch (default: None).
        required_batch_size_multiple (int, optional): require batch size to
            be a multiple of N (default: 1).
        distributed (bool, optional): currently unused; kept for API
            compatibility (default: False).
    """
    max_tokens = max_tokens if max_tokens is not None else sys.maxsize
    max_sentences = max_sentences if max_sentences is not None else sys.maxsize
    bsz_mult = required_batch_size_multiple

    if isinstance(indices, types.GeneratorType):
        indices = np.fromiter(indices, dtype=np.int64, count=-1)

    sample_len = 0
    sample_lens = []
    batch = []
    batches = []
    for i in range(len(indices)):
        idx = indices[i]
        num_tokens = num_tokens_fn(idx)
        sample_lens.append(num_tokens)
        sample_len = max(sample_len, num_tokens)
        assert sample_len <= max_tokens, (
            "sentence at index {} of size {} exceeds max_tokens "
            "limit of {}!".format(idx, sample_len, max_tokens)
        )
        # Cost of the batch if this sample were added: every sample is padded
        # up to the longest sample currently in the batch.
        num_tokens = (len(batch) + 1) * sample_len

        if _is_batch_full(batch, num_tokens, max_tokens, max_sentences):
            mod_len = max(
                bsz_mult * (len(batch) // bsz_mult),
                len(batch) % bsz_mult,
            )
            batches.append(batch[:mod_len])
            batch = batch[mod_len:]
            sample_lens = sample_lens[mod_len:]
            sample_len = max(sample_lens) if len(sample_lens) > 0 else 0
        batch.append(idx)
    if len(batch) > 0:
        batches.append(batch)
    return batches
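# Minimal usage sketch for batch_by_size (hypothetical lengths):
#
#     lengths = [3, 4, 7, 2]
#     batch_by_size(list(range(4)), lambda i: lengths[i], max_tokens=8)
#     # -> [[0, 1], [2], [3]]
#
# Indices 0 and 1 fit together (2 sentences * max length 4 = 8 padded tokens),
# while adding index 2 would cost 3 * 7 = 21 padded tokens and overflow the budget.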


def make_positions(tensor, padding_idx):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1. Padding symbols are ignored.
    """
    # The series of casts and type-conversions here are carefully
    # balanced to both work with ONNX export and XLA. In particular XLA
    # prefers ints, cumsum defaults to output longs, and ONNX doesn't know
    # how to handle the dtype kwarg in cumsum.
    mask = tensor.ne(padding_idx).int()
    return (
        torch.cumsum(mask, dim=1).type_as(mask) * mask
    ).long() + padding_idx
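# Minimal usage sketch for make_positions (hypothetical values), with
# padding_idx=0 so positions start at 1 and padding stays at 0:
#
#     make_positions(torch.tensor([[7, 8, 0, 0]]), padding_idx=0)
#     # -> tensor([[1, 2, 0, 0]])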


def softmax(x, dim):
    return F.softmax(x, dim=dim, dtype=torch.float32)


def unpack_dict_to_list(samples):
    """Split a batched dict into a list of per-sample dicts."""
    samples_ = []
    bsz = samples.get('outputs').size(0)
    for i in range(bsz):
        res = {}
        for k, v in samples.items():
            try:
                res[k] = v[i]
            except Exception:
                # Skip values that cannot be indexed per sample.
                pass
        samples_.append(res)
    return samples_
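# Minimal usage sketch for unpack_dict_to_list (hypothetical shapes): a batched
# dict such as {'outputs': Tensor[B, T], 'names': [...]} becomes a list of B
# per-sample dicts, e.g. result[0] == {'outputs': Tensor[T], 'names': names[0]}.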


def load_ckpt(cur_model, ckpt_base_dir, prefix_in_ckpt='model', force=True, strict=True):
    if os.path.isfile(ckpt_base_dir):
        base_dir = os.path.dirname(ckpt_base_dir)
        checkpoint_path = [ckpt_base_dir]
    else:
        base_dir = ckpt_base_dir
        checkpoint_path = sorted(
            glob.glob(f'{base_dir}/model_ckpt_steps_*.ckpt'),
            key=lambda x: int(re.findall(r'model_ckpt_steps_(\d+)\.ckpt', x)[0]))
    if len(checkpoint_path) > 0:
        checkpoint_path = checkpoint_path[-1]
        state_dict = torch.load(checkpoint_path, map_location="cpu")["state_dict"]
        # Keep only weights saved under `prefix_in_ckpt.` and strip that prefix.
        state_dict = {k[len(prefix_in_ckpt) + 1:]: v for k, v in state_dict.items()
                      if k.startswith(f'{prefix_in_ckpt}.')}
        if not strict:
            cur_model_state_dict = cur_model.state_dict()
            unmatched_keys = []
            for key, param in state_dict.items():
                if key in cur_model_state_dict:
                    new_param = cur_model_state_dict[key]
                    if new_param.shape != param.shape:
                        unmatched_keys.append(key)
                        print("| Unmatched keys: ", key, new_param.shape, param.shape)
            for key in unmatched_keys:
                del state_dict[key]
        cur_model.load_state_dict(state_dict, strict=strict)
        print(f"| load '{prefix_in_ckpt}' from '{checkpoint_path}'.")
    else:
        e_msg = f"| ckpt not found in {base_dir}."
        if force:
            # Raise instead of `assert False`, which is silently stripped under `python -O`.
            raise FileNotFoundError(e_msg)
        else:
            print(e_msg)
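# Minimal usage sketch for load_ckpt (hypothetical paths): load the latest
# 'model_ckpt_steps_*.ckpt' under a checkpoint directory into a module, keeping
# only the weights saved under the 'model.' prefix:
#
#     load_ckpt(my_model, 'checkpoints/my_exp', prefix_in_ckpt='model', strict=True)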


def remove_padding(x, padding_idx=0):
    if x is None:
        return None
    assert len(x.shape) in [1, 2]
    if len(x.shape) == 2:  # [T, C] frames: drop frames that are entirely padding
        return x[np.abs(x).sum(-1) != padding_idx]
    elif len(x.shape) == 1:  # [T] tokens: drop padding tokens
        return x[x != padding_idx]
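# Minimal usage sketch for remove_padding (hypothetical values):
#
#     remove_padding(torch.tensor([3, 5, 0, 0]), padding_idx=0)
#     # -> tensor([3, 5])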


class Timer:
    """Context manager that accumulates wall-clock time under a shared name."""
    timer_map = {}

    def __init__(self, name, print_time=False):
        if name not in Timer.timer_map:
            Timer.timer_map[name] = 0
        self.name = name
        self.print_time = print_time

    def __enter__(self):
        self.t = time.time()

    def __exit__(self, exc_type, exc_val, exc_tb):
        Timer.timer_map[self.name] += time.time() - self.t
        if self.print_time:
            print(self.name, Timer.timer_map[self.name])
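# Minimal usage sketch for Timer (hypothetical model/input): accumulated time
# is keyed by name in the class-level map, so repeated blocks add up.
#
#     with Timer('forward', print_time=True):
#         y = model(x)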


def print_arch(model, model_name='model'):
    print(f"| {model_name} Arch: ", model)
    num_params(model, model_name=model_name)


def num_params(model, print_out=True, model_name="model"):
    parameters = filter(lambda p: p.requires_grad, model.parameters())
    parameters = sum(np.prod(p.size()) for p in parameters) / 1_000_000
    if print_out:
        print(f'| {model_name} Trainable Parameters: {parameters:.3f}M')
    return parameters
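# Minimal usage sketch for print_arch / num_params (hypothetical model):
#
#     print_arch(my_model, model_name='decoder')
#     # prints the module tree, then its trainable parameter count in millions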