""" CTC-based CAPTCHA recognition model. Uses CNN + LSTM + CTC loss - no bounding boxes needed! This approach is standard for sequence recognition tasks where character positions are unknown or variable. """ import torch import torch.nn as nn class CTCCaptchaModel(nn.Module): """ CAPTCHA recognition using CTC (Connectionist Temporal Classification). Architecture: 1. CNN backbone extracts visual features 2. Reshape to sequence (treating width as time steps) 3. Bidirectional LSTM processes sequence 4. Linear layer outputs character probabilities for each time step 5. CTC loss handles alignment between predictions and ground truth No need for bounding boxes - CTC figures out alignment automatically! """ def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False): """ Args: num_classes: Number of character classes (36 for A-Z, 0-9) hidden_size: Hidden size for LSTM layers num_lstm_layers: Number of LSTM layers """ super(CTCCaptchaModel, self).__init__() self.num_classes = num_classes # CTC needs blank token for alignment (class index = num_classes) self.blank_idx = num_classes # CNN backbone for feature extraction # Input: (batch, 1, 60, 160) - grayscale image self.cnn = nn.Sequential( # Block 1 nn.Conv2d(1, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2, 2), # -> (32, 30, 80) # Block 2 nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2, 2), # -> (64, 15, 40) # Block 3 nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d((1, 2)), # Pool only width -> (128, 15, 20) # Block 4 nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d((1, 2)), # Pool only width -> (256, 15, 10) ) # After CNN: (batch, 256, 15, 10) # We'll reshape to: (batch, 10, 256*15) treating width as sequence # So sequence length = 10, feature dim = 256*15 = 3840 self.feature_size = 256 * 15 # channels * height self.sequence_length = 10 # width after pooling # Map CNN features to LSTM input size self.map_to_seq = nn.Linear(self.feature_size, hidden_size) # Bidirectional LSTM to process sequence self.lstm = nn.LSTM( hidden_size, hidden_size, num_layers=num_lstm_layers, bidirectional=True, dropout=0.3 if num_lstm_layers > 1 else 0, batch_first=True ) # Optional self-attention on top of LSTM outputs self.use_attention = use_attention if self.use_attention: self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True) self.attn_norm = nn.LayerNorm(hidden_size * 2) self.attn_dropout = nn.Dropout(0.1) else: self.attn = None # Output layer: map LSTM outputs to character probabilities # +1 for CTC blank token self.fc = nn.Linear(hidden_size * 2, num_classes + 1) # *2 for bidirectional def forward(self, x): """ Args: x: Input images (batch_size, 1, 60, 160) Returns: Log probabilities for CTC loss (sequence_length, batch_size, num_classes+1) """ batch_size = x.size(0) # Extract CNN features features = self.cnn(x) # (batch, 256, 15, 10) # Reshape to sequence: (batch, width, channels*height) # Transpose to treat width as sequence dimension features = features.permute(0, 3, 1, 2) # (batch, 10, 256, 15) features = features.reshape(batch_size, self.sequence_length, self.feature_size) # Map to LSTM input size features = self.map_to_seq(features) # (batch, 10, hidden_size) # Process with LSTM lstm_out, _ = self.lstm(features) # (batch, 10, hidden_size*2) # Optional attention if self.attn is not None: attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out) lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out)) # Get character predictions for each time step logits = self.fc(lstm_out) # (batch, 10, num_classes+1) # CTC expects: (sequence_length, batch, num_classes) logits = logits.permute(1, 0, 2) # (10, batch, num_classes+1) # Apply log_softmax for CTC loss log_probs = torch.nn.functional.log_softmax(logits, dim=2) return log_probs def predict(self, x): """ Decode predictions using greedy decoding (variable length). Returns a list of index lists with blanks and repeats removed. """ self.eval() with torch.no_grad(): log_probs = self.forward(x) # (seq_len, batch, num_classes+1) # Greedy decoding: take argmax at each time step _, preds = log_probs.max(2) # (seq_len, batch) preds = preds.transpose(0, 1) # (batch, seq_len) decoded = [] for pred_seq in preds: decoded_seq = [] prev_char = None for char_idx in pred_seq: char_idx = char_idx.item() # Skip blank tokens if char_idx == self.blank_idx: prev_char = None continue # Skip repeated characters (CTC rule) if char_idx != prev_char: decoded_seq.append(char_idx) prev_char = char_idx decoded.append(decoded_seq) # Return Python lists (variable length) for downstream decoding return decoded class CTCCaptchaModelSimple(nn.Module): """ Simpler CTC model without LSTM (faster training, less memory). Good baseline to start with. """ def __init__(self, num_classes=36): super(CTCCaptchaModelSimple, self).__init__() self.num_classes = num_classes self.blank_idx = num_classes # CNN backbone self.features = nn.Sequential( nn.Conv2d(1, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d((2, 2)), # -> (64, 30, 80) nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d((2, 2)), # -> (128, 15, 40) nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d((1, 2)), # -> (256, 15, 20) nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(), nn.MaxPool2d((1, 2)), # -> (512, 15, 10) ) # Direct mapping to character predictions # Treat width dimension as sequence self.classifier = nn.Sequential( nn.Linear(512 * 15, 256), nn.ReLU(), nn.Dropout(0.3), nn.Linear(256, num_classes + 1) ) self.sequence_length = 10 def forward(self, x): """Forward pass for CTC.""" batch_size = x.size(0) # Extract features features = self.features(x) # (batch, 512, 15, 10) # Reshape: treat width as sequence features = features.permute(0, 3, 1, 2) # (batch, 10, 512, 15) features = features.reshape(batch_size, self.sequence_length, -1) # Classify each time step logits = self.classifier(features) # (batch, 10, num_classes+1) # CTC format logits = logits.permute(1, 0, 2) # (10, batch, num_classes+1) log_probs = torch.nn.functional.log_softmax(logits, dim=2) return log_probs def predict(self, x): """Greedy decoding with variable-length output (list of lists).""" self.eval() with torch.no_grad(): log_probs = self.forward(x) _, preds = log_probs.max(2) preds = preds.transpose(0, 1) decoded = [] for pred_seq in preds: decoded_seq = [] prev_char = None for char_idx in pred_seq: char_idx = char_idx.item() if char_idx == self.blank_idx: prev_char = None continue if char_idx != prev_char: decoded_seq.append(char_idx) prev_char = char_idx decoded.append(decoded_seq) return decoded