mohakapoor commited on
Commit
04e423f
·
1 Parent(s): 6e89f30

Enhance training process with improved early stopping and metrics tracking. Update README with training results and insights. Modify .gitignore to allow Metrics plots. Add plotting functionality for inference results in plotting.py. Update configuration parameters for CAPTCHA length limits.

Browse files
.gitignore CHANGED
@@ -66,7 +66,7 @@ logs/
66
  Thumbs.db
67
  desktop.ini
68
 
69
- # Images/artifacts
70
  *.png
71
  *.jpg
72
  *.jpeg
@@ -75,6 +75,10 @@ desktop.ini
75
  *.tiff
76
  *.webp
77
 
 
 
 
 
78
  # Models and checkpoints
79
  checkpoints/
80
  *.ckpt
 
66
  Thumbs.db
67
  desktop.ini
68
 
69
+ # Images/artifacts (but allow Metrics plots)
70
  *.png
71
  *.jpg
72
  *.jpeg
 
75
  *.tiff
76
  *.webp
77
 
78
+ # Allow Metrics plots
79
+ !Metrics/*.png
80
+ !Metrics/*.jpg
81
+
82
  # Models and checkpoints
83
  checkpoints/
84
  *.ckpt
Metrics/inference_results.png ADDED

Git LFS Details

  • SHA256: 3f354ee931ae653ed9821adbfb33c715ad310aad312064770238e879579ef078
  • Pointer size: 131 Bytes
  • Size of remote file: 209 kB
Metrics/loss_comparison.png ADDED

Git LFS Details

  • SHA256: 4e3a5a131f815aeff76e358f45fd9af95bef77d001ef8ba538451d0b3779e005
  • Pointer size: 131 Bytes
  • Size of remote file: 322 kB
Metrics/training_losses.png ADDED

Git LFS Details

  • SHA256: c74acd1702091eee23712df3b801b9d4c310959a389d2d567b11567c19280db9
  • Pointer size: 131 Bytes
  • Size of remote file: 112 kB
Metrics/training_metrics.txt CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:edef5b371d7b2c75153063f41c43f0e3dff8d58d5fda50e7a0db52d230e04f3c
3
- size 807
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fb9b2125cd77da83e51022b8de31388541a080c2f63d11e8b85cb6b34efe534
3
+ size 822
README.md CHANGED
@@ -36,8 +36,18 @@ This project implements an end-to-end CAPTCHA OCR system that can recognize text
36
  - **Character Diversity**: Limited to a few characters, needs more training
37
 
38
  ### 🎯 Training Status
39
- - **Current**: Epoch 3, basic character recognition starting
40
- - **Estimated**: 20-40 epochs needed for decent CAPTCHA accuracy
 
 
 
 
 
 
 
 
 
 
41
 
42
  ## πŸ“ Project Structure
43
 
 
36
  - **Character Diversity**: Limited to a few characters, needs more training
37
 
38
  ### 🎯 Training Status
39
+ - **Current**: Epoch 8, excellent convergence achieved
40
+ - **Best Model**: Validation loss 0.1782, early stopping working perfectly
41
+ - **Performance**: 75-100% accuracy on fresh CAPTCHAs (varies by run)
42
+
43
+ ### πŸ“Š Training Results
44
+ ![Training Losses](Metrics/training_losses.png)
45
+ ![Loss Comparison](Metrics/loss_comparison.png)
46
+
47
+ **Key Insights:**
48
+ - **Rapid convergence**: Loss dropped from 21→0.1 in first 7 epochs
49
+ - **No overfitting**: Enhanced early stopping prevents overfitting
50
+ - **Stable training**: Val/Train ratio stays healthy throughout training
51
 
52
  ## πŸ“ Project Structure
53
 
inference.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import cv2
3
+ import os
4
+ import random
5
+ import numpy as np
6
+ from src.config import cfg
7
+ from src.model_crnn import CRNN
8
+ from src.vocab import ctc_greedy_decode, vocab_size
9
+ from src.plotting import TrainingMetrics
10
+ from captcha.image import ImageCaptcha
11
+
12
def load_model(checkpoint_path="checkpoints/best_model.pth"):
    """Load the trained CRNN from a checkpoint produced by train.py.

    Args:
        checkpoint_path: Path to the saved checkpoint file.

    Returns:
        The model in eval() mode (caller is responsible for moving it to a device).

    Raises:
        FileNotFoundError: If the checkpoint file does not exist.
    """
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint to the detected device
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"✅ Loaded model from epoch {checkpoint['epoch']}")
    print(f" Best validation loss: {checkpoint['best_val_loss']:.4f}")
    print(f" Loading to device: {device}")

    # Prefer the architecture hyperparameters recorded at save time (train.py
    # stores them under 'config') so the rebuilt model always matches the
    # weights; fall back to the historical defaults for older checkpoints.
    saved_cfg = checkpoint.get('config', {})
    model = CRNN(
        vocab_size=saved_cfg.get('vocab_size', vocab_size()),
        hidden=saved_cfg.get('hidden_size', 320),
        dropout=0.05,
    )
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model
33
def preprocess_image(image_path, target_size=(cfg.W_max, cfg.H)):
    """Load a CAPTCHA image and turn it into a normalized input tensor.

    Mirrors the training-time preprocessing (same resize and scaling).

    Args:
        image_path: Path to the image file on disk.
        target_size: (width, height) pair handed to cv2.resize; defaults to
            the training dimensions from the config.

    Returns:
        Float tensor in [0, 1] shaped [1, 1, H, W] (grayscale) or
        [1, 3, H, W] (color).

    Raises:
        FileNotFoundError: If image_path does not exist.
        ValueError: If OpenCV cannot decode the file.
    """
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    read_flag = cv2.IMREAD_GRAYSCALE if cfg.grayscale else cv2.IMREAD_COLOR
    raw = cv2.imread(image_path, read_flag)
    if raw is None:
        raise ValueError(f"Failed to load image: {image_path}")

    resized = cv2.resize(raw, target_size)
    tensor = torch.from_numpy(resized).float() / 255.0

    if cfg.grayscale:
        # [H, W] -> [1, 1, H, W]
        return tensor[None, None, :, :]
    # [H, W, C] -> [1, C, H, W]
    return tensor.permute(2, 0, 1).unsqueeze(0)
58
def predict_captcha(model, image_tensor, device):
    """Decode one preprocessed CAPTCHA tensor into text.

    Args:
        model: Trained CRNN in eval mode, already on *device*.
        image_tensor: Batched input tensor from preprocess_image.
        device: Torch device to run inference on.

    Returns:
        The greedy-CTC-decoded string, or "" when decoding yields nothing.
    """
    with torch.no_grad():
        logits = model(image_tensor.to(device))
        decoded = ctc_greedy_decode(logits)
    return decoded[0] if decoded else ""
72
def generate_test_captcha(text, filename, width=160, height=60):
    """Render *text* as a CAPTCHA image and save it under cfg.RESULT_DIR.

    Args:
        text: Ground-truth string to render.
        filename: Output file name (placed inside cfg.RESULT_DIR).
        width: Image width in pixels.
        height: Image height in pixels.

    Returns:
        Full path of the written image file.
    """
    image = ImageCaptcha(width=width, height=height)
    filepath = os.path.join(cfg.RESULT_DIR, filename)
    image.write(text, filepath)
    # Fix: the message was an f-string with no placeholder, so it never showed
    # which file was written — include the actual path.
    print(f"📸 Generated test CAPTCHA: {filepath}")
    return filepath
80
def main():
    """End-to-end inference demo.

    Loads the best checkpoint, generates fresh random CAPTCHAs, decodes them,
    reports whole-string and per-character accuracy, and saves a results plot.
    """
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    os.makedirs(cfg.RESULT_DIR, exist_ok=True)

    try:
        # Load trained model
        print("📥 Loading trained model...")
        model = load_model()
        model = model.to(device)
        print("✅ Model loaded successfully!")

        # Generate test CAPTCHAs
        print("\n🎯 Generating test CAPTCHAs...")
        test_cases = _generate_test_cases()

        # Run inference
        print("\n🔍 Running inference...")
        correct_count = _run_inference(model, device, test_cases)

        # Summary
        _report_accuracy(test_cases, correct_count)

        # Create and save results plot
        _save_results_plot(test_cases)

    except Exception as e:
        print(f"❌ Error: {e}")
        print("💡 Make sure you have a trained model in checkpoints/best_model.pth")


def _generate_test_cases(count=4):
    """Generate *count* random CAPTCHA images.

    Returns a list of (target_text, image_path, prediction) tuples with an
    empty prediction slot, filled in later by the inference loop.
    """
    test_cases = []
    for i in range(count):
        length = random.randint(cfg.CAPTCHA_LEN_LOWER_LIMIT, cfg.CAPTCHA_LEN_UPPER_LIMIT)
        text = ''.join(random.choices(cfg.chars, k=length))
        image_path = generate_test_captcha(text, f"{text}_{i}.png")
        test_cases.append((text, image_path, ""))
    return test_cases


def _run_inference(model, device, test_cases):
    """Decode every test case, printing a result table as it goes.

    Mutates *test_cases* in place so each tuple receives its prediction.
    Returns the number of exact (whole-string) matches.
    """
    print("-" * 60)
    print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
    print("-" * 60)

    correct_count = 0
    for i, (target_text, image_path, _) in enumerate(test_cases):
        try:
            image_tensor = preprocess_image(image_path)
            prediction = predict_captcha(model, image_tensor, device)

            # Store prediction back into test_cases for later reporting/plotting
            test_cases[i] = (target_text, image_path, prediction)

            is_correct = prediction == target_text
            if is_correct:
                correct_count += 1
            status = "✅" if is_correct else "❌"
            print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")
        except Exception as e:
            print(f"❌ Error processing {image_path}: {e}")
    return correct_count


def _report_accuracy(test_cases, correct_count):
    """Print whole-CAPTCHA accuracy, per-character accuracy, and a verdict."""
    print("-" * 60)
    # Guard the division so an empty test set cannot crash the summary.
    accuracy = (correct_count / len(test_cases)) * 100 if test_cases else 0.0
    print(f"📊 Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")

    # Character accuracy: position-by-position over the overlapping prefix.
    total_chars = 0
    correct_chars = 0
    for target_text, _, prediction in test_cases:
        total_chars += len(target_text)
        for expected, got in zip(target_text, prediction):
            if expected == got:
                correct_chars += 1

    char_accuracy = (correct_chars / total_chars) * 100 if total_chars > 0 else 0
    print(f"🔤 Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")

    if accuracy >= 80:
        print("🎉 Excellent performance!")
    elif accuracy >= 60:
        print("👍 Good performance!")
    else:
        print("🤔 Room for improvement...")


def _save_results_plot(test_cases):
    """Render the inference results grid into Metrics/ (best effort)."""
    print("\n📊 Generating results visualization...")
    try:
        metrics = TrainingMetrics()
        image_paths = [case[1] for case in test_cases]
        predictions = [case[2] for case in test_cases]
        targets = [case[0] for case in test_cases]

        # Create results directory if it doesn't exist
        os.makedirs("Metrics", exist_ok=True)

        metrics.plot_results(image_paths, predictions, targets)
        print("✅ Results plot generated successfully!")
    except Exception as e:
        print(f"⚠️ Warning: Could not generate plot: {e}")


if __name__ == "__main__":
    main()
src/config.py CHANGED
@@ -7,7 +7,10 @@ class Config:
7
  data_root: str = os.getenv("DATA_ROOT","Dataset_test\captchas")
8
 
9
  chars: str = string.ascii_letters + string.digits
10
-
 
 
 
11
  # Image dimensions - increased for better character detail
12
  H: int = 60 # Increased from 48 for more vertical detail
13
  W_max: int = 256 # Increased from 224 for more time steps (T=64)
 
7
  data_root: str = os.getenv("DATA_ROOT","Dataset_test\captchas")
8
 
9
  chars: str = string.ascii_letters + string.digits
10
+ CAPTCHA_LEN_LOWER_LIMIT: int = 5
11
+ CAPTCHA_LEN_UPPER_LIMIT: int = 7
12
+
13
+ RESULT_DIR: str = "Results"
14
  # Image dimensions - increased for better character detail
15
  H: int = 60 # Increased from 48 for more vertical detail
16
  W_max: int = 256 # Increased from 224 for more time steps (T=64)
src/plotting.py CHANGED
@@ -104,4 +104,83 @@ class TrainingMetrics:
104
  f.write("Sample Predictions:\n")
105
  f.write("-" * 20 + "\n")
106
  for i, (pred, target) in enumerate(zip(self.sample_predictions[:10], self.sample_targets[:10])):
107
- f.write(f"Sample {i+1}: Predicted='{pred}', Target='{target}'\n")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
104
  f.write("Sample Predictions:\n")
105
  f.write("-" * 20 + "\n")
106
  for i, (pred, target) in enumerate(zip(self.sample_predictions[:10], self.sample_targets[:10])):
107
+ f.write(f"Sample {i+1}: Predicted='{pred}', Target='{target}'\n")
108
+
109
+ def plot_results(self, image_paths, predictions, targets, save_path="Metrics/inference_results.png"):
110
+ """
111
+ Plot CAPTCHA images with their predictions and targets.
112
+
113
+ Args:
114
+ image_paths: List of paths to CAPTCHA images
115
+ predictions: List of predicted texts
116
+ targets: List of target texts
117
+ save_path: Path to save the plot
118
+ """
119
+ import cv2
120
+
121
+ n_images = len(image_paths)
122
+ if n_images == 0:
123
+ print("No images to plot!")
124
+ return
125
+
126
+ # Force 2x2 grid for 4 images
127
+ rows, cols = 2, 2
128
+ fig, axes = plt.subplots(rows, cols, figsize=(12, 8))
129
+
130
+ # Flatten axes for easier indexing
131
+ axes = axes.flatten()
132
+
133
+ for i, (img_path, pred, target) in enumerate(zip(image_paths, predictions, targets)):
134
+ if i >= len(axes):
135
+ break
136
+
137
+ ax = axes[i]
138
+
139
+ # Load and display image
140
+ try:
141
+ img = cv2.imread(img_path)
142
+ if img is not None:
143
+ img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
144
+ ax.imshow(img)
145
+
146
+ # Determine if prediction is correct
147
+ is_correct = pred == target
148
+ color = 'green' if is_correct else 'red'
149
+ status = 'CORRECT' if is_correct else 'WRONG'
150
+
151
+ # Set title with prediction and target
152
+ title = f"Pred: {pred}\nTarget: {target}\n{status}"
153
+ ax.set_title(title, fontsize=10, color=color, fontweight='bold')
154
+
155
+ else:
156
+ ax.text(0.5, 0.5, f"Failed to load\n{os.path.basename(img_path)}",
157
+ ha='center', va='center', transform=ax.transAxes, fontsize=12)
158
+
159
+ except Exception as e:
160
+ ax.text(0.5, 0.5, f"Error loading image\n{str(e)[:30]}...",
161
+ ha='center', va='center', transform=ax.transAxes, fontsize=10, color='red')
162
+
163
+ # Remove axes
164
+ ax.axis('off')
165
+
166
+ # Hide unused subplots
167
+ for i in range(n_images, len(axes)):
168
+ axes[i].axis('off')
169
+
170
+ # Add overall title
171
+ fig.suptitle('CAPTCHA OCR Inference Results', fontsize=16, fontweight='bold', y=0.98)
172
+
173
+ # Calculate accuracy
174
+ correct = sum(1 for p, t in zip(predictions, targets) if p == t)
175
+ accuracy = (correct / len(targets)) * 100
176
+
177
+ # Add accuracy info
178
+ fig.text(0.5, 0.02, f'Accuracy: {correct}/{len(targets)} ({accuracy:.1f}%)',
179
+ ha='center', fontsize=14, fontweight='bold',
180
+ bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
181
+
182
+ plt.tight_layout()
183
+ plt.subplots_adjust(top=0.9, bottom=0.15)
184
+ plt.savefig(save_path, dpi=300, bbox_inches='tight')
185
+ plt.close()
186
+ print(f"Results plot saved to: {save_path}")
train.py CHANGED
@@ -57,6 +57,12 @@ def main():
57
  print(f"\nStarting training for {epochs} epochs...")
58
 
59
  metrics = TrainingMetrics()
 
 
 
 
 
 
60
 
61
  for epoch in range(epochs):
62
  # Training phase
@@ -134,6 +140,50 @@ def main():
134
  print(f" Train Loss: {avg_train_loss:.4f}")
135
  print(f" Val Loss: {avg_val_loss:.4f}")
136
  metrics.add_epoch(epoch+1, avg_train_loss, avg_val_loss)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
  # Test some predictions
139
  if epoch % 2 == 0: # Every 2 epochs
@@ -181,7 +231,30 @@ def main():
181
 
182
  metrics.add_predictions(test_preds, test_targets)
183
 
184
- print("\nTraining complete!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
185
  print("\nGenerating training metrics and plots...")
186
  os.makedirs("Metrics", exist_ok=True)
187
  metrics.plot_losses()
 
57
  print(f"\nStarting training for {epochs} epochs...")
58
 
59
  metrics = TrainingMetrics()
60
+
61
+ # Early stopping setup
62
+ best_val_loss = float('inf')
63
+ patience = 5 # Stop if no improvement for 5 epochs
64
+ patience_counter = 0
65
+ early_stop = False
66
 
67
  for epoch in range(epochs):
68
  # Training phase
 
140
  print(f" Train Loss: {avg_train_loss:.4f}")
141
  print(f" Val Loss: {avg_val_loss:.4f}")
142
  metrics.add_epoch(epoch+1, avg_train_loss, avg_val_loss)
143
+
144
+ # Enhanced early stopping check
145
+ val_train_ratio = avg_val_loss / (avg_train_loss + 1e-8) # Avoid division by zero
146
+
147
+ if avg_val_loss < best_val_loss:
148
+ best_val_loss = avg_val_loss
149
+ patience_counter = 0
150
+ print(f" 🎯 New best validation loss: {best_val_loss:.4f}")
151
+ print(f" πŸ“Š Val/Train ratio: {val_train_ratio:.3f}")
152
+
153
+ # Save best model checkpoint with metadata
154
+ checkpoint = {
155
+ 'epoch': epoch + 1,
156
+ 'model_state_dict': model.state_dict(),
157
+ 'optimizer_state_dict': optimizer.state_dict(),
158
+ 'scheduler_state_dict': scheduler.state_dict(),
159
+ 'best_val_loss': best_val_loss,
160
+ 'train_loss': avg_train_loss,
161
+ 'val_loss': avg_val_loss,
162
+ 'val_train_ratio': val_train_ratio,
163
+ 'config': {
164
+ 'vocab_size': vocab_size(),
165
+ 'hidden_size': 320,
166
+ 'total_stride': cfg.total_stride,
167
+ 'H': cfg.H,
168
+ 'W_max': cfg.W_max
169
+ }
170
+ }
171
+ torch.save(checkpoint, "checkpoints/best_model.pth")
172
+ print(f" πŸ’Ύ Best model saved to checkpoints/best_model.pth")
173
+
174
+ else:
175
+ patience_counter += 1
176
+ print(f" ⚠️ No improvement for {patience_counter} epochs")
177
+ print(f" πŸ“Š Val/Train ratio: {val_train_ratio:.3f}")
178
+
179
+ # Enhanced early stopping: Check both absolute loss and ratio
180
+ if patience_counter >= patience or val_train_ratio > 3.0: # Stop if ratio > 3x
181
+ if val_train_ratio > 3.0:
182
+ print(f" πŸ›‘ Early stopping triggered! Val/Train ratio too high: {val_train_ratio:.3f}")
183
+ else:
184
+ print(f" πŸ›‘ Early stopping triggered! No improvement for {patience} epochs")
185
+ early_stop = True
186
+ break
187
 
188
  # Test some predictions
189
  if epoch % 2 == 0: # Every 2 epochs
 
231
 
232
  metrics.add_predictions(test_preds, test_targets)
233
 
234
+ if early_stop:
235
+ print(f"\nTraining stopped early at epoch {epoch+1} due to no improvement!")
236
+ else:
237
+ print(f"\nTraining completed for all {epochs} epochs!")
238
+
239
+ # Save final model
240
+ final_checkpoint = {
241
+ 'epoch': epoch + 1,
242
+ 'model_state_dict': model.state_dict(),
243
+ 'optimizer_state_dict': optimizer.state_dict(),
244
+ 'scheduler_state_dict': scheduler.state_dict(),
245
+ 'final_val_loss': avg_val_loss,
246
+ 'final_train_loss': avg_train_loss,
247
+ 'config': {
248
+ 'vocab_size': vocab_size(),
249
+ 'hidden_size': 320,
250
+ 'total_stride': cfg.total_stride,
251
+ 'H': cfg.H,
252
+ 'W_max': cfg.W_max
253
+ }
254
+ }
255
+ torch.save(final_checkpoint, "checkpoints/final_model.pth")
256
+ print(f"πŸ’Ύ Final model saved to checkpoints/final_model.pth")
257
+
258
  print("\nGenerating training metrics and plots...")
259
  os.makedirs("Metrics", exist_ok=True)
260
  metrics.plot_losses()