# CaptchaOCR / train.py
# Author: mohakapoor — checkpoint commit 6e89f30
# CRNN + CTC training script for captcha recognition.
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from src.config import cfg
from src.collate import ctc_collate
from src.captcha_dataset import CaptchaDataset
from src.vocab import vocab_size, ctc_greedy_decode, decode_indices, itos
from src.plotting import TrainingMetrics
from src.model_crnn import CRNN
import difflib
def cer(pred: str, tgt: str) -> float:
    """Approximate Character Error Rate between *pred* and *tgt*.

    Uses difflib.SequenceMatcher's similarity ratio, so the result is
    0.0 for identical strings and approaches 1.0 as the strings share
    fewer characters. This is a similarity-based approximation, not a
    true Levenshtein-normalized CER.
    """
    sm = difflib.SequenceMatcher(a=pred, b=tgt)
    return 1 - sm.ratio()
def _train_one_epoch(model, train_dl, criterion, optimizer, scheduler, device):
    """Run one training epoch and return the average CTC loss per batch."""
    model.train()
    total_train_loss = 0.0
    num_batches = 0
    for batch_idx, batch in enumerate(train_dl):
        images, targets_flat, target_lengths, input_lengths, _paths = batch
        # CTC sanity checks (first batch of each epoch): the flattened
        # target tensor must match the summed lengths, and CTC requires
        # every target to be no longer than its input sequence.
        if batch_idx == 0:
            assert targets_flat.numel() == target_lengths.sum().item(), "Target lengths mismatch"
            assert torch.all(target_lengths <= input_lengths), "Target longer than input"
            print(f" Batch 0 sanity: input_lens={input_lengths[:5].tolist()}, target_lens={target_lengths[:5].tolist()}")
            print(f" Image stats: min={images.min():.3f}, max={images.max():.3f}, mean={images.mean():.3f}")
        images = images.to(device)
        targets_flat = targets_flat.to(device)
        target_lengths = target_lengths.to(device)
        input_lengths = input_lengths.to(device)
        optimizer.zero_grad(set_to_none=True)
        # AMP is deliberately disabled for CTC numerical stability.
        with torch.amp.autocast('cuda', enabled=False):
            logits = model(images)
            log_probs = logits.log_softmax(dim=-1)
            loss = criterion(log_probs, targets_flat, input_lengths, target_lengths)
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()
        scheduler.step()  # OneCycleLR steps once per batch
        total_train_loss += loss.item()
        num_batches += 1
        # Progress update every 50 batches
        if batch_idx % 50 == 0:
            print(f" Batch {batch_idx}/{len(train_dl)} - Loss: {loss.item():.4f}")
    return total_train_loss / num_batches


def _validate(model, val_dl, criterion, device):
    """Return the average CTC loss per batch over the validation loader."""
    model.eval()
    total_val_loss = 0.0
    num_val_batches = 0
    with torch.no_grad():
        for batch in val_dl:
            images, targets_flat, target_lengths, input_lengths, _paths = batch
            images = images.to(device)
            targets_flat = targets_flat.to(device)
            target_lengths = target_lengths.to(device)
            input_lengths = input_lengths.to(device)
            logits = model(images)
            log_probs = logits.log_softmax(dim=-1)
            loss = criterion(log_probs, targets_flat, input_lengths, target_lengths)
            total_val_loss += loss.item()
            num_val_batches += 1
    return total_val_loss / num_val_batches


def _debug_sample_predictions(model, val_dl, metrics, device):
    """Greedy-decode a few validation samples, print diagnostics, record them."""
    print("Sample predictions:")
    with torch.no_grad():
        test_batch = next(iter(val_dl))
        test_images = test_batch[0][:5].to(device)  # first 5 images
        print(f" Input image shape: {test_images.shape}")
        print(f" Input image min/max: {test_images.min():.4f}/{test_images.max():.4f}")
        test_logits = model(test_images)
        # Debug: check logits shape and values. Use vocab_size() rather than
        # a hard-coded 63 so the message stays correct if cfg.chars changes.
        print(f" Logits shape: {test_logits.shape}")
        print(f" Expected logits shape: [W//stride, B, V] = [{cfg.W_max}//{cfg.total_stride}, 5, {vocab_size()}] = [{cfg.W_max//cfg.total_stride}, 5, {vocab_size()}]")
        print(f" Logits min/max: {test_logits.min():.4f}/{test_logits.max():.4f}")
        # Raw argmax predictions and blank probability (blank is index 0)
        raw_preds = test_logits.argmax(dim=-1)
        probs = test_logits.log_softmax(-1).exp()
        avg_blank_prob = probs[..., 0].mean().item()
        print(f" Raw predictions shape: {raw_preds.shape}")
        print(f" Raw predictions sample: {raw_preds[:10, 0].tolist()}")
        print(f" Avg blank prob (softmax): {avg_blank_prob:.4f}")
        print(f" Blank probability (argmax): {(raw_preds == 0).float().mean():.4f}")
        test_preds = ctc_greedy_decode(test_logits)
        # Decode the flattened target indices back to text: targets are
        # concatenated per-sample, so recover each slice via cumulative offsets.
        targets_flat, target_lengths = test_batch[1], test_batch[2]
        offsets = torch.zeros(len(target_lengths), dtype=torch.long)
        offsets[1:] = torch.cumsum(target_lengths[:-1], dim=0)
        test_targets = []
        for i in range(min(5, len(target_lengths))):
            s = offsets[i].item()
            e = s + target_lengths[i].item()
            indices = targets_flat[s:e].tolist()
            test_targets.append(decode_indices(indices))
        # Approximate CER over this handful of samples
        batch_cer = sum(cer(p, t) for p, t in zip(test_preds, test_targets)) / len(test_targets)
        print(f" Val CER (approx): {batch_cer:.3f}")
        for i, (pred, target) in enumerate(zip(test_preds, test_targets)):
            print(f" {i}: Predicted='{pred}', Target='{target}'")
        metrics.add_predictions(test_preds, test_targets)


def main():
    """Train the CRNN captcha recognizer with CTC loss and save metrics."""
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("Creating datasets...")
    train_ds = CaptchaDataset("train")
    val_ds = CaptchaDataset("val")
    # Debug: check vocabulary and dataset sizes
    print(f"Vocabulary size: {vocab_size()}")
    print(f"First 10 characters: {list(cfg.chars)[:10]}")
    print(f"First 10 itos: {itos[:10]}")
    print(f"Training dataset size: {len(train_ds)}")
    print(f"Validation dataset size: {len(val_ds)}")
    train_dl = DataLoader(train_ds, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True,
                          drop_last=True, collate_fn=ctc_collate)
    val_dl = DataLoader(val_ds, batch_size=cfg.batch_size, shuffle=False,
                        num_workers=cfg.num_workers, pin_memory=True,
                        drop_last=True, collate_fn=ctc_collate)
    model = CRNN(vocab_size=vocab_size()).to(device)
    # Initialize final layer with small weights for stability
    with torch.no_grad():
        torch.nn.init.uniform_(model.fc.weight, -1e-3, 1e-3)
        torch.nn.init.zeros_(model.fc.bias)
    # blank=0 must match the index conventions in src.vocab
    criterion = nn.CTCLoss(blank=0, zero_infinity=True)
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
    # NOTE(review): the original built a GradScaler(enabled=False) that was
    # never used; it was removed since AMP is disabled throughout.
    epochs = 20  # increased for OneCycleLR
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer, max_lr=3e-4, steps_per_epoch=len(train_dl), epochs=epochs
    )
    print(f"\nStarting training for {epochs} epochs...")
    metrics = TrainingMetrics()
    for epoch in range(epochs):
        print(f"\nEpoch {epoch+1}/{epochs}")
        print("Training...")
        avg_train_loss = _train_one_epoch(model, train_dl, criterion,
                                          optimizer, scheduler, device)
        print("Validating...")
        avg_val_loss = _validate(model, val_dl, criterion, device)
        print(f"Epoch {epoch+1}/{epochs} Summary:")
        print(f" Train Loss: {avg_train_loss:.4f}")
        print(f" Val Loss: {avg_val_loss:.4f}")
        metrics.add_epoch(epoch+1, avg_train_loss, avg_val_loss)
        # Test some predictions every 2 epochs
        if epoch % 2 == 0:
            _debug_sample_predictions(model, val_dl, metrics, device)
    print("\nTraining complete!")
    print("\nGenerating training metrics and plots...")
    os.makedirs("Metrics", exist_ok=True)
    metrics.plot_losses()
    metrics.plot_loss_comparison()
    metrics.save_metrics()
    # Final validation test: decode one batch and show the first predictions
    model.eval()
    with torch.no_grad():
        images, targets_flat, target_lengths, input_lengths, _paths = next(iter(val_dl))
        images = images.to(device)
        logits = model(images)
        preds = ctc_greedy_decode(logits)
    print("\nFinal validation predictions:")
    for i, pred in enumerate(preds[:10]):
        print(f" {i}: {pred}")
if __name__ == "__main__":
    # Ensure the checkpoint directory exists before training starts.
    os.makedirs("checkpoints", exist_ok=True)
    main()