# AI-Based-Image-Deblurring-App / src/train_cnn_model.py
# Uploaded by ganeshkumar383 - "Upload 27 files (#2)", commit ecc16d3 (verified)
#!/usr/bin/env python3
"""
CNN Model Training Script
========================
Standalone script to train the CNN deblurring model with comprehensive options.
"""
import os
import sys
import argparse
import logging
from datetime import datetime
# Add modules to path
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
from modules.cnn_deblurring import CNNDeblurModel, train_new_model, quick_train, full_train
# Configure logging: every run writes to its own timestamped log file in the
# current working directory and also echoes records to the console (stderr).
LOG_FILE = f'training_log_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log'
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        # Explicit UTF-8 so non-ASCII exception text (paths, messages) is
        # written correctly regardless of the platform's locale encoding.
        logging.FileHandler(LOG_FILE, encoding='utf-8'),
        logging.StreamHandler(),
    ],
)
logger = logging.getLogger(__name__)
# Default values for the --custom tuning options. The CLI options themselves
# default to None so main() can tell whether the user actually passed them and
# warn when they are combined with a mode that ignores them.
_CUSTOM_DEFAULTS = {
    'samples': 1000,
    'epochs': 20,
    'batch_size': 16,
    'validation_split': 0.2,
    'image_size': 256,
}

# Shown when the trained model object does not expose a `model_path` attribute.
_DEFAULT_MODEL_PATH = 'models/cnn_deblur_model.h5'

# (metrics key, printed label) pairs used when reporting evaluation results.
_VERBOSE_METRIC_LABELS = (
    ('loss', 'Loss'),
    ('mae', 'Mean Absolute Error'),
    ('mse', 'Mean Squared Error'),
)
_SHORT_METRIC_LABELS = (('loss', 'Loss'), ('mae', 'MAE'), ('mse', 'MSE'))


def _build_parser():
    """Return the argument parser for the training CLI.

    Exactly one of --quick/--full/--custom/--test is required. The tuning
    options (--samples, --epochs, --batch-size, --validation-split,
    --image-size) only take effect in --custom mode.
    """
    parser = argparse.ArgumentParser(
        description='Train CNN Deblurring Model',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        # Every example names a mode flag: the mode group below is required,
        # so a bare `--samples 1500` would be rejected by argparse.
        epilog='''
Examples:
  python train_cnn_model.py --quick                              # Quick training (500 samples, 10 epochs)
  python train_cnn_model.py --full                               # Full training (2000 samples, 30 epochs)
  python train_cnn_model.py --custom --samples 1500              # Custom samples with default epochs
  python train_cnn_model.py --custom --samples 1000 --epochs 25  # Custom training
  python train_cnn_model.py --test                               # Test existing model
''',
    )
    # Training modes
    mode_group = parser.add_mutually_exclusive_group(required=True)
    mode_group.add_argument('--quick', action='store_true',
                            help='Quick training (500 samples, 10 epochs)')
    mode_group.add_argument('--full', action='store_true',
                            help='Full training (2000 samples, 30 epochs)')
    mode_group.add_argument('--custom', action='store_true',
                            help='Custom training (specify --samples and --epochs)')
    mode_group.add_argument('--test', action='store_true',
                            help='Test existing model performance')

    # Training parameters (used by --custom only; defaults filled in later)
    parser.add_argument('--samples', type=int, default=None,
                        help=f"Number of training samples (default: {_CUSTOM_DEFAULTS['samples']})")
    parser.add_argument('--epochs', type=int, default=None,
                        help=f"Number of training epochs (default: {_CUSTOM_DEFAULTS['epochs']})")
    parser.add_argument('--batch-size', type=int, default=None,
                        help=f"Training batch size (default: {_CUSTOM_DEFAULTS['batch_size']})")
    parser.add_argument('--validation-split', type=float, default=None,
                        help=f"Validation data split (default: {_CUSTOM_DEFAULTS['validation_split']})")

    # Model parameters
    size = _CUSTOM_DEFAULTS['image_size']
    parser.add_argument('--image-size', type=int, default=None,
                        help=f"Input image size (default: {size}x{size})")

    # Data options
    # NOTE: store_true combined with default=True means this flag is always
    # True; it is kept so existing command lines still parse. Use
    # --force-new-dataset to rebuild the dataset.
    parser.add_argument('--use-existing-dataset', action='store_true', default=True,
                        help='Use existing dataset if available (default: True)')
    parser.add_argument('--force-new-dataset', action='store_true',
                        help='Force creation of new dataset')
    return parser


def _resolve_args(parser, args):
    """Apply the --custom defaults to `args` and validate numeric options.

    Prints a warning when tuning options are combined with a mode that
    ignores them. Invalid values are reported through ``parser.error``,
    which exits with status 2.

    Returns the same namespace, updated in place.
    """
    passed = [name for name in _CUSTOM_DEFAULTS if getattr(args, name) is not None]
    if passed and not args.custom:
        flags = ', '.join('--' + name.replace('_', '-') for name in passed)
        print(f"⚠️ Ignoring {flags}: these options only apply to --custom mode")
    for name, default in _CUSTOM_DEFAULTS.items():
        if getattr(args, name) is None:
            setattr(args, name, default)

    for name in ('samples', 'epochs', 'batch_size', 'image_size'):
        if getattr(args, name) <= 0:
            parser.error(f"--{name.replace('_', '-')} must be a positive integer")
    if not 0.0 <= args.validation_split < 1.0:
        parser.error("--validation-split must be in the range [0, 1)")
    return args


def _interpret_loss(loss):
    """Return a one-line human-readable verdict for an evaluation loss."""
    if loss < 0.01:
        return "🌟 Excellent performance!"
    if loss < 0.05:
        return "👍 Good performance"
    if loss < 0.1:
        return "⚠️ Fair performance - consider more training"
    return "🔄 Poor performance - retrain recommended"


def _print_metrics(header, metrics, labels):
    """Print `header`, then one line per (key, label) pair in `labels`."""
    print(header)
    for key, label in labels:
        print(f" {label}: {metrics[key]:.4f}")


def _log_file_paths():
    """Return the file paths of all FileHandlers attached to the root logger."""
    return [
        handler.baseFilename
        for handler in logging.getLogger().handlers
        if isinstance(handler, logging.FileHandler)
    ]


def _run_test_mode():
    """Load the saved model and print its evaluation metrics.

    Returns True only if the model loaded and evaluation produced metrics.
    """
    print("🧪 Testing Existing Model")
    print("-" * 30)
    model = CNNDeblurModel()
    if not model.load_model(model.model_path):
        print("❌ No trained model found. Train a model first:")
        print(" python train_cnn_model.py --quick")
        return False

    print("✅ Successfully loaded trained model")
    print("📊 Evaluating model performance...")
    metrics = model.evaluate_model()
    if not metrics:
        # Reported through a non-zero exit status, like the other failures.
        print("❌ Failed to evaluate model")
        return False

    _print_metrics("\n📈 Model Performance Metrics:", metrics, _VERBOSE_METRIC_LABELS)
    print(_interpret_loss(metrics['loss']))
    return True


def _run_preset(title, samples, epochs, expected_time, train_fn):
    """Print the configuration of a preset training mode, then run it.

    Returns whatever `train_fn()` returns.
    """
    print(title)
    print("-" * 30)
    print("Configuration:")
    print(f" Samples: {samples}")
    print(f" Epochs: {epochs}")
    print(f" Expected time: {expected_time}")
    print()
    # NOTE(review): the result of quick_train()/full_train() is not checked
    # because their success/failure contract is not visible here — confirm
    # whether they can signal failure and report it if so.
    return train_fn()


def _run_custom(args):
    """Train a model with the user-supplied --custom options.

    Returns the trained CNNDeblurModel, or None if training reported failure.
    """
    # --use-existing-dataset is always True (see _build_parser), so
    # --force-new-dataset is the switch that actually decides this.
    use_existing = args.use_existing_dataset and not args.force_new_dataset

    print("🚀 Custom Training Mode")
    print("-" * 30)
    print("Configuration:")
    print(f" Samples: {args.samples}")
    print(f" Epochs: {args.epochs}")
    print(f" Batch Size: {args.batch_size}")
    print(f" Validation Split: {args.validation_split}")
    print(f" Image Size: {args.image_size}x{args.image_size}")
    print(f" Use Existing Dataset: {use_existing}")
    # Rough heuristic: about one minute per 1000 sample-epochs.
    estimated_minutes = (args.samples * args.epochs) / 1000
    print(f" Estimated time: ~{estimated_minutes:.1f} minutes")
    print()

    model = CNNDeblurModel(input_shape=(args.image_size, args.image_size, 3))
    success = model.train_model(
        epochs=args.epochs,
        batch_size=args.batch_size,
        validation_split=args.validation_split,
        use_existing_dataset=use_existing,
        num_training_samples=args.samples,
    )
    if not success:
        print("❌ Custom training failed!")
        return None

    print("✅ Custom training completed successfully!")
    metrics = model.evaluate_model()
    if metrics:
        _print_metrics("📊 Final Model Performance:", metrics, _SHORT_METRIC_LABELS)
    return model


def _print_completion(model):
    """Print where the model, dataset and training log were written."""
    # Prefer the path the model object reports. quick_train()/full_train()
    # may not return an object with `model_path`, hence the fallback.
    model_path = getattr(model, 'model_path', None) or _DEFAULT_MODEL_PATH
    log_files = _log_file_paths()
    print("\n🎉 Training Process Completed!")
    print(f"📁 Model saved to: {model_path}")
    print("📁 Dataset saved to: data/training_dataset/")
    print(f"📁 Training log: {', '.join(log_files) or 'training_log_*.log'}")
    print("\n🚀 You can now use the trained model in the main application!")


def main(argv=None):
    """Parse command-line options and run the selected training or test mode.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (useful for
            tests and programmatic use). ``None`` keeps normal CLI behavior.

    Returns:
        True on success. False when training or evaluation fails, no trained
        model exists, an unexpected error occurs, or the run is interrupted
        with Ctrl-C.
    """
    parser = _build_parser()
    args = _resolve_args(parser, parser.parse_args(argv))

    # Print banner
    print("🎯 CNN Deblurring Model Training")
    print("=" * 40)
    print(f"Timestamp: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
    print()

    # Ensure output directories exist (relative to the current working directory)
    os.makedirs("models", exist_ok=True)
    os.makedirs("data/training_dataset", exist_ok=True)

    try:
        if args.test:
            return _run_test_mode()
        if args.quick:
            model = _run_preset("🚀 Quick Training Mode", 500, 10, "~10-15 minutes", quick_train)
        elif args.full:
            model = _run_preset("🚀 Full Training Mode", 2000, 30, "~45-60 minutes", full_train)
        else:
            # The mode group is required, so the only remaining mode is --custom.
            model = _run_custom(args)
            if model is None:
                return False
        _print_completion(model)
        return True
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user")
        return False
    except Exception as e:
        # Top-level boundary: keep the full traceback in the log, and show the
        # user a short message.
        logger.exception("Training failed with error: %s", e)
        print(f"\n❌ Training failed: {e}")
        return False
if __name__ == "__main__":
    # Map main()'s truthy/falsy result onto exit status 0 (success) / 1 (failure).
    sys.exit(int(not main()))