#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app.

Scans a models directory for ``.pth`` checkpoint files, groups them by model
family (CNNModel / TinyCNN / MiniCNN / other) based on their file names, and
prints a short report.
"""

import os
import glob

# Horizontal rule used to frame each section of the printed report.
_RULE = "=" * 70

# Category labels in display order; also the keys of find_models()' result.
_CATEGORIES = ("CNNModel", "TinyCNN", "MiniCNN", "Other")


def _categorize(path):
    """Return the category label for one checkpoint path, based on its file name."""
    basename = os.path.basename(path)
    # Tiny/Mini checkpoints (e.g. "best_TinyCNN_model_...") also contain
    # "CNN_model", so they are excluded explicitly from the plain CNN bucket.
    if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
        return "CNNModel"
    if 'TinyCNN' in basename:
        return "TinyCNN"
    if 'MiniCNN' in basename:
        return "MiniCNN"
    return "Other"


def _print_group(label, paths):
    """Print the file names in one category, or a warning if it is empty."""
    if paths:
        print(f"📦 {label} files:")
        for path in paths:
            print(f" ✓ {os.path.basename(path)}")
    else:
        print(f"⚠️ No {label} files found")


def find_models(model_dir=None):
    """Find available model files and print a categorized report.

    Args:
        model_dir: Directory to scan. Defaults to ``<this script's dir>/../models``.

    Returns:
        dict mapping each label in ("CNNModel", "TinyCNN", "MiniCNN", "Other")
        to a sorted list of absolute ``.pth`` file paths. Every list is empty
        when the directory is missing or contains no ``.pth`` files.
    """
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)
    groups = {label: [] for label in _CATEGORIES}

    print(_RULE)
    print("🔍 Searching for model files")
    print(_RULE)
    print(f"📂 Model directory: {model_dir}\n")

    if not os.path.isdir(model_dir):
        print(f"❌ Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return groups

    # Escape the directory part so names containing glob metacharacters
    # ("[", "]", "*", "?") are matched literally. Sort for stable output and
    # skip anything that is not a regular file (e.g. a directory named x.pth).
    pattern = os.path.join(glob.escape(model_dir), "*.pth")
    pth_files = sorted(p for p in glob.glob(pattern) if os.path.isfile(p))

    if not pth_files:
        print("❌ No .pth model files found in the models directory")
        print("\n💡 To train models, run:")
        print(" cd ../src/model/shifted_CNN")
        print(" python main.py --mode train --model_type cnn --epochs 5")
        print(" python main.py --mode train --model_type tinycnn --epochs 5")
        print(" python main.py --mode train --model_type minicnn --epochs 5")
        return groups

    print(f"✅ Found {len(pth_files)} model file(s):\n")

    for path in pth_files:
        groups[_categorize(path)].append(path)

    # Display findings
    _print_group("CNNModel", groups["CNNModel"])
    print()
    _print_group("TinyCNN", groups["TinyCNN"])
    print()
    _print_group("MiniCNN", groups["MiniCNN"])
    if groups["Other"]:
        print("\n📦 Other model files:")
        for path in groups["Other"]:
            print(f" ✓ {os.path.basename(path)}")

    print("\n" + _RULE)
    print("📋 Summary")
    print(_RULE)
    print(f"Total models found: {len(pth_files)}")
    for label in _CATEGORIES:
        print(f"{label}: {len(groups[label])}")

    print("\n💡 Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")
    print("\n🚀 Ready to launch!")
    print("Run: python app.py")
    print(_RULE)

    return groups


if __name__ == "__main__":
    find_models()