File size: 5,331 Bytes
ecc16d3 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 | """
Quick CNN Training Script
========================
Simple script to quickly train the CNN model for the Image Deblurring application.
"""
import os
import sys
# Make this script's own directory importable so that
# `modules.cnn_deblurring` resolves to the local project package.
# It is inserted at the front (not appended) so a same-named package elsewhere
# on sys.path cannot take precedence, and only when absent so re-executing
# this module does not keep growing sys.path.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)
# Training presets offered in the interactive menu: choice -> (samples, epochs).
_TRAINING_MODES = {
    "1": (500, 10),
    "2": (1000, 20),
    "3": (2000, 30),
}


def _ask(message):
    """Prompt with *message* and return the stripped reply.

    Returns None when no answer can be read: stdin is closed (EOFError) or the
    user pressed Ctrl+C at the prompt. Callers treat None as "cancel".
    """
    try:
        return input(message).strip()
    except (EOFError, KeyboardInterrupt):
        # Move off the prompt line before the caller prints its message.
        print()
        return None


def _confirm_overwrite(model_path):
    """Ask whether the existing model at *model_path* may be retrained.

    Returns True to proceed with training, False to stop (keep the existing
    model, an unrecognised answer, or no answer at all).
    """
    print("⚠️ A trained model already exists!")
    print(f" Location: {model_path}")
    choice = _ask("\nDo you want to:\n (1) Keep existing model\n (2) Train new model (overwrites existing)\n\nChoice (1/2): ")
    if choice is None:
        print("❌ Training cancelled.")
        return False
    if choice == "1":
        print("✅ Keeping existing model. You can start using the application!")
        return False
    if choice != "2":
        print("❌ Invalid choice. Exiting.")
        return False
    return True


def _select_training_mode():
    """Show the training presets and return the chosen ``(samples, epochs)``.

    Re-prompts until a listed option is entered; returns None if the user
    gives no answer (closed stdin or Ctrl+C).
    """
    print("Training Options:")
    print(" 1. Quick Training (Recommended for testing)")
    print(" • 500 samples, 10 epochs")
    print(" • Training time: ~10-15 minutes")
    print(" • Good for initial testing")
    print()
    print(" 2. Standard Training")
    print(" • 1000 samples, 20 epochs")
    print(" • Training time: ~20-30 minutes")
    print(" • Balanced quality and time")
    print()
    print(" 3. Full Training")
    print(" • 2000 samples, 30 epochs")
    print(" • Training time: ~45-60 minutes")
    print(" • Best quality results")
    while True:
        choice = _ask("\nSelect training mode (1/2/3): ")
        if choice is None:
            return None
        if choice in _TRAINING_MODES:
            return _TRAINING_MODES[choice]
        print("❌ Invalid choice. Please enter 1, 2, or 3.")


def _report_metrics(metrics):
    """Print the evaluation metrics and a verdict derived from ``metrics['loss']``.

    Expects the keys 'loss', 'mae' and 'mse' (as read by the original script).
    """
    print("📊 Model Performance:")
    print(f" Loss: {metrics['loss']:.4f}")
    print(f" Mean Absolute Error: {metrics['mae']:.4f}")
    print(f" Mean Squared Error: {metrics['mse']:.4f}")
    if metrics['loss'] < 0.05:
        print("🌟 Excellent! Your model is ready for high-quality deblurring!")
    elif metrics['loss'] < 0.1:
        print("👍 Good! Your model will provide decent deblurring results.")
    else:
        print("⚠️ Model trained but may need more training for optimal results.")


def main():
    """Interactively train the CNN deblurring model and report the outcome.

    The user is asked whether to replace an existing model file, which
    training preset to use, and for a final confirmation. The script then
    calls ``CNNDeblurModel.train_model`` and, on success, ``evaluate_model``.
    Training failures and errors are reported on stdout/stderr rather than
    raised; Ctrl+C during training is reported as an interruption.
    """
    import traceback

    print("🎯 AI Image Deblurring - CNN Model Training")
    print("=" * 50)
    print()

    # Import after the module-level sys.path setup so the local `modules`
    # package is found.
    from modules.cnn_deblurring import CNNDeblurModel

    # NOTE(review): this path is relative to the current working directory,
    # not the script directory; presumably it matches where CNNDeblurModel
    # saves its weights - confirm before relying on the existence check.
    model_path = "models/cnn_deblur_model.h5"
    if os.path.exists(model_path) and not _confirm_overwrite(model_path):
        return

    print("🚀 Starting CNN Model Training...")
    print()

    mode = _select_training_mode()
    if mode is None:
        print("❌ Training cancelled.")
        return
    samples, epochs = mode

    print("\n🎯 Training Configuration:")
    print(f" Samples: {samples}")
    print(f" Epochs: {epochs}")
    print(f" Model will be saved to: {model_path}")
    print()

    confirm = _ask("Start training? (y/N): ")
    if confirm is None or confirm.lower() not in ("y", "yes"):
        print("❌ Training cancelled.")
        return

    try:
        print("\n🏗️ Initializing CNN model...")
        model = CNNDeblurModel()
        print("🔄 Starting training process...")
        print(" This will:")
        print(" 1. Create synthetic blur dataset")
        print(" 2. Build U-Net CNN architecture")
        print(" 3. Train the model with early stopping")
        print(" 4. Save the trained model")
        print()
        success = model.train_model(
            epochs=epochs,
            batch_size=16,
            validation_split=0.2,
            use_existing_dataset=True,
            num_training_samples=samples,
        )
        if success:
            print("\n🎉 Training Completed Successfully!")
            print("=" * 40)
            print(f"✅ Model saved to: {model_path}")
            print("✅ Training dataset created and saved")

            # Evaluate the freshly trained model; a falsy result skips the report.
            print("\n🧪 Testing trained model...")
            metrics = model.evaluate_model()
            if metrics:
                _report_metrics(metrics)

            print("\n📋 Next Steps:")
            print(" 1. Run the main application: streamlit run streamlit_app.py")
            print(" 2. Upload a blurry image")
            print(" 3. Select 'CNN Enhancement' method")
            print(" 4. Enjoy high-quality AI deblurring!")
        else:
            print("\n❌ Training Failed!")
            print(" Check the error messages above for details.")
            print(" You can still use other enhancement methods in the application.")
    except KeyboardInterrupt:
        print("\n⚠️ Training interrupted by user.")
        print(" Partial progress may be saved.")
    except Exception as e:
        # Top-level boundary of an interactive script: report instead of
        # crashing, but keep the traceback so the failure can be diagnosed.
        print(f"\n❌ Training error: {e}")
        traceback.print_exc()
        print(" You can still use traditional enhancement methods.")
if __name__ == "__main__":
    # Script entry point: importing this module must never start training.
    main()