File size: 9,393 Bytes
3a3f6c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8c67ab
 
3a3f6c6
 
 
 
c8c67ab
3a3f6c6
 
 
c8c67ab
3a3f6c6
 
 
 
c8c67ab
3a3f6c6
 
c8c67ab
3a3f6c6
 
 
 
c8c67ab
3a3f6c6
 
 
 
c8c67ab
3a3f6c6
c8c67ab
 
 
3a3f6c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8c67ab
3a3f6c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c8c67ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
"""
CTC-based CAPTCHA recognition model.
Uses CNN + LSTM + CTC loss - no bounding boxes needed!

This approach is standard for sequence recognition tasks where
character positions are unknown or variable.
"""
import torch
import torch.nn as nn

 
class CTCCaptchaModel(nn.Module):
    """
    CAPTCHA recognition using CTC (Connectionist Temporal Classification).

    Architecture:
    1. CNN backbone extracts visual features
    2. Reshape to sequence (treating width as time steps)
    3. Bidirectional LSTM processes sequence
    4. (Optional) self-attention refines the LSTM outputs
    5. Linear layer outputs character probabilities for each time step
    6. CTC loss handles alignment between predictions and ground truth

    No need for bounding boxes - CTC figures out alignment automatically!
    """

    def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
            hidden_size: Hidden size for LSTM layers
            num_lstm_layers: Number of LSTM layers
            use_attention: If True, apply a residual multi-head self-attention
                layer (4 heads, LayerNorm, dropout 0.1) on top of the LSTM
                outputs before classification
        """
        super(CTCCaptchaModel, self).__init__()

        self.num_classes = num_classes
        # CTC needs blank token for alignment (class index = num_classes)
        self.blank_idx = num_classes

        # CNN backbone for feature extraction
        # Input: (batch, 1, 60, 160) - grayscale image
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (32, 30, 80)

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (64, 15, 40)

            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (128, 15, 20)

            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 10)
        )

        # After CNN: (batch, 256, 15, 10)
        # We'll reshape to: (batch, 10, 256*15) treating width as sequence
        # So sequence length = 10, feature dim = 256*15 = 3840
        self.feature_size = 256 * 15  # channels * height
        self.sequence_length = 10  # width after pooling

        # Map CNN features to LSTM input size
        self.map_to_seq = nn.Linear(self.feature_size, hidden_size)

        # Bidirectional LSTM to process sequence
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            # nn.LSTM ignores dropout for a single layer (and warns), so gate it
            dropout=0.3 if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Optional self-attention on top of LSTM outputs
        self.use_attention = use_attention
        if self.use_attention:
            self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True)
            self.attn_norm = nn.LayerNorm(hidden_size * 2)
            self.attn_dropout = nn.Dropout(0.1)
        else:
            self.attn = None

        # Output layer: map LSTM outputs to character probabilities
        # +1 for CTC blank token
        self.fc = nn.Linear(hidden_size * 2, num_classes + 1)  # *2 for bidirectional

    def forward(self, x):
        """
        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities for CTC loss (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)

        # Extract CNN features
        features = self.cnn(x)  # (batch, 256, 15, 10)

        # Reshape to sequence: (batch, width, channels*height)
        # Transpose to treat width as sequence dimension
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 256, 15)
        features = features.reshape(batch_size, self.sequence_length, self.feature_size)

        # Map to LSTM input size
        features = self.map_to_seq(features)  # (batch, 10, hidden_size)

        # Process with LSTM
        lstm_out, _ = self.lstm(features)  # (batch, 10, hidden_size*2)

        # Optional attention (residual + LayerNorm, post-norm style)
        if self.attn is not None:
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out))

        # Get character predictions for each time step
        logits = self.fc(lstm_out)  # (batch, 10, num_classes+1)

        # CTC expects: (sequence_length, batch, num_classes)
        logits = logits.permute(1, 0, 2)  # (10, batch, num_classes+1)

        # Apply log_softmax for CTC loss
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)

        return log_probs

    def predict(self, x):
        """
        Decode predictions using greedy decoding (variable length).

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            A list (one entry per image) of index lists with CTC blanks and
            repeated characters removed.
        """
        self.eval()
        with torch.no_grad():
            log_probs = self.forward(x)  # (seq_len, batch, num_classes+1)

            # Greedy decoding: take argmax at each time step
            _, preds = log_probs.max(2)  # (seq_len, batch)
            preds = preds.transpose(0, 1)  # (batch, seq_len)

            decoded = []
            # One .tolist() moves everything to host in a single sync instead of
            # calling .item() per element inside the loop.
            for pred_seq in preds.tolist():
                decoded_seq = []
                prev_char = None

                for char_idx in pred_seq:
                    # Skip blank tokens; a blank also resets the repeat tracker
                    # so "A blank A" decodes to "AA" (standard CTC rule)
                    if char_idx == self.blank_idx:
                        prev_char = None
                        continue

                    # Skip repeated characters (CTC rule)
                    if char_idx != prev_char:
                        decoded_seq.append(char_idx)
                        prev_char = char_idx

                decoded.append(decoded_seq)

            # Return Python lists (variable length) for downstream decoding
            return decoded


class CTCCaptchaModelSimple(nn.Module):
    """
    Lightweight CTC model that skips the recurrent stage entirely.

    A pure CNN backbone produces one feature column per width position; each
    column is classified independently into `num_classes + 1` logits (the
    extra slot is the CTC blank). Faster to train and lighter on memory than
    the LSTM variant - a reasonable first baseline.
    """

    def __init__(self, num_classes=36):
        super(CTCCaptchaModelSimple, self).__init__()

        self.num_classes = num_classes
        # The blank symbol lives one index past the real character classes.
        self.blank_idx = num_classes

        # Convolutional backbone. Input is (batch, 1, 60, 160); the first two
        # pools halve both dims, the last two halve width only, leaving 10
        # columns of height 15.
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (64, 30, 80)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (128, 15, 40)

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (256, 15, 20)

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (512, 15, 10)
        )

        # Per-timestep classifier over a flattened feature column
        # (512 channels x 15 rows = 7680 features per width position).
        self.classifier = nn.Sequential(
            nn.Linear(512 * 15, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes + 1)
        )

        self.sequence_length = 10

    def forward(self, x):
        """Return CTC-ready log-probs of shape (seq_len, batch, num_classes+1)."""
        n = x.size(0)

        feats = self.features(x)  # (n, 512, 15, 10)

        # Make width the sequence axis, then collapse channels x height into
        # one feature vector per timestep: (n, 10, 512, 15) -> (n, 10, 7680).
        seq = feats.permute(0, 3, 1, 2).flatten(2)
        assert seq.size(1) == self.sequence_length

        logits = self.classifier(seq)  # (n, 10, num_classes+1)

        # CTC loss wants time-major tensors, normalized with log_softmax.
        return torch.nn.functional.log_softmax(logits.permute(1, 0, 2), dim=2)

    def predict(self, x):
        """Greedy CTC decode: one variable-length list of indices per image."""
        self.eval()
        with torch.no_grad():
            # Best class per timestep, arranged batch-major: (batch, seq_len).
            best = self.forward(x).argmax(2).transpose(0, 1)

            results = []
            for row in best.tolist():
                chars = []
                last = None
                for idx in row:
                    if idx == self.blank_idx:
                        # Blank separates repeats, so forget the last symbol.
                        last = None
                    elif idx != last:
                        # New (non-repeated) character: keep it.
                        chars.append(idx)
                        last = idx
                results.append(chars)

            return results