| |
|
|
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha*|dx|^gamma) * dx, where dx = a - c,
    a = attractor point, c = bin center, dc = shift in bin center.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    # |dx| is used so the decay is symmetric even for odd gamma.
    return torch.exp(-alpha * (torch.abs(dx) ** gamma)) * dx
|
|
|
|
@torch.jit.script
def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Inverse attractor: dc = dx / (1 + alpha*dx^gamma), where dx = a - c,
    a = attractor point, c = bin center, dc = shift in bin center.
    This is the default one according to the accompanying paper.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional Attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential Attractor strength. Determines the "region of influence" and indirectly number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts - dc; New bin centers = Old bin centers + dc
    """
    denom = 1.0 + alpha * dx ** gamma
    return dx / denom
|
|
|
|
class AttractorLayer(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth).

        Args:
            in_features (int): number of channels of the input feature block.
            n_bins (int): number of bin centers (stored; b_prev supplies the actual bins at forward time).
            n_attractors (int, optional): number of attractor points predicted per pixel. Defaults to 16.
            mlp_dim (int, optional): hidden width of the 1x1-conv MLP. Defaults to 128.
            min_depth (float, optional): lower bound used to scale bin centers to metric depth. Defaults to 1e-3.
            max_depth (float, optional): upper bound used to scale bin centers to metric depth. Defaults to 10.
            alpha (float, optional): attractor strength passed to the attractor function. Defaults to 300.
            gamma (int, optional): attractor reach exponent passed to the attractor function. Defaults to 2.
            kind (str, optional): 'sum' or 'mean' reduction over attractors. Defaults to 'sum'.
            attractor_type (str, optional): 'exp' selects exp_attractor; anything else selects inv_attractor. Defaults to 'exp'.
            memory_efficient (bool, optional): if True, loop over attractors instead of one broadcasted op. Defaults to False.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # 1x1-conv MLP predicting 2 values per attractor per pixel; final ReLU
        # keeps the raw predictions non-negative.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor): feature block; shape - n, c, h, w
            b_prev (torch.Tensor): previous bin centers normed; shape - n, prev_nbins, h, w
            prev_b_embedding (torch.Tensor, optional): embedding added to x (after optional resize). Defaults to None.
            interpolate (bool, optional): resize prev_b_embedding to x's spatial size before adding. Defaults to True.
            is_for_query (bool, optional): unused in this layer; kept for API compatibility.

        Returns:
            tuple(torch.Tensor, torch.Tensor): new bin centers normed and scaled; shape - n, nbins, h, w
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        eps = 1e-3
        A = A + eps  # keep attractor values strictly positive
        n, c, h, w = A.shape
        A = A.view(n, self.n_attractors, 2, h, w)
        # NOTE(review): the original code computed A / A.sum(dim=2, keepdim=True)
        # here and immediately overwrote it with the unnormalized slice below, so
        # the normalization was dead code. The dead division is removed; behavior
        # is unchanged — the first of the two channels is used as-is.
        A_normed = A[:, :, 0, ...]  # n, n_attractors, h, w

        # Bring previous bin centers to the current spatial resolution.
        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # Broadcast all attractor/bin pairs at once: n, n_attractors, n_bins, h, w
            delta_c = func(dist(A_normed.unsqueeze(
                2) - b_centers.unsqueeze(1)), dim=1)
        else:
            # Loop to avoid materializing the full n_attractors x n_bins tensor.
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for i in range(self.n_attractors):
                delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        # Scale normed centers to metric depth, then sort and clip to the valid range.
        B_centers = (self.max_depth - self.min_depth) * \
            b_new_centers + self.min_depth
        B_centers, _ = torch.sort(B_centers, dim=1)
        B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
        return b_new_centers, B_centers
|
|
|
|
class AttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # 1x1-conv MLP; Softplus keeps predicted attractor positions positive
        # without an upper bound (bins are unbounded in this variant).
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor): feature block; shape - n, c, h, w
            b_prev (torch.Tensor): previous bin centers normed; shape - n, prev_nbins, h, w

        Returns:
            tuple(torch.Tensor, torch.Tensor): new bin centers unbounded; shape - n, nbins, h, w.
            Two outputs just to keep the API consistent with the normed version.
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        attractors = self._net(x)
        height, width = attractors.shape[-2:]

        # Previous bin centers, resampled to the current spatial resolution.
        b_centers = nn.functional.interpolate(
            b_prev, (height, width), mode='bilinear', align_corners=True)

        dist = exp_attractor if self.attractor_type == 'exp' else inv_attractor

        if self.memory_efficient:
            # Accumulate one attractor at a time to keep peak memory low.
            delta_c = torch.zeros_like(b_centers, device=b_centers.device)
            for idx in range(self.n_attractors):
                delta_c += dist(attractors[:, idx, ...].unsqueeze(1) - b_centers)

            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors
        else:
            # Single broadcasted op over all attractor/bin pairs.
            reduce_fn = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            delta_c = reduce_fn(
                dist(attractors.unsqueeze(2) - b_centers.unsqueeze(1)), dim=1)

        b_new_centers = b_centers + delta_c
        # No scaling/sorting here; both outputs are the same unbounded centers.
        return b_new_centers, b_new_centers
|
|