| """Light HamHead Decoder. |
| |
| Adapted from: |
| https://github.com/Visual-Attention-Network/SegNeXt/blob/main/mmseg/models/decode_heads/ham_head.py |
| """ |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
| from siclib.models import BaseModel |
| from siclib.models.utils.modules import ConvModule, FeatureFusionBlock |
|
|
| |
| |
|
|
|
|
| class _MatrixDecomposition2DBase(nn.Module): |
| def __init__(self): |
| super().__init__() |
|
|
| self.spatial = True |
|
|
| self.S = 1 |
| self.D = 512 |
| self.R = 64 |
|
|
| self.train_steps = 6 |
| self.eval_steps = 7 |
|
|
| self.inv_t = 100 |
| self.eta = 0.9 |
|
|
| self.rand_init = True |
|
|
| def _build_bases(self, B, S, D, R, device="cpu"): |
| raise NotImplementedError |
|
|
| def local_step(self, x, bases, coef): |
| raise NotImplementedError |
|
|
| |
| def local_inference(self, x, bases): |
| |
| coef = torch.bmm(x.transpose(1, 2), bases) |
| coef = F.softmax(self.inv_t * coef, dim=-1) |
|
|
| steps = self.train_steps if self.training else self.eval_steps |
| for _ in range(steps): |
| bases, coef = self.local_step(x, bases, coef) |
|
|
| return bases, coef |
|
|
| def compute_coef(self, x, bases, coef): |
| raise NotImplementedError |
|
|
| def forward(self, x, return_bases=False): |
| B, C, H, W = x.shape |
|
|
| |
| if self.spatial: |
| D = C // self.S |
| N = H * W |
| x = x.view(B * self.S, D, N) |
| else: |
| D = H * W |
| N = C // self.S |
| x = x.view(B * self.S, N, D).transpose(1, 2) |
|
|
| if not self.rand_init and not hasattr(self, "bases"): |
| bases = self._build_bases(1, self.S, D, self.R, device=x.device) |
| self.register_buffer("bases", bases) |
|
|
| |
| if self.rand_init: |
| bases = self._build_bases(B, self.S, D, self.R, device=x.device) |
| else: |
| bases = self.bases.repeat(B, 1, 1) |
|
|
| bases, coef = self.local_inference(x, bases) |
|
|
| |
| coef = self.compute_coef(x, bases, coef) |
|
|
| |
| x = torch.bmm(bases, coef.transpose(1, 2)) |
|
|
| |
| x = x.view(B, C, H, W) if self.spatial else x.transpose(1, 2).view(B, C, H, W) |
| |
| bases = bases.view(B, self.S, D, self.R) |
|
|
| return x |
|
|
|
|
| class NMF2D(_MatrixDecomposition2DBase): |
| def __init__(self): |
| super().__init__() |
|
|
| self.inv_t = 1 |
|
|
| def _build_bases(self, B, S, D, R, device="cpu"): |
| bases = torch.rand((B * S, D, R)).to(device) |
| bases = F.normalize(bases, dim=1) |
|
|
| return bases |
|
|
| |
| def local_step(self, x, bases, coef): |
| |
| numerator = torch.bmm(x.transpose(1, 2), bases) |
| |
| denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) |
| |
| coef = coef * numerator / (denominator + 1e-6) |
|
|
| |
| numerator = torch.bmm(x, coef) |
| |
| denominator = bases.bmm(coef.transpose(1, 2).bmm(coef)) |
| |
| bases = bases * numerator / (denominator + 1e-6) |
|
|
| return bases, coef |
|
|
| def compute_coef(self, x, bases, coef): |
| |
| numerator = torch.bmm(x.transpose(1, 2), bases) |
| |
| denominator = coef.bmm(bases.transpose(1, 2).bmm(bases)) |
| |
| coef = coef * numerator / (denominator + 1e-6) |
|
|
| return coef |
|
|
|
|
| class Hamburger(nn.Module): |
| def __init__(self, ham_channels=512, norm_cfg=None, **kwargs): |
| super().__init__() |
|
|
| self.ham_in = ConvModule(ham_channels, ham_channels, 1) |
|
|
| self.ham = NMF2D() |
|
|
| self.ham_out = ConvModule(ham_channels, ham_channels, 1) |
|
|
| def forward(self, x): |
| enjoy = self.ham_in(x) |
| enjoy = F.relu(enjoy, inplace=False) |
| enjoy = self.ham(enjoy) |
| enjoy = self.ham_out(enjoy) |
| ham = F.relu(x + enjoy, inplace=False) |
|
|
| return ham |
|
|
|
|
| class LightHamHead(BaseModel): |
| """Is Attention Better Than Matrix Decomposition? |
| This head is the implementation of `HamNet |
| <https://arxiv.org/abs/2109.04553>`_. |
| |
| Args: |
| ham_channels (int): input channels for Hamburger. |
| ham_kwargs (int): kwagrs for Ham. |
| """ |
|
|
| default_conf = { |
| "predict_uncertainty": True, |
| "out_channels": 64, |
| "in_channels": [64, 128, 320, 512], |
| "in_index": [0, 1, 2, 3], |
| "ham_channels": 512, |
| "with_low_level": True, |
| } |
|
|
| def _init(self, conf): |
| self.in_index = conf.in_index |
| self.in_channels = conf.in_channels |
| self.out_channels = conf.out_channels |
| self.ham_channels = conf.ham_channels |
| self.align_corners = False |
| self.predict_uncertainty = conf.predict_uncertainty |
|
|
| self.squeeze = ConvModule(sum(self.in_channels), self.ham_channels, 1) |
|
|
| self.hamburger = Hamburger(self.ham_channels) |
|
|
| self.align = ConvModule(self.ham_channels, self.out_channels, 1) |
|
|
| if self.predict_uncertainty: |
| self.linear_pred_uncertainty = nn.Sequential( |
| ConvModule( |
| in_channels=self.out_channels, |
| out_channels=self.out_channels, |
| kernel_size=3, |
| padding=1, |
| bias=False, |
| ), |
| nn.Conv2d(in_channels=self.out_channels, out_channels=1, kernel_size=1), |
| ) |
|
|
| self.with_ll = conf.with_low_level |
| if self.with_ll: |
| self.out_conv = ConvModule( |
| self.out_channels, self.out_channels, 3, padding=1, bias=False |
| ) |
| self.ll_fusion = FeatureFusionBlock(self.out_channels, upsample=False) |
|
|
| def _forward(self, features): |
| """Forward function.""" |
| |
| inputs = [features["hl"][i] for i in self.in_index] |
|
|
| inputs = [ |
| F.interpolate( |
| level, size=inputs[0].shape[2:], mode="bilinear", align_corners=self.align_corners |
| ) |
| for level in inputs |
| ] |
|
|
| inputs = torch.cat(inputs, dim=1) |
| x = self.squeeze(inputs) |
|
|
| x = self.hamburger(x) |
|
|
| feats = self.align(x) |
|
|
| if self.with_ll: |
| assert "ll" in features, "Low-level features are required for this model" |
| feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) |
| feats = self.out_conv(feats) |
| feats = F.interpolate(feats, scale_factor=2, mode="bilinear", align_corners=False) |
| feats_ll = features["ll"].clone() |
| feats = self.ll_fusion(feats, feats_ll) |
|
|
| uncertainty = ( |
| self.linear_pred_uncertainty(feats).squeeze(1) if self.predict_uncertainty else None |
| ) |
|
|
| return feats, uncertainty |
|
|
| def loss(self, pred, data): |
| raise NotImplementedError |
|
|