Commit 04e423f (parent: 6e89f30)
Enhance training process with improved early stopping and metrics tracking. Update README with training results and insights. Modify .gitignore to allow Metrics plots. Add plotting functionality for inference results in plotting.py. Update configuration parameters for CAPTCHA length limits.
Files changed:
- .gitignore (+5 -1)
- Metrics/inference_results.png (+3 -0)
- Metrics/loss_comparison.png (+3 -0)
- Metrics/training_losses.png (+3 -0)
- Metrics/training_metrics.txt (+2 -2)
- README.md (+12 -2)
- inference.py (+188 -0)
- src/config.py (+4 -1)
- src/plotting.py (+80 -1)
- train.py (+74 -1)
.gitignore
CHANGED

@@ -66,7 +66,7 @@ logs/
 Thumbs.db
 desktop.ini

-# Images/artifacts
+# Images/artifacts (but allow Metrics plots)
 *.png
 *.jpg
 *.jpeg
@@ -75,6 +75,10 @@ desktop.ini
 *.tiff
 *.webp

+# Allow Metrics plots
+!Metrics/*.png
+!Metrics/*.jpg
+
 # Models and checkpoints
 checkpoints/
 *.ckpt
Metrics/inference_results.png
ADDED (binary; tracked with Git LFS)

Metrics/loss_comparison.png
ADDED (binary; tracked with Git LFS)

Metrics/training_losses.png
ADDED (binary; tracked with Git LFS)
Metrics/training_metrics.txt
CHANGED (Git LFS pointer)

@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:7fb9b2125cd77da83e51022b8de31388541a080c2f63d11e8b85cb6b34efe534
+size 822
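As context for the pointer diff above: an LFS pointer is just a small key-value text file (version, oid, size). A minimal reader sketch, assuming the file on disk is still the pointer and not the smudged binary; the function name is hypothetical:

    import pathlib

    def read_lfs_pointer(path):
        """Parse a Git LFS v1 pointer file into a dict of its fields."""
        fields = {}
        for line in pathlib.Path(path).read_text().splitlines():
            if line.strip():
                key, _, value = line.partition(" ")
                fields[key] = value
        return fields

    # e.g. read_lfs_pointer("Metrics/training_metrics.txt")["size"] -> "822"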
README.md
CHANGED

@@ -36,8 +36,18 @@ This project implements an end-to-end CAPTCHA OCR system that can recognize text
 - **Character Diversity**: Limited to a few characters, needs more training

 ### 🎯 Training Status
-- **Current**: Epoch
-- **
+- **Current**: Epoch 8, excellent convergence achieved
+- **Best Model**: Validation loss 0.1782, early stopping working perfectly
+- **Performance**: 75-100% accuracy on fresh CAPTCHAs (varies by run)
+
+### 📈 Training Results
+
+
+
+**Key Insights:**
+- **Rapid convergence**: Loss dropped from 21→0.1 in first 7 epochs
+- **No overfitting**: Enhanced early stopping prevents overfitting
+- **Stable training**: Val/Train ratio stays healthy throughout training

 ## 📁 Project Structure
inference.py
ADDED (entire file, 188 lines)

import torch
import cv2
import os
import random
import numpy as np
from src.config import cfg
from src.model_crnn import CRNN
from src.vocab import ctc_greedy_decode, vocab_size
from src.plotting import TrainingMetrics
from captcha.image import ImageCaptcha

def load_model(checkpoint_path="checkpoints/best_model.pth"):
    """Load the trained model from checkpoint."""
    if not os.path.exists(checkpoint_path):
        raise FileNotFoundError(f"Checkpoint not found: {checkpoint_path}")

    # Detect available device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load checkpoint to the detected device
    checkpoint = torch.load(checkpoint_path, map_location=device)
    print(f"✅ Loaded model from epoch {checkpoint['epoch']}")
    print(f"   Best validation loss: {checkpoint['best_val_loss']:.4f}")
    print(f"   Loading to device: {device}")

    # Create model and load weights
    model = CRNN(vocab_size=vocab_size(), hidden=320, dropout=0.05)
    model.load_state_dict(checkpoint['model_state_dict'])
    model.eval()

    return model

def preprocess_image(image_path, target_size=(cfg.W_max, cfg.H)):
    """Preprocess image for inference (same as training)."""
    # Load image
    if not os.path.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")

    # Read and preprocess
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE if cfg.grayscale else cv2.IMREAD_COLOR)
    if img is None:
        raise ValueError(f"Failed to load image: {image_path}")

    # Resize to target dimensions
    img = cv2.resize(img, target_size)

    # Convert to tensor and normalize
    img_tensor = torch.from_numpy(img).float() / 255.0

    # Add batch and channel dimensions
    if cfg.grayscale:
        img_tensor = img_tensor.unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
    else:
        img_tensor = img_tensor.permute(2, 0, 1).unsqueeze(0)  # [1, 3, H, W]

    return img_tensor

def predict_captcha(model, image_tensor, device):
    """Run inference on a single image."""
    with torch.no_grad():
        # Move to device
        image_tensor = image_tensor.to(device)

        # Forward pass
        logits = model(image_tensor)

        # Decode prediction
        prediction = ctc_greedy_decode(logits)

        return prediction[0] if prediction else ""

def generate_test_captcha(text, filename, width=160, height=60):
    """Generate a test CAPTCHA image."""
    image = ImageCaptcha(width=width, height=height)
    filepath = os.path.join(cfg.RESULT_DIR, filename)
    image.write(text, filepath)
    print(f"📸 Generated test CAPTCHA: {filename}")
    return filepath

def main():
    # Setup
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"🚀 Using device: {device}")

    os.makedirs(cfg.RESULT_DIR, exist_ok=True)

    try:
        # Load trained model
        print("📥 Loading trained model...")
        model = load_model()
        model = model.to(device)
        print("✅ Model loaded successfully!")

        # Generate test CAPTCHAs
        print("\n🎯 Generating test CAPTCHAs...")
        test_cases = []

        for i in range(4):
            # Generate random text
            text = ''.join(random.choices(cfg.chars, k=random.randint(cfg.CAPTCHA_LEN_LOWER_LIMIT, cfg.CAPTCHA_LEN_UPPER_LIMIT)))
            filename = f"{text}_{i}.png"

            # Generate image
            image_path = generate_test_captcha(text, filename)
            test_cases.append((text, image_path, ""))  # Add empty prediction slot

        # Run inference
        print("\n🔍 Running inference...")
        print("-" * 60)
        print(f"{'Target':<15} {'Prediction':<15} {'Correct':<10} {'Image':<20}")
        print("-" * 60)

        correct_count = 0
        for i, (target_text, image_path, _) in enumerate(test_cases):
            try:
                # Preprocess image
                image_tensor = preprocess_image(image_path)

                # Run prediction
                prediction = predict_captcha(model, image_tensor, device)

                # Store prediction in test_cases
                test_cases[i] = (target_text, image_path, prediction)

                # Check if correct
                is_correct = prediction == target_text
                if is_correct:
                    correct_count += 1

                # Display result
                status = "✅" if is_correct else "❌"
                print(f"{target_text:<15} {prediction:<15} {status:<10} {os.path.basename(image_path):<20}")

            except Exception as e:
                print(f"❌ Error processing {image_path}: {e}")

        # Summary
        print("-" * 60)
        accuracy = (correct_count / len(test_cases)) * 100
        print(f"📊 Overall Accuracy: {correct_count}/{len(test_cases)} ({accuracy:.1f}%)")

        # Calculate individual character accuracy
        total_chars = 0
        correct_chars = 0
        for target_text, _, prediction in test_cases:
            total_chars += len(target_text)
            # Count correct characters (position by position)
            min_len = min(len(target_text), len(prediction))
            for i in range(min_len):
                if target_text[i] == prediction[i]:
                    correct_chars += 1

        char_accuracy = (correct_chars / total_chars) * 100 if total_chars > 0 else 0
        print(f"🔤 Character Accuracy: {correct_chars}/{total_chars} ({char_accuracy:.1f}%)")

        if accuracy >= 80:
            print("🎉 Excellent performance!")
        elif accuracy >= 60:
            print("👍 Good performance!")
        else:
            print("🤔 Room for improvement...")

        # Create and save results plot
        print("\n📊 Generating results visualization...")
        try:
            metrics = TrainingMetrics()
            image_paths = [case[1] for case in test_cases]
            predictions = [case[2] for case in test_cases]
            targets = [case[0] for case in test_cases]

            # Create results directory if it doesn't exist
            os.makedirs("Metrics", exist_ok=True)

            # Plot results
            metrics.plot_results(image_paths, predictions, targets)
            print("✅ Results plot generated successfully!")

        except Exception as e:
            print(f"⚠️ Warning: Could not generate plot: {e}")

    except Exception as e:
        print(f"❌ Error: {e}")
        print("💡 Make sure you have a trained model in checkpoints/best_model.pth")

if __name__ == "__main__":
    main()
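The helpers above are reusable outside main(). A minimal sketch of one-off inference, assuming the default checkpoint exists; the image path is a placeholder:

    import torch
    from inference import load_model, preprocess_image, predict_captcha

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = load_model().to(device)  # reads checkpoints/best_model.pth by default
    tensor = preprocess_image("Results/2fX9k_0.png")  # placeholder; any CAPTCHA image path
    print(predict_captcha(model, tensor, device))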
src/config.py
CHANGED

@@ -7,7 +7,10 @@ class Config:
     data_root: str = os.getenv("DATA_ROOT","Dataset_test\captchas")

     chars: str = string.ascii_letters + string.digits
-
+    CAPTCHA_LEN_LOWER_LIMIT: int = 5
+    CAPTCHA_LEN_UPPER_LIMIT: int = 7
+
+    RESULT_DIR: str = "Results"
     # Image dimensions - increased for better character detail
     H: int = 60  # Increased from 48 for more vertical detail
     W_max: int = 256  # Increased from 224 for more time steps (T=64)
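As a quick check of the new limits, this mirrors how inference.py samples a random CAPTCHA string from them (a sketch, runnable from the repo root):

    import random
    from src.config import cfg

    # Length falls inside the configured window (5-7 with the values above)
    length = random.randint(cfg.CAPTCHA_LEN_LOWER_LIMIT, cfg.CAPTCHA_LEN_UPPER_LIMIT)
    text = ''.join(random.choices(cfg.chars, k=length))
    print(length, text)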
src/plotting.py
CHANGED

@@ -104,4 +104,83 @@ class TrainingMetrics:
             f.write("Sample Predictions:\n")
             f.write("-" * 20 + "\n")
             for i, (pred, target) in enumerate(zip(self.sample_predictions[:10], self.sample_targets[:10])):
-                f.write(f"Sample {i+1}: Predicted='{pred}', Target='{target}'\n")
+                f.write(f"Sample {i+1}: Predicted='{pred}', Target='{target}'\n")
+
+    def plot_results(self, image_paths, predictions, targets, save_path="Metrics/inference_results.png"):
+        """
+        Plot CAPTCHA images with their predictions and targets.
+
+        Args:
+            image_paths: List of paths to CAPTCHA images
+            predictions: List of predicted texts
+            targets: List of target texts
+            save_path: Path to save the plot
+        """
+        import cv2
+
+        n_images = len(image_paths)
+        if n_images == 0:
+            print("No images to plot!")
+            return
+
+        # Force 2x2 grid for 4 images
+        rows, cols = 2, 2
+        fig, axes = plt.subplots(rows, cols, figsize=(12, 8))
+
+        # Flatten axes for easier indexing
+        axes = axes.flatten()
+
+        for i, (img_path, pred, target) in enumerate(zip(image_paths, predictions, targets)):
+            if i >= len(axes):
+                break
+
+            ax = axes[i]
+
+            # Load and display image
+            try:
+                img = cv2.imread(img_path)
+                if img is not None:
+                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+                    ax.imshow(img)
+
+                    # Determine if prediction is correct
+                    is_correct = pred == target
+                    color = 'green' if is_correct else 'red'
+                    status = 'CORRECT' if is_correct else 'WRONG'
+
+                    # Set title with prediction and target
+                    title = f"Pred: {pred}\nTarget: {target}\n{status}"
+                    ax.set_title(title, fontsize=10, color=color, fontweight='bold')
+
+                else:
+                    ax.text(0.5, 0.5, f"Failed to load\n{os.path.basename(img_path)}",
+                            ha='center', va='center', transform=ax.transAxes, fontsize=12)
+
+            except Exception as e:
+                ax.text(0.5, 0.5, f"Error loading image\n{str(e)[:30]}...",
+                        ha='center', va='center', transform=ax.transAxes, fontsize=10, color='red')
+
+            # Remove axes
+            ax.axis('off')
+
+        # Hide unused subplots
+        for i in range(n_images, len(axes)):
+            axes[i].axis('off')
+
+        # Add overall title
+        fig.suptitle('CAPTCHA OCR Inference Results', fontsize=16, fontweight='bold', y=0.98)
+
+        # Calculate accuracy
+        correct = sum(1 for p, t in zip(predictions, targets) if p == t)
+        accuracy = (correct / len(targets)) * 100
+
+        # Add accuracy info
+        fig.text(0.5, 0.02, f'Accuracy: {correct}/{len(targets)} ({accuracy:.1f}%)',
+                 ha='center', fontsize=14, fontweight='bold',
+                 bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue", alpha=0.7))
+
+        plt.tight_layout()
+        plt.subplots_adjust(top=0.9, bottom=0.15)
+        plt.savefig(save_path, dpi=300, bbox_inches='tight')
+        plt.close()
+        print(f"Results plot saved to: {save_path}")
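plot_results can also be called on its own, independent of inference.py. A minimal sketch, where the image paths and strings are placeholders:

    import os
    from src.plotting import TrainingMetrics

    os.makedirs("Metrics", exist_ok=True)  # default save_path writes here
    metrics = TrainingMetrics()
    metrics.plot_results(
        image_paths=["Results/abc12_0.png", "Results/xY7Pq_1.png"],  # placeholders
        predictions=["abc12", "xY7Pq"],
        targets=["abc12", "xY7Pq"],
    )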
train.py
CHANGED

@@ -57,6 +57,12 @@ def main():
     print(f"\nStarting training for {epochs} epochs...")

     metrics = TrainingMetrics()
+
+    # Early stopping setup
+    best_val_loss = float('inf')
+    patience = 5  # Stop if no improvement for 5 epochs
+    patience_counter = 0
+    early_stop = False

     for epoch in range(epochs):
         # Training phase
@@ -134,6 +140,50 @@ def main():
         print(f"  Train Loss: {avg_train_loss:.4f}")
         print(f"  Val Loss: {avg_val_loss:.4f}")
         metrics.add_epoch(epoch+1, avg_train_loss, avg_val_loss)
+
+        # Enhanced early stopping check
+        val_train_ratio = avg_val_loss / (avg_train_loss + 1e-8)  # Avoid division by zero
+
+        if avg_val_loss < best_val_loss:
+            best_val_loss = avg_val_loss
+            patience_counter = 0
+            print(f"  🎯 New best validation loss: {best_val_loss:.4f}")
+            print(f"  📊 Val/Train ratio: {val_train_ratio:.3f}")
+
+            # Save best model checkpoint with metadata
+            checkpoint = {
+                'epoch': epoch + 1,
+                'model_state_dict': model.state_dict(),
+                'optimizer_state_dict': optimizer.state_dict(),
+                'scheduler_state_dict': scheduler.state_dict(),
+                'best_val_loss': best_val_loss,
+                'train_loss': avg_train_loss,
+                'val_loss': avg_val_loss,
+                'val_train_ratio': val_train_ratio,
+                'config': {
+                    'vocab_size': vocab_size(),
+                    'hidden_size': 320,
+                    'total_stride': cfg.total_stride,
+                    'H': cfg.H,
+                    'W_max': cfg.W_max
+                }
+            }
+            torch.save(checkpoint, "checkpoints/best_model.pth")
+            print(f"  💾 Best model saved to checkpoints/best_model.pth")
+
+        else:
+            patience_counter += 1
+            print(f"  ⚠️ No improvement for {patience_counter} epochs")
+            print(f"  📊 Val/Train ratio: {val_train_ratio:.3f}")
+
+            # Enhanced early stopping: check both absolute loss and ratio
+            if patience_counter >= patience or val_train_ratio > 3.0:  # Stop if ratio > 3x
+                if val_train_ratio > 3.0:
+                    print(f"  🛑 Early stopping triggered! Val/Train ratio too high: {val_train_ratio:.3f}")
+                else:
+                    print(f"  🛑 Early stopping triggered! No improvement for {patience} epochs")
+                early_stop = True
+                break

         # Test some predictions
         if epoch % 2 == 0:  # Every 2 epochs
@@ -181,7 +231,30 @@ def main():

         metrics.add_predictions(test_preds, test_targets)

-
+    if early_stop:
+        print(f"\nTraining stopped early at epoch {epoch+1} due to no improvement!")
+    else:
+        print(f"\nTraining completed for all {epochs} epochs!")
+
+    # Save final model
+    final_checkpoint = {
+        'epoch': epoch + 1,
+        'model_state_dict': model.state_dict(),
+        'optimizer_state_dict': optimizer.state_dict(),
+        'scheduler_state_dict': scheduler.state_dict(),
+        'final_val_loss': avg_val_loss,
+        'final_train_loss': avg_train_loss,
+        'config': {
+            'vocab_size': vocab_size(),
+            'hidden_size': 320,
+            'total_stride': cfg.total_stride,
+            'H': cfg.H,
+            'W_max': cfg.W_max
+        }
+    }
+    torch.save(final_checkpoint, "checkpoints/final_model.pth")
+    print(f"💾 Final model saved to checkpoints/final_model.pth")
+
     print("\nGenerating training metrics and plots...")
     os.makedirs("Metrics", exist_ok=True)
     metrics.plot_losses()
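Since both checkpoints now carry metadata alongside the weights, they can be inspected without constructing the model. A small sketch using only the keys saved above:

    import torch

    # Load on CPU just to read the metadata written by train.py
    ckpt = torch.load("checkpoints/best_model.pth", map_location="cpu")
    print(f"epoch: {ckpt['epoch']}")
    print(f"best val loss: {ckpt['best_val_loss']:.4f}")
    print(f"val/train ratio: {ckpt['val_train_ratio']:.3f}")
    print(f"model config: {ckpt['config']}")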