# Source: break-captcha / src/model.py
# Author: vedchamp07
# Commit message: Update to v4 model with variable-length support and color inversion
# Commit: c8c67ab
"""
CTC-based CAPTCHA recognition model.
Uses CNN + LSTM + CTC loss - no bounding boxes needed!
This approach is standard for sequence recognition tasks where
character positions are unknown or variable.
"""
import torch
import torch.nn as nn
class CTCCaptchaModel(nn.Module):
    """
    CAPTCHA recognition using CTC (Connectionist Temporal Classification).

    Architecture:
    1. CNN backbone extracts visual features
    2. Reshape to sequence (treating width as time steps)
    3. Bidirectional LSTM processes sequence
    4. Optional multi-head self-attention over the LSTM outputs
    5. Linear layer outputs character log-probabilities for each time step
    6. CTC loss handles alignment between predictions and ground truth

    No need for bounding boxes - CTC figures out alignment automatically!

    The CNN halves the height twice and the width four times, so a
    (1, 60, 160) image yields 10 time steps of 256 * 15 features each.
    Wider images simply yield more time steps; the input height must still
    reduce to 15 rows because ``map_to_seq`` has a fixed input size.
    """

    def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
            hidden_size: Hidden size for LSTM layers
            num_lstm_layers: Number of LSTM layers
            use_attention: If True, apply a residual multi-head self-attention
                block (4 heads, followed by LayerNorm) on top of the LSTM
                outputs. ``hidden_size * 2`` must then be divisible by 4.
        """
        super().__init__()
        self.num_classes = num_classes
        # CTC needs blank token for alignment (class index = num_classes)
        self.blank_idx = num_classes

        # CNN backbone for feature extraction
        # Input: (batch, 1, 60, 160) - grayscale image
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (32, 30, 80)
            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (64, 15, 40)
            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (128, 15, 20)
            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 10)
        )

        # After CNN: (batch, 256, 15, W/16)
        # We reshape to (batch, W/16, 256*15), treating width as the sequence
        self.feature_size = 256 * 15  # channels * height
        # Nominal number of time steps for the reference 160-pixel-wide input.
        # forward() derives the actual value from the feature map, so this is
        # kept only for callers that read the attribute.
        self.sequence_length = 10

        # Map CNN features to LSTM input size
        self.map_to_seq = nn.Linear(self.feature_size, hidden_size)

        # Bidirectional LSTM to process sequence
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            # Inter-layer dropout only makes sense with more than one layer
            dropout=0.3 if num_lstm_layers > 1 else 0,
            batch_first=True,
        )

        # Optional self-attention on top of LSTM outputs
        self.use_attention = use_attention
        if self.use_attention:
            self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True)
            self.attn_norm = nn.LayerNorm(hidden_size * 2)
            self.attn_dropout = nn.Dropout(0.1)
        else:
            self.attn = None

        # Output layer: map LSTM outputs to character scores
        # +1 for CTC blank token, *2 for bidirectional
        self.fc = nn.Linear(hidden_size * 2, num_classes + 1)

    def forward(self, x):
        """
        Args:
            x: Input images (batch_size, 1, 60, width). The reference width
                is 160, which gives 10 time steps; other widths give
                width // 16 time steps.

        Returns:
            Log probabilities for CTC loss
            (sequence_length, batch_size, num_classes+1)
        """
        # Extract CNN features
        features = self.cnn(x)  # (batch, 256, 15, seq_len)

        # Move width in front of (channels, height) and merge the latter two,
        # so each column of the feature map becomes one time step. The number
        # of time steps is taken from the tensor rather than a constant so
        # that inputs of other widths work too.
        features = features.permute(0, 3, 1, 2).flatten(2)  # (batch, seq_len, 256*15)

        # Map to LSTM input size
        features = self.map_to_seq(features)  # (batch, seq_len, hidden_size)

        # Process with LSTM
        lstm_out, _ = self.lstm(features)  # (batch, seq_len, hidden_size*2)

        # Optional attention (residual connection + layer norm)
        if self.attn is not None:
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out))

        # Get character predictions for each time step
        logits = self.fc(lstm_out)  # (batch, seq_len, num_classes+1)

        # CTC expects: (sequence_length, batch, num_classes)
        logits = logits.permute(1, 0, 2)

        # Apply log_softmax for CTC loss
        return torch.nn.functional.log_softmax(logits, dim=2)

    @staticmethod
    def _greedy_decode(sequences, blank_idx):
        """Collapse repeated labels and drop blanks from per-step label lists."""
        decoded = []
        for sequence in sequences:
            labels = []
            # Starting from "blank" means the first real label is always kept
            previous = blank_idx
            for label in sequence:
                # Emit a label only when it starts a new run; a blank between
                # two equal labels resets the run so both are kept.
                if label != blank_idx and label != previous:
                    labels.append(label)
                previous = label
            decoded.append(labels)
        return decoded

    def predict(self, x):
        """
        Decode predictions using greedy decoding (variable length).

        The model is switched to eval mode for the forward pass and its
        previous train/eval mode is restored afterwards, so calling this in
        the middle of training does not leave dropout and batch norm disabled.

        Args:
            x: Input images, same layout as for ``forward``.

        Returns:
            A list with one entry per image; each entry is a list of class
            indices with blanks and repeats removed.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                log_probs = self(x)  # (seq_len, batch, num_classes+1)
        finally:
            self.train(was_training)

        # Greedy decoding: take argmax at each time step
        best_path = log_probs.argmax(dim=2).transpose(0, 1)  # (batch, seq_len)

        # One bulk transfer to Python ints instead of .item() per element
        return self._greedy_decode(best_path.tolist(), self.blank_idx)
class CTCCaptchaModelSimple(nn.Module):
    """
    Simpler CTC model without LSTM (faster training, less memory).
    Good baseline to start with.

    A four-block CNN turns a (1, 60, width) grayscale image into a
    (512, 15, width // 16) feature map; every column of that map is then
    classified independently by a small MLP, giving one CTC time step per
    column (10 steps for the reference width of 160).
    """

    def __init__(self, num_classes=36):
        """
        Args:
            num_classes: Number of character classes; one extra output is
                added for the CTC blank token (index ``num_classes``).
        """
        super().__init__()
        self.num_classes = num_classes
        # CTC blank token uses the index right after the real classes
        self.blank_idx = num_classes

        # CNN backbone
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (64, 30, 80)
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (128, 15, 40)
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 20)
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (512, 15, 10)
        )

        # Direct mapping to character predictions
        # Treat width dimension as sequence
        self.classifier = nn.Sequential(
            nn.Linear(512 * 15, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes + 1),
        )

        # Nominal number of time steps for the reference 160-pixel-wide input.
        # forward() derives the actual value from the feature map, so this is
        # kept only for callers that read the attribute.
        self.sequence_length = 10

    def forward(self, x):
        """
        Forward pass for CTC.

        Args:
            x: Input images (batch_size, 1, 60, width). The input height must
                reduce to 15 rows because the classifier input size is fixed
                at 512 * 15.

        Returns:
            Log probabilities (sequence_length, batch_size, num_classes+1),
            where sequence_length is width // 16 (10 for width 160).
        """
        # Extract features
        feature_map = self.features(x)  # (batch, 512, 15, seq_len)

        # Treat width as the sequence axis and merge (channels, height) into
        # one feature vector per column. The step count comes from the tensor
        # itself so inputs of other widths are handled as well.
        columns = feature_map.permute(0, 3, 1, 2).flatten(2)  # (batch, seq_len, 512*15)

        # Classify each time step
        logits = self.classifier(columns)  # (batch, seq_len, num_classes+1)

        # CTC format: time first
        logits = logits.permute(1, 0, 2)  # (seq_len, batch, num_classes+1)
        return torch.nn.functional.log_softmax(logits, dim=2)

    @staticmethod
    def _greedy_decode(sequences, blank_idx):
        """Collapse repeated labels and drop blanks from per-step label lists."""
        decoded = []
        for sequence in sequences:
            labels = []
            # Starting from "blank" means the first real label is always kept
            previous = blank_idx
            for label in sequence:
                # Emit a label only when it starts a new run; a blank between
                # two equal labels resets the run so both are kept.
                if label != blank_idx and label != previous:
                    labels.append(label)
                previous = label
            decoded.append(labels)
        return decoded

    def predict(self, x):
        """
        Greedy decoding with variable-length output (list of lists).

        The model is put in eval mode for the forward pass and its previous
        train/eval mode is restored afterwards, so calling this during
        training does not leave dropout and batch norm disabled.

        Args:
            x: Input images, same layout as for ``forward``.

        Returns:
            A list with one entry per image; each entry is a list of class
            indices with blanks and repeats removed.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                log_probs = self(x)  # (seq_len, batch, num_classes+1)
        finally:
            self.train(was_training)

        best_path = log_probs.argmax(dim=2).transpose(0, 1)  # (batch, seq_len)

        # One bulk transfer to Python ints instead of .item() per element
        return self._greedy_decode(best_path.tolist(), self.blank_idx)