| | |
| | |
| | |
| | |
| | |
| |
|
| | import errno |
| | import functools |
| | import hashlib |
| | import inspect |
| | import io |
| | import os |
| | import random |
| | import socket |
| | import tempfile |
| | import warnings |
| | import zlib |
| | from contextlib import contextmanager |
| |
|
| | from diffq import UniformQuantizer, DiffQuantizer |
| | import torch as th |
| | import tqdm |
| | from torch import distributed |
| | from torch.nn import functional as F |
| |
|
| |
|
def center_trim(tensor, reference):
    """
    Center trim `tensor` with respect to `reference`, along the last dimension.
    `reference` can also be a number, representing the length to trim to.
    If the size difference != 0 mod 2, the extra sample is removed on the right side.
    """
    # A tensor-like reference is reduced to its last-dimension length.
    if hasattr(reference, "size"):
        reference = reference.size(-1)
    delta = tensor.size(-1) - reference
    if delta < 0:
        raise ValueError(f"tensor must be larger than reference. Delta is {delta}.")
    if delta == 0:
        return tensor
    # Odd deltas drop the extra sample on the right.
    left = delta // 2
    right = delta - left
    return tensor[..., left:-right]
| |
|
| |
|
def average_metric(metric, count=1.):
    """
    Average `metric` which should be a float across all hosts. `count` should be
    the weight for this particular host (i.e. number of examples).
    """
    # Pack [weight, weighted metric] so a single all_reduce sums both at once.
    # NOTE(review): assumes a CUDA device and an initialized process group.
    packed = th.tensor([count, count * metric], dtype=th.float32, device='cuda')
    distributed.all_reduce(packed, op=distributed.ReduceOp.SUM)
    total_count = packed[0].item()
    total_metric = packed[1].item()
    return total_metric / total_count
| |
|
| |
|
def free_port(host='', low=20000, high=40000):
    """
    Return a port number that is most likely free.
    This could suffer from a race condition although
    it should be quite rare.

    Args:
        host (str): host/interface to probe the bind on ('' means all interfaces).
        low (int): inclusive lower bound of the random port range.
        high (int): inclusive upper bound of the random port range.

    Returns:
        int: a port number that was bindable at probe time.
    """
    # Fix: close the probe socket on exit. The original leaked the file
    # descriptor and kept the port bound until garbage collection, so the
    # "free" port could not actually be bound by the caller.
    with socket.socket() as sock:
        while True:
            port = random.randint(low, high)
            try:
                sock.bind((host, port))
            except OSError as error:
                # Port taken: try another one; any other error is a real failure.
                if error.errno == errno.EADDRINUSE:
                    continue
                raise
            return port
| |
|
| |
|
def sizeof_fmt(num, suffix='B'):
    """
    Given `num` bytes, return human readable size.
    Taken from https://stackoverflow.com/a/1094933
    """
    value = num
    for prefix in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if abs(value) < 1024.0:
            return "%3.1f%s%s" % (value, prefix, suffix)
        value /= 1024.0
    # Anything past zebibytes falls through to yobibytes.
    return "%.1f%s%s" % (value, 'Yi', suffix)
| |
|
| |
|
def human_seconds(seconds, display='.2f'):
    """
    Given `seconds` seconds, return human readable duration.
    """
    # Start from microseconds and climb the ladder of units while the
    # converted value would still be "large enough" (>= 0.3 of the next unit).
    value = seconds * 1e6
    unit = 'us'
    for next_unit, ratio in (('ms', 1e3), ('s', 1e3), ('min', 60),
                             ('hrs', 60), ('days', 24)):
        if value / ratio < 0.3:
            break
        value /= ratio
        unit = next_unit
    return f"{format(value, display)} {unit}"
| |
|
| |
|
class TensorChunk:
    """
    Lazy view over `tensor[..., offset:offset + length]` along the last
    dimension, without copying. `padded` materializes the view centered
    inside a zero-padded window of a requested length.
    """

    def __init__(self, tensor, offset=0, length=None):
        total = tensor.shape[-1]
        assert offset >= 0
        assert offset < total

        # Clamp the requested length to what is actually available.
        length = total - offset if length is None else min(total - offset, length)

        self.tensor = tensor
        self.offset = offset
        self.length = length
        self.device = tensor.device

    @property
    def shape(self):
        # Same shape as the underlying tensor, with the last dim replaced
        # by the chunk length.
        out_shape = list(self.tensor.shape)
        out_shape[-1] = self.length
        return out_shape

    def padded(self, target_length):
        """Return the chunk centered in a zero-padded window of `target_length`."""
        delta = target_length - self.length
        total = self.tensor.shape[-1]
        assert delta >= 0

        # Ideal window around the chunk; may spill outside the tensor.
        start = self.offset - delta // 2
        end = start + target_length

        # Clip to valid indices, the spill becomes explicit zero padding.
        clipped_start = max(0, start)
        clipped_end = min(total, end)
        pad_left = clipped_start - start
        pad_right = end - clipped_end

        result = F.pad(self.tensor[..., clipped_start:clipped_end],
                       (pad_left, pad_right))
        assert result.shape[-1] == target_length
        return result
| |
|
| |
|
def tensor_chunk(tensor_or_chunk):
    """Coerce a `th.Tensor` into a `TensorChunk`; pass existing chunks through."""
    if isinstance(tensor_or_chunk, TensorChunk):
        return tensor_or_chunk
    assert isinstance(tensor_or_chunk, th.Tensor)
    return TensorChunk(tensor_or_chunk)
| |
|
| |
|
def apply_model(model, mix, shifts=None, split=False,
                overlap=0.25, transition_power=1., progress=False):
    """
    Apply model to a given mixture.

    Args:
        shifts (int): if > 0, will shift in time `mix` by a random amount between 0 and 0.5 sec
            and apply the opposite shift to the output. This is repeated `shifts` time and
            all predictions are averaged. This effectively makes the model time equivariant
            and improves SDR by up to 0.2 points.
        split (bool): if True, the input will be broken down in 8 seconds extracts
            and predictions will be performed individually on each and concatenated.
            Useful for model with large memory footprint like Tasnet.
        overlap (float): fraction of overlap between consecutive segments when `split`.
        transition_power (float): exponent applied to the cross-fade weights; must be >= 1.
        progress (bool): if True, show a progress bar (requires split=True)
    """
    assert transition_power >= 1, "transition_power < 1 leads to weird behavior."
    device = mix.device
    channels, length = mix.shape
    if split:
        out = th.zeros(len(model.sources), channels, length, device=device)
        sum_weight = th.zeros(length, device=device)
        segment = model.segment_length
        stride = int((1 - overlap) * segment)
        offsets = range(0, length, stride)
        scale = stride / model.samplerate
        if progress:
            offsets = tqdm.tqdm(offsets, unit_scale=scale, ncols=120, unit='seconds')
        # Triangular weight so that overlapping segment predictions cross-fade,
        # with the center of each segment weighted the most.
        weight = th.cat([th.arange(1, segment // 2 + 1),
                         th.arange(segment - segment // 2, 0, -1)]).to(device)
        assert len(weight) == segment
        weight = (weight / weight.max())**transition_power
        for offset in offsets:
            chunk = TensorChunk(mix, offset, segment)
            chunk_out = apply_model(model, chunk, shifts=shifts)
            chunk_length = chunk_out.shape[-1]
            out[..., offset:offset + segment] += weight[:chunk_length] * chunk_out
            sum_weight[offset:offset + segment] += weight[:chunk_length]
            # Fix: removed a dead `offset += segment` here — `offset` is the
            # loop variable and is reassigned by the `for` on every iteration.
        assert sum_weight.min() > 0
        # Normalize by the accumulated cross-fade weights.
        out /= sum_weight
        return out
    elif shifts:
        # Randomized time shifts, averaged, to make the model time equivariant.
        max_shift = int(0.5 * model.samplerate)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(length + 2 * max_shift)
        out = 0
        for _ in range(shifts):
            offset = random.randint(0, max_shift)
            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
            shifted_out = apply_model(model, shifted)
            # Undo the shift before averaging.
            out += shifted_out[..., max_shift - offset:]
        out /= shifts
        return out
    else:
        # Base case: pad to a valid model length, run inference, trim back.
        valid_length = model.valid_length(length)
        mix = tensor_chunk(mix)
        padded_mix = mix.padded(valid_length)
        with th.no_grad():
            out = model(padded_mix.unsqueeze(0))[0]
        return center_trim(out, length)
| |
|
| |
|
@contextmanager
def temp_filenames(count, delete=True):
    """
    Context manager yielding a list of `count` fresh temporary file names.

    The files exist (empty) while the context is active. If `delete` is True,
    they are removed on exit, even if the body raises.
    """
    names = []
    try:
        for _ in range(count):
            # Fix: close each NamedTemporaryFile handle immediately instead of
            # leaking an open file descriptor per name until garbage collection.
            # delete=False keeps the file on disk after the handle is closed.
            with tempfile.NamedTemporaryFile(delete=False) as file:
                names.append(file.name)
        yield names
    finally:
        if delete:
            for name in names:
                os.unlink(name)
| |
|
| |
|
def get_quantizer(model, args, optimizer=None):
    """
    Build the quantizer requested by `args`, or return None.

    `args.diffq` selects a DiffQuantizer (optionally wired to `optimizer`),
    otherwise `args.qat` selects a UniformQuantizer with that many bits.
    """
    if args.diffq:
        quantizer = DiffQuantizer(model, min_size=args.q_min_size, group_size=8)
        if optimizer is not None:
            quantizer.setup_optimizer(optimizer)
        return quantizer
    if args.qat:
        return UniformQuantizer(model, bits=args.qat, min_size=args.q_min_size)
    return None
| |
|
| |
|
def load_model(path, strict=False):
    """
    Load a model package saved by `save_model` and restore its weights.

    When `strict` is False, constructor kwargs not accepted by the saved class
    signature are dropped with a warning instead of failing.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        package = th.load(path, 'cpu')

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if not strict:
        # Drop constructor kwargs the current class no longer accepts.
        sig = inspect.signature(klass)
        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn("Dropping inexistant parameter " + key)
                del kwargs[key]
    model = klass(*args, **kwargs)

    state = package["state"]
    training_args = package["training_args"]
    quantizer = get_quantizer(model, training_args)
    set_state(model, quantizer, state)
    return model
| |
|
| |
|
def get_state(model, quantizer):
    """
    Return a serializable state for `model`: plain CPU tensors when no
    quantizer is used, otherwise the quantizer's state zlib-compressed
    under the 'compressed' key.
    """
    if quantizer is not None:
        quantized = quantizer.get_quantized_state()
        buf = io.BytesIO()
        th.save(quantized, buf)
        return {'compressed': zlib.compress(buf.getvalue())}
    return {name: param.data.to('cpu') for name, param in model.state_dict().items()}
| |
|
| |
|
def set_state(model, quantizer, state):
    """
    Restore `state` (as produced by `get_state`) into `model`.

    Returns the state that was actually applied: the input dict when no
    quantizer is used, otherwise the decompressed quantized state.
    """
    if quantizer is None:
        model.load_state_dict(state)
    else:
        raw = zlib.decompress(state["compressed"])
        # Rebind `state` so the decompressed version is what gets returned.
        state = th.load(io.BytesIO(raw), "cpu")
        quantizer.restore_quantized_state(state)

    return state
| |
|
| |
|
def save_state(state, path):
    """
    Serialize `state` with `th.save` and write it next to `path`, inserting
    a short sha256 content signature into the file name
    (e.g. `model.th` -> `model-1a2b3c4d.th`). `path` must be a `pathlib.Path`.
    """
    buf = io.BytesIO()
    th.save(state, buf)
    payload = buf.getvalue()
    sig = hashlib.sha256(payload).hexdigest()[:8]

    target = path.parent / (path.stem + "-" + sig + path.suffix)
    target.write_bytes(payload)
| |
|
| |
|
def save_model(model, quantizer, training_args, path):
    """
    Serialize `model` to `path` as a self-contained package: its class and
    constructor arguments (captured by `capture_init`), its weights via
    `get_state`, and the training arguments needed to rebuild the quantizer.
    """
    args, kwargs = model._init_args_kwargs
    package = {
        'klass': model.__class__,
        'args': args,
        'kwargs': kwargs,
        'state': get_state(model, quantizer),
        'training_args': training_args,
    }
    th.save(package, path)
| |
|
| |
|
def capture_init(init):
    """
    Decorator for `__init__` that records the constructor's `(args, kwargs)`
    on the instance as `_init_args_kwargs`, so the object can later be
    re-instantiated from a saved package.
    """
    @functools.wraps(init)
    def wrapped(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return wrapped
| |
|