""" Model Components for Context-CrackNet This module contains the building blocks used in Context-CrackNet: - ConvBlock: Basic convolutional block with BatchNorm and ReLU - ResNet50Encoder: Pretrained ResNet50 backbone for feature extraction - AttentionGate: Attention mechanism for skip connections (RFEM) - LinformerSelfAttention: Efficient self-attention with linear complexity - LinformerBlock: Transformer block using Linformer attention (CAGM) """ import torch import torch.nn as nn import torch.nn.functional as F import torchvision.models as models import math try: from torchvision.models import ResNet50_Weights except ImportError: ResNet50_Weights = None class ConvBlock(nn.Module): """ Basic convolutional block: Conv -> BatchNorm -> ReLU (two times) Args: in_channels (int): Number of input channels out_channels (int): Number of output channels """ def __init__(self, in_channels, out_channels): super(ConvBlock, self).__init__() self.conv = nn.Sequential( nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1, bias=False), nn.BatchNorm2d(out_channels), nn.ReLU(), ) def forward(self, x): return self.conv(x) class ResNet50Encoder(nn.Module): """ ResNet50 encoder for hierarchical feature extraction. Extracts features at multiple scales: - x0: (64, H/2, W/2) - After initial conv - x1: (256, H/4, W/4) - After layer1 - x2: (512, H/8, W/8) - After layer2 - x3: (1024, H/16, W/16) - After layer3 - x4: (2048, H/32, W/32) - After layer4 Args: pretrained (bool): Whether to use ImageNet pretrained weights """ def __init__(self, pretrained=True): super(ResNet50Encoder, self).__init__() if ResNet50_Weights is not None: weights = ResNet50_Weights.DEFAULT if pretrained else None resnet = models.resnet50(weights=weights) else: resnet = models.resnet50(pretrained=pretrained) # Initial layers self.conv1 = resnet.conv1 # (64, H/2, W/2) self.bn1 = resnet.bn1 self.relu = resnet.relu self.maxpool = resnet.maxpool # (64, H/4, W/4) # ResNet layers self.layer1 = resnet.layer1 # (256, H/4, W/4) self.layer2 = resnet.layer2 # (512, H/8, W/8) self.layer3 = resnet.layer3 # (1024, H/16, W/16) self.layer4 = resnet.layer4 # (2048, H/32, W/32) def forward(self, x): x = self.conv1(x) # (B, 64, H/2, W/2) x = self.bn1(x) x = self.relu(x) x0 = x # (B, 64, H/2, W/2) x = self.maxpool(x) # (B, 64, H/4, W/4) x1 = self.layer1(x) # (B, 256, H/4, W/4) x2 = self.layer2(x1) # (B, 512, H/8, W/8) x3 = self.layer3(x2) # (B, 1024, H/16, W/16) x4 = self.layer4(x3) # (B, 2048, H/32, W/32) return x0, x1, x2, x3, x4 class AttentionGate(nn.Module): """ Attention Gate for skip connections (Region Focused Enhancement Module - RFEM). Implements attention mechanism that learns to focus on relevant regions in the encoder features based on the decoder gating signal. Args: F_g (int): Number of channels in gating signal (from decoder) F_l (int): Number of channels in encoder feature map F_int (int): Number of intermediate channels """ def __init__(self, F_g, F_l, F_int): super(AttentionGate, self).__init__() # W_g: gating signal (from decoder) self.W_g = nn.Sequential( nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(F_int) ) # W_x: encoder feature map self.W_x = nn.Sequential( nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True), nn.BatchNorm2d(F_int) ) # Psi: attention coefficient self.psi = nn.Sequential( nn.ReLU(), nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True), nn.Sigmoid() ) def forward(self, x, g, return_attention=True): """ Args: x: Encoder feature map g: Gating signal from decoder return_attention: Whether to return attention weights Returns: Attended features, optionally with attention weights """ g1 = self.W_g(g) x1 = self.W_x(x) psi = self.psi(F.relu(g1 + x1)) if return_attention: return x * psi, psi return x * psi class LinformerSelfAttention(nn.Module): """ Linformer Self-Attention with linear complexity O(n*k). Projects keys and values to a lower dimension k to reduce the quadratic complexity of standard self-attention. Args: embed_dim (int): Embedding dimension num_heads (int): Number of attention heads seq_len (int): Sequence length for projection matrices k (int): Projection dimension (default: 256) dropout (float): Dropout rate """ def __init__(self, embed_dim, num_heads, seq_len=None, k=256, dropout=0.1): super(LinformerSelfAttention, self).__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.k = k self.seq_len = seq_len self.head_dim = embed_dim // num_heads # Query, Key, Value linear layers self.q_proj = nn.Linear(embed_dim, embed_dim) self.k_proj = nn.Linear(embed_dim, embed_dim) self.v_proj = nn.Linear(embed_dim, embed_dim) # Projection matrices for linear attention self.proj_E = nn.Parameter(torch.randn(seq_len, k)) self.proj_F = nn.Parameter(torch.randn(seq_len, k)) self.out_proj = nn.Linear(embed_dim, embed_dim) self.dropout = nn.Dropout(dropout) def forward(self, x, return_attention=True): batch_size, seq_len, embed_dim = x.size() # Linear projections Q = self.q_proj(x) # (batch_size, seq_len, embed_dim) K = self.k_proj(x) V = self.v_proj(x) # Reshape for multi-head attention Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2) # Project K and V using proj_E and proj_F proj_E = self.proj_E.unsqueeze(0).unsqueeze(0) # (1, 1, seq_len, k) proj_F = self.proj_F.unsqueeze(0).unsqueeze(0) K_proj = torch.matmul(K.transpose(-2, -1), proj_E).transpose(-2, -1) V_proj = torch.matmul(V.transpose(-2, -1), proj_F).transpose(-2, -1) # Scaled dot-product attention attn_scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / math.sqrt(self.head_dim) attn_probs = F.softmax(attn_scores, dim=-1) attn_probs = self.dropout(attn_probs) attn_output = torch.matmul(attn_probs, V_proj) # Concatenate heads attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim) # Final linear layer attn_output = self.out_proj(attn_output) if return_attention: return attn_output, attn_probs return attn_output class LinformerBlock(nn.Module): """ Transformer block using Linformer attention (Context-Aware Global Module - CAGM). Combines Linformer self-attention with feed-forward network and residual connections. Args: embed_dim (int): Embedding dimension num_heads (int): Number of attention heads seq_len (int): Sequence length k (int): Linformer projection dimension ff_dim (int): Feed-forward network hidden dimension dropout (float): Dropout rate """ def __init__(self, embed_dim, num_heads, seq_len, k=256, ff_dim=512, dropout=0.1): super(LinformerBlock, self).__init__() self.self_attn = LinformerSelfAttention(embed_dim, num_heads, seq_len, k, dropout) self.norm1 = nn.LayerNorm(embed_dim) self.dropout1 = nn.Dropout(dropout) self.feed_forward = nn.Sequential( nn.Linear(embed_dim, ff_dim), nn.ReLU(), nn.Linear(ff_dim, embed_dim), nn.Dropout(dropout) ) self.norm2 = nn.LayerNorm(embed_dim) self.dropout2 = nn.Dropout(dropout) def forward(self, x, return_attention=True): x2 = self.norm1(x) if return_attention: attn_output, attn_probs = self.self_attn(x2, return_attention=True) x = x + self.dropout1(attn_output) x2 = self.norm2(x) ff_output = self.feed_forward(x2) x = x + self.dropout2(ff_output) return x, attn_probs else: attn_output = self.self_attn(x2, return_attention=False) x = x + self.dropout1(attn_output) x2 = self.norm2(x) ff_output = self.feed_forward(x2) x = x + self.dropout2(ff_output) return x