File size: 3,295 Bytes
d120b6d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app
"""

import os
import glob

def _print_model_group(label, paths):
    """Print the files belonging to one model category, or a warning if none.

    Args:
        label: Category name used in the heading / warning (e.g. "TinyCNN").
        paths: Model file paths that were assigned to this category.
    """
    if not paths:
        print(f"⚠️  No {label} files found")
        return
    print(f"📦 {label} files:")
    for path in paths:
        print(f"   ✓ {os.path.basename(path)}")


def find_models(model_dir=None):
    """Find available ``.pth`` model files and print a categorized report.

    Files are bucketed by substring checks on their basename into CNNModel,
    TinyCNN, MiniCNN and "other" groups, followed by a summary and usage
    tips. All output goes to stdout; the function returns ``None``.

    Args:
        model_dir: Directory to scan. Defaults to ``../models`` relative to
            the directory containing this script.
    """
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)

    print("=" * 70)
    print("🔍 Searching for model files")
    print("=" * 70)
    print(f"📂 Model directory: {model_dir}\n")

    # isdir rather than exists: a regular file at this path cannot hold models.
    if not os.path.isdir(model_dir):
        print(f"❌ Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return

    # Escape the directory so glob metacharacters in the path ([, *, ?) match
    # literally; sort so the listing order does not depend on the filesystem.
    pth_files = sorted(glob.glob(os.path.join(glob.escape(model_dir), "*.pth")))

    if not pth_files:
        print("❌ No .pth model files found in the models directory")
        print("\n💡 To train models, run:")
        print("   cd ../src/model/shifted_CNN")
        print("   python main.py --mode train --model_type cnn --epochs 5")
        print("   python main.py --mode train --model_type tinycnn --epochs 5")
        print("   python main.py --mode train --model_type minicnn --epochs 5")
        return

    print(f"✅ Found {len(pth_files)} model file(s):\n")

    # Categorize models by basename. "TinyCNN_model"/"MiniCNN_model" also
    # contain "CNN_model", hence the Tiny/Mini exclusions on the first test.
    cnn_models = []
    tinycnn_models = []
    minicnn_models = []
    other_models = []

    for path in pth_files:
        basename = os.path.basename(path)
        if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
            cnn_models.append(path)
        elif 'TinyCNN' in basename:
            tinycnn_models.append(path)
        elif 'MiniCNN' in basename:
            minicnn_models.append(path)
        else:
            other_models.append(path)

    # Display findings
    _print_model_group("CNNModel", cnn_models)
    print()
    _print_model_group("TinyCNN", tinycnn_models)
    print()
    _print_model_group("MiniCNN", minicnn_models)

    if other_models:
        print("\n📦 Other model files:")
        for path in other_models:
            print(f"   ✓ {os.path.basename(path)}")

    print("\n" + "=" * 70)
    print("📋 Summary")
    print("=" * 70)
    print(f"Total models found: {len(pth_files)}")
    print(f"CNNModel: {len(cnn_models)}")
    print(f"TinyCNN: {len(tinycnn_models)}")
    print(f"MiniCNN: {len(minicnn_models)}")
    print(f"Other: {len(other_models)}")

    print("\n💡 Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")

    print("\n🚀 Ready to launch!")
    print("Run: python app.py")
    print("=" * 70)


if __name__ == "__main__":
    # Run the scan only when executed as a script, not when imported.
    find_models()