Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torchvision import models | |
| # ========================================== | |
| # POSITION ATTENTION MODULE (OPTIMIZED) | |
| # ========================================== | |
class PositionAttention(nn.Module):
    """Self-attention over spatial positions (DANet-style PAM).

    Queries and keys are channel-reduced by ``reduction_ratio`` (best
    config: 16); a (H*W) x (H*W) affinity map re-weights the value
    projection, and the result is added back to the input scaled by a
    learnable ``gamma`` that is zero-initialised, so the module starts
    out as an identity mapping.
    """

    def __init__(self, in_channels, reduction_ratio=16):
        super(PositionAttention, self).__init__()
        reduced = in_channels // reduction_ratio
        self.query_conv = nn.Conv2d(in_channels, reduced, 1)
        self.key_conv = nn.Conv2d(in_channels, reduced, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        # Residual scale; zero init means no effect until trained.
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n, c, h, w = x.size()
        hw = h * w
        queries = self.query_conv(x).view(n, -1, hw).permute(0, 2, 1)
        keys = self.key_conv(x).view(n, -1, hw)
        # (n, hw, hw): affinity between every pair of spatial positions.
        affinity = self.softmax(torch.bmm(queries, keys))
        values = self.value_conv(x).view(n, c, -1)
        attended = torch.bmm(values, affinity.permute(0, 2, 1))
        attended = attended.view(n, c, h, w)
        return self.gamma * attended + x
| # ========================================== | |
| # CHANNEL ATTENTION MODULE (OPTIMIZED) | |
| # ========================================== | |
class ChannelAttention(nn.Module):
    """Self-attention across channels (DANet-style CAM).

    Builds a C x C channel-affinity matrix directly from the flattened
    input (no learned projections), re-weights the channels with it, and
    adds the result back scaled by a zero-initialised learnable
    ``gamma`` — i.e. the module is an identity before training.
    """

    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        n, c, h, w = x.size()
        flat = x.view(n, c, -1)
        # (n, c, c): pairwise channel affinity from raw features.
        affinity = self.softmax(torch.bmm(flat, flat.permute(0, 2, 1)))
        attended = torch.bmm(affinity, flat).view(n, c, h, w)
        return self.gamma * attended + x
| # ========================================== | |
| # TYPE ATTENTION MODULE (OPTIMIZED) | |
| # ========================================== | |
class TypeAttention(nn.Module):
    """Multi-scale type attention with the best searched configuration.

    Baked-in hyper-parameters from the search:
      - scale_channels_ratio: 2
      - num_scales: 3 (kernel sizes 3 / 5 / 7)
      - continuity_channels_ratio: 8 (detector itself is disabled)
      - temperature_scaling: 5.0
      - use_texture_analyzer: True
      - use_continuity_detector: False
      - use_type_prototypes: False

    Returns a tuple ``(out, type_attention_weights)``; the weights are
    always ``None`` because prototype matching is disabled in this
    configuration.
    """

    def __init__(self, in_channels, num_types=7):
        super(TypeAttention, self).__init__()
        self.in_channels = in_channels
        self.num_types = num_types
        # Fixed best-config hyper-parameters.
        scale_channels_ratio = 2
        num_scales = 3
        continuity_channels_ratio = 8  # kept for config parity; detector disabled
        self.temperature_scaling = 5.0
        kernel_sizes = [3, 5, 7]
        scale_ch = in_channels // scale_channels_ratio
        # One conv branch per receptive-field size.
        self.scale_convs = nn.ModuleList(
            nn.Sequential(
                nn.Conv2d(in_channels, scale_ch, kernel_size=k, padding=k // 2),
                nn.BatchNorm2d(scale_ch),
                nn.ReLU(inplace=True),
            )
            for k in kernel_sizes[:num_scales]
        )
        # Depthwise-then-pointwise texture analyzer (enabled in best config).
        self.texture_analyzer = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
                      groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        # Continuity detector intentionally absent (use_continuity_detector=False).
        # Soft per-branch weights from globally pooled multi-scale features.
        hidden_ch = in_channels // 4  # scale_attention_hidden_ratio default
        self.scale_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(scale_ch * num_scales, hidden_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_ch, num_scales, kernel_size=1),
            nn.Softmax(dim=1)
        )
        # Bring the aggregated branch features back up to in_channels.
        self.channel_projection = nn.Conv2d(scale_ch, in_channels, kernel_size=1)
        # Fuses aggregated multi-scale features with texture features.
        self.type_enhancer = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels)
        )
        # Type prototypes intentionally absent (use_type_prototypes=False).
        # Zero-init residual scale: module starts as an identity mapping.
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Run every scale branch, then derive one soft weight per branch.
        branch_feats = [branch(x) for branch in self.scale_convs]
        weights = self.scale_attention(torch.cat(branch_feats, dim=1))
        # Weighted aggregation of the branches.
        fused = branch_feats[0] * weights[:, 0:1, :, :]
        for idx in range(1, len(branch_feats)):
            fused = fused + branch_feats[idx] * weights[:, idx:idx + 1, :, :]
        fused = self.channel_projection(fused)
        # Texture path (continuity detection is disabled, so it passes through).
        texture = self.texture_analyzer(x)
        enhanced = self.type_enhancer(torch.cat([fused, texture], dim=1))
        # Prototype matching disabled: no per-type weights are produced.
        return self.gamma * enhanced + x, None
| # ========================================== | |
| # TRIPLE ATTENTION MODULE (OPTIMIZED) | |
| # ========================================== | |
class TripleAttention(nn.Module):
    """Fuses position, channel, and type attention (best configuration).

    Best config:
      - use_position_attention: True
      - use_channel_attention: True
      - position_reduction_ratio: 16
      - fusion_method: 'equal' (plain mean of the three branch outputs)

    Returns ``(fused, type_weights)`` where ``type_weights`` comes from
    the type-attention branch (``None`` in this configuration).
    """

    def __init__(self, in_channels, num_types=7):
        super(TripleAttention, self).__init__()
        # Flags retained from the hyper-parameter search configuration.
        self.use_position_attention = True
        self.use_channel_attention = True
        self.use_type_attention = True
        self.fusion_method = 'equal'
        # All three branches are enabled in the best config.
        self.position_attention = PositionAttention(in_channels, reduction_ratio=16)
        self.channel_attention = ChannelAttention(in_channels)
        self.type_attention = TypeAttention(in_channels, num_types)

    def forward(self, x):
        pam_out = self.position_attention(x)
        cam_out = self.channel_attention(x)
        tam_out, type_weights = self.type_attention(x)
        # 'equal' fusion: average the three branch outputs.
        fused = (pam_out + cam_out + tam_out) / 3
        return fused, type_weights
| # ========================================== | |
| # STOOLNET MODEL (OPTIMIZED) | |
| # ========================================== | |
class StoolNetTriple(nn.Module):
    """Multi-head stool classifier: ResNet34 backbone + triple attention.

    Best searched configuration: backbone resnet34 (ImageNet-pretrained),
    type head with one 512-unit hidden layer, learning_rate 0.001,
    batch_size 16, weight_decay 0.0001.

    NOTE(review): the search config lists dropout_rates [0.5, 0.3], but
    every head below applies only a single Dropout(0.5) — confirm which
    is intended before retraining.

    Predicts three targets per image: stool type (``num_type_classes``),
    shape (``num_shape_classes``), and color (``num_color_classes``).
    """

    @staticmethod
    def _classifier(in_dim, hidden_dim, out_dim):
        # Shared head layout: Linear -> ReLU -> Dropout(0.5) -> Linear.
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(hidden_dim, out_dim)
        )

    def __init__(self, num_type_classes=7, num_shape_classes=4, num_color_classes=2):
        super(StoolNetTriple, self).__init__()
        # ImageNet-pretrained ResNet34 backbone (best config).
        backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)
        feature_dim = 512  # ResNet34 layer4 output channels
        # Re-register the backbone stages under the names callers /
        # checkpoints expect (state_dict keys unchanged).
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4
        # Attention over the final feature map.
        self.triple_attention = TripleAttention(feature_dim, num_types=num_type_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # Three classification heads with progressively smaller hidden sizes.
        self.fc_type = self._classifier(feature_dim, 512, num_type_classes)
        self.fc_shape = self._classifier(feature_dim, 256, num_shape_classes)
        self.fc_color = self._classifier(feature_dim, 128, num_color_classes)

    def forward(self, x, return_attention=False):
        # Backbone stem + residual stages, in order.
        for stage in (self.conv1, self.bn1, self.relu, self.maxpool,
                      self.layer1, self.layer2, self.layer3, self.layer4):
            x = stage(x)
        # Attention-refined features (type weights are None in best config).
        x, type_attention_weights = self.triple_attention(x)
        feats = torch.flatten(self.avgpool(x), 1)
        type_out = self.fc_type(feats)
        shape_out = self.fc_shape(feats)
        color_out = self.fc_color(feats)
        if return_attention:
            return type_out, shape_out, color_out, type_attention_weights
        return type_out, shape_out, color_out