Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app.

Scans the models directory for .pth files and prints a categorized summary.
"""
| import os | |
| import glob | |
def _model_category(basename):
    """Return the report category for a model file name.

    The checks run in the same order as the name patterns are tested:
    plain CNN first (only when neither 'Tiny' nor 'Mini' appears in the
    name), then TinyCNN, then MiniCNN; anything else is "Other".
    """
    if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
        return "CNNModel"
    if 'TinyCNN' in basename:
        return "TinyCNN"
    if 'MiniCNN' in basename:
        return "MiniCNN"
    return "Other"


def find_models(model_dir=None):
    """Find available model files and print a categorized report.

    Scans ``model_dir`` (non-recursively) for ``*.pth`` files, assigns each
    file to the CNNModel / TinyCNN / MiniCNN / Other category based on
    substrings of its file name, and prints the listing, a summary and
    usage tips to stdout.

    Args:
        model_dir: Directory to scan. When omitted, the ``models`` directory
            one level above this script's directory is used.

    Returns:
        None. Everything is reported via ``print``.
    """
    # NOTE(review): the leading symbols in the messages below look
    # mis-encoded (probably emoji); they are kept unchanged here -- confirm
    # against the original file before "fixing" them.
    if model_dir is None:
        model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)

    banner = "=" * 70
    print(banner)
    print("π Searching for model files")
    print(banner)
    print(f"π Model directory: {model_dir}\n")

    # isdir (not exists): a regular file at this path is not a usable
    # models directory either.
    if not os.path.isdir(model_dir):
        print(f"β Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return

    # glob order depends on the file system; sort for a stable report.
    pth_files = sorted(glob.glob(os.path.join(model_dir, "*.pth")))
    if not pth_files:
        print("β No .pth model files found in the models directory")
        print("\nπ‘ To train models, run:")
        print(" cd ../src/model/shifted_CNN")
        print(" python main.py --mode train --model_type cnn --epochs 5")
        print(" python main.py --mode train --model_type tinycnn --epochs 5")
        print(" python main.py --mode train --model_type minicnn --epochs 5")
        return

    print(f"β Found {len(pth_files)} model file(s):\n")

    # Categorize models
    named_categories = ("CNNModel", "TinyCNN", "MiniCNN")
    buckets = {label: [] for label in named_categories}
    buckets["Other"] = []
    for path in pth_files:
        buckets[_model_category(os.path.basename(path))].append(path)

    # Display findings: one section per named category, separated by a
    # blank line (none after the last section).
    for position, label in enumerate(named_categories):
        if position:
            print()
        if buckets[label]:
            print(f"π¦ {label} files:")
            for path in buckets[label]:
                print(f" β {os.path.basename(path)}")
        else:
            print(f"β οΈ No {label} files found")

    # The "Other" section is only shown when it has entries.
    if buckets["Other"]:
        print("\nπ¦ Other model files:")
        for path in buckets["Other"]:
            print(f" β {os.path.basename(path)}")

    print("\n" + banner)
    print("π Summary")
    print(banner)
    print(f"Total models found: {len(pth_files)}")
    for label in (*named_categories, "Other"):
        print(f"{label}: {len(buckets[label])}")

    print("\nπ‘ Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")
    print("\nπ Ready to launch!")
    print("Run: python app.py")
    print(banner)
# Script entry point: run the model scan only when this file is executed
# directly, not when it is imported.
if __name__ == '__main__':
    find_models()