"""
Quick CNN Training Script
=========================

Simple script to quickly train the CNN model for the Image Deblurring application.
"""
|
| |
|
| | import os
|
| | import sys
|
| |
|
| |
|
# Make the directory containing this script importable so that the
# project-local ``modules`` package resolves no matter which working
# directory the script is launched from.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    # Guarded so that re-importing this module (or running it from its own
    # directory, where the path is already present) does not add duplicates.
    sys.path.append(current_dir)
|
| |
|
# Menu choice -> (number of training samples, number of epochs).
_TRAINING_MODES = {
    "1": (500, 10),
    "2": (1000, 20),
    "3": (2000, 30),
}


def _select_training_mode():
    """Prompt until a valid mode is entered; return ``(samples, epochs)``."""
    while True:
        choice = input("\nSelect training mode (1/2/3): ").strip()
        if choice in _TRAINING_MODES:
            return _TRAINING_MODES[choice]
        print("❌ Invalid choice. Please enter 1, 2, or 3.")


def _report_metrics(metrics):
    """Print loss/MAE/MSE from *metrics* and a verdict derived from the loss."""
    print("📊 Model Performance:")
    print(f" Loss: {metrics['loss']:.4f}")
    print(f" Mean Absolute Error: {metrics['mae']:.4f}")
    print(f" Mean Squared Error: {metrics['mse']:.4f}")

    loss = metrics['loss']
    if loss < 0.05:
        print("🌟 Excellent! Your model is ready for high-quality deblurring!")
    elif loss < 0.1:
        print("👍 Good! Your model will provide decent deblurring results.")
    else:
        print("⚠️ Model trained but may need more training for optimal results.")


def main():
    """Interactively train the CNN deblurring model.

    Flow:
      1. If a model file already exists at ``models/cnn_deblur_model.h5``
         (resolved against the current working directory), ask whether to
         keep it or retrain.
      2. Ask for a training mode (quick / standard / full) and confirmation.
      3. Call ``CNNDeblurModel().train_model(...)`` and, when it reports
         success, ``evaluate_model()`` and print the resulting metrics.

    The function only prints to stdout and always returns ``None``. Errors
    raised while training are reported instead of propagated, and a failed
    import of the project module is reported the same way.
    """
    print("🎯 AI Image Deblurring - CNN Model Training")
    print("=" * 50)
    print()

    # Imported lazily so the banner is shown before the (potentially slow)
    # model module loads. A missing dependency is reported in the same
    # friendly style as every other failure instead of a raw traceback.
    try:
        from modules.cnn_deblurring import CNNDeblurModel
    except ImportError as exc:
        print(f"❌ Could not import the CNN module: {exc}")
        print(" Make sure the project's dependencies are installed.")
        return

    # NOTE(review): this path is only checked and displayed here; where the
    # model is actually written is decided inside CNNDeblurModel — confirm
    # the two agree.
    model_path = "models/cnn_deblur_model.h5"
    if os.path.exists(model_path):
        print("⚠️ A trained model already exists!")
        print(f" Location: {model_path}")

        choice = input("\nDo you want to:\n (1) Keep existing model\n (2) Train new model (overwrites existing)\n\nChoice (1/2): ").strip()

        if choice == "1":
            print("✅ Keeping existing model. You can start using the application!")
            return
        if choice != "2":
            print("❌ Invalid choice. Exiting.")
            return

    print("🚀 Starting CNN Model Training...")
    print()

    print("Training Options:")
    print(" 1. Quick Training (Recommended for testing)")
    print(" • 500 samples, 10 epochs")
    print(" • Training time: ~10-15 minutes")
    print(" • Good for initial testing")
    print()
    print(" 2. Standard Training")
    print(" • 1000 samples, 20 epochs")
    print(" • Training time: ~20-30 minutes")
    print(" • Balanced quality and time")
    print()
    print(" 3. Full Training")
    print(" • 2000 samples, 30 epochs")
    print(" • Training time: ~45-60 minutes")
    print(" • Best quality results")

    samples, epochs = _select_training_mode()

    print("\n🎯 Training Configuration:")
    print(f" Samples: {samples}")
    print(f" Epochs: {epochs}")
    print(f" Model will be saved to: {model_path}")
    print()

    # Anything other than an explicit "y" cancels (the default is "No").
    confirm = input("Start training? (y/N): ").strip().lower()
    if confirm != 'y':
        print("❌ Training cancelled.")
        return

    try:
        print("\n🏗️ Initializing CNN model...")
        model = CNNDeblurModel()

        print("🔄 Starting training process...")
        print(" This will:")
        print(" 1. Create synthetic blur dataset")
        print(" 2. Build U-Net CNN architecture")
        print(" 3. Train the model with early stopping")
        print(" 4. Save the trained model")
        print()

        success = model.train_model(
            epochs=epochs,
            batch_size=16,
            validation_split=0.2,
            use_existing_dataset=True,
            num_training_samples=samples,
        )

        if not success:
            print("\n❌ Training Failed!")
            print(" Check the error messages above for details.")
            print(" You can still use other enhancement methods in the application.")
            return

        print("\n🎉 Training Completed Successfully!")
        print("=" * 40)
        print(f"✅ Model saved to: {model_path}")
        print("✅ Training dataset created and saved")

        print("\n🧪 Testing trained model...")
        metrics = model.evaluate_model()
        if metrics:
            _report_metrics(metrics)

        print("\n🚀 Next Steps:")
        print(" 1. Run the main application: streamlit run streamlit_app.py")
        print(" 2. Upload a blurry image")
        print(" 3. Select 'CNN Enhancement' method")
        print(" 4. Enjoy high-quality AI deblurring!")

    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
        print(" Partial progress may be saved.")

    except Exception as e:
        # Top-level boundary of an interactive script: report and fall
        # through rather than crash with a traceback.
        print(f"\n❌ Training error: {e}")
        print(" You can still use traditional enhancement methods.")
|
| |
|
# Launch the interactive trainer only when this file is executed as a
# script; importing it as a module must not start any prompts.
if __name__ == "__main__":
    main()