| | |
| | |
| | |
| | |
| |
|
| | from typing import List, Optional |
| |
|
| | import torch |
| | from torch import Tensor |
| |
|
| |
|
@torch.jit.script
def script_skip_tensor_list(x: List[Tensor], mask):
    """
    Apply `mask` to every tensor in `x` and return the results as a list.

    A tensor whose first dimension has the same length as `mask` is indexed
    along dim 0; any other tensor is indexed along dim 1. If the indexed
    result has no elements, the original (unindexed) tensor is kept instead.
    """
    kept: List[Tensor] = []
    for item in x:
        # Pick the axis to index by comparing the leading sizes.
        if item.size(0) == mask.size(0):
            picked = item[mask]
        else:
            picked = item[:, mask]

        # An empty selection falls back to the untouched input tensor.
        if picked.numel() == 0:
            kept.append(item)
        else:
            kept.append(picked)
    return kept
| |
|
| |
|
@torch.jit.script
def script_skip_tensor(x: Tensor, mask):
    """
    Apply `mask` to a single tensor.

    `x` is indexed along dim 0 when its first dimension has the same length
    as `mask`, otherwise along dim 1. `x` itself is returned when its first
    dimension is empty or when the indexed result has no elements.
    """
    # Nothing to select from when dim 0 is empty.
    if x.size(0) == 0:
        return x

    if x.size(0) == mask.size(0):
        selected = x[mask]
    else:
        selected = x[:, mask]

    # An empty selection falls back to the untouched input.
    return x if selected.numel() == 0 else selected
| |
|
| |
|
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
    """
    Pad a 2D/3D tensor along dim=1 up to length `trg_dim`.

    The added positions are filled with `padding_idx`. `x` is returned
    unchanged when dim 1 already has length `trg_dim`. Asserts that `x` is
    2D or 3D and that `trg_dim` is not smaller than the current dim-1 size.
    """
    if x is None:
        return None

    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x

    # Shape of the block that gets appended along dim 1.
    pad_len = trg_dim - x.size(1)
    if x.dim() == 3:
        pad_shape = [x.size(0), pad_len, x.size(2)]
    else:
        pad_shape = [x.size(0), pad_len]

    # `.to(x)` matches the dtype/device of `x` before filling.
    padding = torch.zeros(pad_shape).to(x).fill_(padding_idx)
    return torch.cat([x, padding], 1)
| |
|
| |
|
@torch.jit.script
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    """
    Return `x` when it is not None, otherwise return `y`.
    """
    if x is not None:
        return x
    else:
        return y
| |
|
| |
|
@torch.jit.script
def fill_tensors(
    x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int
) -> Optional[Tensor]:
    """
    Write the rows of y into x at the positions selected by mask (dim=0).

    Returns x untouched when x is None, x has an empty first dimension,
    y is None, or the mask selects nothing. Returns y itself when the mask
    selects every row of x. Otherwise the dim-1 sizes are reconciled:
      * equal sizes: y is assigned directly into x[mask];
      * x narrower than y: x is first expanded along dim 1 with padding_idx
        via expand_2d_or_3d_tensor, then y is assigned;
      * x wider than y: the selected rows are reset to padding_idx and y is
        written into their leading y.size(1) positions.

    NOTE: assignments through x[mask] write into the tensor that x refers to
    at that point, so the caller's tensor is modified in place in the
    branches that do not rebind x.
    """
    if x is None or x.size(0) == 0 or y is None:
        return x
    assert x.dim() == y.dim() and mask.size(0) == x.size(0)
    assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))

    num_masked = mask.sum()
    if num_masked == 0:
        return x
    assert num_masked == y.size(0)
    # Every row is replaced, so y is already the full answer.
    if num_masked == x.size(0):
        return y

    if x.size(1) == y.size(1):
        x[mask] = y
    elif x.size(1) < y.size(1):
        # Grow x along dim 1 so the rows of y fit.
        x = expand_2d_or_3d_tensor(x, y.size(1), padding_idx)
        x[mask] = y
    else:
        # Clear the selected rows first so the tail beyond y stays padding.
        x[mask] = torch.tensor(padding_idx).type_as(x)
        if x.dim() == 2:
            x[mask, : y.size(1)] = y
        else:
            x[mask, : y.size(1), :] = y
    return x
| |
|