Spaces:
Sleeping
Sleeping
File size: 3,295 Bytes
d120b6d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app
"""
import os
import glob
def _classify_model(basename):
    """Return the category key ('cnn', 'tinycnn', 'minicnn' or 'other') for a model filename."""
    # "TinyCNN_model" / "MiniCNN_model" also contain the substring "CNN_model",
    # so the plain-CNN test must explicitly exclude the Tiny/Mini variants.
    if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
        return 'cnn'
    if 'TinyCNN' in basename:
        return 'tinycnn'
    if 'MiniCNN' in basename:
        return 'minicnn'
    return 'other'


def _print_models(header, paths):
    """Print *header* followed by one check-marked line per model file basename."""
    print(header)
    for path in paths:
        print(f" ✓ {os.path.basename(path)}")


def _report_group(label, paths):
    """Print the file list for one model family, or a warning if it has no files."""
    if paths:
        _print_models(f"📦 {label} files:", paths)
    else:
        print(f"⚠️ No {label} files found")


def find_models(model_dir=None):
    """Scan a directory for ``.pth`` checkpoints and print a categorized report.

    Files are grouped by filename into CNNModel, TinyCNN, MiniCNN and "other"
    checkpoints, listed per group, and summarized with per-group counts.
    Nothing is loaded or modified; the function only prints to stdout.

    Args:
        model_dir: Directory to scan. When omitted (``None``), defaults to
            ``../models`` relative to the directory containing this script.

    Returns:
        None. Returns early, after printing guidance, when the directory
        does not exist or contains no ``.pth`` files.
    """
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)

    print("=" * 70)
    print("🔍 Searching for model files")
    print("=" * 70)
    print(f"📁 Model directory: {model_dir}\n")

    # isdir (not exists): a regular file at this path is not a usable model directory.
    if not os.path.isdir(model_dir):
        print(f"❌ Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return

    # Escape the directory part so characters like '[' or '*' in the path are
    # matched literally instead of being treated as glob wildcards; sort for a
    # deterministic listing regardless of filesystem enumeration order.
    pattern = os.path.join(glob.escape(model_dir), "*.pth")
    pth_files = sorted(glob.glob(pattern))
    if not pth_files:
        print("❌ No .pth model files found in the models directory")
        print("\n💡 To train models, run:")
        print(" cd ../src/model/shifted_CNN")
        print(" python main.py --mode train --model_type cnn --epochs 5")
        print(" python main.py --mode train --model_type tinycnn --epochs 5")
        print(" python main.py --mode train --model_type minicnn --epochs 5")
        return

    print(f"✅ Found {len(pth_files)} model file(s):\n")

    groups = {'cnn': [], 'tinycnn': [], 'minicnn': [], 'other': []}
    for path in pth_files:
        groups[_classify_model(os.path.basename(path))].append(path)

    _report_group("CNNModel", groups['cnn'])
    print()
    _report_group("TinyCNN", groups['tinycnn'])
    print()
    _report_group("MiniCNN", groups['minicnn'])
    # Unrecognized files are listed only when present (no "not found" warning).
    if groups['other']:
        _print_models("\n📦 Other model files:", groups['other'])

    print("\n" + "=" * 70)
    print("📊 Summary")
    print("=" * 70)
    print(f"Total models found: {len(pth_files)}")
    print(f"CNNModel: {len(groups['cnn'])}")
    print(f"TinyCNN: {len(groups['tinycnn'])}")
    print(f"MiniCNN: {len(groups['minicnn'])}")
    print(f"Other: {len(groups['other'])}")
    print("\n💡 Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")
    print("\n🚀 Ready to launch!")
    print("Run: python app.py")
    print("=" * 70)
if __name__ == '__main__':
    # Run the model discovery report only when executed directly, never on import.
    find_models()
|