#!/usr/bin/env python3
"""
Quick test script to verify CNN training functionality.

Expects to be run from the project root: both the import path for the
project-local ``modules`` package and the training-dataset path below are
resolved relative to the current working directory.

Exit status is 0 when every check passes and 1 otherwise.
"""

import os
import sys

# Make the project-local ``modules`` package importable. NOTE: '.' is the
# current working directory, not this file's directory.
sys.path.append('.')

# Directory scanned for user-supplied training images (relative to the CWD).
DATASET_PATH = "data/training_dataset"

# Lower-case extensions counted as training images. Kept as a tuple so it can
# be passed directly to str.endswith().
VALID_IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')


def test_cnn_import():
    """Check that the CNN module imports.

    Returns:
        True if ``modules.cnn_deblurring.CNNDeblurModel`` imports, else False.
    """
    print("🧪 Testing CNN module import...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401
    except Exception as e:  # any import-time error counts as a failure
        print(f"❌ CNN import failed: {e}")
        return False
    else:
        print("✅ CNN module imported successfully")
        return True


def test_model_creation():
    """Check that a CNNDeblurModel can be constructed and built.

    Returns:
        True if construction, ``build_model()`` and the attribute reads below
        succeed, else False.
    """
    print("🧪 Testing model creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel
        model = CNNDeblurModel()
        model.build_model()
        print("✅ Model created successfully")
        # Attribute reads stay inside the try so a missing attribute is
        # reported as a failed check instead of aborting the whole suite.
        print(f" Model input shape: {model.input_shape}")
        print(f" Model built: {model.model is not None}")
        return True
    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        return False


def test_user_images():
    """Report the user training images found in ``DATASET_PATH``.

    Only regular files whose name ends (case-insensitively) with one of
    ``VALID_IMAGE_EXTENSIONS`` are counted; sub-directories are ignored even
    if their name looks like an image. An existing but empty directory still
    passes.

    Returns:
        True if the dataset directory exists and could be scanned, else False.
    """
    print("🧪 Testing user images detection...")
    try:
        if not os.path.isdir(DATASET_PATH):
            print("⚠️ Training dataset directory not found")
            return False

        # Sorted so the sample listing below is deterministic across runs.
        user_images = sorted(
            name for name in os.listdir(DATASET_PATH)
            if name.lower().endswith(VALID_IMAGE_EXTENSIONS)
            and os.path.isfile(os.path.join(DATASET_PATH, name))
        )

        print(f"✅ Found {len(user_images)} user training images")
        for img in user_images[:5]:  # Show first 5
            print(f" - {img}")
        if len(user_images) > 5:
            print(f" ... and {len(user_images) - 5} more")
        return True
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False


def test_quick_dataset_creation():
    """Check that the model can generate a small (10-sample) dataset.

    Returns:
        True if ``create_training_dataset`` returns a (blurred, clean) pair
        whose ``.shape`` can be printed, else False.
    """
    print("🧪 Testing quick dataset creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel
        model = CNNDeblurModel()

        # Create small dataset for testing; don't write it to disk.
        print(" Creating 10 sample dataset...")
        blurred, clean = model.create_training_dataset(
            num_samples=10, save_dataset=False
        )
        print("✅ Dataset created successfully")
        print(f" Blurred images shape: {blurred.shape}")
        print(f" Clean images shape: {clean.shape}")
        return True
    except Exception as e:
        print(f"❌ Dataset creation failed: {e}")
        return False


def main():
    """Run all checks, print a summary, and return a process exit status.

    Returns:
        0 if every check passed, 1 otherwise (suitable for ``sys.exit``).
    """
    # Emoji can't be encoded by some stdout encodings (e.g. a redirected
    # stream using a legacy Windows code page); substitute unencodable
    # characters instead of crashing with UnicodeEncodeError.
    if hasattr(sys.stdout, "reconfigure"):
        sys.stdout.reconfigure(errors="replace")

    print("🚀 CNN Training Test Suite")
    print("=" * 40)

    tests = [
        ("CNN Import", test_cnn_import),
        ("Model Creation", test_model_creation),
        ("User Images", test_user_images),
        ("Dataset Creation", test_quick_dataset_creation),
    ]

    passed = 0
    total = len(tests)
    for name, test_func in tests:
        print(f"\n📋 {name}")
        print("-" * 20)
        if test_func():
            passed += 1

    print()
    print("=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    if passed == total:
        print("🎉 All tests passed! Training should work correctly.")
        print("\n💡 Next steps:")
        print(" 1. Go to your Streamlit app: http://localhost:8503")
        print(" 2. Look for '🤖 CNN Model Management' in sidebar")
        print(" 3. Click '⚡ Quick Train' to start training")
        return 0

    print("❌ Some tests failed. Please check the errors above.")
    return 1


if __name__ == "__main__":
    sys.exit(main())