from typing import Tuple

import torch
from torch import nn

from models.regression_head import UpsamplingLayer
from models.transformer import PrototypeAttentionBlock
from models.ops.modules.ms_deform_attn import MSDeformAttn


class C_base(nn.Module):
    """Refine a coarse image stream and two high-resolution streams, then fuse them.

    Three token streams are processed independently:

    * the coarse stream built from ``image_embeddings``,
    * ``l1``: built from ``hq_features[0]`` / ``hq_pos[0]`` / ``hq_prototypes[0]``,
    * ``l2``: built from ``hq_features[1]`` / ``hq_pos[1]`` / ``hq_prototypes[1]``.

    Each stream goes through a stack of ``PrototypeAttentionBlock`` layers and
    then a stack of single-level ``MSDeformAttn`` layers. The coarse stream is
    then upsampled three times. The ``l2`` stream is added after the first
    upsampling step and the ``l1`` stream after the second. An auxiliary output
    is produced by upsampling the ``l1`` stream alone.
    """

    def __init__(
        self,
        *,
        transformer_dim: int,
        num_prototype_attn_steps: int,
        num_image_attn_steps: int,
    ) -> None:
        """Build the per-stream attention stacks and the upsampling head.

        Args:
            transformer_dim: Channel width of every stream. It is used by the
                prototype-attention blocks, the deformable-attention blocks and
                the upsampling layers.
            num_prototype_attn_steps: Number of ``PrototypeAttentionBlock``
                layers per stream.
            num_image_attn_steps: Number of ``MSDeformAttn`` layers per stream.
        """
        super().__init__()
        self.transformer_dim = transformer_dim

        # One attention stack per stream (coarse, l1, l2). Registration and
        # construction order is kept stable, so state_dict keys and seeded
        # parameter initialisation are unchanged.
        self.image_attention = nn.ModuleList()
        self.image_attention_l1 = nn.ModuleList()
        self.image_attention_l2 = nn.ModuleList()
        self.prototype_attention = nn.ModuleList()
        self.prototype_attention_l1 = nn.ModuleList()
        self.prototype_attention_l2 = nn.ModuleList()

        prototype_stacks = (
            self.prototype_attention,
            self.prototype_attention_l1,
            self.prototype_attention_l2,
        )
        for _ in range(num_prototype_attn_steps):
            for stack in prototype_stacks:
                stack.append(PrototypeAttentionBlock(transformer_dim, num_heads=8))

        image_stacks = (
            self.image_attention,
            self.image_attention_l1,
            self.image_attention_l2,
        )
        for _ in range(num_image_attn_steps):
            for stack in image_stacks:
                # d_model must equal the stream width. A fixed 256 only worked
                # when transformer_dim == 256.
                stack.append(
                    MSDeformAttn(
                        d_model=transformer_dim, n_levels=1, n_heads=8, n_points=8
                    )
                )

        # Upsampling head.
        self.up1 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up2 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up3 = UpsamplingLayer(transformer_dim, transformer_dim)
        self.up_aux = UpsamplingLayer(transformer_dim, transformer_dim)

        # Nominal single-level metadata for a 64x64 coarse grid and its 2x / 4x
        # hq grids. These attributes are kept only for backward compatibility.
        # forward() derives the metadata from the actual input sizes instead.
        h, w = 64, 64
        self.spatial_shapes = torch.tensor([[h, w]])
        self.valid_ratios = torch.tensor([[1.0, 1.0]])
        self.level_start_index = torch.tensor([[0]])

        self.spatial_shapes2 = torch.tensor([[h * 2, w * 2]])
        self.valid_ratios2 = torch.tensor([[1.0, 1.0]])
        self.level_start_index2 = torch.tensor([[0]])

        self.spatial_shapes1 = torch.tensor([[h * 4, w * 4]])
        self.valid_ratios1 = torch.tensor([[1.0, 1.0]])
        self.level_start_index1 = torch.tensor([[0]])

    @staticmethod
    def get_reference_points(spatial_shapes, valid_ratios, device="cpu"):
        """Return normalised pixel-centre reference points for deformable attention.

        For every ``(H_, W_)`` row of ``spatial_shapes`` a grid of pixel centres
        is built. The centres are normalised to ``(0, 1)``: x is divided by
        ``valid_ratios[lvl, 0] * W_`` and y by ``valid_ratios[lvl, 1] * H_``.
        The points are stacked as ``(x, y)``.

        The final broadcast against ``valid_ratios[:, None]`` is only shape-safe
        for a single level (``valid_ratios`` of shape ``(1, 2)``). In that case
        the result has shape ``(1, H_ * W_, 1, 2)``.
        """
        reference_points_list = []
        for lvl, (H_, W_) in enumerate(spatial_shapes):
            ref_y, ref_x = torch.meshgrid(
                torch.linspace(0.5, H_ - 0.5, H_, device=device),
                torch.linspace(0.5, W_ - 0.5, W_, device=device),
                indexing="ij",
            )
            ref_y = ref_y.reshape(-1)[None] / (valid_ratios[lvl, 1] * H_)
            ref_x = ref_x.reshape(-1)[None] / (valid_ratios[lvl, 0] * W_)
            ref = torch.stack((ref_x, ref_y), -1)
            reference_points_list.append(ref)
        reference_points = torch.cat(reference_points_list, 1)
        reference_points = reference_points[:, :, None] * valid_ratios[:, None]
        return reference_points

    @staticmethod
    def _to_tokens(x: torch.Tensor) -> torch.Tensor:
        """Flatten a ``(B, C, H, W)`` map into ``(B, H*W, C)`` tokens."""
        return x.flatten(2).permute(0, 2, 1)

    @staticmethod
    def _to_map(tokens: torch.Tensor, h: int, w: int) -> torch.Tensor:
        """Inverse of ``_to_tokens``: turn ``(B, H*W, C)`` tokens into a ``(B, C, h, w)`` map."""
        b, _, c = tokens.shape
        # reshape (not view): the transposed tensor is non-contiguous.
        return tokens.transpose(1, 2).reshape(b, c, h, w)

    def _single_level_meta(self, h: int, w: int, device: torch.device):
        """Build ``(spatial_shapes, level_start_index, reference_points)`` for one ``h x w`` level."""
        spatial_shapes = torch.tensor([[h, w]], dtype=torch.long, device=device)
        # A single level always starts at token 0. The shape is (n_levels,).
        level_start_index = spatial_shapes.new_zeros((1,))
        # The whole grid is valid (no padding), so both ratios are 1.
        valid_ratios = torch.ones((1, 2), device=device)
        reference_points = self.get_reference_points(
            spatial_shapes, valid_ratios, device=device
        )
        return spatial_shapes, level_start_index, reference_points

    def _deform_attend(
        self,
        layers: nn.ModuleList,
        tokens: torch.Tensor,
        pos: torch.Tensor,
        h: int,
        w: int,
        device: torch.device,
    ) -> torch.Tensor:
        """Run a stack of single-level ``MSDeformAttn`` layers over an ``h x w`` token grid.

        Each layer queries with ``tokens + pos`` and samples values from
        ``tokens``. Its output replaces ``tokens``; no residual is added here.
        """
        spatial_shapes, level_start_index, reference_points = self._single_level_meta(
            h, w, device
        )
        for layer in layers:
            tokens = layer(
                tokens + pos,
                reference_points,
                tokens,
                spatial_shapes,
                level_start_index,
            )
        return tokens

    def forward(
        self,
        image_embeddings: torch.Tensor,
        image_pe: torch.Tensor,
        prototype_embeddings: torch.Tensor,
        hq_features: torch.Tensor,
        hq_prototypes: torch.Tensor,
        hq_pos: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Refine the three streams and fuse them into the main and auxiliary outputs.

        Args:
            image_embeddings: Coarse feature map of shape ``(B, C, H, W)``.
            image_pe: Positional encoding for ``image_embeddings``. It is tiled
                ``B`` times along dim 0, which presumably means it arrives with
                batch size 1. NOTE(review): confirm against callers.
            prototype_embeddings: Prototypes for the coarse stream.
            hq_features: Indexable pair of high-resolution maps: ``[0]`` is the
                l1 map and ``[1]`` is the l2 map.
            hq_prototypes: Indexable pair of prototypes for l1 / l2.
            hq_pos: Indexable pair of positional encodings for l1 / l2.

        Returns:
            ``(src, src_aux)``. ``src`` is the fused main output and
            ``src_aux`` is the upsampled l1 stream.
        """
        device = image_embeddings.device
        b, _, h, w = image_embeddings.shape
        # Grid sizes come from the inputs, so every MSDeformAttn call receives
        # spatial_shapes that match its token count.
        h1, w1 = hq_features[0].shape[-2:]
        h2, w2 = hq_features[1].shape[-2:]

        image_pe = torch.repeat_interleave(image_pe, b, dim=0)

        src = self._to_tokens(image_embeddings)
        src_pos = self._to_tokens(image_pe)
        hq_l1 = self._to_tokens(hq_features[0])
        hq_l1_pos = self._to_tokens(hq_pos[0])
        hq_l2 = self._to_tokens(hq_features[1])
        hq_l2_pos = self._to_tokens(hq_pos[1])

        # Prototype attention, one stack per stream.
        for layer in self.prototype_attention:
            src = layer(image_f=src, prototypes=prototype_embeddings)
        for layer in self.prototype_attention_l1:
            hq_l1 = layer(image_f=hq_l1, prototypes=hq_prototypes[0])
        for layer in self.prototype_attention_l2:
            hq_l2 = layer(image_f=hq_l2, prototypes=hq_prototypes[1])

        # Deformable attention over each stream's own grid.
        src = self._deform_attend(self.image_attention, src, src_pos, h, w, device)
        hq_l1 = self._deform_attend(
            self.image_attention_l1, hq_l1, hq_l1_pos, h1, w1, device
        )
        hq_l2 = self._deform_attend(
            self.image_attention_l2, hq_l2, hq_l2_pos, h2, w2, device
        )

        src = self._to_map(src, h, w)
        hq_l2 = self._to_map(hq_l2, h2, w2)
        hq_l1 = self._to_map(hq_l1, h1, w1)

        # Coarse-to-fine fusion. The sums require up1(src) to match the l2 grid
        # and up2(...) to match the l1 grid.
        src = self.up1(src) + hq_l2
        src = self.up2(src) + hq_l1
        src = self.up3(src)
        src_aux = self.up_aux(hq_l1)
        return src, src_aux