Spaces:
Running
Running
Enhance training process with improved early stopping and metrics tracking. Update README with training results and insights. Modify .gitignore to allow Metrics plots. Add plotting functionality for inference results in plotting.py. Update configuration parameters for CAPTCHA length limits.
04e423f
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from datetime import datetime | |
| import os | |
class TrainingMetrics:
    """Accumulate training losses and sample predictions, and write reports.

    Losses are recorded per epoch via :meth:`add_epoch`; predicted/target
    text pairs via :meth:`add_predictions`. The remaining methods render the
    collected data to disk as PNG plots or a plain-text summary. Every writer
    creates the parent directory of its ``save_path`` if it does not exist.
    """

    def __init__(self):
        # Parallel lists: index i holds the values passed to the i-th
        # add_epoch() call.
        self.train_losses = []
        self.val_losses = []
        self.epochs = []
        # Parallel lists of predicted and expected texts.
        self.sample_predictions = []
        self.sample_targets = []

    @staticmethod
    def _ensure_parent_dir(save_path):
        """Create the directory containing *save_path* when it is missing."""
        parent = os.path.dirname(save_path)
        # An empty dirname means "current directory": nothing to create.
        if parent:
            os.makedirs(parent, exist_ok=True)

    def add_epoch(self, epoch, train_loss, val_loss):
        """Record the training and validation loss for one epoch."""
        self.epochs.append(epoch)
        self.train_losses.append(train_loss)
        self.val_losses.append(val_loss)

    def add_predictions(self, predictions, targets):
        """Append predicted texts and their matching target texts.

        Args:
            predictions: Iterable of predicted texts.
            targets: Iterable of target texts, in the same order.
        """
        self.sample_predictions.extend(predictions)
        self.sample_targets.extend(targets)

    def plot_losses(self, save_path="Metrics/training_losses.png"):
        """Save a line plot of training and validation loss per epoch.

        Args:
            save_path: Destination image path.
        """
        self._ensure_parent_dir(save_path)
        fig = plt.figure(figsize=(10, 6))
        try:
            plt.plot(self.epochs, self.train_losses, 'b-',
                     label='Training Loss', linewidth=2)
            plt.plot(self.epochs, self.val_losses, 'r-',
                     label='Validation Loss', linewidth=2)
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training and Validation Loss Over Time')
            plt.legend()
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
        print(f"Loss plot saved to: {save_path}")

    def plot_loss_comparison(self, save_path="Metrics/loss_comparison.png"):
        """Save a 2x2 panel comparing training and validation loss.

        Panels: raw losses, train-minus-val difference, val/train ratio
        (0 where the training loss is not positive), and improvement of each
        loss relative to its first recorded value. Nothing is written when no
        losses have been recorded.

        Args:
            save_path: Destination image path.
        """
        # The improvement panel indexes element 0 of both lists.
        if not self.train_losses or not self.val_losses:
            print("No epochs recorded; skipping loss comparison plot.")
            return

        self._ensure_parent_dir(save_path)
        fig = plt.figure(figsize=(12, 8))
        try:
            # Main loss plot
            plt.subplot(2, 2, 1)
            plt.plot(self.epochs, self.train_losses, 'b-', label='Training Loss')
            plt.plot(self.epochs, self.val_losses, 'r-', label='Validation Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss')
            plt.title('Training vs Validation Loss')
            plt.legend()
            plt.grid(True, alpha=0.3)

            # Loss difference plot
            plt.subplot(2, 2, 2)
            loss_diff = [t - v for t, v in zip(self.train_losses, self.val_losses)]
            plt.plot(self.epochs, loss_diff, 'g-', label='Train - Val Loss')
            plt.xlabel('Epoch')
            plt.ylabel('Loss Difference')
            plt.title('Overfitting Indicator')
            plt.legend()
            plt.grid(True, alpha=0.3)

            # Loss ratio plot
            plt.subplot(2, 2, 3)
            loss_ratio = [
                v / t if t > 0 else 0
                for t, v in zip(self.train_losses, self.val_losses)
            ]
            plt.plot(self.epochs, loss_ratio, 'm-', label='Val/Train Loss Ratio')
            plt.xlabel('Epoch')
            plt.ylabel('Ratio')
            plt.title('Validation/Training Loss Ratio')
            plt.legend()
            plt.grid(True, alpha=0.3)

            # Loss improvement plot
            plt.subplot(2, 2, 4)
            first_train = self.train_losses[0]
            first_val = self.val_losses[0]
            train_improvement = [first_train - t for t in self.train_losses]
            val_improvement = [first_val - v for v in self.val_losses]
            plt.plot(self.epochs, train_improvement, 'b-',
                     label='Training Improvement')
            plt.plot(self.epochs, val_improvement, 'r-',
                     label='Validation Improvement')
            plt.xlabel('Epoch')
            plt.ylabel('Loss Improvement')
            plt.title('Loss Improvement from Start')
            plt.legend()
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
        print(f"Loss comparison plot saved to: {save_path}")

    def save_metrics(self, save_path="Metrics/training_metrics.txt"):
        """Write a plain-text summary of the recorded metrics.

        The report contains a timestamp, the epoch count, final/best losses
        and their improvement since the first epoch, and up to ten
        prediction/target pairs. When no losses were recorded the loss
        summary is replaced by a short notice instead of raising.

        Args:
            save_path: Destination text file path.
        """
        timestamp = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
        lines = [
            "CAPTCHA OCR Training Metrics\n",
            "=" * 50 + "\n\n",
            f"Training completed at: {timestamp}\n",
            f"Total epochs: {len(self.epochs)}\n\n",
            "Loss Summary:\n",
            "-" * 20 + "\n",
        ]

        if self.train_losses and self.val_losses:
            train, val = self.train_losses, self.val_losses
            lines += [
                f"Final training loss: {train[-1]:.4f}\n",
                f"Final validation loss: {val[-1]:.4f}\n",
                f"Best training loss: {min(train):.4f}\n",
                f"Best validation loss: {min(val):.4f}\n",
                f"Training loss improvement: {train[0] - train[-1]:.4f}\n",
                f"Validation loss improvement: {val[0] - val[-1]:.4f}\n\n",
            ]
        else:
            # [-1] / min() would raise on empty lists.
            lines.append("No epochs recorded.\n\n")

        lines += ["Sample Predictions:\n", "-" * 20 + "\n"]
        sample_pairs = zip(self.sample_predictions[:10], self.sample_targets[:10])
        lines.extend(
            f"Sample {number}: Predicted='{pred}', Target='{target}'\n"
            for number, (pred, target) in enumerate(sample_pairs, start=1)
        )

        # The report is fully built before the file is opened, so a
        # formatting error cannot leave a truncated file behind.
        self._ensure_parent_dir(save_path)
        with open(save_path, 'w', encoding='utf-8') as f:
            f.writelines(lines)

    def plot_results(self, image_paths, predictions, targets,
                     save_path="Metrics/inference_results.png"):
        """Plot CAPTCHA images with their predictions and targets.

        Images are drawn on a fixed 2x2 grid, so at most four are shown; the
        accuracy caption is computed over all prediction/target pairs.

        Args:
            image_paths: List of paths to CAPTCHA images
            predictions: List of predicted texts
            targets: List of target texts
            save_path: Path to save the plot
        """
        n_images = len(image_paths)
        if n_images == 0:
            print("No images to plot!")
            return

        # Imported lazily so the rest of the class works without OpenCV.
        import cv2

        self._ensure_parent_dir(save_path)

        # Force 2x2 grid for 4 images
        rows, cols = 2, 2
        fig, axes = plt.subplots(rows, cols, figsize=(12, 8))
        try:
            # Flatten axes for easier indexing
            axes = axes.flatten()
            if n_images > len(axes):
                print(f"Showing the first {len(axes)} of {n_images} images.")

            drawn = 0
            # zip() stops at the shortest input, which also caps the loop at
            # the number of grid cells.
            for ax, img_path, pred, target in zip(axes, image_paths,
                                                  predictions, targets):
                drawn += 1
                # Best-effort per cell: a bad image must not abort the grid.
                try:
                    img = cv2.imread(img_path)
                    if img is None:
                        ax.text(0.5, 0.5,
                                f"Failed to load\n{os.path.basename(img_path)}",
                                ha='center', va='center',
                                transform=ax.transAxes, fontsize=12)
                    else:
                        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                        ax.imshow(img)
                        # Determine if prediction is correct
                        is_correct = pred == target
                        color = 'green' if is_correct else 'red'
                        status = 'CORRECT' if is_correct else 'WRONG'
                        # Set title with prediction and target
                        title = f"Pred: {pred}\nTarget: {target}\n{status}"
                        ax.set_title(title, fontsize=10, color=color,
                                     fontweight='bold')
                except Exception as e:
                    ax.text(0.5, 0.5, f"Error loading image\n{str(e)[:30]}...",
                            ha='center', va='center', transform=ax.transAxes,
                            fontsize=10, color='red')
                # Remove axes
                ax.axis('off')

            # Hide every cell that was not drawn, including those skipped
            # because predictions/targets were shorter than image_paths.
            for ax in axes[drawn:]:
                ax.axis('off')

            # Add overall title
            fig.suptitle('CAPTCHA OCR Inference Results', fontsize=16,
                         fontweight='bold', y=0.98)

            # Calculate accuracy; guard against an empty targets list.
            total = len(targets)
            correct = sum(1 for p, t in zip(predictions, targets) if p == t)
            accuracy = (correct / total) * 100 if total else 0.0

            # Add accuracy info
            fig.text(0.5, 0.02, f'Accuracy: {correct}/{total} ({accuracy:.1f}%)',
                     ha='center', fontsize=14, fontweight='bold',
                     bbox=dict(boxstyle="round,pad=0.3", facecolor="lightblue",
                               alpha=0.7))
            plt.tight_layout()
            plt.subplots_adjust(top=0.9, bottom=0.15)
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        finally:
            # Release the figure even when plotting or saving fails.
            plt.close(fig)
        print(f"Results plot saved to: {save_path}")