import torch
import torch.nn.functional as F
from torch import nn, Tensor, einsum, IntTensor, FloatTensor, BoolTensor

from random import random

from einops.layers.torch import Rearrange
from einops import rearrange, repeat, reduce, pack, unpack

# helpers
def exists(val):
    return val is not None

def identity(t):
    return t

def default(val, d):
    return val if exists(val) else d

def divisible_by(num, den):
    return (num % den) == 0

def is_odd(n):
    return not divisible_by(n, 2)

def coin_flip():
    return random() < 0.5

def pack_one(t, pattern):
    return pack([t], pattern)

def unpack_one(t, ps, pattern):
    return unpack(t, ps, pattern)[0]
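# a minimal usage sketch (illustrative, not from the source): pack_one / unpack_one
# wrap einops pack / unpack for the common single-tensor case - pack the leading
# dims into one axis, do some work, then restore the original shape
#
#   t = torch.randn(2, 3, 4)
#   packed, ps = pack_one(t, '* d')            # packed.shape == (6, 4)
#   restored = unpack_one(packed, ps, '* d')   # restored.shape == (2, 3, 4)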
# tensor helpers

def prob_mask_like(shape, prob, device):
    if prob == 1:
        return torch.ones(shape, device = device, dtype = torch.bool)
    elif prob == 0:
        return torch.zeros(shape, device = device, dtype = torch.bool)
    else:
        return torch.zeros(shape, device = device).float().uniform_(0, 1) < prob
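# illustrative example (shapes assumed, not from the source): prob_mask_like draws
# an independent boolean per position, True with probability `prob`
#
#   mask = prob_mask_like((2, 1024), 0.7, device = torch.device('cpu'))
#   mask.dtype            # torch.bool
#   mask.float().mean()   # roughly 0.7 in expectation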
def reduce_masks_with_and(*masks):
    masks = [*filter(exists, masks)]

    if len(masks) == 0:
        return None

    mask, *rest_masks = masks

    for rest_mask in rest_masks:
        mask = mask & rest_mask

    return mask
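# sketch of intended use - the mask names here are hypothetical, not from the
# source; the function ANDs together any number of optional boolean masks,
# skipping the ones that are None
#
#   combined = reduce_masks_with_and(padding_mask, None, attn_mask)   # padding_mask & attn_mask
#   reduce_masks_with_and(None, None)                                 # returns None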
def interpolate_1d(t, length, mode = 'bilinear'):
    " pytorch does not offer 1d interpolation, so hack around it by converting to 2d "

    dtype = t.dtype
    t = t.float()

    implicit_one_channel = t.ndim == 2

    if implicit_one_channel:
        t = rearrange(t, 'b n -> b 1 n')

    t = rearrange(t, 'b d n -> b d n 1')
    t = F.interpolate(t, (length, 1), mode = mode)
    t = rearrange(t, 'b d n 1 -> b d n')

    if implicit_one_channel:
        t = rearrange(t, 'b 1 n -> b n')

    t = t.to(dtype)
    return t
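# rough usage sketch (shapes assumed, not from the source) - e.g. stretching a
# (batch, time) sequence to a new length along the time axis
#
#   t = torch.randn(2, 100)
#   interpolate_1d(t, 250).shape   # torch.Size([2, 250])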
def curtail_or_pad(t, target_length):
    length = t.shape[-2]

    if length > target_length:
        t = t[..., :target_length, :]
    elif length < target_length:
        t = F.pad(t, (0, 0, 0, target_length - length), value = 0.)

    return t
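# quick sketch (shapes assumed, not from the source) - operates on the
# second-to-last (sequence) dimension of a (batch, seq, dim) tensor
#
#   x = torch.randn(2, 100, 512)
#   curtail_or_pad(x, 80).shape    # torch.Size([2, 80, 512]) - truncated
#   curtail_or_pad(x, 120).shape   # torch.Size([2, 120, 512]) - zero padded at the end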
# mask construction helpers

def mask_from_start_end_indices(
    seq_len: int,
    start: Tensor,
    end: Tensor
):
    assert start.shape == end.shape
    device = start.device

    seq = torch.arange(seq_len, device = device, dtype = torch.long)
    seq = seq.reshape(*((-1,) * start.ndim), seq_len)
    seq = seq.expand(*start.shape, seq_len)

    mask = seq >= start[..., None].long()
    mask &= seq < end[..., None].long()
    return mask
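# worked example (values chosen for illustration): each row gets True on [start, end)
#
#   start = torch.tensor([1, 0])
#   end = torch.tensor([3, 2])
#   mask_from_start_end_indices(4, start, end)
#   # tensor([[False,  True,  True, False],
#   #         [ True,  True, False, False]])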
def mask_from_frac_lengths(
    seq_len: int,
    frac_lengths: Tensor
):
    device = frac_lengths.device

    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.zeros_like(frac_lengths).float().uniform_(0, 1)
    start = (max_start * rand).clamp(min = 0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
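# illustrative example (values assumed, not from the source): each row gets one
# contiguous True span covering `frac * seq_len` positions, at a random offset
#
#   frac = torch.tensor([0.5, 0.25])
#   mask = mask_from_frac_lengths(8, frac)
#   mask.shape   # torch.Size([2, 8]); row 0 has 4 True positions, row 1 has 2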