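"""Packed-sequence utilities for variable-length video/text hidden states.

Samples of different shapes are concatenated along a single token dimension,
with per-sample shapes tracked in a companion LongTensor; the *_idx helpers
precompute index permutations so the same reordering can be replayed cheaply.
"""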

from itertools import chain
from typing import Callable, Dict, List, Tuple

import einops
import torch


def flatten(
    hid: List[torch.FloatTensor],
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Pack a list of (..., c) hidden states into one (L, c) tensor.

    Returns the packed tensor plus a (num_samples, ndim - 1) LongTensor of
    per-sample leading shapes, which unflatten() uses to restore the list.
    """
    assert len(hid) > 0
    shape = torch.stack([torch.tensor(x.shape[:-1], device=hid[0].device) for x in hid])
    hid = torch.cat([x.flatten(0, -2) for x in hid])
    return hid, shape


def unflatten(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
) -> List[torch.Tensor]:
    """Inverse of flatten(): split a packed (L, c) tensor back into a list
    of per-sample tensors with their original leading shapes."""
    hid_len = hid_shape.prod(-1)
    hid = hid.split(hid_len.tolist())
    hid = [x.unflatten(0, s.tolist()) for x, s in zip(hid, hid_shape)]
    return hid
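

# A minimal round-trip sketch (hypothetical shapes): two samples of different
# spatiotemporal size share one packed (L, c) tensor.
#
#     vid0 = torch.randn(4, 8, 8, 64)          # (t, h, w, c)
#     vid1 = torch.randn(2, 16, 16, 64)
#     hid, hid_shape = flatten([vid0, vid1])    # hid: (768, 64)
#     out0, out1 = unflatten(hid, hid_shape)    # original shapes restored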


def concat(
    vid: torch.FloatTensor,
    txt: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> torch.FloatTensor:
    """Interleave packed vid and txt tokens sample by sample:
    [vid_0, txt_0, vid_1, txt_1, ...]."""
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    return torch.cat(list(chain(*zip(vid, txt))))


def concat_idx(
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> Tuple[
    Callable,
    Callable,
]:
    """Precompute index permutations for concat() and its inverse.

    Returns (cat_fn, split_fn): cat_fn interleaves packed vid and txt tokens
    per sample; split_fn restores the two packed tensors.
    """
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    tgt_idx = concat(vid_idx, txt_idx, vid_len, txt_len)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda vid, txt: torch.index_select(torch.cat([vid, txt]), 0, tgt_idx),
        lambda mixed: torch.index_select(mixed, 0, src_idx).split([len(vid_idx), len(txt_idx)]),
    )
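

# Sketch of the closure pair (hypothetical lengths): the first callable
# interleaves the packed tensors per sample, the second inverts it.
#
#     cat_fn, split_fn = concat_idx(torch.tensor([3, 3]), torch.tensor([2, 2]))
#     mixed = cat_fn(vid, txt)       # [v0 v0 v0 t0 t0 v1 v1 v1 t1 t1]
#     vid2, txt2 = split_fn(mixed)   # recovers vid and txt exactly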


def unconcat(
    mixed: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
) -> Tuple[
    torch.FloatTensor,
    torch.FloatTensor,
]:
    """Inverse of concat(): split an interleaved tensor back into packed
    vid and txt tensors."""
    interleave_len = list(chain(*zip(vid_len.tolist(), txt_len.tolist())))
    mixed = mixed.split(interleave_len)
    vid = torch.cat(mixed[0::2])
    txt = torch.cat(mixed[1::2])
    return vid, txt
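

# concat/unconcat move data eagerly; concat_idx builds reusable index
# closures for the same layout. Round trip (same packed tensors as above):
#
#     mixed = concat(vid, txt, vid_len, txt_len)
#     vid2, txt2 = unconcat(mixed, vid_len, txt_len)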


def repeat_concat(
    vid: torch.FloatTensor,
    txt: torch.FloatTensor,
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
    txt_repeat: List[int],
) -> torch.FloatTensor:
    """Interleave vid and txt tokens per sample, repeating each txt chunk
    txt_repeat[i] times so one prompt can pair with multiple vid chunks.
    Requires sum(txt_repeat) == len(vid_len)."""
    vid = torch.split(vid, vid_len.tolist())
    txt = torch.split(txt, txt_len.tolist())
    txt = [[x] * n for x, n in zip(txt, txt_repeat)]
    txt = list(chain(*txt))
    return torch.cat(list(chain(*zip(vid, txt))))


def repeat_concat_idx(
    vid_len: torch.LongTensor,
    txt_len: torch.LongTensor,
    txt_repeat: torch.LongTensor,
) -> Tuple[
    Callable,
    Callable,
]:
    """Precompute index permutations for repeat_concat() and its inverse,
    which averages the repeated txt tokens back into one copy per sample."""
    device = vid_len.device
    vid_idx = torch.arange(vid_len.sum(), device=device)
    txt_idx = torch.arange(len(vid_idx), len(vid_idx) + txt_len.sum(), device=device)
    txt_repeat_list = txt_repeat.tolist()
    tgt_idx = repeat_concat(vid_idx, txt_idx, vid_len, txt_len, txt_repeat_list)
    src_idx = torch.argsort(tgt_idx)
    txt_idx_len = len(tgt_idx) - len(vid_idx)
    repeat_txt_len = (txt_len * txt_repeat).tolist()

    def unconcat_coalesce(mixed):
        """
        Un-concat vid & txt, and coalesce the repeated txt.
        e.g. vid [0 1 2 3 4 5 6 7 8] -> 3 splits -> [0 1 2] [3 4 5] [6 7 8]
             txt [9 10]
             repeat_concat ==> [0 1 2 9 10 3 4 5 9 10 6 7 8 9 10]
        1. argsort re-index ==> [0 1 2 3 4 5 6 7 8 9 9 9 10 10 10]
           split ==> vid_out [0 1 2 3 4 5 6 7 8] txt_out [9 9 9 10 10 10]
        2. reshape & mean for each sample to coalesce the repeated txt.
        """
        vid_out, txt_out = mixed[src_idx].split([len(vid_idx), txt_idx_len])
        txt_out_coalesced = []
        for txt, repeat_time in zip(txt_out.split(repeat_txt_len), txt_repeat_list):
            txt = txt.reshape(-1, repeat_time, *txt.shape[1:]).mean(1)
            txt_out_coalesced.append(txt)
        return vid_out, torch.cat(txt_out_coalesced)

    return (
        lambda vid, txt: torch.cat([vid, txt])[tgt_idx],
        unconcat_coalesce,
    )
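

# Sketch (numbers from the docstring above): one shared txt prompt repeated
# across three vid chunks, then averaged back on the inverse pass.
#
#     vid_len = torch.tensor([3, 3, 3])
#     txt_len = torch.tensor([2])
#     cat_fn, split_fn = repeat_concat_idx(vid_len, txt_len, torch.tensor([3]))
#     mixed = cat_fn(vid, txt)       # txt tokens appear three times
#     vid2, txt2 = split_fn(mixed)   # txt2 is the mean of the three repeats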


def rearrange(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: int,
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply an einops rearrange pattern to every sample in a packed tensor,
    returning the re-packed tensor and its new per-sample shapes."""
    return flatten([einops.rearrange(h, pattern, **kwargs) for h in unflatten(hid, hid_shape)])


def rearrange_idx(
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: int,
) -> Tuple[Callable, Callable, torch.LongTensor]:
    """Precompute gather/scatter permutations that apply (and undo) an
    einops rearrange on any packed tensor with this layout. The pattern
    must leave the trailing channel axis intact."""
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape = rearrange(hid_idx, hid_shape, pattern, **kwargs)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
    )
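

# Sketch (hypothetical pattern): build the permutation once, reuse it on any
# packed tensor with the same hid_shape, then undo it.
#
#     fwd, bwd, new_shape = rearrange_idx(hid_shape, "t h w c -> h w t c")
#     hid_sp = fwd(hid)    # spatial-major token order
#     hid    = bwd(hid_sp)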


def repeat(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    pattern: str,
    **kwargs: torch.LongTensor,
) -> Tuple[
    torch.FloatTensor,
    torch.LongTensor,
]:
    """Apply an einops repeat pattern per sample. Each kwarg is a LongTensor
    holding one repeat factor per sample."""
    hid = unflatten(hid, hid_shape)
    kwargs = [{k: v[i].item() for k, v in kwargs.items()} for i in range(len(hid))]
    return flatten([einops.repeat(h, pattern, **a) for h, a in zip(hid, kwargs)])
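

# Sketch (hypothetical 2-D samples): duplicate each sample's tokens along a
# new leading axis, with a different factor per sample.
#
#     hid, hid_shape = repeat(hid, hid_shape, "l c -> (n l) c",
#                             n=torch.tensor([2, 3]))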


def pack(
    samples: List[torch.Tensor],
) -> Tuple[
    List[torch.Tensor],
    List[List[int]],
]:
    """Group same-shape samples into stacked batches. Returns the batches
    plus the original sample indices for unpack()."""
    batches = {}
    indices = {}
    for i, sample in enumerate(samples):
        shape = sample.shape
        batches.setdefault(shape, []).append(sample)
        indices.setdefault(shape, []).append(i)

    batches = list(map(torch.stack, batches.values()))
    indices = list(indices.values())
    return batches, indices


def unpack(
    batches: List[torch.Tensor],
    indices: List[List[int]],
) -> List[torch.Tensor]:
    """Inverse of pack(): scatter batched samples back to their original
    list positions."""
    samples = [None] * (max(chain(*indices)) + 1)
    for batch, index in zip(batches, indices):
        for sample, i in zip(batch.unbind(), index):
            samples[i] = sample
    return samples
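

# Sketch (hypothetical shapes): same-shape samples batch together for a
# batched op, then unpack restores the original order.
#
#     a, b, c = torch.randn(8, 4), torch.randn(6, 4), torch.randn(8, 4)
#     batches, indices = pack([a, b, c])   # shapes [(2, 8, 4), (1, 6, 4)]
#     outs = unpack(batches, indices)      # [a, b, c], order preserved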


def window(
    hid: torch.FloatTensor,
    hid_shape: torch.LongTensor,
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Split every sample into windows via window_fn, then re-pack all
    windows as independent samples. Also returns windows per sample."""
    hid = unflatten(hid, hid_shape)
    hid = list(map(window_fn, hid))
    hid_windows = torch.tensor(list(map(len, hid)), device=hid_shape.device)
    hid, hid_shape = flatten(list(chain(*hid)))
    return hid, hid_shape, hid_windows


def window_idx(
    hid_shape: torch.LongTensor,
    window_fn: Callable[[torch.Tensor], List[torch.Tensor]],
):
    """Precompute gather/scatter permutations for window() and its inverse
    on any packed tensor with this layout."""
    hid_idx = torch.arange(hid_shape.prod(-1).sum(), device=hid_shape.device).unsqueeze(-1)
    tgt_idx, tgt_shape, tgt_windows = window(hid_idx, hid_shape, window_fn)
    tgt_idx = tgt_idx.squeeze(-1)
    src_idx = torch.argsort(tgt_idx)
    return (
        lambda hid: torch.index_select(hid, 0, tgt_idx),
        lambda hid: torch.index_select(hid, 0, src_idx),
        tgt_shape,
        tgt_windows,
    )
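

# Sketch (hypothetical window_fn): chunk each sample into two temporal
# halves, run a windowed op on the permuted tokens, then undo.
#
#     fwd, bwd, win_shape, win_counts = window_idx(
#         hid_shape, lambda x: list(x.chunk(2, dim=0))
#     )
#     hid_win = fwd(hid)    # window tokens packed contiguously
#     hid     = bwd(hid_win)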