ganeshkumar383's picture
Upload 27 files (#2)
ecc16d3 verified
#!/usr/bin/env python3
"""
Quick test script to verify CNN training functionality
"""
import sys
import os
# Add the current working directory to the import search path so the
# project-local `modules` package imported by the checks below can be found
# when this script is launched from the project root.
sys.path.append('.')
def test_cnn_import():
    """Check that ``CNNDeblurModel`` can be imported from the project.

    Returns:
        bool: True when the import succeeds, False when any exception is
        raised (the exception text is printed).
    """
    print("πŸ§ͺ Testing CNN module import...")
    imported = False
    try:
        # Only the import itself is being exercised; the name is not used.
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401
        print("βœ… CNN module imported successfully")
        imported = True
    except Exception as exc:
        print(f"❌ CNN import failed: {exc}")
    return imported
def test_model_creation():
    """Instantiate ``CNNDeblurModel``, build it, and report its state.

    Prints the model's ``input_shape`` and whether its ``model`` attribute
    is set after ``build_model()``.

    Returns:
        bool: True when construction and build complete without raising,
        False otherwise (the exception text is printed).
    """
    print("πŸ§ͺ Testing model creation...")
    created = False
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        deblur_model = CNNDeblurModel()
        deblur_model.build_model()
        print("βœ… Model created successfully")
        input_shape = deblur_model.input_shape
        print(f" Model input shape: {input_shape}")
        is_built = deblur_model.model is not None
        print(f" Model built: {is_built}")
        created = True
    except Exception as exc:
        print(f"❌ Model creation failed: {exc}")
    return created
def test_user_images():
    """Check that the training-dataset directory exists and list its images.

    Looks in ``data/training_dataset`` (relative to the current working
    directory) for regular files whose name ends in a known image extension,
    compared case-insensitively, and prints the count plus a preview of at
    most five names.

    Returns:
        bool: True when the directory exists and could be listed (even if it
        contains no images), False when it is missing or listing it raised.
    """
    print("πŸ§ͺ Testing user images detection...")
    try:
        dataset_path = "data/training_dataset"
        # isdir rather than exists: a plain file at this path is not a dataset.
        if not os.path.isdir(dataset_path):
            print("⚠️ Training dataset directory not found")
            return False

        # str.endswith accepts a tuple, so one call covers every extension.
        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
        # os.listdir returns entries in arbitrary order and includes
        # sub-directories; sort and keep regular files only so the count and
        # the preview are stable from run to run.
        user_images = sorted(
            entry for entry in os.listdir(dataset_path)
            if entry.lower().endswith(valid_extensions)
            and os.path.isfile(os.path.join(dataset_path, entry))
        )
        print(f"βœ… Found {len(user_images)} user training images")
        for img in user_images[:5]:  # Show first 5
            print(f" - {img}")
        if len(user_images) > 5:
            print(f" ... and {len(user_images) - 5} more")
        return True
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False
def test_quick_dataset_creation():
    """Generate a 10-sample training dataset without saving it.

    Calls ``CNNDeblurModel.create_training_dataset(num_samples=10,
    save_dataset=False)`` and prints the ``shape`` of both returned objects.

    Returns:
        bool: True when the dataset is produced without raising, False
        otherwise (the exception text is printed).
    """
    print("πŸ§ͺ Testing quick dataset creation...")
    succeeded = False
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        deblur_model = CNNDeblurModel()
        # Keep the sample count tiny so this check stays quick.
        print(" Creating 10 sample dataset...")
        blurred_batch, clean_batch = deblur_model.create_training_dataset(
            num_samples=10,
            save_dataset=False,
        )
        print("βœ… Dataset created successfully")
        print(f" Blurred images shape: {blurred_batch.shape}")
        print(f" Clean images shape: {clean_batch.shape}")
        succeeded = True
    except Exception as exc:
        print(f"❌ Dataset creation failed: {exc}")
    return succeeded
def main(tests=None):
    """Run the CNN training checks and print a pass/fail summary.

    Args:
        tests: Optional iterable of ``(label, callable)`` pairs. Each
            callable takes no arguments and returns a truthy value on
            success. When omitted, the four checks defined in this file
            are run.

    Returns:
        int: 0 when every check passed, 1 otherwise, so the value can be
        handed straight to ``sys.exit``.
    """
    print("πŸš€ CNN Training Test Suite")
    print("=" * 40)
    if tests is None:
        tests = [
            ("CNN Import", test_cnn_import),
            ("Model Creation", test_model_creation),
            ("User Images", test_user_images),
            ("Dataset Creation", test_quick_dataset_creation),
        ]
    else:
        # Materialise once so len() works for any iterable passed in.
        tests = list(tests)

    passed = 0
    total = len(tests)
    for name, test_func in tests:
        print(f"\nπŸ“‹ {name}")
        print("-" * 20)
        if test_func():
            passed += 1
        print()

    print("=" * 40)
    print(f"πŸ“Š Test Results: {passed}/{total} tests passed")
    all_passed = passed == total
    if all_passed:
        print("πŸŽ‰ All tests passed! Training should work correctly.")
        print("\nπŸ’‘ Next steps:")
        print(" 1. Go to your Streamlit app: http://localhost:8503")
        print(" 2. Look for 'πŸ€– CNN Model Management' in sidebar")
        print(" 3. Click '⚑ Quick Train' to start training")
    else:
        print("❌ Some tests failed. Please check the errors above.")
    return 0 if all_passed else 1


if __name__ == "__main__":
    # Propagate failures through the exit status so shells and CI notice them.
    sys.exit(main())