Spaces:
Sleeping
Sleeping
Commit ·
3a3f6c6
1
Parent(s): fd44722
Add CAPTCHA breaker app
Browse files- app.py +273 -0
- models/captcha_model_v3.pth +3 -0
- requirements.txt +9 -0
- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-314.pyc +0 -0
- src/__pycache__/model.cpython-314.pyc +0 -0
- src/model.py +294 -0
app.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Gradio app for testing CAPTCHA model.
|
| 3 |
+
Allows uploading CAPTCHA images and getting predictions with preprocessing.
|
| 4 |
+
"""
|
| 5 |
+
import gradio as gr
|
| 6 |
+
import torch
|
| 7 |
+
from torchvision import transforms
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import string
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import numpy as np
|
| 12 |
+
import cv2
|
| 13 |
+
|
| 14 |
+
from src.model import CTCCaptchaModel
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Setup
|
| 18 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 19 |
+
CHARACTERS = string.digits + string.ascii_lowercase + string.ascii_uppercase
|
| 20 |
+
MODEL_PATH = Path("models/captcha_model_v3.pth")
|
| 21 |
+
|
| 22 |
+
# Load model
|
| 23 |
+
model = CTCCaptchaModel(num_classes=len(CHARACTERS), use_attention=True)
|
| 24 |
+
|
| 25 |
+
# Load checkpoint
|
| 26 |
+
checkpoint = torch.load(MODEL_PATH, map_location=DEVICE)
|
| 27 |
+
if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
|
| 28 |
+
model.load_state_dict(checkpoint['model_state_dict'])
|
| 29 |
+
else:
|
| 30 |
+
model.load_state_dict(checkpoint)
|
| 31 |
+
|
| 32 |
+
model.to(DEVICE)
|
| 33 |
+
model.eval()
|
| 34 |
+
|
| 35 |
+
# Image preprocessing transforms
|
| 36 |
+
transform = transforms.Compose([
|
| 37 |
+
transforms.Resize((60, 160)),
|
| 38 |
+
transforms.ToTensor(),
|
| 39 |
+
transforms.Normalize(mean=[0.5], std=[0.5])
|
| 40 |
+
])
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def preprocess_image(image):
    """
    Binarize and denoise a CAPTCHA image.

    Args:
        image: PIL Image (any mode)

    Returns:
        Preprocessed single-channel PIL Image
    """
    # Grayscale first, then let Otsu pick the binarization threshold
    # automatically from the image histogram.
    gray = np.array(image.convert('L'))
    _, thresholded = cv2.threshold(gray, 0, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)

    # Morphological closing with a small elliptical kernel fills pinhole
    # noise without distorting character strokes much.
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    cleaned = cv2.morphologyEx(thresholded, cv2.MORPH_CLOSE, ellipse)

    # Hand back a PIL Image for the torchvision transform pipeline.
    return Image.fromarray(cleaned)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def predict_captcha(image, ground_truth=""):
    """
    Predict CAPTCHA text from image with preprocessing.

    Args:
        image: PIL Image or numpy array
        ground_truth: Optional ground truth text for comparison
            (case-sensitive exact match)

    Returns:
        Tuple of (markdown result string, preprocessed PIL image).
        On failure returns (error message, None).
    """
    try:
        # Gradio may deliver a numpy array; normalize to PIL.
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Resize to the model's expected 160x60 (PIL size is width, height).
        if image.size != (160, 60):
            image = image.resize((160, 60), Image.LANCZOS)

        # Grayscale + Otsu thresholding + morphological denoising.
        processed_image = preprocess_image(image)

        # To normalized tensor with a batch dimension, on the model's device.
        image_tensor = transform(processed_image).unsqueeze(0).to(DEVICE)

        # Greedy CTC decoding happens inside model.predict; take batch item 0.
        with torch.no_grad():
            pred_indices = model.predict(image_tensor)[0]

        # Map class indices back to characters. Indices >= len(CHARACTERS)
        # are CTC blank/padding and must be dropped.
        predicted_text = ''.join(
            CHARACTERS[idx.item()] for idx in pred_indices
            if idx.item() < len(CHARACTERS)
        )

        # Format output with styling
        result = "### 🎯 Prediction Result\n\n"
        result += f"# **{predicted_text}**\n\n"
        result += f"*Length: {len(predicted_text)} characters*\n\n"

        # Guard against None (e.g. a cleared/absent textbox) before .strip().
        if ground_truth and ground_truth.strip():
            is_correct = predicted_text == ground_truth  # case-sensitive
            result += f"**Expected:** {ground_truth}\n\n"
            if is_correct:
                result += "## ✅ **CORRECT!**"
            else:
                result += "## ❌ **INCORRECT**"

        return result, processed_image

    except Exception as e:
        return f"❌ **Error:** {str(e)}", None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def extract_from_filename(filename):
    """
    Extract ground-truth text from a CAPTCHA filename (format: TEXT_INDEX.png).

    Args:
        filename: A path string, a pathlib.Path, or a file-like/upload object
            exposing the path via a ``name`` attribute (as Gradio uploads do).

    Returns:
        The text portion before the first underscore of the file stem,
        or "" when no filename is available.
    """
    if not filename:
        return ""
    # Upload objects expose the path via .name; strings are used directly.
    # (Previously plain string paths silently returned "".)
    path = filename.name if hasattr(filename, 'name') else filename
    return Path(path).stem.split('_')[0]
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
# Create Gradio interface
# Layout: header banner, then a two-column row (upload controls | results),
# then an info row, a collapsible architecture section, event wiring, footer.
with gr.Blocks(title="🔐 CAPTCHA Breaker", theme=gr.themes.Soft()) as demo:
    # Header banner (rendered as centered HTML-in-markdown)
    gr.Markdown("""
    <div style="text-align: center; padding: 20px;">

    # 🔐 CAPTCHA Breaker

    ### Advanced AI-Powered CAPTCHA Recognition

    Powered by **CNN + LSTM + Self-Attention** neural network

    </div>
    """)

    with gr.Row():
        with gr.Column(scale=2):
            gr.Markdown("#### 📸 Upload Your CAPTCHA")
            # image_mode="L" forces grayscale delivery, matching the
            # single-channel input the model was trained on.
            image_input = gr.Image(
                type="pil",
                label="Drop CAPTCHA image here",
                image_mode="L"
            )

            with gr.Row():
                # Optional ground truth; when filled in, predict_captcha
                # reports CORRECT/INCORRECT against the prediction.
                ground_truth_input = gr.Textbox(
                    label="Expected Answer (optional)",
                    placeholder="Type here to verify accuracy",
                    lines=1,
                    scale=3
                )
                predict_button = gr.Button(
                    "🔍 Decode",
                    variant="primary",
                    scale=1
                )

        with gr.Column(scale=2):
            gr.Markdown("#### 🎯 Results")
            # Placeholder markdown; replaced by predict_captcha's result string.
            output = gr.Markdown(
                "<div style='text-align: center; padding: 40px; color: #888;'>Upload an image to get started</div>"
            )

    with gr.Row():
        with gr.Column():
            gr.Markdown("#### 🔬 Preprocessing Steps Applied:")
            gr.Markdown("""
            - ✓ Auto-resize to 60×160 (if needed)
            - ✓ Grayscale conversion
            - ✓ Otsu's thresholding
            - ✓ Morphological closing (denoising)
            - ✓ Tensor normalization
            - ✓ Variable length support (3-7 chars)
            - ✓ Lowercase + Uppercase + Digits
            """)

        with gr.Column():
            gr.Markdown("#### 📊 Character Set:")
            gr.Markdown("""
            - **Digits:** 0-9
            - **Lowercase:** a-z
            - **Uppercase:** A-Z
            - **Total:** 62 characters
            """)

        with gr.Column():
            gr.Markdown("#### 🖼️ Processed Image:")
            # Shows the binarized/denoised image actually fed to the model.
            preprocessed_image = gr.Image(
                label="Input After Preprocessing",
                type="pil"
            )

    # Info section
    with gr.Accordion("ℹ️ Model Architecture & Performance", open=False):
        gr.Markdown("""
        ### 🏗️ Architecture

        ```
        Input Image (1, 60, 160) [Auto-resized if needed]
        ↓
        CNN: 4 Convolutional Blocks
        • Progressive feature extraction
        • 1→32→64→128→256 channels
        ↓
        Bidirectional LSTM: 2 layers
        • 256 hidden units each direction
        • Learns sequential dependencies
        ↓
        Self-Attention: 4 heads
        • Refines character representations
        • Improves focus on important features
        ↓
        CTC Loss: Automatic Alignment
        • No bounding boxes needed!
        • Learns character positions automatically
        ↓
        Output: Variable-length prediction (3-7 characters)
        ```

        ### 📈 Model Capabilities (v3)

        | Feature | Details |
        |---------|---------|
        | **Model Version** | v3 (Latest) |
        | **Text Length** | 3-7 characters (variable) |
        | **Character Set** | 0-9, a-z, A-Z (62 total) |
        | **Architecture** | CNN + LSTM + Attention |
        | **Training Data** | 10,000 synthetic CAPTCHAs |
        | **Image Resize** | Automatic (any size → 60×160) |

        ### ⚠️ Known Limitations

        - 0 vs O confusion (visual similarity)
        - i vs l vs 1 confusion (very similar shapes)
        - Limited performance on decorative/stylized fonts
        - Sensitive to extreme image distortions
        """)

    # Connect buttons to prediction function
    predict_button.click(
        fn=predict_captcha,
        inputs=[image_input, ground_truth_input],
        outputs=[output, preprocessed_image]
    )

    # Auto-predict on image upload
    # The lambda passes "" so no ground-truth comparison is shown.
    image_input.change(
        fn=lambda img: predict_captcha(img, ""),
        inputs=image_input,
        outputs=[output, preprocessed_image]
    )

    # Footer
    gr.Markdown("""
    ---
    <div style="text-align: center; color: #999; padding: 20px;">
    Built with PyTorch | Device: {device} | GitHub: vedchamp07/captcha-breaker
    </div>
    """.format(device=DEVICE))


if __name__ == "__main__":
    # share=True creates a public gradio.live tunnel in addition to localhost.
    demo.launch(share=True)
|
models/captcha_model_v3.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9e724f2d10b44f23f6794de5aa316b809388006f66eb39059851b6cd750e6de4
|
| 3 |
+
size 20361923
|
requirements.txt
ADDED
|
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
captcha
|
| 4 |
+
Pillow
|
| 5 |
+
numpy
|
| 6 |
+
matplotlib
|
| 7 |
+
tqdm
|
| 8 |
+
opencv-python # For preprocessing (grayscale, noise removal)
|
| 9 |
+
gradio # For interactive web app
|
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-314.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
src/__pycache__/model.cpython-314.pyc
ADDED
|
Binary file (10.6 kB). View file
|
|
|
src/model.py
ADDED
|
@@ -0,0 +1,294 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
CTC-based CAPTCHA recognition model.
|
| 3 |
+
Uses CNN + LSTM + CTC loss - no bounding boxes needed!
|
| 4 |
+
|
| 5 |
+
This approach is standard for sequence recognition tasks where
|
| 6 |
+
character positions are unknown or variable.
|
| 7 |
+
"""
|
| 8 |
+
import torch
|
| 9 |
+
import torch.nn as nn
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class CTCCaptchaModel(nn.Module):
    """
    CAPTCHA recognition using CTC (Connectionist Temporal Classification).

    Architecture:
        1. CNN backbone extracts visual features
        2. Reshape to sequence (treating width as time steps)
        3. Bidirectional LSTM processes sequence
        4. Optional multi-head self-attention refines the sequence
        5. Linear layer outputs character probabilities for each time step
        6. CTC loss handles alignment between predictions and ground truth

    No need for bounding boxes - CTC figures out alignment automatically!
    """

    def __init__(self, num_classes=36, hidden_size=256, num_lstm_layers=2, use_attention=False):
        """
        Args:
            num_classes: Number of character classes (36 for A-Z, 0-9)
            hidden_size: Hidden size for LSTM layers
            num_lstm_layers: Number of LSTM layers
            use_attention: If True, add multi-head self-attention on LSTM output
        """
        super(CTCCaptchaModel, self).__init__()

        self.num_classes = num_classes
        # CTC needs blank token for alignment (class index = num_classes)
        self.blank_idx = num_classes

        # CNN backbone for feature extraction
        # Input: (batch, 1, 60, 160) - grayscale image
        self.cnn = nn.Sequential(
            # Block 1
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (32, 30, 80)

            # Block 2
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),  # -> (64, 15, 40)

            # Block 3
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (128, 15, 20)

            # Block 4
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # Pool only width -> (256, 15, 10)
        )

        # After CNN: (batch, 256, 15, 10)
        # We'll reshape to: (batch, 10, 256*15) treating width as sequence
        # So sequence length = 10, feature dim = 256*15 = 3840
        self.feature_size = 256 * 15  # channels * height
        self.sequence_length = 10  # width after pooling

        # Map CNN features to LSTM input size
        self.map_to_seq = nn.Linear(self.feature_size, hidden_size)

        # Bidirectional LSTM to process sequence
        self.lstm = nn.LSTM(
            hidden_size,
            hidden_size,
            num_layers=num_lstm_layers,
            bidirectional=True,
            # nn.LSTM warns if dropout is set with a single layer
            dropout=0.3 if num_lstm_layers > 1 else 0,
            batch_first=True
        )

        # Optional self-attention on top of LSTM outputs (post-norm residual)
        self.use_attention = use_attention
        if self.use_attention:
            self.attn = nn.MultiheadAttention(hidden_size * 2, num_heads=4, dropout=0.1, batch_first=True)
            self.attn_norm = nn.LayerNorm(hidden_size * 2)
            self.attn_dropout = nn.Dropout(0.1)
        else:
            self.attn = None

        # Output layer: map LSTM outputs to character probabilities
        # +1 for CTC blank token
        self.fc = nn.Linear(hidden_size * 2, num_classes + 1)  # *2 for bidirectional

    def forward(self, x):
        """
        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities for CTC loss (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)

        # Extract CNN features
        features = self.cnn(x)  # (batch, 256, 15, 10)

        # Reshape to sequence: (batch, width, channels*height)
        # Transpose to treat width as sequence dimension
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 256, 15)
        features = features.reshape(batch_size, self.sequence_length, self.feature_size)

        # Map to LSTM input size
        features = self.map_to_seq(features)  # (batch, 10, hidden_size)

        # Process with LSTM
        lstm_out, _ = self.lstm(features)  # (batch, 10, hidden_size*2)

        # Optional attention with residual connection + layer norm
        if self.attn is not None:
            attn_out, _ = self.attn(lstm_out, lstm_out, lstm_out)
            lstm_out = self.attn_norm(lstm_out + self.attn_dropout(attn_out))

        # Get character predictions for each time step
        logits = self.fc(lstm_out)  # (batch, 10, num_classes+1)

        # CTC expects: (sequence_length, batch, num_classes)
        logits = logits.permute(1, 0, 2)  # (10, batch, num_classes+1)

        # Apply log_softmax for CTC loss
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)

        return log_probs

    def predict(self, x):
        """
        Decode predictions using greedy (best-path) CTC decoding.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Predicted character indices (batch_size, sequence_length),
            right-padded with the CTC blank index (== num_classes).
            Consumers must drop indices >= num_classes when mapping
            back to characters.
        """
        self.eval()
        with torch.no_grad():
            log_probs = self.forward(x)  # (seq_len, batch, num_classes+1)

            # Greedy decoding: take argmax at each time step
            _, preds = log_probs.max(2)  # (seq_len, batch)
            preds = preds.transpose(0, 1)  # (batch, seq_len)

            # Decode: remove blanks and collapse repeated characters (CTC rule)
            decoded = []
            for pred_seq in preds:
                decoded_seq = []
                prev_char = None

                for char_idx in pred_seq:
                    char_idx = char_idx.item()

                    # Blank separates genuine repeats, so it resets the tracker
                    if char_idx == self.blank_idx:
                        prev_char = None
                        continue

                    # Skip repeated characters (CTC rule)
                    if char_idx != prev_char:
                        decoded_seq.append(char_idx)
                    prev_char = char_idx

                decoded.append(decoded_seq)

            # Pad sequences to a fixed length.
            # BUGFIX: pad with the blank index, not 0 — index 0 is a real
            # character class, so zero-padding injected spurious characters
            # into the decoded text downstream. Also pad/truncate to the
            # full sequence length (10) instead of a hard-coded 5, so the
            # advertised 6-7 character CAPTCHAs survive decoding.
            max_len = self.sequence_length
            padded = [
                (seq + [self.blank_idx] * (max_len - len(seq)))[:max_len]
                for seq in decoded
            ]

            # Return tensor on same device as input
            return torch.tensor(padded, dtype=torch.long, device=x.device)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class CTCCaptchaModelSimple(nn.Module):
    """
    Simpler CTC model without LSTM (faster training, less memory).
    Good baseline to start with.
    """

    def __init__(self, num_classes=36):
        """
        Args:
            num_classes: Number of character classes (excluding the CTC blank)
        """
        super(CTCCaptchaModelSimple, self).__init__()

        self.num_classes = num_classes
        # CTC blank token occupies the extra class index
        self.blank_idx = num_classes

        # CNN backbone
        self.features = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (64, 30, 80)

            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d((2, 2)),  # -> (128, 15, 40)

            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (256, 15, 20)

            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d((1, 2)),  # -> (512, 15, 10)
        )

        # Direct mapping to character predictions
        # Treat width dimension as sequence
        self.classifier = nn.Sequential(
            nn.Linear(512 * 15, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes + 1)  # +1 for the CTC blank
        )

        self.sequence_length = 10

    def forward(self, x):
        """
        Forward pass for CTC.

        Args:
            x: Input images (batch_size, 1, 60, 160)

        Returns:
            Log probabilities (sequence_length, batch_size, num_classes+1)
        """
        batch_size = x.size(0)

        # Extract features
        features = self.features(x)  # (batch, 512, 15, 10)

        # Reshape: treat width as sequence
        features = features.permute(0, 3, 1, 2)  # (batch, 10, 512, 15)
        features = features.reshape(batch_size, self.sequence_length, -1)

        # Classify each time step
        logits = self.classifier(features)  # (batch, 10, num_classes+1)

        # CTC format: (sequence_length, batch, num_classes+1)
        logits = logits.permute(1, 0, 2)  # (10, batch, num_classes+1)
        log_probs = torch.nn.functional.log_softmax(logits, dim=2)

        return log_probs

    def predict(self, x):
        """
        Greedy (best-path) CTC decoding.

        Returns:
            (batch_size, sequence_length) long tensor of class indices,
            right-padded with the blank index; consumers must drop
            indices >= num_classes.
        """
        self.eval()
        with torch.no_grad():
            log_probs = self.forward(x)
            _, preds = log_probs.max(2)
            preds = preds.transpose(0, 1)

            # Decode: drop blanks, collapse repeats (CTC rule)
            decoded = []
            for pred_seq in preds:
                decoded_seq = []
                prev_char = None

                for char_idx in pred_seq:
                    char_idx = char_idx.item()
                    if char_idx == self.blank_idx:
                        # Blank resets the repeat tracker
                        prev_char = None
                        continue
                    if char_idx != prev_char:
                        decoded_seq.append(char_idx)
                    prev_char = char_idx

                decoded.append(decoded_seq)

            # BUGFIX: pad with the blank index, not 0 — index 0 is a valid
            # character class, so zero-padding injected spurious characters
            # into decoded output. Pad/truncate to the full sequence length
            # instead of a hard-coded 5 to keep longer predictions intact.
            max_len = self.sequence_length
            padded = [
                (seq + [self.blank_idx] * (max_len - len(seq)))[:max_len]
                for seq in decoded
            ]

            # Return tensor on same device as input
            return torch.tensor(padded, dtype=torch.long, device=x.device)
|