Commit 6e89f30
Parent(s): ada63c0

checkpoint
Files changed:
- .gitignore                   +2   -12
- Metrics/.gitkeep             +3   -0
- Metrics/training_metrics.txt +3   -0
- README.md                    +59  -28
- src/captcha_dataset.py       +49  -0
- src/config.py                +21  -8
- src/data.py                  +2   -2
- src/model_crnn.py            +80  -0
- src/plotting.py              +107 -0
- src/test.py                  +49  -0
- train.py                     +206 -0
- train_sanity.py              +96  -0
.gitignore
CHANGED

@@ -1,12 +1,4 @@
-
-#!/usr/bin/env bash
-# Create a .gitignore that keeps the Dataset folder but ignores its contents,
-# plus common Python/ML ignores. Run this from your repo root.
-
-set -e
-
-cat > .gitignore << 'EOF'
-# Keep the Dataset folder but ignore its contents
+# Keep Dataset folders but ignore their contents
 Dataset/
 !Dataset/.gitkeep
 !Dataset/**/
@@ -32,7 +24,6 @@ pip-wheel-metadata/
 wheels/
 .pytest_cache/
 .coverage
-#.coverage.*  # uncomment if you create multiple coverage files
 htmlcov/
 .cache/
 .mypy_cache/
@@ -75,7 +66,7 @@ logs/
 Thumbs.db
 desktop.ini
 
-# Images/artifacts
+# Images/artifacts
 *.png
 *.jpg
 *.jpeg
@@ -143,5 +134,4 @@ cmake-build-*/
 *.class
 .gradle/
 build/
-EOF
 
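One caveat on the resulting rules: git cannot re-include a file whose parent directory is excluded, so with `Dataset/` ignored, the `!Dataset/.gitkeep` and `!Dataset/**/` negations have no effect. The conventional form ignores the directory's contents instead (`Dataset/*` plus `!Dataset/.gitkeep`), which keeps the folder tracked while dropping everything inside it.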
Metrics/.gitkeep
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:f67f029777a688ea90615d9a21b1935347c102ecee39cb5d50f740a4e95095eb
+size 122
Metrics/training_metrics.txt
ADDED

@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:edef5b371d7b2c75153063f41c43f0e3dff8d58d5fda50e7a0db52d230e04f3c
+size 807
README.md
CHANGED

@@ -13,18 +13,31 @@ This project implements an end-to-end CAPTCHA OCR system that can recognize text
 ## 🏗️ Current Status
 
 ### ✅ Completed Components
-- **Dataset Generation**: Synthetic CAPTCHA creation with train/val/test splits
+- **Dataset Generation**: Synthetic CAPTCHA creation with train/val/test splits (8k train, 1k val)
 - **Configuration**: Centralized config with image dimensions and training parameters
-- **Vocabulary System**: Character encoding/decoding with CTC blank token support
+- **Vocabulary System**: Character encoding/decoding with CTC blank token support (63 classes)
 - **CTC Collate Function**: Proper batching for variable-length sequences
 - **CTC Decoding**: Greedy decode for inference
-
-
-- **
-- **
-- **
-
-
+- **PyTorch Dataset Class**: Image loading and preprocessing with proper cv2 resizing
+- **CRNN Model**: CNN encoder + BiLSTM + LayerNorm + linear output (working!)
+- **Training Loop**: Complete epoch-based training pipeline with validation
+- **Metrics & Plotting**: Training/validation loss tracking with beautiful visualizations
+- **Debugging Tools**: Comprehensive logging of logits, predictions, and model health
+
+### ✅ What's Working
+- **Training Pipeline**: Stable training loop with proper loss convergence
+- **Model Architecture**: CRNN produces correct output shapes (56×batch×63)
+- **Data Loading**: Proper image preprocessing and CTC batching
+- **Early Learning**: Model outputs first characters after 3 epochs (blank prob: 1.0→0.975)
+
+### ❌ What's Not Working Yet
+- **Accuracy**: Still very low, mostly single characters (`'t', 'tu'`)
+- **Sequence Length**: Not yet producing full CAPTCHA sequences
+- **Character Diversity**: Limited to a few characters, needs more training
+
+### 🎯 Training Status
+- **Current**: Epoch 3, basic character recognition starting
+- **Estimated**: 20-40 epochs needed for decent CAPTCHA accuracy
 
 ## 📁 Project Structure
 
@@ -38,10 +51,14 @@ CaptchaDetect/
 │   └── test/          # 10% of data
 ├── src/
 │   ├── config.py      # Configuration and hyperparameters
-│   ├── vocab.py       # Character vocabulary and CTC encoding
+│   ├── vocab.py       # Character vocabulary and CTC encoding/decoding
 │   ├── data.py        # Dataset generation script
 │   ├── collate.py     # CTC batching function
-│
+│   ├── captcha_dataset.py  # PyTorch Dataset class
+│   ├── model_crnn.py  # CRNN model architecture
+│   └── plotting.py    # Training metrics and visualization
+├── train.py           # Main training script (✅ WORKING!)
+├── Metrics/           # Training plots and logs (auto-generated)
 ├── .gitignore         # Ignores dataset contents, keeps structure
 └── README.md          # This file
 ```
@@ -57,18 +74,24 @@ pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu12
 pip install captcha pandas pillow
 ```
 
-### 2. Generate
+### 2. Generate Training Dataset
 ```bash
 cd src
 python data.py
 ```
-This creates
+This creates 10,000 synthetic CAPTCHAs in `Dataset_test/captchas/` with proper train/val/test splits.
 
-### 3.
-
-
-
-
+### 3. Start Training
+```bash
+python train.py
+```
+This starts the full training pipeline with automatic metrics generation.
+
+### 4. Monitor Progress
+Training will show:
+- Real-time loss and prediction samples
+- Automatic plot generation in `Metrics/` folder
+- Comprehensive training logs and summaries
 
 ## 🎮 Usage
 
@@ -84,21 +107,29 @@ Edit `src/config.py` to adjust:
 
 ## 🔬 Technical Details
 
-### Model Architecture
-- **CNN Encoder**:
-- **BiLSTM**:
-- **
+### Model Architecture (CRNN)
+- **CNN Encoder**: SmallCNN with stride=4, reduces W=224→56 timesteps
+- **BiLSTM**: 2-layer bidirectional LSTM (256 hidden, dropout=0.1)
+- **LayerNorm**: Stabilizes training before output layer
+- **Linear Output**: Maps to 63 classes (62 chars + 1 blank token)
+
+### Training Optimizations
+- **AdamW Optimizer**: lr=3e-4, weight_decay=1e-4
+- **Gradient Clipping**: max_norm=1.0 prevents exploding gradients
+- **Weight Initialization**: Small uniform weights (-1e-3, 1e-3) for stability
+- **Numeric Stability**: AMP disabled during initial training for stability
 
 ### CTC Training
-- **Input**: Images resized to 48×224
+- **Input**: Images resized to 48×224 (height×width)
 - **Output**: Character sequences (a-z, A-Z, 0-9)
-- **Loss**: CTCLoss with blank=0
-- **Decoding**: Greedy CTC decode
+- **Loss**: CTCLoss with blank=0, zero_infinity=True
+- **Decoding**: Greedy CTC decode with duplicate removal
 
-### Data
-- **Images**: Grayscale, normalized
+### Data Pipeline
+- **Images**: Grayscale, normalized to [0,1], proper cv2 resizing
 - **Labels**: CSV with filename and text label
-- **Batching**: Variable-length sequences
+- **Batching**: Variable-length sequences with custom CTC collate function
+- **Debugging**: Real-time monitoring of logits, blank probability, predictions
 
 ## 📊 Performance Expectations
 
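The CTC items above (blank=0, 63 classes, greedy decode with duplicate removal) live in `src/vocab.py`, which this commit does not touch. For reference, a minimal sketch of that decoding rule; the real `ctc_greedy_decode` and `itos` may differ in detail:

```python
import torch

def greedy_ctc_decode_sketch(logits: torch.Tensor, itos: list) -> list:
    """logits: [T, B, V] with index 0 as the CTC blank."""
    best = logits.argmax(dim=-1)  # [T, B] best class per timestep
    decoded = []
    for b in range(best.shape[1]):
        prev, chars = 0, []
        for idx in best[:, b].tolist():
            # Collapse consecutive repeats, then drop blanks
            if idx != prev and idx != 0:
                chars.append(itos[idx])
            prev = idx
        decoded.append("".join(chars))
    return decoded
```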
src/captcha_dataset.py
ADDED

@@ -0,0 +1,49 @@
import glob
import cv2
import pandas as pd
import torch
import os
from src.config import cfg
from dataclasses import dataclass

@dataclass
class CaptchaDataset(torch.utils.data.Dataset):
    def __init__(self, folder: str):
        self.data_root = cfg.data_root
        df = pd.read_csv(f"{self.data_root}/{folder}/labels.csv")
        self.data = []
        for _, row in df.iterrows():
            filename = row['filename']
            label = row['label']
            img_path = f"{self.data_root}/{folder}/{row['filename']}"

            # Check if file actually exists
            if os.path.exists(img_path):
                self.data.append((img_path, label, folder))
            else:
                print(f"Warning: Image file not found: {img_path}")

        print(f"Loaded {len(self.data)} valid images from {folder}")
        self.img_dim = (cfg.W_max, cfg.H)  # cv2.resize expects (width, height)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_path, label_string, folder = self.data[idx]

        # Load image with error checking
        img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE if cfg.grayscale else cv2.IMREAD_COLOR)

        if img is None:
            raise ValueError(f"Failed to load image: {img_path}")

        img = cv2.resize(img, self.img_dim)
        img_tensor = torch.from_numpy(img).float() / 255.0  # Normalize to [0,1]
        img_tensor = img_tensor.unsqueeze(0)  # Add channel dimension
        return img_tensor, label_string, img_path
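`CaptchaDataset` returns a `[1, H, W]` float tensor in [0,1] plus the raw label string and path (label encoding is left to the collate function). Incidentally, the `@dataclass` decorator is inert here: the class defines its own `__init__` and declares no annotated fields. A quick illustrative check under the current config (H=60, W_max=256):

```python
from src.captcha_dataset import CaptchaDataset

ds = CaptchaDataset("val")
img, label, path = ds[0]
print(img.shape)                           # torch.Size([1, 60, 256]) -> [C, H, W]
print(img.min().item(), img.max().item())  # values within [0.0, 1.0]
print(label, path)                         # a 5-7 char string and the image path
```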
src/config.py
CHANGED

@@ -7,14 +7,27 @@ class Config:
     data_root: str = os.getenv("DATA_ROOT","Dataset_test\captchas")
 
     chars: str = string.ascii_letters + string.digits
-
-
-
-
-
-
-
+
+    # Image dimensions - increased for better character detail
+    H: int = 60        # Increased from 48 for more vertical detail
+    W_max: int = 256   # Increased from 224 for more time steps (T=64)
+    grayscale: bool = True
+
+    # Model architecture
+    total_stride: int = 4   # CNN width downsampling factor
+
+    # Training hyperparameters
+    batch_size: int = 32      # Local testing
+    batch_size_t4: int = 128  # Colab T4 recommendation
     num_workers: int = 4
-    amp: bool = True
+    amp: bool = True
+
+    # Learning rate and optimization
+    lr: float = 3e-4
+    weight_decay: float = 1e-4
+
+    # Training duration
+    epochs: int = 40       # For 100k dataset
+    epochs_test: int = 10  # For 1k test dataset
 
 cfg = Config()
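With these values, the CNN's width downsampling yields T = W_max / total_stride = 256 / 4 = 64 timesteps per image, comfortably above the 7-character upper limit that CTC must align (the README's "48×224 → 56" figures describe the previous dimensions). Note also that `data_root` keeps a Windows-style backslash path; forward slashes would be portable. An illustrative check:

```python
from src.config import cfg

T = cfg.W_max // cfg.total_stride
print(T)  # 64 timesteps out of the CNN
# CTC requires input_length >= target_length for every sample
assert T >= 7, "timesteps must cover the longest CAPTCHA (7 chars)"
```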
src/data.py
CHANGED

@@ -8,11 +8,11 @@ import pandas as pd
 # config
 DATASET_DIR = "Dataset_test/captchas"
 LABELS = "Dataset_test/labels.csv"
-NUM_IMAGES =
+NUM_IMAGES = 10000
 CHARS = string.ascii_letters + string.digits
 CAPTCHA_LEN_LOWER_LIMIT = 5
 CAPTCHA_LEN_UPPER_LIMIT = 7
-directories = [["train",0.8],["
+directories = [["train",0.8],["val",0.1],["test",0.1]]
 
 os.makedirs(DATASET_DIR, exist_ok=True)
 image = ImageCaptcha(width=160, height=60)
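The split fractions work out to 8,000/1,000/1,000 images for train/val/test, matching the README's "(8k train, 1k val)". The generation loop itself is outside this hunk, so the sketch below only illustrates how these constants combine; any name not shown above is an assumption:

```python
import random
import string

CHARS = string.ascii_letters + string.digits
CAPTCHA_LEN_LOWER_LIMIT = 5
CAPTCHA_LEN_UPPER_LIMIT = 7

def random_label() -> str:
    # Draw a 5-7 character alphanumeric CAPTCHA label
    n = random.randint(CAPTCHA_LEN_LOWER_LIMIT, CAPTCHA_LEN_UPPER_LIMIT)
    return "".join(random.choices(CHARS, k=n))

NUM_IMAGES = 10000
directories = [["train", 0.8], ["val", 0.1], ["test", 0.1]]
for name, frac in directories:
    print(name, int(NUM_IMAGES * frac))  # train 8000, val 1000, test 1000
```

Note the generator draws 160×60 images, which the dataset later stretches to 256×60 at load time.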
src/model_crnn.py
ADDED

@@ -0,0 +1,80 @@
import torch
import torch.nn as nn
from src.config import cfg

class SmallCNN(nn.Module):
    """
    Improved CNN with BatchNorm and residual connections.
    Produces feature map with total stride 4 along width,
    and compresses height to ~1 via pooling.
    """
    def __init__(self, in_ch=1) -> None:
        super().__init__()
        # First conv block: H,W -> H/2, W/2
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_ch, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))  # stride 2x2
        )

        # Second conv block: maintain H/2, W/2 -> W/4
        self.conv2 = nn.Sequential(
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=(1, 2), stride=(1, 2))  # height stride 1, width stride 2
        )

        # Residual block at 128 channels
        self.residual = nn.Sequential(
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, padding=1),
            nn.BatchNorm2d(128)
        )
        self.residual_relu = nn.ReLU(inplace=True)

        self.height_pool = nn.AdaptiveAvgPool2d((1, None))  # squeeze height to 1

    def forward(self, x):
        # First two conv blocks
        f = self.conv1(x)  # [B, 64, H/2, W/2]
        f = self.conv2(f)  # [B, 128, H/2, W/4]

        # Residual connection
        residual = f
        f = self.residual(f)       # [B, 128, H/2, W/4]
        f = f + residual           # Skip connection
        f = self.residual_relu(f)  # [B, 128, H/2, W/4]

        # Height pooling
        f = self.height_pool(f)  # [B, 128, 1, W/4]
        f = f.squeeze(2)         # [B, 128, W/4]
        f = f.permute(2, 0, 1)   # [T(=W/4), B, 128]
        return f

class CRNN(nn.Module):
    def __init__(self, vocab_size: int, in_ch: int = 1, hidden: int = 320, layers: int = 2, dropout: float = 0.05):
        super().__init__()
        self.cnn = SmallCNN(in_ch=in_ch)
        self.rnn = nn.LSTM(input_size=128, hidden_size=hidden, num_layers=layers,
                           bidirectional=True, dropout=dropout, batch_first=False)
        self.norm = nn.LayerNorm(2 * hidden)  # Add LayerNorm for stability
        self.fc = nn.Linear(2 * hidden, vocab_size)

        # Initialize weights properly
        self._init_weights()

    def _init_weights(self):
        # Initialize final linear layer with small weights
        nn.init.xavier_uniform_(self.fc.weight)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x):
        seq = self.cnn(x)     # [T,B,C=128]
        y, _ = self.rnn(seq)  # [T,B,2H]
        y = self.norm(y)      # [T,B,2H] - Apply LayerNorm
        logits = self.fc(y)   # [T,B,V]
        return logits
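A dummy forward pass confirms the shape comments. With the current config (60×256 input) the time axis comes out at 64; the README's "56×batch×63" figure assumed the older 224-pixel width. Illustrative only:

```python
import torch
from src.model_crnn import CRNN

model = CRNN(vocab_size=63)
x = torch.randn(2, 1, 60, 256)  # [B, C, H, W] dummy batch
logits = model(x)
print(logits.shape)  # torch.Size([64, 2, 63]) -> [T, B, V]
```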
src/plotting.py
ADDED

@@ -0,0 +1,107 @@
import matplotlib.pyplot as plt
import numpy as np
from datetime import datetime
import os


class TrainingMetrics:
    def __init__(self):
        self.train_losses = []
        self.val_losses = []
        self.epochs = []
        self.sample_predictions = []
        self.sample_targets = []

    def add_epoch(self, epoch, train_loss, val_loss):
        self.epochs.append(epoch)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)

    def add_predictions(self, predictions, targets):
        self.sample_predictions.extend(predictions)
        self.sample_targets.extend(targets)

    def plot_losses(self, save_path="Metrics/training_losses.png"):
        plt.figure(figsize=(10, 6))
        plt.plot(self.epochs, self.train_losses, 'b-', label='Training Loss', linewidth=2)
        plt.plot(self.epochs, self.val_losses, 'r-', label='Validation Loss', linewidth=2)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Loss Over Time')
        plt.legend()
        plt.grid(True, alpha=0.3)
        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Loss plot saved to: {save_path}")

    def plot_loss_comparison(self, save_path="Metrics/loss_comparison.png"):
        plt.figure(figsize=(12, 8))

        # Main loss plot
        plt.subplot(2, 2, 1)
        plt.plot(self.epochs, self.train_losses, 'b-', label='Training Loss')
        plt.plot(self.epochs, self.val_losses, 'r-', label='Validation Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training vs Validation Loss')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss difference plot
        plt.subplot(2, 2, 2)
        loss_diff = [t - v for t, v in zip(self.train_losses, self.val_losses)]
        plt.plot(self.epochs, loss_diff, 'g-', label='Train - Val Loss')
        plt.xlabel('Epoch')
        plt.ylabel('Loss Difference')
        plt.title('Overfitting Indicator')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss ratio plot
        plt.subplot(2, 2, 3)
        loss_ratio = [v / t if t > 0 else 0 for t, v in zip(self.train_losses, self.val_losses)]
        plt.plot(self.epochs, loss_ratio, 'm-', label='Val/Train Loss Ratio')
        plt.xlabel('Epoch')
        plt.ylabel('Ratio')
        plt.title('Validation/Training Loss Ratio')
        plt.legend()
        plt.grid(True, alpha=0.3)

        # Loss improvement plot
        plt.subplot(2, 2, 4)
        train_improvement = [self.train_losses[0] - t for t in self.train_losses]
        val_improvement = [self.val_losses[0] - v for v in self.val_losses]
        plt.plot(self.epochs, train_improvement, 'b-', label='Training Improvement')
        plt.plot(self.epochs, val_improvement, 'r-', label='Validation Improvement')
        plt.xlabel('Epoch')
        plt.ylabel('Loss Improvement')
        plt.title('Loss Improvement from Start')
        plt.legend()
        plt.grid(True, alpha=0.3)

        plt.tight_layout()
        plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.close()
        print(f"Loss comparison plot saved to: {save_path}")

    def save_metrics(self, save_path="Metrics/training_metrics.txt"):
        with open(save_path, 'w') as f:
            f.write("CAPTCHA OCR Training Metrics\n")
            f.write("=" * 50 + "\n\n")
            f.write(f"Training completed at: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write(f"Total epochs: {len(self.epochs)}\n\n")

            f.write("Loss Summary:\n")
            f.write("-" * 20 + "\n")
            f.write(f"Final training loss: {self.train_losses[-1]:.4f}\n")
            f.write(f"Final validation loss: {self.val_losses[-1]:.4f}\n")
            f.write(f"Best training loss: {min(self.train_losses):.4f}\n")
            f.write(f"Best validation loss: {min(self.val_losses):.4f}\n")
            f.write(f"Training loss improvement: {self.train_losses[0] - self.train_losses[-1]:.4f}\n")
            f.write(f"Validation loss improvement: {self.val_losses[0] - self.val_losses[-1]:.4f}\n\n")

            f.write("Sample Predictions:\n")
            f.write("-" * 20 + "\n")
            for i, (pred, target) in enumerate(zip(self.sample_predictions[:10], self.sample_targets[:10])):
                f.write(f"Sample {i+1}: Predicted='{pred}', Target='{target}'\n")
src/test.py
ADDED

@@ -0,0 +1,49 @@
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version: {torch.version.cuda}")

    # GPU detection and info
    gpu_count = torch.cuda.device_count()
    print(f"Number of GPUs: {gpu_count}")

    for i in range(gpu_count):
        gpu_name = torch.cuda.get_device_name(i)
        gpu_memory = torch.cuda.get_device_properties(i).total_memory / 1024**3  # Convert to GB
        print(f"GPU {i}: {gpu_name}")
        print(f"GPU {i} Memory: {gpu_memory:.1f} GB")

    # Current GPU
    current_gpu = torch.cuda.current_device()
    print(f"Current GPU: {current_gpu}")

    # Test GPU tensor operations
    print("\nTesting GPU operations...")
    try:
        # Create a test tensor on GPU
        test_tensor = torch.randn(1000, 1000).cuda()
        print(f"✓ Successfully created tensor on GPU: {test_tensor.shape}")
        print(f"✓ Tensor device: {test_tensor.device}")

        # Test basic operations
        result = torch.mm(test_tensor, test_tensor.T)
        print(f"✓ Matrix multiplication successful: {result.shape}")

        # Memory usage
        allocated = torch.cuda.memory_allocated() / 1024**2  # MB
        cached = torch.cuda.memory_reserved() / 1024**2  # MB
        print(f"✓ GPU Memory allocated: {allocated:.1f} MB")
        print(f"✓ GPU Memory cached: {cached:.1f} MB")

        # Clean up
        del test_tensor, result
        torch.cuda.empty_cache()
        print("✓ GPU memory cleaned up successfully")

    except Exception as e:
        print(f"✗ GPU test failed: {e}")
else:
    print("CUDA not available - PyTorch will use CPU only")
train.py
ADDED

@@ -0,0 +1,206 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.config import cfg
from src.collate import ctc_collate
from src.captcha_dataset import CaptchaDataset
from src.vocab import vocab_size, ctc_greedy_decode, decode_indices, itos
from src.plotting import TrainingMetrics
from src.model_crnn import CRNN
import difflib

def cer(pred: str, tgt: str) -> float:
    """Approximate Character Error Rate using difflib."""
    sm = difflib.SequenceMatcher(a=pred, b=tgt)
    return 1 - sm.ratio()

def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    in_ch = 1 if cfg.grayscale else 3

    print("Creating datasets...")
    train_ds = CaptchaDataset("train")
    val_ds = CaptchaDataset("val")

    # Debug: Check vocabulary
    print(f"Vocabulary size: {vocab_size()}")
    print(f"First 10 characters: {list(cfg.chars)[:10]}")
    print(f"First 10 itos: {itos[:10]}")

    print(f"Training dataset size: {len(train_ds)}")
    print(f"Validation dataset size: {len(val_ds)}")

    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True,
                          drop_last=True, collate_fn=ctc_collate)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=True,
                        drop_last=True, collate_fn=ctc_collate)

    model = CRNN(vocab_size=vocab_size()).to(device)

    # Initialize final layer with small weights for stability
    with torch.no_grad():
        torch.nn.init.uniform_(model.fc.weight, -1e-3, 1e-3)
        torch.nn.init.zeros_(model.fc.bias)

    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    scaler = torch.amp.GradScaler('cuda', enabled=False)  # Disable AMP for stability

    # Epoch-based training with scheduler
    epochs = 20  # Increased for OneCycleLR
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=3e-4, steps_per_epoch=len(train_dl), epochs=epochs
    )
    print(f"\nStarting training for {epochs} epochs...")

    metrics = TrainingMetrics()

    for epoch in range(epochs):
        # Training phase
        model.train()
        total_train_loss = 0
        num_batches = 0

        print(f"\nEpoch {epoch+1}/{epochs}")
        print("Training...")

        for batch_idx, batch in enumerate(train_dl):
            images, targets_flat, target_lengths, input_lengths, paths = batch

            # CTC sanity checks (first batch of each epoch)
            if batch_idx == 0:
                assert targets_flat.numel() == target_lengths.sum().item(), "Target lengths mismatch"
                assert torch.all(target_lengths <= input_lengths), "Target longer than input"
                print(f"  Batch 0 sanity: input_lens={input_lengths[:5].tolist()}, target_lens={target_lengths[:5].tolist()}")
                print(f"  Image stats: min={images.min():.3f}, max={images.max():.3f}, mean={images.mean():.3f}")

            images = images.to(device)
            targets_flat = targets_flat.to(device)
            target_lengths = target_lengths.to(device)
            input_lengths = input_lengths.to(device)

            optimizer.zero_grad(set_to_none=True)

            with torch.amp.autocast('cuda', enabled=False):
                logits = model(images)
                log_probs = logits.log_softmax(dim=-1)
                loss = criterion(log_probs, targets_flat, input_lengths, target_lengths)

            loss.backward()

            # Gradient clipping to prevent exploding gradients
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

            optimizer.step()
            scheduler.step()  # OneCycleLR step per batch

            total_train_loss += loss.item()
            num_batches += 1

            # Progress update every 50 batches
            if batch_idx % 50 == 0:
                print(f"  Batch {batch_idx}/{len(train_dl)} - Loss: {loss.item():.4f}")

        avg_train_loss = total_train_loss / num_batches

        # Validation phase
        model.eval()
        total_val_loss = 0
        num_val_batches = 0

        print("Validating...")
        with torch.no_grad():
            for batch in val_dl:
                images, targets_flat, target_lengths, input_lengths, paths = batch
                images = images.to(device)
                targets_flat = targets_flat.to(device)
                target_lengths = target_lengths.to(device)
                input_lengths = input_lengths.to(device)

                logits = model(images)
                log_probs = logits.log_softmax(dim=-1)
                loss = criterion(log_probs, targets_flat, input_lengths, target_lengths)

                total_val_loss += loss.item()
                num_val_batches += 1

        avg_val_loss = total_val_loss / num_val_batches

        print(f"Epoch {epoch+1}/{epochs} Summary:")
        print(f"  Train Loss: {avg_train_loss:.4f}")
        print(f"  Val Loss: {avg_val_loss:.4f}")
        metrics.add_epoch(epoch+1, avg_train_loss, avg_val_loss)

        # Test some predictions
        if epoch % 2 == 0:  # Every 2 epochs
            print("Sample predictions:")
            with torch.no_grad():
                test_batch = next(iter(val_dl))
                test_images = test_batch[0][:5].to(device)  # First 5 images
                print(f"  Input image shape: {test_images.shape}")
                print(f"  Input image min/max: {test_images.min():.4f}/{test_images.max():.4f}")
                test_logits = model(test_images)

                # Debug: Check logits shape and values
                print(f"  Logits shape: {test_logits.shape}")
                print(f"  Expected logits shape: [W//stride, B, V] = [{cfg.W_max}//{cfg.total_stride}, 5, 63] = [{cfg.W_max//cfg.total_stride}, 5, 63]")
                print(f"  Logits min/max: {test_logits.min():.4f}/{test_logits.max():.4f}")

                # Check raw predictions and blank probability (from softmax)
                raw_preds = test_logits.argmax(dim=-1)
                probs = test_logits.log_softmax(-1).exp()
                avg_blank_prob = probs[..., 0].mean().item()
                print(f"  Raw predictions shape: {raw_preds.shape}")
                print(f"  Raw predictions sample: {raw_preds[:10, 0].tolist()}")
                print(f"  Avg blank prob (softmax): {avg_blank_prob:.4f}")
                print(f"  Blank probability (argmax): {(raw_preds == 0).float().mean():.4f}")

                test_preds = ctc_greedy_decode(test_logits)

                # Decode the target integers back to text strings with proper offsets
                targets_flat, target_lengths = test_batch[1], test_batch[2]
                offsets = torch.zeros(len(target_lengths), dtype=torch.long)
                offsets[1:] = torch.cumsum(target_lengths[:-1], dim=0)
                test_targets = []
                for i in range(min(5, len(target_lengths))):
                    s = offsets[i].item()
                    e = s + target_lengths[i].item()
                    indices = targets_flat[s:e].tolist()
                    test_targets.append(decode_indices(indices))

                # Calculate CER for this batch
                batch_cer = sum(cer(p, t) for p, t in zip(test_preds, test_targets)) / len(test_targets)
                print(f"  Val CER (approx): {batch_cer:.3f}")

                for i, (pred, target) in enumerate(zip(test_preds, test_targets)):
                    print(f"  {i}: Predicted='{pred}', Target='{target}'")

                metrics.add_predictions(test_preds, test_targets)

    print("\nTraining complete!")
    print("\nGenerating training metrics and plots...")
    os.makedirs("Metrics", exist_ok=True)
    metrics.plot_losses()
    metrics.plot_loss_comparison()
    metrics.save_metrics()

    # Final validation test
    model.eval()
    with torch.no_grad():
        images, targets_flat, target_lengths, input_lengths, paths = next(iter(val_dl))
        images = images.to(device)
        logits = model(images)
        preds = ctc_greedy_decode(logits)

    print("\nFinal validation predictions:")
    for i, pred in enumerate(preds[:10]):
        print(f"  {i}: {pred}")


if __name__ == "__main__":
    os.makedirs("checkpoints", exist_ok=True)
    main()
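Two notes on train.py. First, `cer` is a similarity-based approximation: difflib's `ratio()` is 2M/(len(a)+len(b)), so `cer("tu", "tuVX9")` = 1 − 4/7 ≈ 0.43. Second, `main()` creates a checkpoints/ directory but never writes to it; a hedged sketch of per-epoch saving that could slot in after the validation phase (not part of this commit):

```python
import torch

def save_checkpoint(model, optimizer, epoch, val_loss,
                    path="checkpoints/crnn_last.pt"):
    # Illustrative helper -- train.py does not yet save checkpoints.
    torch.save({
        "epoch": epoch,
        "model_state": model.state_dict(),
        "optimizer_state": optimizer.state_dict(),
        "val_loss": val_loss,
    }, path)
```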
train_sanity.py
ADDED

@@ -0,0 +1,96 @@
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.config import cfg
from src.collate import ctc_collate
from src.captcha_dataset import CaptchaDataset
from src.vocab import vocab_size, ctc_greedy_decode
from src.model_crnn import CRNN


def main():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    in_ch = 1 if cfg.grayscale else 3

    print("Creating datasets...")
    train_ds = CaptchaDataset("train")
    val_ds = CaptchaDataset("val")

    print(f"Training dataset size: {len(train_ds)}")
    print(f"Validation dataset size: {len(val_ds)}")

    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True,
                          drop_last=True, collate_fn=ctc_collate)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=True,
                        drop_last=True, collate_fn=ctc_collate)

    # # Test training data
    # print("\nTesting training data...")
    # for batch in train_dl:
    #     images, targets_flat, target_lengths, input_lengths, paths = batch
    #     print(f"Training batch shape: {images.shape}")
    #     print(f"Sample labels: {targets_flat[:10]}")
    #     break

    # # Test validation data
    # print("\nTesting validation data...")
    # try:
    #     for batch in val_dl:
    #         images, targets_flat, target_lengths, input_lengths, paths = batch
    #         print(f"Validation batch shape: {images.shape}")
    #         print(f"Sample labels: {targets_flat[:10]}")
    #         break
    # except Exception as e:
    #     print(f"Error in validation data: {e}")
    #     print("This suggests there are issues with some validation images")

    model = CRNN(vocab_size=vocab_size()).to(device)
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    scaler = torch.amp.GradScaler('cuda', enabled=cfg.amp and device.type == "cuda")

    model.train()
    steps = 200
    it = iter(train_dl)
    for step in range(1, steps+1):
        try:
            images, targets_flat, target_lengths, input_lengths, paths = next(it)
        except StopIteration:
            it = iter(train_dl)
            images, targets_flat, target_lengths, input_lengths, paths = next(it)

        images = images.to(device)
        targets_flat = targets_flat.to(device)
        target_lengths = target_lengths.to(device)
        input_lengths = input_lengths.to(device)

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', enabled=scaler.is_enabled()):
            logits = model(images)
            log_probs = logits.log_softmax(dim=-1)
            loss = criterion(log_probs, targets_flat, input_lengths, target_lengths)
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if step % 20 == 0:
            print(f"step {step}/{steps} - loss {loss.item():.4f}")

    model.eval()
    with torch.no_grad():
        images, targets_flat, target_lengths, input_lengths, paths = next(iter(val_dl))
        images = images.to(device)
        logits = model(images)
        preds = ctc_greedy_decode(logits)

    print("Sanity check complete")


if __name__ == "__main__":
    os.makedirs("checkpoints", exist_ok=True)
    main()