# Pavement-Evaluation/src/models/components.py
# Author: Blessing988 - commit a81ff3f (verified): "deploy: Context-CrackNet Pavement Analyzer"
"""
Model Components for Context-CrackNet
This module contains the building blocks used in Context-CrackNet:
- ConvBlock: Basic convolutional block with BatchNorm and ReLU
- ResNet50Encoder: Pretrained ResNet50 backbone for feature extraction
- AttentionGate: Attention mechanism for skip connections (RFEM)
- LinformerSelfAttention: Efficient self-attention with linear complexity
- LinformerBlock: Transformer block using Linformer attention (CAGM)
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import math
try:
from torchvision.models import ResNet50_Weights
except ImportError:
ResNet50_Weights = None
class ConvBlock(nn.Module):
    """
    Two stacked (3x3 Conv -> BatchNorm -> ReLU) stages.

    Both convolutions use padding=1 and no bias (BatchNorm supplies the
    shift), so the spatial size of the input is preserved.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        layers = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        # Layers are appended in the same order as a literal Sequential would
        # list them, so state_dict keys are conv.0 ... conv.5.
        for stage_in in (in_channels, out_channels):
            layers.extend(
                [
                    nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(out_channels),
                    nn.ReLU(),
                ]
            )
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages to ``x``."""
        return self.conv(x)
class ResNet50Encoder(nn.Module):
    """
    ResNet50 backbone that exposes its intermediate feature maps.

    Outputs, per the original author's shape annotations:
    - x0: (64, H/2, W/2) - stem output (conv1 -> bn1 -> relu), before max-pool
    - x1: (256, H/4, W/4) - after layer1
    - x2: (512, H/8, W/8) - after layer2
    - x3: (1024, H/16, W/16) - after layer3
    - x4: (2048, H/32, W/32) - after layer4

    Args:
        pretrained (bool): Whether to use ImageNet pretrained weights
    """

    def __init__(self, pretrained=True):
        super().__init__()
        if ResNet50_Weights is None:
            # torchvision without the weights enum: use the legacy boolean flag.
            backbone = models.resnet50(pretrained=pretrained)
        else:
            chosen = ResNet50_Weights.DEFAULT if pretrained else None
            backbone = models.resnet50(weights=chosen)

        # Stem modules, reused from the torchvision model. Attribute names and
        # registration order are kept so checkpoints load unchanged.
        self.conv1 = backbone.conv1
        self.bn1 = backbone.bn1
        self.relu = backbone.relu
        self.maxpool = backbone.maxpool

        # Residual stages
        self.layer1 = backbone.layer1
        self.layer2 = backbone.layer2
        self.layer3 = backbone.layer3
        self.layer4 = backbone.layer4

    def forward(self, x):
        """Return the tuple (x0, x1, x2, x3, x4) of multi-scale features."""
        stem = self.relu(self.bn1(self.conv1(x)))
        features = [stem]
        current = self.maxpool(stem)
        # Each stage consumes the previous stage's output; collect all four.
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            current = stage(current)
            features.append(current)
        return tuple(features)
class AttentionGate(nn.Module):
    """
    Attention gate for skip connections (Region Focused Enhancement Module - RFEM).

    Computes a per-pixel coefficient in [0, 1] from the encoder features and
    the decoder gating signal, then scales the encoder features by it:

        psi = sigmoid(conv1x1(relu(W_g(g) + W_x(x))))
        out = x * psi

    If the gating signal's spatial size differs from the encoder map's, the
    projected gating signal is bilinearly resampled to the encoder map's size
    before the two are added.

    Args:
        F_g (int): Number of channels in gating signal (from decoder)
        F_l (int): Number of channels in encoder feature map
        F_int (int): Number of intermediate channels
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # W_g: 1x1 conv + BN projecting the gating signal to F_int channels
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # W_x: 1x1 conv + BN projecting the encoder features to F_int channels
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int),
        )
        # Psi: ReLU -> 1x1 conv to a single channel -> sigmoid (coefficient in [0, 1])
        self.psi = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid(),
        )

    def forward(self, x, g, return_attention=True):
        """
        Args:
            x: Encoder feature map, (B, F_l, H, W)
            g: Gating signal from decoder, (B, F_g, H_g, W_g); H_g/W_g may differ from H/W
            return_attention: Whether to return attention weights

        Returns:
            ``x * psi`` with the same shape as ``x``; if ``return_attention``,
            also ``psi`` with shape (B, 1, H, W).
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        if g1.shape[-2:] != x1.shape[-2:]:
            # Previously a shape mismatch here raised in the addition below;
            # resample the gating projection onto the encoder grid instead.
            g1 = F.interpolate(
                g1, size=x1.shape[-2:], mode="bilinear", align_corners=False
            )
        # self.psi already begins with ReLU, so no separate activation is
        # applied here (a second ReLU would be a no-op).
        psi = self.psi(g1 + x1)
        attended = x * psi  # psi's single channel broadcasts over x's channels
        if return_attention:
            return attended, psi
        return attended
class LinformerSelfAttention(nn.Module):
    """
    Linformer self-attention with linear complexity O(n*k).

    Keys and values are projected along the sequence axis, from ``seq_len``
    tokens down to ``k`` rows, by the learned matrices ``proj_E`` / ``proj_F``
    (each of shape ``(seq_len, k)``). The per-head score matrix is therefore
    ``(seq_len, k)`` instead of ``(seq_len, seq_len)``.

    Args:
        embed_dim (int): Embedding dimension; must be divisible by ``num_heads``.
        num_heads (int): Number of attention heads.
        seq_len (int): Number of tokens the projection matrices are sized for.
            Required in practice despite the ``None`` default. Every input to
            ``forward`` must have exactly this many tokens.
        k (int): Projection dimension (default: 256).
        dropout (float): Dropout rate applied to the attention probabilities.

    Raises:
        ValueError: If ``seq_len`` is None or ``embed_dim`` is not divisible
            by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, seq_len=None, k=256, dropout=0.1):
        # Validate up front. Otherwise these surface as opaque errors from
        # torch.randn (seq_len=None) or from Tensor.view inside forward.
        if seq_len is None:
            raise ValueError(
                "LinformerSelfAttention requires seq_len to size its "
                "projection matrices"
            )
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.k = k
        self.seq_len = seq_len
        self.head_dim = embed_dim // num_heads

        # Query, key, value projections over the embedding dimension
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Sequence-axis projections for keys (E) and values (F).
        # NOTE(review): these are drawn from an unscaled standard normal.
        # Parameter names, shapes and creation order are kept as-is so that
        # existing checkpoints load and seeded runs reproduce.
        self.proj_E = nn.Parameter(torch.randn(seq_len, k))
        self.proj_F = nn.Parameter(torch.randn(seq_len, k))

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_attention=True):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim). Its seq_len
                must equal the value given at construction.
            return_attention: Whether to also return the attention weights.

        Returns:
            Output of shape (batch_size, seq_len, embed_dim). If
            ``return_attention`` is set, also the attention weights of shape
            (batch_size, num_heads, seq_len, k). The weights are taken after
            dropout, so in training mode their rows need not sum to 1.

        Raises:
            ValueError: If the input sequence length differs from ``self.seq_len``.
        """
        batch_size, seq_len, embed_dim = x.size()
        if seq_len != self.seq_len:
            raise ValueError(
                f"expected sequence length {self.seq_len}, got {seq_len}; "
                "the Linformer projections are fixed-size"
            )

        # Linear projections, split into heads: (B, N, D) -> (B, heads, N, head_dim)
        Q = self.q_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compress K and V along the sequence axis:
        # (B, heads, N, head_dim) -> (B, heads, k, head_dim)
        proj_E = self.proj_E.unsqueeze(0).unsqueeze(0)  # (1, 1, seq_len, k)
        proj_F = self.proj_F.unsqueeze(0).unsqueeze(0)
        K_proj = torch.matmul(K.transpose(-2, -1), proj_E).transpose(-2, -1)
        V_proj = torch.matmul(V.transpose(-2, -1), proj_F).transpose(-2, -1)

        # Scaled dot-product attention over the k projected positions
        attn_scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = self.dropout(F.softmax(attn_scores, dim=-1))
        attn_output = torch.matmul(attn_probs, V_proj)  # (B, heads, N, head_dim)

        # Merge heads back to (B, N, D) and apply the output projection
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)
        attn_output = self.out_proj(attn_output)

        if return_attention:
            return attn_output, attn_probs
        return attn_output
class LinformerBlock(nn.Module):
    """
    Pre-norm transformer block built on Linformer attention
    (Context-Aware Global Module - CAGM).

    Computes:
        x = x + dropout1(attn(norm1(x)))
        x = x + dropout2(feed_forward(norm2(x)))

    Args:
        embed_dim (int): Embedding dimension
        num_heads (int): Number of attention heads
        seq_len (int): Sequence length
        k (int): Linformer projection dimension
        ff_dim (int): Feed-forward network hidden dimension
        dropout (float): Dropout rate
    """

    def __init__(self, embed_dim, num_heads, seq_len, k=256, ff_dim=512, dropout=0.1):
        super().__init__()
        # Submodules are created in the same order as before so that seeded
        # initialisation and state_dict keys are unchanged.
        self.self_attn = LinformerSelfAttention(embed_dim, num_heads, seq_len, k, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, return_attention=True):
        """
        Run the block on ``x``.

        Returns ``x_out``, or ``(x_out, attn_probs)`` when ``return_attention``
        is truthy. ``attn_probs`` is whatever ``self.self_attn`` returns as its
        attention weights.
        """
        want_attn = bool(return_attention)
        attn_result = self.self_attn(self.norm1(x), return_attention=want_attn)
        if want_attn:
            attn_output, attn_probs = attn_result
        else:
            attn_output = attn_result

        # Residual around attention
        x = x + self.dropout1(attn_output)
        # Residual around the feed-forward network.
        # NOTE(review): feed_forward already ends in nn.Dropout, so in training
        # this branch is dropped out twice. Kept as-is to preserve existing
        # training behaviour.
        x = x + self.dropout2(self.feed_forward(self.norm2(x)))

        if want_attn:
            return x, attn_probs
        return x