"""
CTC-based CAPTCHA recognition model.
Uses CNN + LSTM + CTC loss - no bounding boxes needed!
This approach is standard for sequence recognition tasks where
character positions are unknown or variable.
"""
import torch
import torch.nn as nn
class CTCCaptchaModel(nn.Module):
    """
    CAPTCHA recognition using CTC (Connectionist Temporal Classification).

    Architecture:
        1. CNN backbone extracts visual features
        2. Reshape to a sequence (treating width as time steps)
        3. Bidirectional LSTM processes the sequence
        4. Linear layer outputs character probabilities for each time step
        5. CTC loss handles alignment between predictions and ground truth

    No need for bounding boxes - CTC figures out alignment automatically!
    """

    def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
            hidden_size: Hidden size for LSTM layers
            num_lstm_layers: Number of LSTM layers
            use_attention: If True, add a multi-head self-attention layer
                (residual + LayerNorm) on top of the LSTM outputs.
        """
        super(CTCCaptchaModel, self).__init__()
        self.num_classes = num_classes
        # CTC needs a dedicated blank token for alignment (class index = num_classes)
        self.blank_idx = num_classes
        # CNN backbone for feature extraction
        # Input: (batch, 1, 60, 160) - grayscale image
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (32, 30, 80)
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (64, 15, 40)
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (128, 15, 20)
            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 10)
        )
        # After CNN: (batch, 256, 15, 10)
        # We'll reshape to: (batch, 10, 256*15) treating width as sequence
        # So sequence length = 10, feature dim = 256*15 = 3840
        self.feature_size = 256 * 15  # channels * height
        self.sequence_length = 10  # width after pooling
        # Map CNN features to LSTM input size
        self.map_to_seq = nn.Linear(self.feature_size, hidden_size)
        # Bidirectional LSTM to process sequence
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            # Inter-layer dropout is only valid when there is more than one layer
            dropout=0.3 if num_lstm_layers > 1 else 0,
            batch_first=True
        )
        # Optional self-attention on top of LSTM outputs
        self.use_attention = use_attention
        if self.use_attention:
            self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True)
            self.attn_norm = nn.LayerNorm(hidden_size * 2)
            self.attn_dropout = nn.Dropout(0.1)
        else:
            self.attn = None
        # Output layer: map LSTM outputs to character probabilities
        # +1 for CTC blank token; *2 for bidirectional LSTM outputs
        self.fc = nn.Linear(hidden_size * 2, num_classes + 1)

    def forward(self, x):
        """
        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities for CTC loss (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)
        # Extract CNN features
        features = self.cnn(x)  # (batch, 256, 15, 10)
        # Reshape to sequence: (batch, width, channels*height)
        # Transpose to treat width as sequence dimension
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 256, 15)
        features = features.reshape(batch_size, self.sequence_length, self.feature_size)
        # Map to LSTM input size
        features = self.map_to_seq(features)  # (batch, 10, hidden_size)
        # Process with LSTM
        lstm_out, _ = self.lstm(features)  # (batch, 10, hidden_size*2)
        # Optional attention with residual connection + LayerNorm
        if self.attn is not None:
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out))
        # Get character predictions for each time step
        logits = self.fc(lstm_out)  # (batch, 10, num_classes+1)
        # CTC expects: (sequence_length, batch, num_classes)
        logits = logits.permute(1, 0, 2)  # (10, batch, num_classes+1)
        # Apply log_softmax for CTC loss
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)
        return log_probs

    def predict(self, x):
        """
        Decode predictions using greedy (best-path) decoding.

        Temporarily switches to eval mode for inference and then restores
        the module's previous train/eval state, so calling this from inside
        a training loop does not silently leave the model in eval mode.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            List with one entry per batch item; each entry is a
            variable-length list of class indices with CTC blanks and
            repeated characters removed.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                log_probs = self.forward(x)  # (seq_len, batch, num_classes+1)
                # Greedy decoding: take argmax at each time step
                _, preds = log_probs.max(2)  # (seq_len, batch)
                preds = preds.transpose(0, 1)  # (batch, seq_len)
                # Return Python lists (variable length) for downstream decoding
                return [self._greedy_decode(pred_seq) for pred_seq in preds]
        finally:
            # Restore the caller's train/eval mode
            self.train(was_training)

    def _greedy_decode(self, pred_seq):
        """Collapse one per-step argmax sequence using the CTC rule:
        drop blanks and merge consecutive repeats (a blank between two
        identical characters separates genuine repeats)."""
        decoded_seq = []
        prev_char = None
        for char_idx in pred_seq.tolist():
            # Skip blank tokens; a blank resets the repeat tracker
            if char_idx == self.blank_idx:
                prev_char = None
                continue
            # Skip repeated characters (CTC rule)
            if char_idx != prev_char:
                decoded_seq.append(char_idx)
            prev_char = char_idx
        return decoded_seq
class CTCCaptchaModelSimple(nn.Module):
    """
    Simpler CTC model without LSTM (faster training, less memory).
    Good baseline to start with.
    """

    def __init__(self, num_classes=36):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
        """
        super(CTCCaptchaModelSimple, self).__init__()
        self.num_classes = num_classes
        # CTC blank token gets the extra class index
        self.blank_idx = num_classes
        # CNN backbone; input is (batch, 1, 60, 160) grayscale
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (64, 30, 80)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (128, 15, 40)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (256, 15, 20)
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (512, 15, 10)
        )
        # Direct mapping to character predictions
        # Treat width dimension as sequence
        self.classifier = nn.Sequential(
            nn.Linear(512 * 15, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes + 1)  # +1 for CTC blank
        )
        self.sequence_length = 10  # width after pooling

    def forward(self, x):
        """
        Forward pass for CTC.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)
        # Extract features
        features = self.features(x)  # (batch, 512, 15, 10)
        # Reshape: treat width as sequence
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 512, 15)
        features = features.reshape(batch_size, self.sequence_length, -1)
        # Classify each time step
        logits = self.classifier(features)  # (batch, 10, num_classes+1)
        # CTC format: (sequence_length, batch, num_classes+1)
        logits = logits.permute(1, 0, 2)
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)
        return log_probs

    def predict(self, x):
        """
        Greedy decoding with variable-length output (list of lists).

        Temporarily switches to eval mode and restores the module's
        previous train/eval state afterwards, so a validation call inside
        a training loop does not leave the model stuck in eval mode.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            List of per-sample index lists with blanks and repeats removed.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                log_probs = self.forward(x)  # (seq_len, batch, num_classes+1)
                _, preds = log_probs.max(2)  # (seq_len, batch)
                preds = preds.transpose(0, 1)  # (batch, seq_len)
                return [self._greedy_decode(pred_seq) for pred_seq in preds]
        finally:
            # Restore the caller's train/eval mode
            self.train(was_training)

    def _greedy_decode(self, pred_seq):
        """Collapse one per-step argmax sequence using the CTC rule:
        drop blanks, merge consecutive repeats (blank separates real repeats)."""
        decoded_seq = []
        prev_char = None
        for char_idx in pred_seq.tolist():
            # Blank token: skip and reset the repeat tracker
            if char_idx == self.blank_idx:
                prev_char = None
                continue
            # Skip repeated characters (CTC rule)
            if char_idx != prev_char:
                decoded_seq.append(char_idx)
            prev_char = char_idx
        return decoded_seq