| | |
| |
|
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| | |
| |
|
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | |
| |
|
| | import torch |
| | import torch.nn as nn |
| |
|
| |
|
@torch.jit.script
def exp_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Exponential attractor: dc = exp(-alpha * |dx|^gamma) * dx.

    Here dx = a - c, where a is the attractor point, c is the bin center and
    dc is the resulting shift of the bin center.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential attractor strength. Determines the "region of influence" and indirectly the number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts dc; new bin centers = old bin centers + dc.
    """
    # |dx|^gamma: how quickly the pull decays with distance from the attractor.
    distance_term = torch.abs(dx).pow(gamma)
    # Weight equals 1 at dx == 0 and decays towards 0 as |dx| grows (alpha > 0).
    weight = torch.exp(-alpha * distance_term)
    return weight * dx
| |
|
| |
|
@torch.jit.script
def inv_attractor(dx, alpha: float = 300, gamma: int = 2):
    """Inverse attractor: dc = dx / (1 + alpha * |dx|^gamma).

    Here dx = a - c, where a is the attractor point, c is the bin center and
    dc is the resulting shift of the bin center.
    This is the default one according to the accompanying paper.

    The magnitude |dx| is used in the denominator (as in ``exp_attractor``) so
    that the denominator stays >= 1 for alpha >= 0. With the signed power an
    odd ``gamma`` and a negative ``dx`` could drive it to zero or below, which
    yields infinite or sign-flipped shifts. For even ``gamma`` (including the
    default of 2) the result is identical to the signed form.

    Args:
        dx (torch.Tensor): The difference tensor dx = Ai - Cj, where Ai is the attractor point and Cj is the bin center.
        alpha (float, optional): Proportional attractor strength. Determines the absolute strength. Lower alpha = greater attraction. Defaults to 300.
        gamma (int, optional): Exponential attractor strength. Determines the "region of influence" and indirectly the number of bin centers affected. Lower gamma = farther reach. Defaults to 2.

    Returns:
        torch.Tensor: Delta shifts dc; new bin centers = old bin centers + dc.
    """
    return dx.div(1 + alpha * torch.abs(dx).pow(gamma))
| |
|
| |
|
class AttractorLayer(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are bounded on the interval (min_depth, max_depth)

        Args:
            in_features (int): Number of channels of the input feature block.
            n_bins (int): Number of bins; stored on the module, not used by this layer's computation.
            n_attractors (int, optional): Number of attractor points predicted per pixel. Defaults to 16.
            mlp_dim (int, optional): Hidden width of the 1x1-conv MLP. Defaults to 128.
            min_depth (float, optional): Lower bound used to scale/clip the returned centers. Defaults to 1e-3.
            max_depth (float, optional): Upper bound used to scale/clip the returned centers. Defaults to 10.
            alpha (float, optional): Proportional attractor strength passed to the attractor function. Defaults to 300.
            gamma (int, optional): Exponential attractor strength passed to the attractor function. Defaults to 2.
            kind (str, optional): How per-attractor shifts are aggregated: 'sum' or 'mean'. Defaults to 'sum'.
            attractor_type (str, optional): 'exp' selects ``exp_attractor``; anything else selects ``inv_attractor``. Defaults to 'exp'.
            memory_efficient (bool, optional): Accumulate shifts one attractor at a time instead of
                materialising the full (attractor x bin) difference tensor. Defaults to False.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # Per-pixel MLP (1x1 convs) predicting two non-negative values per attractor.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors*2, 1, 1, 0),
            nn.ReLU(inplace=True)
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
            prev_b_embedding (torch.Tensor, optional) : embedding added to ``x`` before the MLP.
            interpolate (bool, optional) : bilinearly resize ``prev_b_embedding`` to the spatial size of ``x`` first.
            is_for_query (bool, optional) : currently unused; kept for API compatibility.

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers normed and scaled; shape - n, nbins, h, w.
                The first tensor is ``b_prev + delta`` (not sorted, not clipped); the second is that
                tensor scaled to the depth range, sorted along dim 1 and clipped to [min_depth, max_depth].

        Note:
            ``alpha`` and ``gamma`` given to the constructor are forwarded to the attractor
            function. Earlier revisions never forwarded them, so the attractor always ran with
            its own defaults (300, 2); weights trained with that code under non-default settings
            were effectively trained with 300 / 2.
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)
        eps = 1e-3
        A = A + eps  # keep attractor values strictly positive after the final ReLU
        n, c, h, w = A.shape
        A = A.view(n, self.n_attractors, 2, h, w)
        # NOTE(review): the original code computed ``A / A.sum(dim=2, keepdim=True)`` and then
        # immediately overwrote the result with the un-normalised first channel. The dead
        # computation is removed and the effective behaviour (un-normalised channel 0) is kept.
        # Switching to the normalised value would change the outputs of every existing use, so
        # confirm the intent before doing that.
        A_normed = A[:, :, 0, ...]  # n, n_attractors, h, w

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        # The attractor functions are TorchScript functions typed (float, int).
        # int() truncates a non-integral gamma.
        alpha = float(self.alpha)
        gamma = int(self.gamma)

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # Difference tensor has shape n, n_attractors, nbins, h, w; reduce over attractors.
            delta_c = func(
                dist(A_normed.unsqueeze(2) - b_centers.unsqueeze(1), alpha, gamma), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers)
            for i in range(self.n_attractors):
                # each term has shape n, nbins, h, w
                delta_c += dist(A_normed[:, i, ...].unsqueeze(1) - b_centers, alpha, gamma)

            # The loop accumulates a sum; convert to a mean only in this branch
            # (the vectorised branch already applied torch.mean).
            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = (self.max_depth - self.min_depth) * \
            b_new_centers + self.min_depth
        B_centers, _ = torch.sort(B_centers, dim=1)
        B_centers = torch.clip(B_centers, self.min_depth, self.max_depth)
        return b_new_centers, B_centers
| |
|
| |
|
class AttractorLayerUnnormed(nn.Module):
    def __init__(self, in_features, n_bins, n_attractors=16, mlp_dim=128, min_depth=1e-3, max_depth=10,
                 alpha=300, gamma=2, kind='sum', attractor_type='exp', memory_efficient=False):
        """
        Attractor layer for bin centers. Bin centers are unbounded

        Args:
            in_features (int): Number of channels of the input feature block.
            n_bins (int): Number of bins; stored on the module, not used by this layer's computation.
            n_attractors (int, optional): Number of attractor points predicted per pixel. Defaults to 16.
            mlp_dim (int, optional): Hidden width of the 1x1-conv MLP. Defaults to 128.
            min_depth (float, optional): Stored on the module; not used by ``forward``. Defaults to 1e-3.
            max_depth (float, optional): Stored on the module; not used by ``forward``. Defaults to 10.
            alpha (float, optional): Proportional attractor strength passed to the attractor function. Defaults to 300.
            gamma (int, optional): Exponential attractor strength passed to the attractor function. Defaults to 2.
            kind (str, optional): How per-attractor shifts are aggregated: 'sum' or 'mean'. Defaults to 'sum'.
            attractor_type (str, optional): 'exp' selects ``exp_attractor``; anything else selects ``inv_attractor``. Defaults to 'exp'.
            memory_efficient (bool, optional): Accumulate shifts one attractor at a time instead of
                materialising the full (attractor x bin) difference tensor. Defaults to False.
        """
        super().__init__()

        self.n_attractors = n_attractors
        self.n_bins = n_bins
        self.min_depth = min_depth
        self.max_depth = max_depth
        self.alpha = alpha
        self.gamma = gamma
        self.kind = kind
        self.attractor_type = attractor_type
        self.memory_efficient = memory_efficient

        # Per-pixel MLP (1x1 convs); Softplus keeps the predicted attractor points positive.
        self._net = nn.Sequential(
            nn.Conv2d(in_features, mlp_dim, 1, 1, 0),
            nn.ReLU(inplace=True),
            nn.Conv2d(mlp_dim, n_attractors, 1, 1, 0),
            nn.Softplus()
        )

    def forward(self, x, b_prev, prev_b_embedding=None, interpolate=True, is_for_query=False):
        """
        Args:
            x (torch.Tensor) : feature block; shape - n, c, h, w
            b_prev (torch.Tensor) : previous bin centers normed; shape - n, prev_nbins, h, w
            prev_b_embedding (torch.Tensor, optional) : embedding added to ``x`` before the MLP.
            interpolate (bool, optional) : bilinearly resize ``prev_b_embedding`` to the spatial size of ``x`` first.
            is_for_query (bool, optional) : currently unused; kept for API compatibility.

        Returns:
            tuple(torch.Tensor,torch.Tensor) : new bin centers unbounded; shape - n, nbins, h, w. Two outputs just to keep the API consistent with the normed version

        Note:
            ``alpha`` and ``gamma`` given to the constructor are forwarded to the attractor
            function. Earlier revisions never forwarded them, so the attractor always ran with
            its own defaults (300, 2); weights trained with that code under non-default settings
            were effectively trained with 300 / 2.
        """
        if prev_b_embedding is not None:
            if interpolate:
                prev_b_embedding = nn.functional.interpolate(
                    prev_b_embedding, x.shape[-2:], mode='bilinear', align_corners=True)
            x = x + prev_b_embedding

        A = self._net(x)  # n, n_attractors, h, w
        h, w = A.shape[-2:]

        b_prev = nn.functional.interpolate(
            b_prev, (h, w), mode='bilinear', align_corners=True)
        b_centers = b_prev

        if self.attractor_type == 'exp':
            dist = exp_attractor
        else:
            dist = inv_attractor

        # The attractor functions are TorchScript functions typed (float, int).
        # int() truncates a non-integral gamma.
        alpha = float(self.alpha)
        gamma = int(self.gamma)

        if not self.memory_efficient:
            func = {'mean': torch.mean, 'sum': torch.sum}[self.kind]
            # Difference tensor has shape n, n_attractors, nbins, h, w; reduce over attractors.
            delta_c = func(
                dist(A.unsqueeze(2) - b_centers.unsqueeze(1), alpha, gamma), dim=1)
        else:
            delta_c = torch.zeros_like(b_centers)
            for i in range(self.n_attractors):
                # each term has shape n, nbins, h, w
                delta_c += dist(A[:, i, ...].unsqueeze(1) - b_centers, alpha, gamma)

            # The loop accumulates a sum; convert to a mean only in this branch
            # (the vectorised branch already applied torch.mean).
            if self.kind == 'mean':
                delta_c = delta_c / self.n_attractors

        b_new_centers = b_centers + delta_c
        B_centers = b_new_centers

        return b_new_centers, B_centers
| |
|