#!/usr/bin/env python3
"""
CNN Model Training Script
=========================

Standalone script to train the CNN deblurring model with comprehensive options.
"""

import os
import sys
import argparse
import logging
from datetime import datetime

# Add modules to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from modules.cnn_deblurring import CNNDeblurModel, train_new_model, quick_train, full_train

# Per-run log file. The name is kept in a constant so the completion summary
# can report the exact file instead of a glob pattern.
LOG_FILE = f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so logged exception text never depends on the
        # platform's locale encoding.
        logging.FileHandler(LOG_FILE, encoding='utf-8'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def _build_parser():
    """Return the argument parser for the training CLI."""
    parser = argparse.ArgumentParser(
        description='Train CNN Deblurring Model',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # A mode flag is mandatory (the group below is required), so the
        # custom examples must include --custom to be valid invocations.
        epilog='''
Examples:
  python train_cnn_model.py --quick                              # Quick training (500 samples, 10 epochs)
  python train_cnn_model.py --full                               # Full training (2000 samples, 30 epochs)
  python train_cnn_model.py --custom --samples 1500              # Custom samples with default epochs
  python train_cnn_model.py --custom --samples 1000 --epochs 25  # Custom training
  python train_cnn_model.py --test                               # Test existing model
'''
    )

    # Training modes: exactly one must be chosen.
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument('--quick', action='store_true',
                            help='Quick training (500 samples, 10 epochs)')
    mode_group.add_argument('--full', action='store_true',
                            help='Full training (2000 samples, 30 epochs)')
    mode_group.add_argument('--custom', action='store_true',
                            help='Custom training (specify --samples and --epochs)')
    mode_group.add_argument('--test', action='store_true',
                            help='Test existing model performance')

    # Training parameters (only used by --custom)
    parser.add_argument('--samples', type=int, default=1000,
                        help='Number of training samples (default: 1000)')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of training epochs (default: 20)')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Training batch size (default: 16)')
    parser.add_argument('--validation-split', type=float, default=0.2,
                        help='Validation data split (default: 0.2)')

    # Model parameters
    parser.add_argument('--image-size', type=int, default=256,
                        help='Input image size (default: 256x256)')

    # Data options. --use-existing-dataset is a store_true flag whose default
    # is already True, so it can never be False and nothing reads it; it is
    # kept only so existing command lines that pass it keep working.
    parser.add_argument('--use-existing-dataset', action='store_true', default=True,
                        help='Accepted for backward compatibility; has no effect '
                             '(an existing dataset is used unless --force-new-dataset is given)')
    parser.add_argument('--force-new-dataset', action='store_true',
                        help='Force creation of new dataset')
    return parser


def _validate_custom_args(parser, args):
    """Reject custom-mode values that cannot describe a meaningful training run.

    Calls ``parser.error`` (which exits with status 2) on the first bad value.
    """
    for dest in ('samples', 'epochs', 'batch_size', 'image_size'):
        value = getattr(args, dest)
        if value <= 0:
            flag = '--' + dest.replace('_', '-')
            parser.error(f"{flag} must be a positive integer (got {value})")
    if not 0.0 <= args.validation_split < 1.0:
        parser.error(f"--validation-split must be in [0, 1) (got {args.validation_split})")


def _describe_loss(loss):
    """Map an evaluation loss to the user-facing quality verdict (lower is better)."""
    if loss < 0.01:
        return "🌟 Excellent performance!"
    if loss < 0.05:
        return "👍 Good performance"
    if loss < 0.1:
        return "⚠️ Fair performance - consider more training"
    return "🔄 Poor performance - retrain recommended"


def _run_test_mode():
    """Load the saved model and print its evaluation metrics.

    Returns:
        True only if the model loaded AND evaluation produced metrics.
    """
    print("🧪 Testing Existing Model")
    print("-" * 30)

    model = CNNDeblurModel()
    if not model.load_model(model.model_path):
        print("❌ No trained model found. Train a model first:")
        print(" python train_cnn_model.py --quick")
        return False

    print("✅ Successfully loaded trained model")
    print("📊 Evaluating model performance...")
    metrics = model.evaluate_model()
    if not metrics:
        print("❌ Failed to evaluate model")
        # Previously this fell through to a successful exit status.
        return False

    print("\n📈 Model Performance Metrics:")
    print(f" Loss: {metrics['loss']:.4f}")
    print(f" Mean Absolute Error: {metrics['mae']:.4f}")
    print(f" Mean Squared Error: {metrics['mse']:.4f}")
    print(_describe_loss(metrics['loss']))
    return True


def _run_preset(title, samples, epochs, expected_time, train_fn):
    """Print a preset's configuration, then run its training function.

    ``samples``/``epochs``/``expected_time`` are display-only here; the actual
    values are whatever ``train_fn`` uses internally.
    """
    print(f"🚀 {title}")
    print("-" * 30)
    print("Configuration:")
    print(f" Samples: {samples}")
    print(f" Epochs: {epochs}")
    print(f" Expected time: {expected_time}")
    print()
    # NOTE(review): the return value of quick_train/full_train is not
    # inspected, so a failure they signal via their return value (rather than
    # by raising) is not detected here -- confirm their contract.
    return train_fn()


def _run_custom(args):
    """Train a model with the user-supplied parameters and report its metrics.

    Returns:
        The truthiness of ``CNNDeblurModel.train_model``'s result.
    """
    print("🚀 Custom Training Mode")
    print("-" * 30)
    print("Configuration:")
    print(f" Samples: {args.samples}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch Size: {args.batch_size}")
    print(f" Validation Split: {args.validation_split}")
    print(f" Image Size: {args.image_size}x{args.image_size}")
    print(f" Use Existing Dataset: {not args.force_new_dataset}")

    # Rough heuristic: ~1 minute per 1000 sample-epochs.
    estimated_minutes = (args.samples * args.epochs) / 1000
    print(f" Estimated time: ~{estimated_minutes:.1f} minutes")
    print()

    input_shape = (args.image_size, args.image_size, 3)
    model = CNNDeblurModel(input_shape=input_shape)
    success = model.train_model(
        epochs=args.epochs,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
        use_existing_dataset=not args.force_new_dataset,
        num_training_samples=args.samples
    )
    if not success:
        print("❌ Custom training failed!")
        return False

    print("✅ Custom training completed successfully!")
    metrics = model.evaluate_model()
    if metrics:
        print("📊 Final Model Performance:")
        print(f" Loss: {metrics['loss']:.4f}")
        print(f" MAE: {metrics['mae']:.4f}")
        print(f" MSE: {metrics['mse']:.4f}")
    return True


def _print_completion_summary():
    """Print where the training artifacts were written."""
    print("\n🎉 Training Process Completed!")
    # NOTE(review): hard-coded paths; presumably they match the defaults used
    # by modules.cnn_deblurring -- verify if those defaults change.
    print("📁 Model saved to: models/cnn_deblur_model.h5")
    print("📁 Dataset saved to: data/training_dataset/")
    print(f"📁 Training log: {LOG_FILE}")
    print("\n🚀 You can now use the trained model in the main application!")


def main():
    """Parse CLI arguments, run the selected mode, and report the outcome.

    Returns:
        True when the selected mode finished successfully, False otherwise.
        The ``__main__`` guard maps this to exit status 0 / 1. Invalid
        arguments make argparse exit with status 2 before any work is done.
    """
    parser = _build_parser()
    args = parser.parse_args()
    if args.custom:
        _validate_custom_args(parser, args)

    # Print banner
    print("🎯 CNN Deblurring Model Training")
    print("=" * 40)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Ensure directories exist
    os.makedirs("models", exist_ok=True)
    os.makedirs("data/training_dataset", exist_ok=True)

    try:
        if args.test:
            # Test mode only reports metrics; no completion summary.
            return _run_test_mode()

        if args.quick:
            _run_preset("Quick Training Mode", 500, 10, "~10-15 minutes", quick_train)
        elif args.full:
            _run_preset("Full Training Mode", 2000, 30, "~45-60 minutes", full_train)
        elif args.custom:
            if not _run_custom(args):
                return False

        _print_completion_summary()
        return True

    except KeyboardInterrupt:
        logger.warning("Training interrupted by user")
        print("\n⚠️ Training interrupted by user")
        return False
    except Exception as e:
        # Top-level boundary: record the full traceback in the log file.
        logger.exception("Training failed with error: %s", e)
        print(f"\n❌ Training failed: {e}")
        return False


if __name__ == "__main__":
    success = main()
    sys.exit(0 if success else 1)