Spaces:
Running
Running
from torch import nn
from torch.nn import functional as F
class SpatialAttention(nn.Module):
    """Per-pixel spatial attention over a feature map.

    A 1x1 convolution collapses the channels into a single score map,
    which is softmax-normalized and used to reweight the input.

    NOTE(review): the softmax runs along dim=2 (the height axis) only,
    not over the flattened H*W plane -- confirm this is intentional.
    """

    def __init__(self, in_channels):
        super(SpatialAttention, self).__init__()
        # 1x1 conv: one attention score per spatial location.
        self.conv1 = nn.Conv2d(in_channels, 1, kernel_size=1, stride=1, padding=0)

    def forward(self, x):
        """Return `x` scaled by its normalized attention map.

        Args:
            x: feature tensor of shape (N, C, H, W).

        Returns:
            Tensor with the same shape as `x`.
        """
        scores = self.conv1(x)
        weights = F.softmax(scores, dim=2)  # normalize along the height axis
        return x * weights
class DCT_Attention_Fusion_Conv(nn.Module):
    """Fuse RGB and DCT feature maps into a single feature map.

    Each modality is passed through its own SpatialAttention, globally
    average-pooled to 1x1, broadcast back to the input resolution with
    bilinear interpolation, and the two results are summed and passed
    through a ReLU.

    NOTE(review): upsampling a 1x1 pooled map just broadcasts a constant
    over H x W, so the per-pixel attention pattern is discarded before
    the fusion -- confirm this is the intended design.
    """

    def __init__(self, channels):
        super(DCT_Attention_Fusion_Conv, self).__init__()
        # One attention branch and one global pooling per modality.
        self.rgb_attention = SpatialAttention(channels)
        self.depth_attention = SpatialAttention(channels)
        self.rgb_pooling = nn.AdaptiveAvgPool2d(1)
        self.depth_pooling = nn.AdaptiveAvgPool2d(1)

    def forward(self, rgb_features, DCT_features):
        """Return the fused feature map.

        Args:
            rgb_features: RGB feature tensor of shape (N, C, H, W).
            DCT_features: DCT feature tensor of shape (N, C, H, W).

        Returns:
            Tensor with the spatial size of `rgb_features`.
        """
        # Spatial attention for each modality.
        rgb_att = self.rgb_attention(rgb_features)
        dct_att = self.depth_attention(DCT_features)
        # Global average pooling to a 1x1 descriptor per channel.
        rgb_vec = self.rgb_pooling(rgb_att)
        dct_vec = self.depth_pooling(dct_att)
        # Broadcast each descriptor back to its modality's spatial size.
        rgb_map = F.interpolate(rgb_vec, size=rgb_features.size()[2:],
                                mode='bilinear', align_corners=False)
        dct_map = F.interpolate(dct_vec, size=DCT_features.size()[2:],
                                mode='bilinear', align_corners=False)
        # Element-wise sum of the two modalities, then ReLU.
        return F.relu(rgb_map + dct_map)