"""
Context-CrackNet Model Architecture

This module contains the main Context-CrackNet model and its ablation variant
for crack segmentation in pavement images.

The architecture combines:
- ResNet50 encoder for hierarchical feature extraction
- Linformer-based Context-Aware Global Module (CAGM) for global context
- Attention-gated skip connections (Region Focused Enhancement Module - RFEM)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.components import (
    ConvBlock,
    ResNet50Encoder,
    AttentionGate,
    LinformerBlock,
)


def _apply_global_context(linformer, feature_map):
    """Run a token-sequence block over a (B, C, H, W) feature map.

    The spatial grid is flattened into a (B, H*W, C) token sequence, passed
    through ``linformer`` and folded back to (B, C, H, W).

    Returns:
        (transformed feature map, attention returned by ``linformer``)
    """
    b, c, h, w = feature_map.shape
    tokens = feature_map.view(b, c, -1).permute(0, 2, 1)  # (B, H*W, C)
    # Always request the attention explicitly so the call returns a 2-tuple
    # regardless of the block's default for ``return_attention``.
    tokens, attention = linformer(tokens, return_attention=True)
    return tokens.permute(0, 2, 1).view(b, c, h, w), attention


def _gated_skip(gate, skip, decoder_feat):
    """Gate ``skip`` with ``decoder_feat`` and concatenate both on channels.

    Returns:
        (concatenated tensor ``[gated_skip, decoder_feat]``, attention map)
    """
    # Explicit ``return_attention=True`` for the same reason as above.
    gated, attention = gate(skip, decoder_feat, return_attention=True)
    return torch.cat([gated, decoder_feat], dim=1), attention


class Context_CrackNet(nn.Module):
    """
    Context-CrackNet: A novel architecture for crack segmentation.

    Combines a ResNet50 encoder with Linformer-based global attention (CAGM)
    and attention-gated skip connections (RFEM) for accurate crack detection.

    Args:
        in_channels (int): Number of input image channels (default: 3).
            NOTE: this value is only stored on the instance; it is not
            forwarded to the encoder constructor.
        out_channels (int): Number of output segmentation classes (default: 1)
        img_size (int): Input image size (default: 448). The Linformer
            sequence length is fixed to ``(img_size // 16) ** 2`` at
            construction time. The decoder upsamples the stride-32 feature
            map by 2 and concatenates it with the stride-16 map, so the size
            should be divisible by 32 for those shapes to line up (assumes
            the encoder strides listed in ``forward`` - confirm against
            ``ResNet50Encoder``).
        num_heads (int): Number of attention heads in Linformer (default: 8)
        ff_dim (int): Feed-forward dimension in Linformer (default: 2048)
        linformer_k (int): Projection dimension for Linformer (default: 256)
        pretrained (bool): Whether to use pretrained ResNet50 weights
            (default: True)

    Example:
        >>> model = Context_CrackNet(in_channels=3, out_channels=1, img_size=448)
        >>> x = torch.randn(1, 3, 448, 448)
        >>> output, attention_maps = model(x)
        >>> output.shape
        torch.Size([1, 1, 448, 448])
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1,
                 img_size: int = 448, num_heads: int = 8, ff_dim: int = 2048,
                 linformer_k: int = 256, pretrained: bool = True):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        # Encoder
        self.encoder = ResNet50Encoder(pretrained=pretrained)

        # Linformer Block (CAGM - Context-Aware Global Module)
        seq_len = (img_size // 16) ** 2  # Sequence length at layer3
        self.linformer = LinformerBlock(
            embed_dim=1024,
            num_heads=num_heads,
            seq_len=seq_len,
            k=linformer_k,
            ff_dim=ff_dim,
        )

        # Decoder with attention gates (RFEM - Region Focused Enhancement
        # Module). Each ConvBlock takes the gated skip concatenated with the
        # upsampled decoder tensor, hence twice the gate's channel count.
        self.up4 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.attention3 = AttentionGate(F_g=1024, F_l=1024, F_int=512)
        self.conv3 = ConvBlock(2048, 1024)

        self.up3 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.attention2 = AttentionGate(F_g=512, F_l=512, F_int=256)
        self.conv2 = ConvBlock(1024, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.attention1 = AttentionGate(F_g=256, F_l=256, F_int=128)
        self.conv1 = ConvBlock(512, 256)

        self.up1 = nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2)
        self.attention0 = AttentionGate(F_g=64, F_l=64, F_int=32)
        self.conv0 = ConvBlock(128, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor, return_attention: bool = True):
        """
        Forward pass through Context-CrackNet.

        Args:
            x: Input tensor of shape (B, C, H, W)
            return_attention: Whether to return attention maps

        Returns:
            output: Segmentation mask of shape (B, out_channels, H, W)
            attention_maps: Only when ``return_attention`` is True, the list
                ``[attn0, attn1, attn2, attn3, linformer_attention]``
                (attention gates from the finest to the coarsest level,
                followed by the Linformer attention).
        """
        # Encoder
        x0, x1, x2, x3, x4 = self.encoder(x)
        # x0: (B, 64, H/2, W/2)
        # x1: (B, 256, H/4, W/4)
        # x2: (B, 512, H/8, W/8)
        # x3: (B, 1024, H/16, W/16)
        # x4: (B, 2048, H/32, W/32)

        # Apply Linformer block to x3 (CAGM)
        x3, linformer_attention = _apply_global_context(self.linformer, x3)

        # Decoder Level 4
        d4, attn3 = _gated_skip(self.attention3, x3, self.up4(x4))
        d4 = self.conv3(d4)

        # Decoder Level 3
        d3, attn2 = _gated_skip(self.attention2, x2, self.up3(d4))
        d3 = self.conv2(d3)

        # Decoder Level 2
        d2, attn1 = _gated_skip(self.attention1, x1, self.up2(d3))
        d2 = self.conv1(d2)

        # Decoder Level 1
        d1, attn0 = _gated_skip(self.attention0, x0, self.up1(d2))
        d1 = self.conv0(d1)

        # Final output, upsampled to the original input size
        output = self.final(d1)
        output = F.interpolate(
            output,
            size=(x.size(2), x.size(3)),
            mode='bilinear',
            align_corners=False,
        )

        if return_attention:
            return output, [attn0, attn1, attn2, attn3, linformer_attention]
        return output


class Context_CrackNet_ablation(nn.Module):
    """
    Context-CrackNet with optional modules for ablation studies.

    Allows selective enabling/disabling of RFEM and CAGM modules to study
    their individual contributions.

    NOTE: with ``use_rfem=False`` the decoder receives no encoder skip
    features at all - each ConvBlock only sees the upsampled tensor. Because
    the CAGM output is consumed solely through the level-4 attention gate,
    ``use_cagm=True`` combined with ``use_rfem=False`` runs the Linformer but
    its result does not reach the output.

    Args:
        in_channels (int): Number of input image channels (accepted for
            signature parity with ``Context_CrackNet``; not used here)
        out_channels (int): Number of output segmentation classes
        img_size (int): Input image size; fixes the Linformer sequence
            length to ``(img_size // 16) ** 2``
        num_heads (int): Number of attention heads in Linformer
        ff_dim (int): Feed-forward dimension in Linformer
        linformer_k (int): Projection dimension for Linformer
        use_rfem (bool): Whether to include the RFEM module (attention gates)
        use_cagm (bool): Whether to include the CAGM module (Linformer)
        pretrained (bool): Whether to use pretrained ResNet50 weights
    """

    def __init__(self, in_channels: int = 3, out_channels: int = 1,
                 img_size: int = 448, num_heads: int = 8, ff_dim: int = 2048,
                 linformer_k: int = 256, use_rfem: bool = True,
                 use_cagm: bool = True, pretrained: bool = True):
        super().__init__()

        self.use_rfem = use_rfem
        self.use_cagm = use_cagm

        self.encoder = ResNet50Encoder(pretrained=pretrained)

        # Linformer Block (optional based on use_cagm)
        seq_len = (img_size // 16) ** 2
        self.linformer = LinformerBlock(
            embed_dim=1024,
            num_heads=num_heads,
            seq_len=seq_len,
            k=linformer_k,
            ff_dim=ff_dim,
        ) if use_cagm else None

        # Decoder with optional attention gates. Without RFEM there is no
        # concatenation, so each ConvBlock's input width is halved.
        self.up4 = nn.ConvTranspose2d(2048, 1024, kernel_size=2, stride=2)
        self.attention3 = AttentionGate(F_g=1024, F_l=1024, F_int=512) if use_rfem else None
        self.conv3 = ConvBlock(2048 if use_rfem else 1024, 1024)

        self.up3 = nn.ConvTranspose2d(1024, 512, kernel_size=2, stride=2)
        self.attention2 = AttentionGate(F_g=512, F_l=512, F_int=256) if use_rfem else None
        self.conv2 = ConvBlock(1024 if use_rfem else 512, 512)

        self.up2 = nn.ConvTranspose2d(512, 256, kernel_size=2, stride=2)
        self.attention1 = AttentionGate(F_g=256, F_l=256, F_int=128) if use_rfem else None
        self.conv1 = ConvBlock(512 if use_rfem else 256, 256)

        self.up1 = nn.ConvTranspose2d(256, 64, kernel_size=2, stride=2)
        self.attention0 = AttentionGate(F_g=64, F_l=64, F_int=32) if use_rfem else None
        self.conv0 = ConvBlock(128 if use_rfem else 64, 64)

        self.final = nn.Conv2d(64, out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor):
        """
        Forward pass through the ablation variant.

        Args:
            x: Input tensor of shape (B, C, H, W)

        Returns:
            Segmentation output resized to the spatial size of ``x``.
            Attention maps are never returned by this variant.
        """
        # Encoder
        x0, x1, x2, x3, x4 = self.encoder(x)

        # Apply Linformer block if enabled
        if self.use_cagm:
            x3, _ = _apply_global_context(self.linformer, x3)

        # Decoder Level 4
        d4 = self.up4(x4)
        if self.use_rfem:
            d4, _ = _gated_skip(self.attention3, x3, d4)
        d4 = self.conv3(d4)

        # Decoder Level 3
        d3 = self.up3(d4)
        if self.use_rfem:
            d3, _ = _gated_skip(self.attention2, x2, d3)
        d3 = self.conv2(d3)

        # Decoder Level 2
        d2 = self.up2(d3)
        if self.use_rfem:
            d2, _ = _gated_skip(self.attention1, x1, d2)
        d2 = self.conv1(d2)

        # Decoder Level 1
        d1 = self.up1(d2)
        if self.use_rfem:
            d1, _ = _gated_skip(self.attention0, x0, d1)
        d1 = self.conv0(d1)

        # Final output, upsampled to the original input size
        output = self.final(d1)
        output = F.interpolate(
            output,
            size=(x.size(2), x.size(3)),
            mode='bilinear',
            align_corners=False,
        )

        return output