"""
Quick test script to verify CNN training functionality
"""

import sys
import os

# Add the current working directory to the import path; the checks below
# import `modules.cnn_deblurring` relative to it.
sys.path.append('.')


def test_cnn_import():
    """Check that ``CNNDeblurModel`` can be imported.

    Returns:
        bool: True if ``from modules.cnn_deblurring import CNNDeblurModel``
        succeeds, False if it raises any ``Exception``.
    """
    print("🧪 Testing CNN module import...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401
    except Exception as e:
        # Any exception raised during the import is reported as a failed
        # check instead of propagating to the caller.
        print(f"❌ CNN import failed: {e}")
        return False
    # Success path lives outside the try so only the import is guarded.
    print("✅ CNN module imported successfully")
    return True


def test_model_creation():
    """Check that a ``CNNDeblurModel`` can be instantiated and built.

    Instantiates the model with no arguments, calls ``build_model()`` and
    reads its ``input_shape`` and ``model`` attributes.

    Returns:
        bool: True if all of the above complete without raising, False if
        any step raises an ``Exception``.
    """
    print("🧪 Testing model creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()
        model.build_model()
        # Read the attributes inside the try so that a missing attribute is
        # still reported as a failed check rather than raised.
        input_shape = model.input_shape
        is_built = model.model is not None
    except Exception as e:
        print(f"❌ Model creation failed: {e}")
        return False
    print("✅ Model created successfully")
    print(f" Model input shape: {input_shape}")
    print(f" Model built: {is_built}")
    return True


def test_user_images(dataset_path="data/training_dataset"):
    """Check that the training dataset directory exists and list its images.

    Args:
        dataset_path: Directory to scan. Defaults to
            ``"data/training_dataset"`` (relative to the working directory).

    Returns:
        bool: True if the directory exists (even when it holds no images),
        False if it is missing or scanning it raises an ``Exception``.
    """
    print("🧪 Testing user images detection...")
    # str.endswith accepts a tuple, so one call covers every extension.
    valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
    try:
        if not os.path.isdir(dataset_path):
            print("⚠️ Training dataset directory not found")
            return False
        # Sorted so the preview below is the same on every run; os.listdir
        # returns entries in arbitrary order.
        user_images = sorted(
            entry for entry in os.listdir(dataset_path)
            if entry.lower().endswith(valid_extensions)
        )
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False

    print(f"✅ Found {len(user_images)} user training images")
    # Preview at most five file names, then summarize the remainder.
    for img in user_images[:5]:
        print(f" - {img}")
    if len(user_images) > 5:
        print(f" ... and {len(user_images) - 5} more")
    return True


def test_quick_dataset_creation():
    """Check that a small training dataset can be generated.

    Instantiates ``CNNDeblurModel`` and calls
    ``create_training_dataset(num_samples=10, save_dataset=False)``, then
    reads ``.shape`` from both returned objects.

    Returns:
        bool: True if the dataset is created and both shapes can be read,
        False if any step raises an ``Exception``.
    """
    print("🧪 Testing quick dataset creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()
        print(" Creating 10 sample dataset...")
        blurred, clean = model.create_training_dataset(
            num_samples=10, save_dataset=False
        )
        # Read the shapes inside the try so an unexpected return type is
        # still reported as a failed check rather than raised.
        blurred_shape = blurred.shape
        clean_shape = clean.shape
    except Exception as e:
        print(f"❌ Dataset creation failed: {e}")
        return False
    print("✅ Dataset created successfully")
    print(f" Blurred images shape: {blurred_shape}")
    print(f" Clean images shape: {clean_shape}")
    return True


def main(tests=None):
    """Run the checks in order and print a pass/fail summary.

    Args:
        tests: Optional iterable of ``(label, callable)`` pairs. Each
            callable is invoked with no arguments and counts as passed when
            it returns a truthy value. Defaults to the four checks defined
            in this file.

    Returns:
        bool: True if every check passed, False otherwise.
    """
    print("🚀 CNN Training Test Suite")
    print("=" * 40)

    if tests is None:
        tests = [
            ("CNN Import", test_cnn_import),
            ("Model Creation", test_model_creation),
            ("User Images", test_user_images),
            ("Dataset Creation", test_quick_dataset_creation),
        ]
    else:
        # Materialize so len() works and a one-shot iterable is not exhausted.
        tests = list(tests)

    passed = 0
    total = len(tests)

    for name, test_func in tests:
        print(f"\n📋 {name}")
        print("-" * 20)
        if test_func():
            passed += 1
        print()

    print("=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")

    all_passed = passed == total
    if all_passed:
        print("🎉 All tests passed! Training should work correctly.")
        print("\n💡 Next steps:")
        print(" 1. Go to your Streamlit app: http://localhost:8503")
        print(" 2. Look for '🤖 CNN Model Management' in sidebar")
        print(" 3. Click '⚡ Quick Train' to start training")
    else:
        print("❌ Some tests failed. Please check the errors above.")
    return all_passed


# Run the suite only when executed as a script, not when imported.
if __name__ == "__main__":
    main()