|
|
| """
|
| CNN Model Training Script
|
| ========================
|
|
|
| Standalone script to train the CNN deblurring model with comprehensive options.
|
| """
|
|
|
| import os
|
| import sys
|
| import argparse
|
| import logging
|
| from datetime import datetime
|
|
|
|
|
| sys.path.append(os.path.dirname(os.path.abspath(__file__)))
|
|
|
| from modules.cnn_deblurring import CNNDeblurModel, train_new_model, quick_train, full_train
|
|
|
|
|
# Name of the log file for this run. The timestamp is taken once, when the
# module is loaded, so each run writes to its own file in the working directory.
LOG_FILE = f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'

# Send INFO-and-above records both to the per-run log file and to the console.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so that non-ASCII text in log records does not depend
        # on the platform's default file encoding.
        logging.FileHandler(LOG_FILE, encoding='utf-8'),
        logging.StreamHandler(),
    ],
)

# Module-level logger used by the rest of this script.
logger = logging.getLogger(__name__)
|
|
|
# NOTE(review): the symbol prefixes in the user-facing messages below appear
# mis-encoded in this copy of the file (they look like UTF-8 emoji decoded with
# a single-byte codec). They are kept as found; restore the intended characters
# from the original source. Two literals that were broken across lines by this
# damage ("Successfully loaded trained model" and "Custom training completed
# successfully!") have been rejoined onto a single line.


def _build_parser():
    """Create the command-line parser for the training script."""
    parser = argparse.ArgumentParser(
        description='Train CNN Deblurring Model',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Exactly one mode flag is required, so the custom examples must
        # include --custom to be runnable as written.
        epilog='''
Examples:
  python train_cnn_model.py --quick                               # Quick training (500 samples, 10 epochs)
  python train_cnn_model.py --full                                # Full training (2000 samples, 30 epochs)
  python train_cnn_model.py --custom --samples 1500               # Custom samples with default epochs
  python train_cnn_model.py --custom --samples 1000 --epochs 25   # Custom training
  python train_cnn_model.py --test                                # Test existing model
''',
    )

    # Operating mode: exactly one of these must be given.
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument('--quick', action='store_true',
                            help='Quick training (500 samples, 10 epochs)')
    mode_group.add_argument('--full', action='store_true',
                            help='Full training (2000 samples, 30 epochs)')
    mode_group.add_argument('--custom', action='store_true',
                            help='Custom training (specify --samples and --epochs)')
    mode_group.add_argument('--test', action='store_true',
                            help='Test existing model performance')

    # Training parameters; only read in --custom mode.
    parser.add_argument('--samples', type=int, default=1000,
                        help='Number of training samples (default: 1000)')
    parser.add_argument('--epochs', type=int, default=20,
                        help='Number of training epochs (default: 20)')
    parser.add_argument('--batch-size', type=int, default=16,
                        help='Training batch size (default: 16)')
    parser.add_argument('--validation-split', type=float, default=0.2,
                        help='Validation data split (default: 0.2)')

    # Model parameters; only read in --custom mode.
    parser.add_argument('--image-size', type=int, default=256,
                        help='Input image size (default: 256x256)')

    # Dataset options. --use-existing-dataset is store_true with default=True,
    # so it is always True and is never read; it is kept so existing command
    # lines that pass it continue to parse. --force-new-dataset is the flag
    # that actually controls dataset reuse.
    parser.add_argument('--use-existing-dataset', action='store_true', default=True,
                        help='Use existing dataset if available (default: True)')
    parser.add_argument('--force-new-dataset', action='store_true',
                        help='Force creation of new dataset')
    return parser


def _validate_custom_args(parser, args):
    """Reject nonsensical --custom values via parser.error (exit status 2)."""
    positive_int_options = (
        ('--samples', args.samples),
        ('--epochs', args.epochs),
        ('--batch-size', args.batch_size),
        ('--image-size', args.image_size),
    )
    for flag, value in positive_int_options:
        if value < 1:
            parser.error(f"{flag} must be a positive integer (got {value})")
    if not 0.0 <= args.validation_split < 1.0:
        parser.error(
            "--validation-split must be a fraction in [0, 1) "
            f"(got {args.validation_split})"
        )


def _run_test_mode():
    """Load the saved model and print its metrics; return True only if both steps succeed."""
    print("π§ͺ Testing Existing Model")
    print("-" * 30)

    model = CNNDeblurModel()
    if not model.load_model(model.model_path):
        print("β No trained model found. Train a model first:")
        print(" python train_cnn_model.py --quick")
        return False

    print("β Successfully loaded trained model")

    print("π Evaluating model performance...")
    metrics = model.evaluate_model()
    if not metrics:
        print("β Failed to evaluate model")
        # A failed evaluation must not be reported as success to the caller.
        return False

    print("\nπ Model Performance Metrics:")
    print(f" Loss: {metrics['loss']:.4f}")
    print(f" Mean Absolute Error: {metrics['mae']:.4f}")
    print(f" Mean Squared Error: {metrics['mse']:.4f}")

    # Coarse rating of the reported loss.
    if metrics['loss'] < 0.01:
        print("π Excellent performance!")
    elif metrics['loss'] < 0.05:
        print("π Good performance")
    elif metrics['loss'] < 0.1:
        print("β οΈ Fair performance - consider more training")
    else:
        print("π Poor performance - retrain recommended")
    return True


def _run_custom_mode(args):
    """Train a model with the user-supplied settings; return True on success."""
    print("π Custom Training Mode")
    print("-" * 30)
    print("Configuration:")
    print(f" Samples: {args.samples}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch Size: {args.batch_size}")
    print(f" Validation Split: {args.validation_split}")
    print(f" Image Size: {args.image_size}x{args.image_size}")
    print(f" Use Existing Dataset: {not args.force_new_dataset}")

    # Rough rule of thumb used by this script: 1000 sample-epochs per minute.
    estimated_minutes = (args.samples * args.epochs) / 1000
    print(f" Estimated time: ~{estimated_minutes:.1f} minutes")
    print()

    input_shape = (args.image_size, args.image_size, 3)
    model = CNNDeblurModel(input_shape=input_shape)

    success = model.train_model(
        epochs=args.epochs,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
        use_existing_dataset=not args.force_new_dataset,
        num_training_samples=args.samples,
    )
    if not success:
        print("β Custom training failed!")
        return False

    print("β Custom training completed successfully!")

    metrics = model.evaluate_model()
    if metrics:
        print("π Final Model Performance:")
        print(f" Loss: {metrics['loss']:.4f}")
        print(f" MAE: {metrics['mae']:.4f}")
        print(f" MSE: {metrics['mse']:.4f}")
    return True


def main(argv=None):
    """Run the training script in the mode selected on the command line.

    Args:
        argv: Optional list of command-line arguments (without the program
            name). When None, argparse reads them from ``sys.argv``.

    Returns:
        True when the selected mode finished successfully, False when it
        failed, was interrupted with Ctrl+C, or raised an exception.

    Raises:
        SystemExit: raised by argparse (status 2) for an invalid command line,
            including out-of-range values in ``--custom`` mode.
    """
    parser = _build_parser()
    args = parser.parse_args(argv)

    # The numeric options are only consumed in custom mode, so only validate
    # them there; do it before any directory or model work happens.
    if args.custom:
        _validate_custom_args(parser, args)

    print("π― CNN Deblurring Model Training")
    print("=" * 40)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Make sure the output directories exist before any mode runs.
    os.makedirs("models", exist_ok=True)
    os.makedirs("data/training_dataset", exist_ok=True)

    try:
        if args.test:
            return _run_test_mode()

        if args.quick:
            print("π Quick Training Mode")
            print("-" * 30)
            print("Configuration:")
            print(" Samples: 500")
            print(" Epochs: 10")
            print(" Expected time: ~10-15 minutes")
            print()

            # NOTE(review): the return value is not inspected, so a failure
            # that quick_train() signals by return value (rather than by
            # raising) is reported as success - confirm its contract.
            quick_train()

        elif args.full:
            print("π Full Training Mode")
            print("-" * 30)
            print("Configuration:")
            print(" Samples: 2000")
            print(" Epochs: 30")
            print(" Expected time: ~45-60 minutes")
            print()

            # NOTE(review): same caveat as quick_train() above.
            full_train()

        elif args.custom:
            if not _run_custom_mode(args):
                return False

        print("\nπ Training Process Completed!")
        print("π Model saved to: models/cnn_deblur_model.h5")
        print("π Dataset saved to: data/training_dataset/")
        print("π Training log: training_log_*.log")
        print("\nπ You can now use the trained model in the main application!")
        return True

    except KeyboardInterrupt:
        print("\nβ οΈ Training interrupted by user")
        return False
    except Exception as e:
        # Top-level boundary: record the full traceback in the log, show a
        # short message on the console, and report failure to the caller.
        logger.exception("Training failed with error: %s", e)
        print(f"\nβ Training failed: {e}")
        return False
|
|
|
# Script entry point: run main() and turn its boolean result into the process
# exit status (0 when it reports success, 1 otherwise).
if __name__ == "__main__":
    exit_status = 0 if main() else 1
    sys.exit(exit_status)