import copy
from collections import OrderedDict

import numpy as np
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from omegaconf import OmegaConf

from .modules import FeatureEncoder
from .us import normalize


def build_model(config):
    model = OmegaConf.to_container(config, resolve=True)
    return model
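

# NOTE (assumption): `gumbel_sigmoid` is called by Sim2Mask and DINOTextSim2Mask below but is
# not defined or imported in this file. A minimal sketch of the standard straight-through
# Gumbel-sigmoid relaxation it is expected to implement could look like this; the project's
# actual helper may differ in detail.
def gumbel_sigmoid(logits, tau=1.0, hard=False):
    # two Gumbel(0, 1) samples; their difference is Logistic noise added to the logits
    gumbels1 = -torch.empty_like(logits).exponential_().log()
    gumbels2 = -torch.empty_like(logits).exponential_().log()
    y_soft = torch.sigmoid((logits + gumbels1 - gumbels2) / tau)
    if hard:
        # straight-through estimator: hard 0/1 mask in the forward pass, soft gradient in backward
        y_hard = y_soft.gt(0.5).type(y_soft.dtype)
        return y_hard - y_soft.detach() + y_soft
    return y_soft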


class Sim2Mask(nn.Module):
    def __init__(self, init_w=1.0, init_b=0.0, gumbel_tau=1.0, learnable=True):
        super().__init__()
        self.init_w = init_w
        self.init_b = init_b
        self.gumbel_tau = gumbel_tau
        self.learnable = learnable

        assert not ((init_w is None) ^ (init_b is None))
        if learnable:
            self.w = nn.Parameter(torch.full([], float(init_w)))
            self.b = nn.Parameter(torch.full([], float(init_b)))
        else:
            self.w = init_w
            self.b = init_b

    def forward(self, x, deterministic=False):
        logits = x * self.w + self.b

        soft_mask = torch.sigmoid(logits)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(logits.dtype)
        else:
            hard_mask = gumbel_sigmoid(logits, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask

    def extra_repr(self):
        return f'init_w={self.init_w}, init_b={self.init_b}, learnable={self.learnable}, gumbel_tau={self.gumbel_tau}'
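

# Usage sketch (assumption, not part of the original module): Sim2Mask maps a similarity map to
# a binary mask with a learnable scale/shift. With deterministic=True the hard mask is a plain
# 0.5 threshold; otherwise a straight-through Gumbel-sigmoid sample keeps the mask differentiable.
#
#   sim2mask = Sim2Mask(init_w=10.0, init_b=-2.5)
#   simmap = torch.randn(2, 4, 7, 7)                     # [B, N, H, W] image-text similarity map
#   hard, soft = sim2mask(simmap)                        # stochastic, differentiable hard mask + soft mask
#   hard_det, _ = sim2mask(simmap, deterministic=True)   # thresholded mask for inference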


class MaskerBackbone(nn.Module):
    """Masker image encoder backbone."""

    def __init__(self, clip_visual, freeze_idx):
        super().__init__()
        self.transformer = copy.deepcopy(clip_visual.transformer)
        self.transformer.resblocks = self.transformer.resblocks[freeze_idx:]

        for block in self.transformer.resblocks:
            if hasattr(block, "hook_handler"):
                block.hook_handler.remove()

        self.ln_post = copy.deepcopy(clip_visual.ln_post)
        self.proj = copy.deepcopy(clip_visual.proj)

        self.layers = len(self.transformer.resblocks)
        self.patch_size = clip_visual.patch_size

        self.output_dim = clip_visual.output_dim if self.proj is not None else clip_visual.width

    def forward(self, x, spatial=True, ignore_last_attn=True):
        if self.layers:
            x = self.transformer(x, ignore_last_attn=ignore_last_attn)

        x = x.permute(1, 0, 2)  # LND -> NLD

        if spatial:
            x = self.ln_post(x)
        else:
            x = self.ln_post(x[:, 0, :])

        if self.proj is not None:
            x = x @ self.proj

        return x


class MaskerImageFeatureEncoder(FeatureEncoder):
    def __init__(self, backbone: nn.Module, decoder: nn.Module, ignore_last_attn: bool = True):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn
        self.patch_size = backbone.patch_size
        self.backbone = backbone
        self.decoder = decoder

        for resblock in self.backbone.transformer.resblocks:
            resblock.hook_handler = resblock.register_forward_hook(self.hook)

    def _encode(self, image, image_feat):
        H, W = image.shape[-2:]
        h = H // self.patch_size
        w = W // self.patch_size

        x = self.backbone(image_feat, spatial=True, ignore_last_attn=self.ignore_last_attn)
        x = rearrange(x[:, 1:], "B (H W) C -> B C H W", H=h, W=w)
        x = self.decoder(x)

        return x
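

# NOTE (assumption): `gather_cat` is used by Masker.forward below but is not defined or imported
# in this file. It is expected to all-gather a tensor across distributed ranks and concatenate
# along the batch dimension while keeping gradients for the local shard. A minimal sketch:
def gather_cat(tensor, grad=False, contiguous_grad=False):
    if not dist.is_initialized() or dist.get_world_size() == 1:
        return tensor

    gathered = [torch.zeros_like(tensor) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, tensor)
    if grad:
        # all_gather outputs carry no autograd history; re-insert the local tensor so
        # gradients still flow through this rank's shard
        gathered[dist.get_rank()] = tensor
    out = torch.cat(gathered, dim=0)
    return out.contiguous() if contiguous_grad else out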


class Masker(nn.Module):
    def __init__(self, backbone, decoder, image_proj, sim2mask, ignore_last_attn, **kwargs):
        super().__init__()
        self.ignore_last_attn = ignore_last_attn

        # `decoder` arrives as a config dict; MODELS is assumed to be the project's model
        # registry (imported elsewhere) that builds the decoder module from it.
        decoder["C"] = backbone.output_dim
        decoder = MODELS.build(decoder)
        decoder = nn.Sequential(OrderedDict([
            ("decoder", decoder),
            ("image_proj", image_proj),
        ]))

        self.image_encoder = MaskerImageFeatureEncoder(backbone, decoder, ignore_last_attn=ignore_last_attn)

        self.sim2mask = Sim2Mask(**sim2mask)

    def forward(self, image, image_feat, text_emb, deterministic=False):
        B = image.size(0)
        image_emb, feats = self.image_encoder(image, image_feat, ret_feats=True)  # [B, C, H, W]

        image_emb_norm = normalize(image_emb, dim=1)
        text_emb_norm = normalize(text_emb, dim=-1)

        H, W = image_emb.shape[2:]
        D = dist.get_world_size()

        # gather text embeddings from all ranks and compute the dense image-text similarity map
        all_text_emb_norm = gather_cat(text_emb_norm, grad=True, contiguous_grad=True)
        simmap = torch.einsum("bchw,nc->bnhw", image_emb_norm, all_text_emb_norm)
        mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)

        # positive masks: each image paired with its own caption (global index across ranks)
        pos_indices = torch.arange(B, dtype=torch.long, device=image_emb.device) + B * dist.get_rank()
        pos_mask = mask[torch.arange(B), pos_indices].unsqueeze(1)

        offdiag = torch.ones(B, B * D, dtype=torch.bool, device=mask.device)
        offdiag[torch.arange(B), pos_indices] = False

        soft_pos_mask = soft_mask[torch.arange(B), pos_indices].unsqueeze(1)
        soft_neg_mask = soft_mask.masked_select(offdiag[..., None, None]).view(B, B * D - 1, H, W)

        masks = {
            "pos": pos_mask,  # [B, 1, H, W]
            "soft_pos": soft_pos_mask,  # [B, 1, H, W]
            "soft_neg": soft_neg_mask,  # [B, B*D-1, H, W]
            "soft_all": soft_mask,  # [B, B*D, H, W]
        }

        return masks, image_emb, text_emb, feats

    @torch.no_grad()
    def forward_seg(self, image, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image [B, 3, H, W]
            image_feat [L, B, C]: CLIP features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): return the hard (thresholded) mask instead of the soft one.
                Note that the soft mask is required for proper evaluation.

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
            simmap [B, N, H', W']: raw image-text similarity map
        """
        image_emb = self.image_encoder(image, image_feat)  # [B, C, H', W']

        image_emb = normalize(image_emb, dim=1)
        text_emb = normalize(text_emb, dim=-1)

        simmap = torch.einsum("b c h w, n c -> b n h w", image_emb, text_emb)

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap
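

# Inference sketch (assumption, not part of the original module): given a trained Masker, an
# image batch, its intermediate CLIP features, and one text embedding per class, forward_seg
# yields a per-class soft mask that can be argmax-ed into a segmentation map.
#
#   masker.eval()
#   mask, simmap = masker.forward_seg(image, image_feat, class_text_emb)  # [B, N, H', W']
#   seg = mask.argmax(dim=1)                                              # [B, H', W'] class indices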


class DINOTextMasker(nn.Module):
    def __init__(self, similarity_type="cosine"):
        super().__init__()
        self.sim2mask = DINOTextSim2Mask()
        self.sim2mask = self.sim2mask.eval()
        self.similarity_type = similarity_type

    def forward(self, image, image_feat, text_emb, deterministic=False):
        # training forward is not implemented; this module is used for inference via forward_seg
        pass

    @torch.no_grad()
    def forward_seg(self, image_feat, text_emb, deterministic=True, hard=False):
        """Make mask by 1:N matching

        Args:
            image_feat [B, C, H', W']: spatial image features
            text_emb [N, C]
            deterministic (bool): deterministic inference flag for gumbel noise
            hard (bool): return the hard (thresholded) mask instead of the soft one.
                Note that the soft mask is required for proper evaluation.

        Return:
            mask [B, N, H', W'] (H' and W' are downsampled H/W)
            simmap [B, N, H', W']: raw image-text similarity map
        """
        b, c, h, w = image_feat.shape
        n, c = text_emb.shape

        if self.similarity_type == "cosine":
            image_feat = normalize(image_feat, dim=1)

            simmap = torch.einsum("b c h w, n c -> b n h w", image_feat, text_emb)
        else:
            raise NotImplementedError("similarity type {} not implemented".format(self.similarity_type))

        hard_mask, soft_mask = self.sim2mask(simmap, deterministic=deterministic)
        mask = hard_mask if hard else soft_mask

        return mask, simmap


class DINOTextSim2Mask(nn.Module):
    def __init__(self, gumbel_tau=1.0):
        super().__init__()
        self.gumbel_tau = gumbel_tau

    def forward(self, x, deterministic=False):
        soft_mask = torch.sigmoid(x)
        if deterministic:
            hard_mask = soft_mask.gt(0.5).type(x.dtype)
        else:
            hard_mask = gumbel_sigmoid(x, hard=True, tau=self.gumbel_tau)

        return hard_mask, soft_mask
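

# Usage sketch (assumption, not part of the original module): DINOTextMasker thresholds a raw
# cosine-similarity map between spatial features (e.g. DINO-style [B, C, H', W'] features) and
# per-class text embeddings, without the learnable scale/shift of Sim2Mask.
#
#   masker = DINOTextMasker(similarity_type="cosine")
#   mask, simmap = masker.forward_seg(spatial_feat, class_text_emb)  # [B, N, H', W']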