File size: 9,590 Bytes
f8013fd
 
 
 
 
 
 
 
 
 
 
a81ff3f
 
 
 
 
 
 
 
 
 
f8013fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a81ff3f
 
 
 
 
 
 
 
 
 
f8013fd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
"""
Model Components for Context-CrackNet

This module contains the building blocks used in Context-CrackNet:

- ConvBlock: Basic convolutional block (Conv -> BatchNorm -> ReLU, applied twice)
- ResNet50Encoder: ResNet50 backbone (optionally ImageNet-pretrained) for multi-scale feature extraction
- AttentionGate: Attention mechanism for skip connections (RFEM)
- LinformerSelfAttention: Efficient self-attention with complexity linear in sequence length
- LinformerBlock: Transformer block using Linformer attention (CAGM)
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import math

try:
    from torchvision.models import ResNet50_Weights
except ImportError:
    ResNet50_Weights = None


class ConvBlock(nn.Module):
    """
    Two stacked 3x3 convolution stages, each Conv2d -> BatchNorm2d -> ReLU.

    Padding of 1 keeps the spatial size unchanged. The convolutions are built
    with ``bias=False``; each is immediately followed by a BatchNorm2d layer.

    Args:
        in_channels (int): Number of input channels
        out_channels (int): Number of output channels
    """
    def __init__(self, in_channels, out_channels):
        super().__init__()
        stages = []
        # First stage maps in_channels -> out_channels, second keeps out_channels.
        for stage_in in (in_channels, out_channels):
            stages.extend((
                nn.Conv2d(stage_in, out_channels, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(),
            ))
        # A single Sequential named `conv` keeps the state_dict keys
        # (conv.0.weight, conv.1.running_mean, ..., conv.4.*) identical.
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both convolution stages over ``x``."""
        return self.conv(x)


class ResNet50Encoder(nn.Module):
    """
    ResNet50 encoder for hierarchical feature extraction.

    ``forward`` returns five feature maps. The shapes below assume H and W
    are multiples of 32:

    - x0: (64, H/2, W/2)     stem output (conv1 -> bn1 -> relu), before max-pooling
    - x1: (256, H/4, W/4)    after layer1
    - x2: (512, H/8, W/8)    after layer2
    - x3: (1024, H/16, W/16) after layer3
    - x4: (2048, H/32, W/32) after layer4

    Only the stem and the four residual stages are kept from the torchvision
    model; its pooling/classification head is not registered here.

    Args:
        pretrained (bool): Whether to load ImageNet pretrained weights. Uses
            ``ResNet50_Weights.DEFAULT`` when torchvision provides the weights
            enum, and the legacy ``pretrained=`` flag otherwise.
    """
    def __init__(self, pretrained=True):
        super().__init__()
        if ResNet50_Weights is None:
            # Older torchvision without the weights enum: use the legacy flag.
            backbone = models.resnet50(pretrained=pretrained)
        else:
            chosen_weights = ResNet50_Weights.DEFAULT if pretrained else None
            backbone = models.resnet50(weights=chosen_weights)

        # Re-register the backbone's submodules under the same attribute names
        # and in the same order, so the state_dict keys are unchanged.
        for part in ("conv1", "bn1", "relu", "maxpool",
                     "layer1", "layer2", "layer3", "layer4"):
            setattr(self, part, getattr(backbone, part))

    def forward(self, x):
        stem = self.relu(self.bn1(self.conv1(x)))   # (B, 64, H/2, W/2)
        x1 = self.layer1(self.maxpool(stem))        # (B, 256, H/4, W/4)
        x2 = self.layer2(x1)                        # (B, 512, H/8, W/8)
        x3 = self.layer3(x2)                        # (B, 1024, H/16, W/16)
        x4 = self.layer4(x3)                        # (B, 2048, H/32, W/32)
        return stem, x1, x2, x3, x4


class AttentionGate(nn.Module):
    """
    Attention Gate for skip connections (Region Focused Enhancement Module - RFEM).

    The encoder feature map ``x`` and the decoder gating signal ``g`` are each
    projected to ``F_int`` channels (1x1 conv + BatchNorm) and summed. The sum
    goes through ReLU -> 1x1 conv (to one channel) -> Sigmoid, producing an
    attention map ``psi`` in (0, 1). ``x`` is then scaled by ``psi``.

    If ``g`` has a different spatial size than ``x`` (e.g. it comes from a
    coarser decoder stage), its projection is bilinearly resampled to the
    spatial size of ``x`` before the sum.

    Args:
        F_g (int): Number of channels in gating signal (from decoder)
        F_l (int): Number of channels in encoder feature map
        F_int (int): Number of intermediate channels
    """
    def __init__(self, F_g, F_l, F_int):
        super(AttentionGate, self).__init__()
        # W_g: 1x1 projection of the gating signal (from decoder)
        self.W_g = nn.Sequential(
            nn.Conv2d(F_g, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # W_x: 1x1 projection of the encoder feature map
        self.W_x = nn.Sequential(
            nn.Conv2d(F_l, F_int, kernel_size=1, stride=1, padding=0, bias=True),
            nn.BatchNorm2d(F_int)
        )
        # Psi: ReLU -> 1x1 conv to a single channel -> Sigmoid.
        # The ReLU lives here, so forward() must not apply another one. The layer
        # order is kept so existing state_dict keys (psi.1.weight/bias) stay valid.
        self.psi = nn.Sequential(
            nn.ReLU(),
            nn.Conv2d(F_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
            nn.Sigmoid()
        )

    def forward(self, x, g, return_attention=True):
        """
        Args:
            x: Encoder feature map, 4-D (N, F_l, H, W)
            g: Gating signal from decoder, 4-D (N, F_g, H', W'); resampled to
               (H, W) after projection when H' x W' differs
            return_attention: Whether to also return the attention map

        Returns:
            ``x * psi`` with the same shape as ``x``; if ``return_attention`` is
            true, the tuple ``(x * psi, psi)`` where ``psi`` is (N, 1, H, W).
        """
        g1 = self.W_g(g)
        x1 = self.W_x(x)
        if g1.shape[-2:] != x1.shape[-2:]:
            # Align a gating signal from a different resolution so the additive
            # combination below is well defined.
            g1 = F.interpolate(g1, size=x1.shape[-2:], mode="bilinear", align_corners=False)
        # self.psi already starts with ReLU; an extra F.relu here was redundant.
        psi = self.psi(g1 + x1)
        attended = x * psi
        if return_attention:
            return attended, psi
        return attended


class LinformerSelfAttention(nn.Module):
    """
    Linformer Self-Attention with linear complexity O(n*k).

    Keys and values are projected along the sequence axis from ``seq_len`` to
    ``k`` positions with learned (seq_len, k) matrices ``proj_E`` / ``proj_F``.
    Each query therefore attends over ``k`` positions instead of ``seq_len``,
    avoiding the quadratic cost of standard self-attention.

    Args:
        embed_dim (int): Embedding dimension
        num_heads (int): Number of attention heads; must divide ``embed_dim``
        seq_len (int): Sequence length for projection matrices (required)
        k (int): Projection dimension (default: 256)
        dropout (float): Dropout rate applied to the attention probabilities

    Raises:
        ValueError: if ``seq_len`` is None or ``embed_dim`` is not divisible
            by ``num_heads``.
    """
    def __init__(self, embed_dim, num_heads, seq_len=None, k=256, dropout=0.1):
        super(LinformerSelfAttention, self).__init__()
        # proj_E / proj_F are (seq_len, k) parameters, so a concrete length is
        # required; with the None default torch.randn would fail opaquely.
        if seq_len is None:
            raise ValueError(
                "LinformerSelfAttention requires seq_len (number of tokens) "
                "to build its projection matrices"
            )
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.k = k
        self.seq_len = seq_len
        self.head_dim = embed_dim // num_heads

        # Query, Key, Value linear layers
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Sequence-axis projection matrices for linear attention.
        # NOTE(review): drawn from an unscaled N(0, 1); because the projection
        # sums over seq_len tokens, the projected magnitudes grow with seq_len.
        # Kept as-is (trained checkpoints overwrite it), but consider scaling
        # when training from scratch.
        self.proj_E = nn.Parameter(torch.randn(seq_len, k))
        self.proj_F = nn.Parameter(torch.randn(seq_len, k))

        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, return_attention=True):
        """
        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim); seq_len must
               equal the length the projection matrices were built for.
            return_attention: Whether to also return the attention probabilities

        Returns:
            Output of shape (batch_size, seq_len, embed_dim); if
            ``return_attention`` is true, the tuple ``(output, attn_probs)`` where
            ``attn_probs`` has shape (batch_size, num_heads, seq_len, k) and is
            taken after dropout.

        Raises:
            ValueError: if the input sequence length does not match ``proj_E``.
        """
        batch_size, seq_len, embed_dim = x.size()
        expected_len = self.proj_E.shape[0]
        if seq_len != expected_len:
            raise ValueError(
                f"expected sequence length {expected_len} (rows of proj_E), got {seq_len}"
            )

        # Linear projections, each (batch_size, seq_len, embed_dim)
        Q = self.q_proj(x)
        K = self.k_proj(x)
        V = self.v_proj(x)

        # Split heads: (batch_size, num_heads, seq_len, head_dim)
        Q = Q.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        K = K.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        V = V.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Compress the sequence axis: (B, H, d, N) @ (N, k) -> (B, H, d, k) -> (B, H, k, d).
        # matmul broadcasts the 2-D projection over the batch/head dims.
        K_proj = torch.matmul(K.transpose(-2, -1), self.proj_E).transpose(-2, -1)
        V_proj = torch.matmul(V.transpose(-2, -1), self.proj_F).transpose(-2, -1)

        # Scaled dot-product attention over the k projected positions
        attn_scores = torch.matmul(Q, K_proj.transpose(-2, -1)) / math.sqrt(self.head_dim)
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.dropout(attn_probs)

        attn_output = torch.matmul(attn_probs, V_proj)

        # Merge heads back to (batch_size, seq_len, embed_dim)
        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, embed_dim)

        # Final linear layer
        attn_output = self.out_proj(attn_output)

        if return_attention:
            return attn_output, attn_probs
        return attn_output
    

class LinformerBlock(nn.Module):
    """
    Pre-norm Transformer block built on Linformer attention
    (Context-Aware Global Module - CAGM).

    forward() computes:
        h   = x + dropout1(self_attn(norm1(x)))
        out = h + dropout2(feed_forward(norm2(h)))

    Args:
        embed_dim (int): Embedding dimension
        num_heads (int): Number of attention heads
        seq_len (int): Sequence length
        k (int): Linformer projection dimension
        ff_dim (int): Feed-forward network hidden dimension
        dropout (float): Dropout rate
    """
    def __init__(self, embed_dim, num_heads, seq_len, k=256, ff_dim=512, dropout=0.1):
        super().__init__()
        # Registration order below is kept so parameter init order and
        # state_dict key order are unchanged.
        self.self_attn = LinformerSelfAttention(embed_dim, num_heads, seq_len, k, dropout)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)

        # NOTE(review): this stack already ends in Dropout and forward() applies
        # dropout2 on top, so in training mode the FFN output passes through two
        # independent dropout masks. Kept as-is for fidelity with existing runs.
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_dim, ff_dim),
            nn.ReLU(),
            nn.Linear(ff_dim, embed_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout2 = nn.Dropout(dropout)

    def forward(self, x, return_attention=True):
        """
        Args:
            x: Input sequence passed to LinformerSelfAttention after norm1
            return_attention: Whether to also return the attention probabilities

        Returns:
            The block output; if ``return_attention`` is true, the tuple
            ``(output, attn_probs)`` with the probabilities from ``self_attn``.
        """
        want_attn = bool(return_attention)
        attn_result = self.self_attn(self.norm1(x), return_attention=want_attn)
        if want_attn:
            attn_output, attn_probs = attn_result
        else:
            attn_output, attn_probs = attn_result, None

        # Attention sub-layer residual, then feed-forward sub-layer residual.
        hidden = x + self.dropout1(attn_output)
        ff_output = self.feed_forward(self.norm2(hidden))
        out = hidden + self.dropout2(ff_output)

        return (out, attn_probs) if want_attn else out