File size: 3,896 Bytes
ecc16d3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
#!/usr/bin/env python3
"""

Quick test script to verify CNN training functionality

"""

import sys
import os
sys.path.append('.')

def test_cnn_import():
    """Check that ``CNNDeblurModel`` can be imported from ``modules.cnn_deblurring``.

    Returns:
        bool: True if the import (and the success message) completed,
        False if anything raised; the exception text is printed.
    """
    print("πŸ§ͺ Testing CNN module import...")
    succeeded = False
    try:
        from modules.cnn_deblurring import CNNDeblurModel  # noqa: F401  (import is the test)
        print("βœ… CNN module imported successfully")
        succeeded = True
    except Exception as exc:
        print(f"❌ CNN import failed: {exc}")
    return succeeded

def test_model_creation():
    """Instantiate ``CNNDeblurModel``, call ``build_model()`` and print its state.

    Returns:
        bool: True if construction, building and the status prints all
        completed; False if any step raised (the exception text is printed).
    """
    print("πŸ§ͺ Testing model creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        cnn = CNNDeblurModel()
        cnn.build_model()
        print("βœ… Model created successfully")
        print(f"   Model input shape: {cnn.input_shape}")
        print(f"   Model built: {cnn.model is not None}")
    except Exception as exc:
        print(f"❌ Model creation failed: {exc}")
        return False
    return True

def test_user_images(dataset_path="data/training_dataset"):
    """Report the user-supplied training images found in *dataset_path*.

    Only regular files whose name ends (case-insensitively) in one of the
    supported image extensions are counted. Names are sorted so that the
    "first 5" preview is the same on every run and platform.

    Args:
        dataset_path: Directory to scan. The default is a relative path, so it
            is resolved against the current working directory.

    Returns:
        bool: True if the directory exists and was listed (even when it holds
        no images). False if it is missing, is not a directory, or listing
        it raised (the exception text is printed).
    """
    print("πŸ§ͺ Testing user images detection...")
    try:
        # isdir rather than exists: a plain file at this path should be
        # reported as "not found", not surface as an os.listdir error.
        if not os.path.isdir(dataset_path):
            print("⚠️ Training dataset directory not found")
            return False
        # str.endswith accepts a tuple, checking all suffixes in one call.
        valid_extensions = ('.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.tif')
        user_images = sorted(
            name for name in os.listdir(dataset_path)
            if name.lower().endswith(valid_extensions)
            # Skip sub-directories that merely look like image names.
            and os.path.isfile(os.path.join(dataset_path, name))
        )
        print(f"βœ… Found {len(user_images)} user training images")
        for img in user_images[:5]:  # Show first 5
            print(f"   - {img}")
        if len(user_images) > 5:
            print(f"   ... and {len(user_images) - 5} more")
        return True
    except Exception as e:
        print(f"❌ User images test failed: {e}")
        return False

def test_quick_dataset_creation(num_samples=10):
    """Build a small training dataset via ``CNNDeblurModel`` and print its shapes.

    Args:
        num_samples: Number of samples passed to
            ``create_training_dataset``. Kept small (default 10) so the
            check runs quickly.

    Returns:
        bool: True if the dataset was created and both ``.shape`` values were
        printed. False if any step raised (the exception text is printed).
    """
    print("πŸ§ͺ Testing quick dataset creation...")
    try:
        from modules.cnn_deblurring import CNNDeblurModel

        # create_training_dataset is called on the model instance directly;
        # build_model() is not invoked for this check.
        model = CNNDeblurModel()

        print(f"   Creating {num_samples} sample dataset...")
        # NOTE(review): save_dataset=False presumably keeps this check from
        # writing the dataset to disk -- confirm in modules.cnn_deblurring.
        blurred, clean = model.create_training_dataset(
            num_samples=num_samples, save_dataset=False
        )

        print("βœ… Dataset created successfully")
        print(f"   Blurred images shape: {blurred.shape}")
        print(f"   Clean images shape: {clean.shape}")
        return True
    except Exception as e:
        print(f"❌ Dataset creation failed: {e}")
        return False

def main():
    """Run every check in order, print a summary, and return an exit status.

    Each check is a zero-argument callable returning True on success. A blank
    line is printed after each one, then the pass count and either next-step
    hints or a failure notice.

    Returns:
        int: 0 when every check passed, 1 otherwise, so the script's exit
        status reflects the result.
    """
    print("πŸš€ CNN Training Test Suite")
    print("=" * 40)
    
    tests = [
        ("CNN Import", test_cnn_import),
        ("Model Creation", test_model_creation),
        ("User Images", test_user_images),
        ("Dataset Creation", test_quick_dataset_creation)
    ]
    
    passed = 0
    total = len(tests)
    
    for name, test_func in tests:
        print(f"\nπŸ“‹ {name}")
        print("-" * 20)
        if test_func():
            passed += 1
        print()
    
    print("=" * 40)
    print(f"πŸ“Š Test Results: {passed}/{total} tests passed")
    
    if passed == total:
        print("πŸŽ‰ All tests passed! Training should work correctly.")
        print("\nπŸ’‘ Next steps:")
        print("   1. Go to your Streamlit app: http://localhost:8503")
        print("   2. Look for 'πŸ€– CNN Model Management' in sidebar")
        print("   3. Click '⚑ Quick Train' to start training")
        return 0

    print("❌ Some tests failed. Please check the errors above.")
    return 1

if __name__ == "__main__":
    # Propagate the result so shells/CI see a non-zero status on failure.
    sys.exit(main())