| |
| |
| |
| |
|
|
| from typing import List, Optional |
|
|
| import torch |
| from torch import Tensor |
|
|
|
|
@torch.jit.script
def script_skip_tensor_list(x: List[Tensor], mask):
    """
    Apply ``mask`` to every tensor in ``x``.

    A tensor whose leading dimension equals ``mask.size(0)`` is indexed on
    dim 0 (``t[mask]``); any other tensor is indexed on dim 1
    (``t[:, mask]``). If that selection comes out empty, the original
    tensor is kept in its place instead. The returned list has one entry
    per input tensor, in the same order.

    NOTE(review): ``mask`` is presumably a boolean mask over rows/columns;
    confirm against callers.
    """
    kept: List[Tensor] = []
    for item in x:
        if item.size(0) == mask.size(0):
            selected = item[mask]
        else:
            selected = item[:, mask]
        # An empty selection falls back to the unmasked tensor.
        if selected.numel() == 0:
            kept.append(item)
        else:
            kept.append(selected)
    return kept
|
|
|
|
@torch.jit.script
def script_skip_tensor(x: Tensor, mask):
    """
    Apply ``mask`` to a single tensor ``x``.

    An empty ``x`` (leading dimension 0) is returned untouched. Otherwise
    ``x`` is indexed on dim 0 when its leading dimension equals
    ``mask.size(0)``, and on dim 1 otherwise. If the selection is empty,
    the original ``x`` is returned instead of the empty result.
    """
    if x.size(0) == 0:
        return x
    if x.size(0) == mask.size(0):
        picked = x[mask]
    else:
        picked = x[:, mask]
    return x if picked.numel() == 0 else picked
|
|
|
|
@torch.jit.script
def expand_2d_or_3d_tensor(x, trg_dim: int, padding_idx: int):
    """
    Expand a 2D/3D tensor on dim=1 up to ``trg_dim`` entries.

    New positions along dim 1 are appended after the existing ones and
    filled with ``padding_idx``; for a 3D tensor the size of dim 2 is kept.

    Args:
        x: tensor with 2 or 3 dimensions. Under ``torch.jit.script`` an
            unannotated parameter is typed ``Tensor``, so the ``None`` check
            below only takes effect when the function runs un-scripted.
        trg_dim: desired size of dim 1; must be >= ``x.size(1)``.
        padding_idx: fill value for the appended positions.

    Returns:
        ``x`` itself when ``x.size(1) == trg_dim``; otherwise a new tensor
        with the same dtype and device as ``x`` and ``trg_dim`` on dim 1.

    Raises:
        AssertionError: if ``x`` is not 2D/3D or ``trg_dim < x.size(1)``.
    """
    if x is None:
        return None

    assert x.dim() == 2 or x.dim() == 3
    assert trg_dim >= x.size(1), (trg_dim, x.size())
    if trg_dim == x.size(1):
        return x

    dims = [x.size(0), trg_dim - x.size(1)]
    if x.dim() == 3:
        dims.append(x.size(2))
    # Allocate the padding directly with x's dtype and device. The previous
    # torch.zeros(dims).to(x).fill_(...) first built a float32 CPU tensor,
    # then converted/copied it (a host-to-device transfer for GPU inputs).
    padding = x.new_full(dims, padding_idx)
    return torch.cat([x, padding], 1)
|
|
|
|
@torch.jit.script
def coalesce(x: Optional[Tensor], y: Tensor) -> Tensor:
    """Return ``x`` when it is not ``None``, otherwise fall back to ``y``."""
    if x is not None:
        return x
    return y
|
|
|
|
@torch.jit.script
def fill_tensors(
    x: Optional[Tensor], mask, y: Optional[Tensor], padding_idx: int
) -> Optional[Tensor]:
    """
    Write the rows of ``y`` into ``x`` at the dim-0 positions selected by
    ``mask``, reconciling any size mismatch on dim 1.

    Early exits:
      * ``x`` is None, ``x`` has no rows, or ``y`` is None -> ``x``;
      * ``mask`` selects nothing -> ``x``;
      * ``mask`` selects every row of ``x`` -> ``y``.

    Otherwise (``x``/``y`` 2D, or 3D with equal dim-2 sizes):
      * ``x`` narrower than ``y`` on dim 1: ``x`` is first widened with
        ``padding_idx`` (a new tensor), then the selected rows get ``y``;
      * ``x`` wider than ``y``: selected rows are reset to ``padding_idx``
        and their leading ``y.size(1)`` columns are overwritten by ``y``;
      * equal widths: selected rows are overwritten by ``y``.
    Except in the widening case, ``x`` is modified in place.

    NOTE(review): ``mask`` is presumably a boolean row mask whose number of
    True entries equals ``y.size(0)`` (asserted below); confirm with callers.
    """
    if x is None or x.size(0) == 0 or y is None:
        return x
    assert x.dim() == y.dim() and mask.size(0) == x.size(0)
    assert x.dim() == 2 or (x.dim() == 3 and x.size(2) == y.size(2))

    num_masked = mask.sum()
    if num_masked == 0:
        return x
    assert num_masked == y.size(0)
    if num_masked == x.size(0):
        # Every row is replaced, so y itself is the result.
        return y

    x_width = x.size(1)
    y_width = y.size(1)
    if x_width == y_width:
        x[mask] = y
    elif x_width < y_width:
        x = expand_2d_or_3d_tensor(x, y_width, padding_idx)
        x[mask] = y
    else:
        x[mask] = torch.tensor(padding_idx).type_as(x)
        # Trailing dims are implicit, so this covers both 2D and 3D x.
        x[mask, :y_width] = y
    return x
|
|