#!/usr/bin/env python3
"""Debug script to analyze the entire training pipeline"""
import os
import json
import numpy as np
from PIL import Image
import pillow_heif
import torch
# Register the HEIF/HEIC decoder with Pillow so Image.open() can read the
# iPhone .HEIC photos in data/images (Pillow has no native HEIF support).
pillow_heif.register_heif_opener()
print("=" * 80)
print("COMPREHENSIVE PIPELINE ANALYSIS")
print("=" * 80)

# 1. Check label lengths.
# CTC loss requires input_length > label_length for every sample, so the
# label length distribution directly constrains the model's output width.
print("\n1. LABEL LENGTH ANALYSIS")
print("-" * 80)
labels_dir = "data/labels"
label_files = sorted(f for f in os.listdir(labels_dir) if f.endswith('.txt'))
label_lengths = []
for label_file in label_files:
    # Join non-empty lines with single spaces, mirroring training preprocessing.
    # Iterate the file directly instead of materializing readlines().
    with open(os.path.join(labels_dir, label_file), 'r', encoding='utf-8') as f:
        text = ' '.join(line.strip() for line in f if line.strip())
    label_lengths.append(len(text))
    if len(label_lengths) <= 3:  # preview only the first few samples
        print(f" {label_file}: {len(text)} chars - '{text[:60]}...'")
# Guard against an empty label directory: min()/max() raise ValueError on [].
if label_lengths:
    print(f"\nLabel length statistics:")
    print(f" Min: {min(label_lengths)} chars")
    print(f" Max: {max(label_lengths)} chars")
    print(f" Mean: {np.mean(label_lengths):.1f} chars")
    print(f" ⚠️ CTC requires: input_length > label_length")
else:
    print(f"\n⚠️ No .txt label files found in {labels_dir}")
# 2. Check image dimensions.
# The resized width becomes the CTC time-step count, so it must exceed the
# label lengths measured in section 1.
print("\n2. IMAGE DIMENSIONS ANALYSIS")
print("-" * 80)
images_dir = "data/images"
# Case-insensitive extension match: exports use both '.HEIC' and '.heic'.
image_files = sorted(f for f in os.listdir(images_dir) if f.lower().endswith('.heic'))[:3]
IMG_HEIGHT = 256
for img_file in image_files:
    # Context manager closes the underlying file handle (Image.open is lazy
    # and otherwise leaks the descriptor); .size is available without load().
    with Image.open(os.path.join(images_dir, img_file)) as img:
        orig_w, orig_h = img.size
    # Resize to fixed height while preserving aspect ratio.
    new_w = int(IMG_HEIGHT * orig_w / orig_h)
    print(f" {img_file}:")
    print(f" Original: {orig_w}×{orig_h}")
    print(f" After resize: {new_w}×{IMG_HEIGHT}")
    # Simulate CNN output:
    # 3 stride-2 layers reduce height: 256 → 128 → 64 → 32
    # Adaptive pool to (1, None): 32 → 1
    # Width stays the same (stride is (2,1), not (2,2)),
    # so the sequence length equals the resized width.
    final_sequence_length = new_w
    print(f" Model output sequence: ~{final_sequence_length} time steps")
# 3. Check one validation sample.
# Static observations about the val split; nothing is read from disk here.
print("\n3. VALIDATION SET ANALYSIS")
print("-" * 80)
for warning in (
    "⚠️ Validation set has only 1 sample!",
    "⚠️ Batch size is 2, but only 1 val sample exists!",
    "⚠️ This likely causes the validation loss = 0.0 bug!",
):
    print(warning)
print("\nPossible issues:")
for issue in (
    " 1. DataLoader creates incomplete batch",
    " 2. Label length might be 0 after preprocessing",
    " 3. Sequence length might be <= label length",
):
    print(issue)
# 4. Load vocabulary and check.
# `vocab` is reused by the CTC requirements section below, so the name must stay.
print("\n4. VOCABULARY ANALYSIS")
print("-" * 80)
with open('outputs/vocabulary.json', 'r') as f:
    vocab = json.load(f)
for report_line in (
    f" Vocab size: {vocab['vocab_size']} characters",
    f" Blank index: {vocab['blank_idx']}",
    f" Num classes: {vocab['num_classes']}",
    " ✓ Correctly configured for CTC",
):
    print(report_line)
# 5. Image quality analysis.
# Static description of the preprocessing pipeline and its suspected weak spots.
print("\n5. IMAGE QUALITY ANALYSIS")
print("-" * 80)
quality_report = (
    "Preprocessing steps:",
    " 1. RGB → Grayscale (loses color information)",
    " 2. CLAHE enhancement (clipLimit=2.0)",
    " 3. Resize to height=256 (maintains aspect ratio)",
    " 4. Normalize to [0, 1]",
    "\n⚠️ Potential issues:",
    " - HEIC images are photos (~900KB) with lots of detail",
    " - Grayscale might lose colored ink vs background",
    " - CLAHE might over-enhance and create artifacts",
    " - Height 256 might be too small for detailed handwriting",
)
print("\n".join(quality_report))
# 6. Architecture analysis.
# Walk through the CNN shape transitions from input image to CTC sequence.
print("\n6. MODEL ARCHITECTURE ANALYSIS")
print("-" * 80)
architecture_trace = (
    "CNN Architecture:",
    " Input: [B, 1, 256, W]",
    " Conv1: [B, 64, 256, W]",
    " Conv2: [B, 128, 128, W] (stride 2 on height)",
    " Conv3: [B, 256, 128, W]",
    " Conv4: [B, 256, 64, W] (stride 2 on height)",
    " Conv5: [B, 512, 64, W]",
    " Conv6: [B, 512, 32, W] (stride 2 on height)",
    " AdaptivePool: [B, 512, 1, W]",
    " Squeeze: [B, 512, W]",
    " Permute: [B, W, 512]",
    "\n⚠️ Potential issue:",
    " - Height reduced from 256 → 32 before pooling",
    " - For tall handwriting, this might lose vertical information",
    " - Adaptive pool to height=1 might be too aggressive",
)
print("\n".join(architecture_trace))
# 7. CTC Requirements Check.
# Re-states the CTC constraint using the vocabulary loaded in section 4.
print("\n7. CTC REQUIREMENTS")
print("-" * 80)
print("For CTC to work:")
# f-string for consistency with the rest of the script (was str.format(**vocab),
# which also splatted the whole dict when only two keys are needed).
print(f" ✓ blank_idx must be valid: {vocab['blank_idx']} < {vocab['num_classes']}")
print(f" ✓ Model outputs {vocab['num_classes']} classes")
print(" ⚠️ input_length MUST be > label_length for each sample")
print(" ⚠️ With labels averaging ~100 chars, need >100 time steps")
print(" ⚠️ Sequence length = image_width (after resize)")
print(" ⚠️ If image width < label length → CTC FAILS")
# Final recommendations, grouped as (heading, bullet items).
print("\n" + "=" * 80)
print("RECOMMENDATIONS")
print("=" * 80)
recommendation_groups = (
    ("1. FIX VALIDATION BUG:", (
        " - Increase validation set to at least 2 samples",
        " - OR set batch_size=1 for validation",
    )),
    ("2. FIX IMAGE QUALITY:", (
        " - Increase IMG_HEIGHT to 512 or higher",
        " - Consider keeping color (RGB) instead of grayscale",
        " - Reduce CLAHE clipLimit or remove it",
    )),
    ("3. FIX ARCHITECTURE:", (
        " - Reduce number of stride-2 layers (use max pooling)",
        " - OR increase initial image height",
        " - Ensure sequence_length >> max_label_length",
    )),
    ("4. GET MORE DATA:", (
        " - 12 samples is way too small",
        " - Need at least 100+ labeled samples",
        " - Use data augmentation (rotation, perspective, noise)",
    )),
)
for group_index, (heading, bullets) in enumerate(recommendation_groups):
    if group_index:  # blank separator between groups, but not after the last one
        print("")
    print(heading)
    for bullet in bullets:
        print(bullet)
print("=" * 80)
|