| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
|
|
|
|
def sum_tensor(inp, axes, keepdim=False):
    """
    sums a tensor over several axes, one axis at a time

    Negative axes are interpreted relative to the *original* tensor (as in
    numpy/torch) and are normalized before reducing, so that e.g. ``(-1, -2)``
    reduces the last two dimensions and ``(1, -1)`` on a 2-D tensor reduces
    dimension 1 only once.

    :param inp: torch tensor to reduce
    :param axes: int or iterable of ints, the axes to sum over (duplicates are ignored)
    :param keepdim: if True, reduced axes are kept with size 1
    :return: the reduced tensor (inp itself if axes is empty)
    :raises IndexError: if an axis is outside [-inp.dim(), inp.dim() - 1]
    """
    # torch treats a 0-dim tensor as if it had the valid dims [-1, 0]
    ndim = max(inp.dim(), 1)
    axes = np.asarray(axes).astype(int).ravel()

    if ((axes < -ndim) | (axes >= ndim)).any():
        raise IndexError(
            "axes {} out of range for tensor with {} dimension(s)".format(
                axes.tolist(), inp.dim()))

    # Map negative axes to their non-negative position first; otherwise they
    # would be re-interpreted against the shrinking tensor in the loop below.
    axes = np.unique(np.where(axes < 0, axes + ndim, axes))

    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        # reduce from the last axis to the first so that the remaining
        # (smaller) axis indices stay valid after each reduction
        for ax in axes[::-1]:
            inp = inp.sum(int(ax))
    return inp
|
|
|
|
def mean_tensor(inp, axes, keepdim=False):
    """
    averages a tensor over several axes, one axis at a time

    Negative axes are interpreted relative to the *original* tensor (as in
    numpy/torch) and are normalized before reducing, so that e.g. ``(-1, -2)``
    reduces the last two dimensions and ``(1, -1)`` on a 2-D tensor reduces
    dimension 1 only once.

    :param inp: torch tensor to reduce
    :param axes: int or iterable of ints, the axes to average over (duplicates are ignored)
    :param keepdim: if True, reduced axes are kept with size 1
    :return: the reduced tensor (inp itself if axes is empty)
    :raises IndexError: if an axis is outside [-inp.dim(), inp.dim() - 1]
    """
    # torch treats a 0-dim tensor as if it had the valid dims [-1, 0]
    ndim = max(inp.dim(), 1)
    axes = np.asarray(axes).astype(int).ravel()

    if ((axes < -ndim) | (axes >= ndim)).any():
        raise IndexError(
            "axes {} out of range for tensor with {} dimension(s)".format(
                axes.tolist(), inp.dim()))

    # Map negative axes to their non-negative position first; otherwise they
    # would be re-interpreted against the shrinking tensor in the loop below.
    axes = np.unique(np.where(axes < 0, axes + ndim, axes))

    if keepdim:
        for ax in axes:
            inp = inp.mean(int(ax), keepdim=True)
    else:
        # reduce from the last axis to the first so that the remaining
        # (smaller) axis indices stay valid after each reduction
        for ax in axes[::-1]:
            inp = inp.mean(int(ax))
    return inp
|
|
|
|
def flip(x, dim):
    """
    mirrors a tensor along a single dimension (element order along dim is reversed)

    :param x: torch tensor to mirror
    :param dim: index of the dimension to reverse (negative values count from the end)
    :return: new tensor of the same shape as x with the entries along dim in reverse order
    """
    # descending index vector: size-1, size-2, ..., 0 (on the same device as x)
    reversed_idx = torch.arange(x.size(dim) - 1, -1, -1,
                                dtype=torch.long, device=x.device)

    # take everything along all dimensions, except dim which is read backwards
    selector = [slice(None) for _ in range(x.dim())]
    selector[dim] = reversed_idx
    return x[tuple(selector)]
|
|
|
|
|
|