HF-Demo / setup_models.py
felix2703's picture
Add Gradio demo with 6 CNN models (using Git LFS for checkpoints)
d120b6d
#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app
"""
import os
import glob
def _categorize_model(basename):
    """Return the report category for a checkpoint file name.

    Categories are "CNNModel", "TinyCNN", "MiniCNN" or "Other". Matching is
    case-sensitive substring matching. The plain-CNN rule is checked first
    and excludes names containing "Tiny"/"Mini", so those fall through to
    their own categories.
    """
    if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
        return 'CNNModel'
    if 'TinyCNN' in basename:
        return 'TinyCNN'
    if 'MiniCNN' in basename:
        return 'MiniCNN'
    return 'Other'


def _print_group(label, paths):
    """Print the file names in one category, or a warning line if it is empty."""
    if paths:
        print(f"📦 {label} files:")
        for path in paths:
            print(f" ✓ {os.path.basename(path)}")
    else:
        print(f"⚠️ No {label} files found")


def find_models(model_dir=None):
    """Find available ``.pth`` model files and print a categorized report.

    Scans a directory for ``*.pth`` checkpoints, groups them by model family
    based on substrings of the file name, and prints a summary plus usage
    tips for the Gradio app.

    Args:
        model_dir: Directory to scan. ``None`` (the default) uses the
            ``models`` directory one level above this script's directory,
            i.e. ``<script dir>/../models``.

    Returns:
        A dict mapping each category label ("CNNModel", "TinyCNN",
        "MiniCNN", "Other", in that order) to a sorted list of matching
        ``.pth`` paths. All lists are empty when the directory is missing
        or contains no ``.pth`` files.
    """
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)

    groups = {label: [] for label in ('CNNModel', 'TinyCNN', 'MiniCNN', 'Other')}
    banner = "=" * 70

    print(banner)
    print("🔍 Searching for model files")
    print(banner)
    print(f"📂 Model directory: {model_dir}\n")

    # isdir (not exists): a regular file called "models" is not a model directory.
    if not os.path.isdir(model_dir):
        print(f"❌ Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return groups

    # Escape the directory so glob metacharacters in it ('[', ']', '*', '?')
    # are matched literally; sort so the listing is stable across filesystems.
    pattern = os.path.join(glob.escape(model_dir), "*.pth")
    pth_files = sorted(p for p in glob.glob(pattern) if os.path.isfile(p))

    if not pth_files:
        print("❌ No .pth model files found in the models directory")
        print("\n💡 To train models, run:")
        print(" cd ../src/model/shifted_CNN")
        print(" python main.py --mode train --model_type cnn --epochs 5")
        print(" python main.py --mode train --model_type tinycnn --epochs 5")
        print(" python main.py --mode train --model_type minicnn --epochs 5")
        return groups

    print(f"✅ Found {len(pth_files)} model file(s):\n")

    for path in pth_files:
        groups[_categorize_model(os.path.basename(path))].append(path)

    # Display findings: known families always get a line (list or warning),
    # "Other" is only shown when non-empty.
    _print_group("CNNModel", groups["CNNModel"])
    print()
    _print_group("TinyCNN", groups["TinyCNN"])
    print()
    _print_group("MiniCNN", groups["MiniCNN"])
    if groups["Other"]:
        print("\n📦 Other model files:")
        for path in groups["Other"]:
            print(f" ✓ {os.path.basename(path)}")

    print("\n" + banner)
    print("📋 Summary")
    print(banner)
    print(f"Total models found: {len(pth_files)}")
    for label, paths in groups.items():
        print(f"{label}: {len(paths)}")
    print("\n💡 Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")
    print("\n🚀 Ready to launch!")
    print("Run: python app.py")
    print(banner)
    return groups
if __name__ == "__main__":
    # Only run the scan-and-report when executed as a script, not on import.
    find_models()