import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
# ==========================================
# POSITION ATTENTION MODULE (OPTIMIZED)
# ==========================================
class PositionAttention(nn.Module):
    """
    Position Attention Module with best configuration
    - reduction_ratio: 16 (from best config)
    """
    def __init__(self, in_channels, reduction_ratio=16):
        super(PositionAttention, self).__init__()
        self.query_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1)
        self.key_conv = nn.Conv2d(in_channels, in_channels // reduction_ratio, 1)
        self.value_conv = nn.Conv2d(in_channels, in_channels, 1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Query: (B, H*W, C/r), Key: (B, C/r, H*W) -> spatial attention map (B, H*W, H*W)
        proj_query = self.query_conv(x).view(batch_size, -1, H * W).permute(0, 2, 1)
        proj_key = self.key_conv(x).view(batch_size, -1, H * W)
        attention = self.softmax(torch.bmm(proj_query, proj_key))
        # Re-weight the value projection by the spatial attention map
        proj_value = self.value_conv(x).view(batch_size, C, -1)
        out = torch.bmm(proj_value, attention.permute(0, 2, 1))
        out = out.view(batch_size, C, H, W)
        # Learnable residual scaling (gamma starts at zero)
        out = self.gamma * out + x
        return out
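
# Illustrative usage sketch (not part of the original upload): a quick shape check
# showing that PositionAttention preserves the (B, C, H, W) layout. The tensor
# sizes below are assumptions chosen to mirror a ResNet34 stage-4 output.
def _demo_position_attention():
    pam = PositionAttention(in_channels=512, reduction_ratio=16)
    y = pam(torch.randn(2, 512, 7, 7))
    assert y.shape == (2, 512, 7, 7)
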
# ==========================================
# CHANNEL ATTENTION MODULE (OPTIMIZED)
# ==========================================
class ChannelAttention(nn.Module):
    """
    Channel Attention Module with best configuration
    """
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Channel affinity: (B, C, H*W) x (B, H*W, C) -> attention map (B, C, C)
        proj_query = x.view(batch_size, C, -1)
        proj_key = x.view(batch_size, C, -1).permute(0, 2, 1)
        attention = self.softmax(torch.bmm(proj_query, proj_key))
        proj_value = x.view(batch_size, C, -1)
        out = torch.bmm(attention, proj_value)
        out = out.view(batch_size, C, H, W)
        # Learnable residual scaling (gamma starts at zero)
        out = self.gamma * out + x
        return out
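
# Illustrative check (assumption, not in the original file): gamma is initialised
# to zero, so a freshly constructed module returns its input unchanged; the
# channel attention only contributes once gamma is learned.
def _demo_channel_attention():
    cam = ChannelAttention(in_channels=512)
    x = torch.randn(2, 512, 7, 7)
    assert torch.allclose(cam(x), x)
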
# ==========================================
# TYPE ATTENTION MODULE (OPTIMIZED)
# ==========================================
class TypeAttention(nn.Module):
    """
    Type Attention Module with best configuration:
    - scale_channels_ratio: 2
    - num_scales: 3
    - continuity_channels_ratio: 8
    - temperature_scaling: 5.0
    - use_texture_analyzer: True
    - use_continuity_detector: False
    - use_type_prototypes: False
    """
    def __init__(self, in_channels, num_types=7):
        super(TypeAttention, self).__init__()
        self.in_channels = in_channels
        self.num_types = num_types
        # Best config parameters
        scale_channels_ratio = 2
        num_scales = 3
        continuity_channels_ratio = 8
        self.temperature_scaling = 5.0
        kernel_sizes = [3, 5, 7]
        scale_ch = in_channels // scale_channels_ratio
        # Multi-scale feature extraction (3 scales)
        self.scale_convs = nn.ModuleList()
        for k_size in kernel_sizes[:num_scales]:
            self.scale_convs.append(nn.Sequential(
                nn.Conv2d(in_channels, scale_ch, kernel_size=k_size, padding=k_size//2),
                nn.BatchNorm2d(scale_ch),
                nn.ReLU(inplace=True)
            ))
        # Texture analyzer (ENABLED in best config)
        self.texture_analyzer = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1,
                      groups=in_channels),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True)
        )
        # Continuity detector (DISABLED in best config)
        # Not initialized since use_continuity_detector = False
        # Scale attention
        total_scale_ch = scale_ch * num_scales
        hidden_ch = in_channels // 4  # scale_attention_hidden_ratio default
        self.scale_attention = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(total_scale_ch, hidden_ch, kernel_size=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden_ch, num_scales, kernel_size=1),
            nn.Softmax(dim=1)
        )
        # Channel projection
        self.channel_projection = nn.Conv2d(scale_ch, in_channels, kernel_size=1)
        # Type enhancer
        self.type_enhancer = nn.Sequential(
            nn.Conv2d(in_channels * 2, in_channels, kernel_size=1),
            nn.BatchNorm2d(in_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1),
            nn.BatchNorm2d(in_channels)
        )
        # Type prototypes (DISABLED in best config)
        # Not initialized since use_type_prototypes = False
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        batch_size, C, H, W = x.size()
        # Multi-scale extraction
        scale_feats = [conv(x) for conv in self.scale_convs]
        multiscale_feat = torch.cat(scale_feats, dim=1)
        # Scale attention
        scale_weights = self.scale_attention(multiscale_feat)
        # Weighted aggregation
        aggregated = sum(feat * scale_weights[:, i:i+1, :, :]
                         for i, feat in enumerate(scale_feats))
        aggregated = self.channel_projection(aggregated)
        # Texture analysis (ENABLED)
        texture_features = self.texture_analyzer(x)
        # Continuity detection (DISABLED - skip this step)
        # texture_features remains unchanged
        # Fusion
        combined = torch.cat([aggregated, texture_features], dim=1)
        enhanced = self.type_enhancer(combined)
        # Type prototype matching (DISABLED - skip this step)
        type_attended = enhanced
        type_attention_weights = None
        out = self.gamma * type_attended + x
        return out, type_attention_weights
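
# Illustrative usage sketch (assumption, not in the original upload): with the
# continuity detector and type prototypes disabled, forward() returns the
# attended feature map plus None for the type-attention weights.
def _demo_type_attention():
    tam = TypeAttention(in_channels=512, num_types=7)
    out, weights = tam(torch.randn(2, 512, 7, 7))
    assert out.shape == (2, 512, 7, 7) and weights is None
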
# ==========================================
# TRIPLE ATTENTION MODULE (OPTIMIZED)
# ==========================================
class TripleAttention(nn.Module):
    """
    Triple Attention Module with best configuration:
    - use_position_attention: True
    - use_channel_attention: True
    - position_reduction_ratio: 16
    - fusion_method: 'equal'
    """
    def __init__(self, in_channels, num_types=7):
        super(TripleAttention, self).__init__()
        # Best config parameters
        self.use_position_attention = True
        self.use_channel_attention = True
        self.use_type_attention = True
        self.fusion_method = 'equal'
        # Initialize attention modules
        self.position_attention = PositionAttention(in_channels, reduction_ratio=16)
        self.channel_attention = ChannelAttention(in_channels)
        self.type_attention = TypeAttention(in_channels, num_types)

    def forward(self, x):
        outputs = []
        type_weights = None
        # Position attention
        outputs.append(self.position_attention(x))
        # Channel attention
        outputs.append(self.channel_attention(x))
        # Type attention
        tam_out, type_weights = self.type_attention(x)
        outputs.append(tam_out)
        # Fusion using equal weighting
        out = sum(outputs) / len(outputs)
        return out, type_weights
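
# Illustrative check (assumption, not in the original file): all three branches
# use a zero-initialised gamma, so at initialisation each branch passes x through
# unchanged and the equal-weight fusion returns (approximately) x itself.
def _demo_triple_attention():
    ta = TripleAttention(in_channels=512, num_types=7)
    x = torch.randn(2, 512, 7, 7)
    fused, weights = ta(x)
    assert torch.allclose(fused, x) and weights is None
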
# ==========================================
# STOOLNET MODEL (OPTIMIZED)
# ==========================================
class StoolNetTriple(nn.Module):
    """
    StoolNet with best configuration:
    - backbone: resnet34
    - fc_type_hidden: [512]
    - dropout_rates: [0.5, 0.3]
    - learning_rate: 0.001
    - batch_size: 16
    - weight_decay: 0.0001
    """
    def __init__(self, num_type_classes=7, num_shape_classes=4, num_color_classes=2):
        super(StoolNetTriple, self).__init__()
        # Best config: ResNet34 backbone
        weights = models.ResNet34_Weights.IMAGENET1K_V1
        resnet = models.resnet34(weights=weights)
        feature_dim = 512
        # Backbone layers
        self.conv1 = resnet.conv1
        self.bn1 = resnet.bn1
        self.relu = resnet.relu
        self.maxpool = resnet.maxpool
        self.layer1 = resnet.layer1
        self.layer2 = resnet.layer2
        self.layer3 = resnet.layer3
        self.layer4 = resnet.layer4
        # Triple Attention with best config
        self.triple_attention = TripleAttention(feature_dim, num_types=num_type_classes)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        # FC layers with best config
        # Type classifier: [512] hidden layer with the first dropout rate (0.5)
        self.fc_type = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, num_type_classes)
        )
        # Shape classifier: simpler architecture with the first dropout rate
        self.fc_shape = nn.Sequential(
            nn.Linear(feature_dim, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, num_shape_classes)
        )
        # Color classifier: simpler architecture with the first dropout rate
        self.fc_color = nn.Sequential(
            nn.Linear(feature_dim, 128),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(128, num_color_classes)
        )

    def forward(self, x, return_attention=False):
        # Backbone forward pass
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        # Triple attention
        x, type_attention_weights = self.triple_attention(x)
        # Global average pooling
        x = self.avgpool(x)
        x = torch.flatten(x, 1)
        # Classification heads
        type_out = self.fc_type(x)
        shape_out = self.fc_shape(x)
        color_out = self.fc_color(x)
        if return_attention:
            return type_out, shape_out, color_out, type_attention_weights
        else:
            return type_out, shape_out, color_out
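
# Minimal smoke test (assumption: not part of the original upload). Running this
# file directly builds the model and pushes a dummy ImageNet-sized batch through
# it; the three heads should yield logits of sizes 7, 4 and 2 respectively.
# Note that constructing the model downloads the ImageNet-pretrained ResNet34
# weights on first use.
if __name__ == "__main__":
    model = StoolNetTriple(num_type_classes=7, num_shape_classes=4, num_color_classes=2)
    model.eval()
    with torch.no_grad():
        dummy = torch.randn(1, 3, 224, 224)
        type_out, shape_out, color_out = model(dummy)
    print(type_out.shape, shape_out.shape, color_out.shape)
    # Expected: torch.Size([1, 7]) torch.Size([1, 4]) torch.Size([1, 2])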