#!/usr/bin/env python3
"""
Quick test script to verify CNN training functionality
"""
import sys
import os
sys.path.append('.')
def test_cnn_import():
    """Check that ``CNNDeblurModel`` can be imported from ``modules.cnn_deblurring``.

    Prints a progress line and the outcome to stdout.

    Returns:
        bool: True when the import succeeds, False when it raises.
    """
    print("🧪 Testing CNN module import...")
    try:
        # The name is not used here; the check is only that the import works.
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401
        print("✅ CNN module imported successfully")
        return True
    except Exception as e:
        # Broad on purpose: any failure raised while importing is reported
        # as a failed check instead of aborting the whole script.
        print(f"❌ CNN import failed: {e}")
        return False
def test_model_creation():
    """Check that a ``CNNDeblurModel`` can be instantiated and built.

    Creates the model with no arguments, calls ``build_model()`` on it, then
    prints its ``input_shape`` and whether its ``model`` attribute is set.

    Returns:
        bool: True when every step completes, False when any step raises.
    """
    print("🧪 Testing model creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()
        model.build_model()
        print("✅ Model created successfully")
        print(f" Model input shape: {model.input_shape}")
        print(f" Model built: {model.model is not None}")
        return True
    except Exception as e:
        # Broad on purpose: import, construction and build failures are all
        # reported the same way so the remaining checks can still run.
        print(f"❌ Model creation failed: {e}")
        return False
def test_user_images():
    """Report the image files found in ``data/training_dataset``.

    The path is relative to the current working directory. A directory entry
    counts as an image when its lower-cased name ends with one of the
    recognised extensions. The count is printed, followed by up to five
    names in sorted order and a "... and N more" line when there are more.

    Returns:
        bool: True when the directory exists (even if it holds no images),
        False when it is missing or listing it raises.
    """
    print("🧪 Testing user images detection...")
    try:
        dataset_path = "data/training_dataset"
        if not os.path.exists(dataset_path):
            print("⚠️ Training dataset directory not found")
            return False

        # str.endswith accepts a tuple, so one call covers every extension.
        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
        # os.listdir returns entries in arbitrary order; sort so the preview
        # below is the same on every run.
        user_images = sorted(
            f for f in os.listdir(dataset_path)
            if f.lower().endswith(valid_extensions)
        )
        print(f"✅ Found {len(user_images)} user training images")
        for img in user_images[:5]:  # Show first 5
            print(f" - {img}")
        if len(user_images) > 5:
            print(f" ... and {len(user_images) - 5} more")
        return True
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False
def test_quick_dataset_creation():
    """Check that a small training dataset can be generated.

    Instantiates ``CNNDeblurModel`` and calls its ``create_training_dataset``
    method with ``num_samples=10`` and ``save_dataset=False``, then prints the
    ``shape`` of the two returned objects.

    Returns:
        bool: True when the dataset is created and both shapes are printed,
        False when any step raises.
    """
    print("🧪 Testing quick dataset creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        model = CNNDeblurModel()
        # Create small dataset for testing
        print(" Creating 10 sample dataset...")
        blurred, clean = model.create_training_dataset(num_samples=10, save_dataset=False)
        print("✅ Dataset created successfully")
        print(f" Blurred images shape: {blurred.shape}")
        print(f" Clean images shape: {clean.shape}")
        return True
    except Exception as e:
        # Broad on purpose: any failure is reported as a failed check.
        print(f"❌ Dataset creation failed: {e}")
        return False
def main():
    """Run every check in order and print a pass/fail summary.

    Each check is a zero-argument callable whose truthy return value counts
    as a pass. After the summary line, either the next steps or a failure
    notice is printed.

    Returns:
        bool: True when every check passed, otherwise False.
    """
    print("🚀 CNN Training Test Suite")
    print("=" * 40)
    tests = [
        ("CNN Import", test_cnn_import),
        ("Model Creation", test_model_creation),
        ("User Images", test_user_images),
        ("Dataset Creation", test_quick_dataset_creation),
    ]
    passed = 0
    total = len(tests)
    for name, test_func in tests:
        print(f"\n📋 {name}")
        print("-" * 20)
        if test_func():
            passed += 1
        print()
    print("=" * 40)
    print(f"📊 Test Results: {passed}/{total} tests passed")
    all_passed = passed == total
    if all_passed:
        print("🎉 All tests passed! Training should work correctly.")
        print("\n💡 Next steps:")
        print(" 1. Go to your Streamlit app: http://localhost:8503")
        print(" 2. Look for '🤖 CNN Model Management' in sidebar")
        print(" 3. Click '⚡ Quick Train' to start training")
    else:
        print("❌ Some tests failed. Please check the errors above.")
    return all_passed


if __name__ == "__main__":
    # Propagate the outcome as the process exit status so shells and CI can
    # tell a failed run from a successful one.
    sys.exit(0 if main() else 1)