"""
Quick CNN Training Script
========================

Simple script to quickly train the CNN model for the Image Deblurring
application.
"""

import os
import sys

# Make this script's directory importable so ``modules.cnn_deblurring``
# resolves however the script was launched. Guarded so that importing this
# module more than once does not keep appending the same entry.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.append(current_dir)

# Path checked for an existing model and reported after training.
# NOTE(review): this is relative to the current working directory, not to
# this script; presumably it matches where CNNDeblurModel saves -- confirm.
MODEL_PATH = "models/cnn_deblur_model.h5"

# Menu key -> (title, training samples, epochs, estimated time, note).
# Single source of truth for both the printed menu and the chosen config.
TRAINING_MODES = {
    "1": ("Quick Training (Recommended for testing)", 500, 10,
          "~10-15 minutes", "Good for initial testing"),
    "2": ("Standard Training", 1000, 20,
          "~20-30 minutes", "Balanced quality and time"),
    "3": ("Full Training", 2000, 30,
          "~45-60 minutes", "Best quality results"),
}


def _prompt(message):
    """Read one stripped line from stdin, or return None on EOF / Ctrl+C.

    Treating a closed stdin or an interrupt as "no answer" lets every
    caller cancel cleanly instead of crashing with a traceback.
    """
    try:
        return input(message).strip()
    except (EOFError, KeyboardInterrupt):
        print()
        return None


def _should_skip_training(model_path):
    """Ask what to do when a model file already exists.

    Returns True when the caller should stop (model kept, invalid answer,
    or no answer at all) and False when training should proceed.
    """
    if not os.path.exists(model_path):
        return False

    print("⚠️ A trained model already exists!")
    print(f" Location: {model_path}")
    choice = _prompt(
        "\nDo you want to:\n"
        " (1) Keep existing model\n"
        " (2) Train new model (overwrites existing)\n\n"
        "Choice (1/2): "
    )
    if choice is None:
        print("❌ Training cancelled.")
        return True
    if choice == "1":
        print("✅ Keeping existing model. You can start using the application!")
        return True
    if choice != "2":
        print("❌ Invalid choice. Exiting.")
        return True
    return False


def _choose_training_mode():
    """Print the training menu and return ``(samples, epochs)``.

    Re-prompts until a key of TRAINING_MODES is entered; returns None if
    input ends or is interrupted.
    """
    print("Training Options:")
    menu_entries = []
    for key, (title, samples, epochs, duration, note) in TRAINING_MODES.items():
        menu_entries.append(
            f" {key}. {title}\n"
            f" • {samples} samples, {epochs} epochs\n"
            f" • Training time: {duration}\n"
            f" • {note}"
        )
    # A blank line separates entries (none after the last one).
    print("\n\n".join(menu_entries))

    while True:
        choice = _prompt("\nSelect training mode (1/2/3): ")
        if choice is None:
            return None
        if choice in TRAINING_MODES:
            _, samples, epochs, _, _ = TRAINING_MODES[choice]
            return samples, epochs
        print("❌ Invalid choice. Please enter 1, 2, or 3.")


def _report_evaluation(model):
    """Evaluate the freshly trained model and print its metrics.

    Evaluation problems (an exception from ``evaluate_model`` or a metrics
    mapping missing 'loss'/'mae'/'mse') are reported here as evaluation
    failures, so they are not mistaken for a training failure: by the time
    this runs, training has already reported success.
    """
    print("\n🧪 Testing trained model...")
    try:
        metrics = model.evaluate_model()
        if not metrics:
            return
        loss = metrics["loss"]
        print("📊 Model Performance:")
        print(f" Loss: {loss:.4f}")
        print(f" Mean Absolute Error: {metrics['mae']:.4f}")
        print(f" Mean Squared Error: {metrics['mse']:.4f}")
    except Exception as e:
        print(f"⚠️ Model evaluation failed: {e}")
        return

    if loss < 0.05:
        print("🌟 Excellent! Your model is ready for high-quality deblurring!")
    elif loss < 0.1:
        print("👍 Good! Your model will provide decent deblurring results.")
    else:
        print("⚠️ Model trained but may need more training for optimal results.")


def _print_next_steps():
    """Print how to use the trained model in the Streamlit application."""
    print("\n🚀 Next Steps:")
    print(" 1. Run the main application: streamlit run streamlit_app.py")
    print(" 2. Upload a blurry image")
    print(" 3. Select 'CNN Enhancement' method")
    print(" 4. Enjoy high-quality AI deblurring!")


def main():
    """Interactively train the CNN deblurring model.

    Flow: optionally keep an existing model, pick a training preset,
    confirm, train via ``CNNDeblurModel.train_model``, then evaluate and
    print next steps. Returns None; every outcome is reported on stdout.
    """
    print("🎯 AI Image Deblurring - CNN Model Training")
    print("=" * 50)
    print()

    # Imported lazily (after the header and the sys.path setup) so that a
    # missing module or dependency is reported as a message, not a traceback.
    try:
        from modules.cnn_deblurring import CNNDeblurModel
    except ImportError as e:
        print(f"❌ Could not import the CNN model module: {e}")
        print(" Check that the 'modules' package and its dependencies are available.")
        return

    model_path = MODEL_PATH
    if _should_skip_training(model_path):
        return

    print("🚀 Starting CNN Model Training...")
    print()

    config = _choose_training_mode()
    if config is None:
        print("❌ Training cancelled.")
        return
    samples, epochs = config

    print("\n🎯 Training Configuration:")
    print(f" Samples: {samples}")
    print(f" Epochs: {epochs}")
    print(f" Model will be saved to: {model_path}")
    print()

    confirm = _prompt("Start training? (y/N): ")
    if (confirm or "").lower() != "y":
        print("❌ Training cancelled.")
        return

    try:
        print("\n🏗️ Initializing CNN model...")
        model = CNNDeblurModel()

        print("📊 Starting training process...")
        print(" This will:")
        print(" 1. Create synthetic blur dataset")
        print(" 2. Build U-Net CNN architecture")
        print(" 3. Train the model with early stopping")
        print(" 4. Save the trained model")
        print()

        success = model.train_model(
            epochs=epochs,
            batch_size=16,
            validation_split=0.2,
            use_existing_dataset=True,
            num_training_samples=samples,
        )

        if not success:
            print("\n❌ Training Failed!")
            print(" Check the error messages above for details.")
            print(" You can still use other enhancement methods in the application.")
            return

        print("\n🎉 Training Completed Successfully!")
        print("=" * 40)
        print(f"✅ Model saved to: {model_path}")
        print("✅ Training dataset created and saved")

        _report_evaluation(model)
        _print_next_steps()

    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
        print(" Partial progress may be saved.")
    except Exception as e:
        print(f"\n❌ Training error: {e}")
        print(" You can still use traditional enhancement methods.")


if __name__ == "__main__":
    main()