Spaces:
Runtime error
Runtime error
| import torch | |
def safe_log(z, eps=1e-7):
    """Return ``log(z + eps)`` element-wise.

    The small additive ``eps`` keeps the logarithm finite at ``z == 0``
    (it yields ``log(eps)`` instead of ``-inf``). No clamping is performed,
    so entries below ``-eps`` still produce ``nan``.

    Args:
        z: input tensor.
        eps: offset added before taking the log. Defaults to ``1e-7``,
            the value previously hard-coded here.

    Returns:
        Tensor of the same shape as ``z``.
    """
    return torch.log(z + eps)
def log_sum_exp(value, dim=None, keepdim=False):
    """Numerically stable implementation of
    ``value.exp().sum(dim, keepdim).log()``.

    Delegates to :func:`torch.logsumexp`. A naive "subtract the max"
    shift yields ``nan`` when a slice's maximum is infinite
    (``-inf - -inf`` or ``inf - inf``). ``torch.logsumexp`` handles this:
    an all ``-inf`` slice gives ``-inf`` and a slice containing ``+inf``
    gives ``+inf``.

    Args:
        value: input tensor.
        dim: dimension (or tuple of dimensions) to reduce. If ``None``,
            reduce over every element.
        keepdim: keep the reduced dimension(s) with size 1. Ignored when
            ``dim`` is ``None``.

    Returns:
        The log of the summed exponentials. When ``dim`` is ``None`` this
        is a 0-dim tensor.
    """
    if not (value.is_floating_point() or value.is_complex()):
        # Integer tensors used to work here via torch.exp's promotion;
        # some torch versions reject integer input to logsumexp, so promote.
        value = value.to(torch.get_default_dtype())
    if dim is None:
        # Full reduction: flatten and reduce the single axis -> 0-dim result.
        return torch.logsumexp(value.reshape(-1), dim=0)
    return torch.logsumexp(value, dim=dim, keepdim=bool(keepdim))
def generate_grid(zmin, zmax, dz, device, ndim=2):
    """Generate a 1- or 2-dimensional grid of evenly spaced points.

    The axis values are ``x = torch.arange(zmin, zmax, dz)`` (``zmin``
    inclusive, ``zmax`` exclusive), so ``k = len(x)``, which is
    ``ceil((zmax - zmin) / dz)``.

    Args:
        zmin: start of the axis (inclusive).
        zmax: end of the axis (exclusive).
        dz: spacing between consecutive points.
        device: device the returned tensor is moved to.
        ndim: 2 for a 2-D grid, 1 for a 1-D column of points.

    Returns:
        If ``ndim == 2``: a tuple ``(grid, k)``. ``grid`` has shape
        ``(k * k, 2)``, and row ``i * k + j`` is ``(x[i], x[j])``.
        If ``ndim == 1``: only the grid tensor (no ``k``), shape ``(k, 1)``.

    Raises:
        ValueError: if ``ndim`` is neither 1 nor 2.
    """
    if ndim not in (1, 2):
        # Previously fell through and returned None silently.
        raise ValueError(f"ndim must be 1 or 2, got {ndim!r}")
    x = torch.arange(zmin, zmax, dz)
    if ndim == 1:
        return x.unsqueeze(1).to(device)
    k = x.size(0)
    # First coordinate varies slowest, second fastest (row-major order).
    first = x.repeat_interleave(k)
    second = x.repeat(k)
    return torch.stack((first, second), dim=-1).to(device), k