import numpy as np
import torch
from torch import nn as nn

import isegm.model.initializer as initializer
def select_activation_function(activation):
    """Resolve an activation by name ("relu"/"softplus") or pass an nn.Module through."""
    if isinstance(activation, str):
        if activation.lower() == "relu":
            return nn.ReLU
        elif activation.lower() == "softplus":
            return nn.Softplus
        else:
            raise ValueError(f"Unknown activation type {activation}")
    elif isinstance(activation, nn.Module):
        return activation
    else:
        raise ValueError(f"Unknown activation type {activation}")
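

# Illustrative usage sketch (not part of the original module): string inputs resolve
# to a module class, which the caller instantiates; nn.Module inputs pass through as-is.
def _demo_select_activation():
    activation_cls = select_activation_function("relu")  # -> nn.ReLU (a class, not an instance)
    return activation_cls()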
class BilinearConvTranspose2d(nn.ConvTranspose2d):
    """Transposed convolution whose weights are initialized to bilinear upsampling by `scale`."""

    def __init__(self, in_channels, out_channels, scale, groups=1):
        kernel_size = 2 * scale - scale % 2
        self.scale = scale

        super().__init__(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=scale,
            padding=1,
            groups=groups,
            bias=False,
        )

        self.apply(
            initializer.Bilinear(scale=scale, in_channels=in_channels, groups=groups)
        )
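

# Illustrative usage sketch (not part of the original module): with padding=1 and
# kernel_size = 2 * scale - scale % 2, the layer upsamples spatially by exactly `scale`
# (16x16 -> 32x32 for scale=2), starting from bilinear-interpolation weights.
def _demo_bilinear_upsample():
    up = BilinearConvTranspose2d(in_channels=8, out_channels=8, scale=2)
    x = torch.randn(1, 8, 16, 16)
    return up(x)  # -> 1 x 8 x 32 x 32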
class DistMaps(nn.Module):
    """Encodes click coordinates as distance maps (or binary disks) over the image grid."""

    def __init__(self, norm_radius, spatial_scale=1.0, cpu_mode=False, use_disks=False):
        super(DistMaps, self).__init__()
        self.spatial_scale = spatial_scale
        self.norm_radius = norm_radius
        self.cpu_mode = cpu_mode
        self.use_disks = use_disks
        if self.cpu_mode:
            from isegm.utils.cython import get_dist_maps
            self._get_dist_maps = get_dist_maps

    def get_coord_features(self, points, batchsize, rows, cols):
        if self.cpu_mode:
            coords = []
            for i in range(batchsize):
                norm_delimeter = (
                    1.0 if self.use_disks else self.spatial_scale * self.norm_radius
                )
                coords.append(
                    self._get_dist_maps(
                        points[i].cpu().float().numpy(), rows, cols, norm_delimeter
                    )
                )
            coords = (
                torch.from_numpy(np.stack(coords, axis=0)).to(points.device).float()
            )
        else:
            num_points = points.shape[1] // 2
            points = points.view(-1, points.size(2))
            points, points_order = torch.split(points, [2, 1], dim=1)

            # Points with negative coordinates are padding and must not contribute.
            invalid_points = torch.max(points, dim=1, keepdim=False)[0] < 0

            row_array = torch.arange(
                start=0, end=rows, step=1, dtype=torch.float32, device=points.device
            )
            col_array = torch.arange(
                start=0, end=cols, step=1, dtype=torch.float32, device=points.device
            )

            # 'ij'-indexed grid of pixel coordinates, repeated once per point.
            coord_rows, coord_cols = torch.meshgrid(row_array, col_array)
            coords = (
                torch.stack((coord_rows, coord_cols), dim=0)
                .unsqueeze(0)
                .repeat(points.size(0), 1, 1, 1)
            )

            add_xy = (points * self.spatial_scale).view(
                points.size(0), points.size(1), 1, 1
            )
            coords.add_(-add_xy)
            if not self.use_disks:
                coords.div_(self.norm_radius * self.spatial_scale)
            coords.mul_(coords)

            # Squared Euclidean distance to each click, stored in the first channel.
            coords[:, 0] += coords[:, 1]
            coords = coords[:, :1]

            coords[invalid_points, :, :, :] = 1e6

            coords = coords.view(-1, num_points, 1, rows, cols)
            coords = coords.min(dim=1)[0]  # -> (bs * num_masks * 2) x 1 x h x w
            coords = coords.view(-1, 2, rows, cols)

        if self.use_disks:
            coords = (coords <= (self.norm_radius * self.spatial_scale) ** 2).float()
        else:
            coords.sqrt_().mul_(2).tanh_()

        return coords

    def forward(self, x, coords):
        return self.get_coord_features(coords, x.shape[0], x.shape[2], x.shape[3])
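

# Illustrative usage sketch (not part of the original module). The point layout below
# is an assumption: each point is (row, col, click_index), the first half of the rows
# are positive clicks, the second half negative, and rows of -1 are padding.
def _demo_dist_maps():
    dist_maps = DistMaps(norm_radius=5, spatial_scale=1.0, use_disks=True)
    image = torch.zeros(1, 3, 64, 64)
    points = torch.tensor([[[10.0, 20.0, 0.0],     # one positive click at row 10, col 20
                            [-1.0, -1.0, -1.0]]])  # padded (unused) negative slot
    return dist_maps(image, points)  # -> 1 x 2 x 64 x 64 click encoding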
class ScaleLayer(nn.Module):
    """Multiplies the input by a single learnable, non-negative scale factor."""

    def __init__(self, init_value=1.0, lr_mult=1):
        super().__init__()
        self.lr_mult = lr_mult
        self.scale = nn.Parameter(
            torch.full((1,), init_value / lr_mult, dtype=torch.float32)
        )

    def forward(self, x):
        scale = torch.abs(self.scale * self.lr_mult)
        return x * scale
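

# Illustrative usage sketch (not part of the original module): because the parameter is
# stored as init_value / lr_mult and multiplied back by lr_mult in forward(), lr_mult
# rescales the parameter's effective learning rate while the initial gain stays init_value.
def _demo_scale_layer():
    gain = ScaleLayer(init_value=0.05, lr_mult=1)
    x = torch.randn(2, 16, 8, 8)
    return gain(x)  # same shape as x, scaled by ~0.05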
class BatchImageNormalize:
    """Normalizes an NCHW batch with per-channel mean and std, without modifying the input."""

    def __init__(self, mean, std, dtype=torch.float):
        self.mean = torch.as_tensor(mean, dtype=dtype)[None, :, None, None]
        self.std = torch.as_tensor(std, dtype=dtype)[None, :, None, None]

    def __call__(self, tensor):
        tensor = tensor.clone()
        tensor.sub_(self.mean.to(tensor.device)).div_(self.std.to(tensor.device))
        return tensor
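

# Illustrative usage sketch (not part of the original module): per-channel normalization
# of an image batch; the ImageNet mean/std values here are only an example.
def _demo_batch_image_normalize():
    normalize = BatchImageNormalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    images = torch.rand(4, 3, 224, 224)  # RGB batch with values in [0, 1]
    return normalize(images)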