felix2703 committed on
Commit
d120b6d
·
1 Parent(s): c45f38a

Add Gradio demo with 6 CNN models (using Git LFS for checkpoints)

README copy.md ADDED
@@ -0,0 +1,263 @@
1
+ # πŸ”’ Shifted MNIST CNN Classifier - Gradio Demo
2
+
3
+ A Gradio web application for comparing three CNN architectures (CNNModel, TinyCNN, and MiniCNN) trained on shifted MNIST labels, with real-time inference timing.
4
+
5
+ ## 🌟 Features
6
+
7
+ - **Three Model Architectures**:
8
+ - **CNNModel**: 817,354 parameters - High accuracy, slower inference
9
+ - **TinyCNN**: 94,410 parameters - Balanced speed and accuracy
10
+ - **MiniCNN**: ~19,000 parameters - Fast inference, lightweight
11
+
12
+ - **Real-time Inference Timing**: See inference time in milliseconds for each prediction
13
+ - **Probability Distribution**: Visualize prediction confidence across all digits
14
+ - **Side-by-Side Comparison**: Compare all models simultaneously
15
+ - **Interactive Interface**: Upload images via file, webcam, or clipboard
16
+
17
+ ## πŸš€ Quick Start
18
+
19
+ ### 1. Install Dependencies
20
+
21
+ ```bash
22
+ pip install -r requirements.txt
23
+ ```
24
+
25
+ ### 2. Prepare Model Checkpoints
26
+
27
+ Place your trained model checkpoints in the `../models/` directory with these names:
28
+ - `best_CNN_model_acc_XX.XX.pth`
29
+ - `best_TinyCNN_model_acc_XX.XX.pth`
30
+ - `best_MiniCNN_model_acc_XX.XX.pth`
31
+
32
+ Or update the model paths in `app.py`:
33
+ ```python
34
+ cnn_model_path = 'path/to/your/cnn_model.pth'
35
+ tinycnn_model_path = 'path/to/your/tinycnn_model.pth'
36
+ minicnn_model_path = 'path/to/your/minicnn_model.pth'
37
+ ```
38
+
39
+ ### 3. Launch the App
40
+
41
+ ```bash
42
+ python app.py
43
+ ```
44
+
45
+ The app will be available at `http://localhost:7860`.
46
+
47
+ ## πŸ“ Usage
48
+
49
+ 1. **Upload an Image**: Draw or upload a handwritten digit (0-9)
50
+ 2. **Choose Mode**:
51
+ - **Individual Models**: Test each model separately
52
+ - **Compare All Models**: See predictions from all models side-by-side
53
+ 3. **View Results**:
54
+ - Predicted digit
55
+ - Shifted label (internal representation)
56
+ - Confidence score
57
+ - Inference time in milliseconds
58
+ - Top 3 predictions with probabilities
59
+
60
+ ## 🎯 About Shifted MNIST
61
+
62
+ The models are trained on **shifted MNIST labels** where:
63
+ - Original Digit 0 β†’ Shifted Label 9
64
+ - Original Digit 1 β†’ Shifted Label 8
65
+ - Original Digit 2 β†’ Shifted Label 7
66
+ - ... (reversed mapping)
67
+
68
+ The app automatically unmaps the predictions to show the **original digit**.
69
+
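The reversed mapping can be written as a pair of lookup tables; since 9 - (9 - d) = d, the shift is its own inverse. A minimal sketch (the helper name `unmap` is illustrative, not from the repo):

```python
# Shifted-label scheme used in training: original digit d gets label 9 - d.
LABEL_MAPPING = {d: 9 - d for d in range(10)}               # digit -> shifted label
REVERSE_MAPPING = {v: k for k, v in LABEL_MAPPING.items()}  # shifted label -> digit

def unmap(shifted_label: int) -> int:
    """Recover the original digit from a model's shifted prediction."""
    return REVERSE_MAPPING[shifted_label]

# The mapping is an involution, so mapping twice returns the original digit.
assert all(unmap(LABEL_MAPPING[d]) == d for d in range(10))
print(unmap(9))  # 0
```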
70
+ ## πŸ“Š Model Comparison
71
+
72
+ | Model | Parameters | Architecture | Best For |
73
+ |-------|------------|--------------|----------|
74
+ | CNNModel | 817,354 | 3 Conv blocks + 3 FC layers with dropout | High accuracy tasks |
75
+ | TinyCNN | 94,410 | 3 Conv blocks + Global Avg Pool + FC | Balanced performance |
76
+ | MiniCNN | ~19,000 | 2 Conv blocks + Global Avg Pool + FC | Edge devices, fast inference |
77
+
78
+ ## ⏱️ Performance Tips
79
+
80
+ - **GPU Acceleration**: Models automatically use CUDA if available
81
+ - **First Prediction**: May be slower due to model warm-up
82
+ - **Batch Processing**: Consider using the comparison mode for efficiency
83
+ - **Image Quality**: Clear, centered digits work best
84
+
85
+ ## πŸ› οΈ Customization
86
+
87
+ ### Change Port
88
+ ```python
89
+ demo.launch(server_port=8080) # Change from default 7860
90
+ ```
91
+
92
+ ### Enable Public URL
93
+ ```python
94
+ demo.launch(share=True) # Creates a public Gradio link
95
+ ```
96
+
97
+ ### Modify Model Paths
98
+ Edit the `MODEL_DIR` and model path variables in `app.py`.
99
+
100
+ ## πŸ“¦ Project Structure
101
+
102
+ ```
103
+ HF_demo/
104
+ β”œβ”€β”€ app.py # Main Gradio application
105
+ β”œβ”€β”€ requirements.txt # Python dependencies
106
+ └── README.md # This file
107
+
108
+ ../models/ # Model checkpoints directory
109
+ β”œβ”€β”€ best_CNN_model_acc_99.22.pth
110
+ β”œβ”€β”€ best_TinyCNN_model_acc_98.95.pth
111
+ └── best_MiniCNN_model_acc_95.00.pth
112
+ ```
113
+
114
+ ## πŸ› Troubleshooting
115
+
116
+ ### HuggingFace Upload Issues
117
+
118
+ #### Authentication Problems
119
+ ```bash
120
+ # Login manually first
121
+ python -c "
122
+ from huggingface_hub import login
123
+ login('your_token_here')
124
+ "
125
+ ```
126
+
127
+ #### Repository Not Found
128
+ The setup command creates the repository automatically. If it fails:
129
+ 1. Create repository manually on HuggingFace Hub
130
+ 2. Make sure your token has write access
131
+ 3. Check username spelling
132
+
133
+ #### Upload Failures
134
+ - Check internet connection
135
+ - Verify model file exists and is readable
136
+ - Check available disk space on HuggingFace Hub (5GB free tier limit)
137
+ - Ensure model file isn't corrupted
138
+
139
+ ### Gradio App Issues
140
+
141
+ #### Model Not Found Error
142
+ - Ensure model checkpoint files exist in the correct directory
143
+ - Check file paths in `app.py`
144
+ - Verify model naming convention
145
+
146
+ #### Import Error
147
+ - Make sure the parent directory structure is correct
148
+ - Check that `src/model/shifted_CNN/model.py` exists
149
+
150
+ #### Slow Inference
151
+ - Check if CUDA is available: `torch.cuda.is_available()`
152
+ - Reduce image size if needed
153
+ - Use MiniCNN for fastest inference
154
+
155
+ ### Version Conflicts
156
+ If you get version conflicts, delete `upload_state.json` and run setup again.
157
+
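Deleting the state file can be scripted; a minimal sketch assuming `upload_state.json` lives in the current working directory:

```python
from pathlib import Path

state_file = Path("upload_state.json")
if state_file.exists():
    # Removing the cached upload state forces a fresh scan on the next setup run.
    state_file.unlink()
```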
158
+ ## πŸ“„ License
159
+
160
+ This project is part of the Beat Nobita Challenge.
161
+
162
+ ## 🀝 Contributing
163
+
164
+ Contributions are welcome! Feel free to:
165
+ - Add new model architectures
166
+ - Improve the UI/UX
167
+ - Add batch prediction support
168
+ - Optimize inference speed
169
+
170
+ ## πŸ“š API Reference
171
+
172
+ ### HuggingFaceUploader Class
173
+
174
+ ```python
175
+ from hf_uploader import HuggingFaceUploader
176
+
177
+ uploader = HuggingFaceUploader(username, repo_name="shifted-mnist-cnn")
178
+
179
+ # Methods:
180
+ uploader.login(token=None) # Login to HF Hub
181
+ uploader.setup_repository() # Create/setup repository
182
+ uploader.upload_model(model_path) # Upload single model
183
+ uploader.upload_all_models(models_dir) # Upload all models in directory
184
+ uploader.get_model_info(model_path) # Get model metadata
185
+ ```
186
+
187
+ ### ModelMonitor Class
188
+
189
+ ```python
190
+ from auto_upload import ModelMonitor
191
+
192
+ monitor = ModelMonitor(username, repo_name, models_dir)
193
+
194
+ # Methods:
195
+ monitor.scan_for_new_models() # Scan for new models
196
+ monitor.upload_new_models(model_list) # Upload list of models
197
+ monitor.run_once() # Single scan + upload cycle
198
+ monitor.run_monitor(interval=300) # Continuous monitoring
199
+ ```
200
+
201
+ ## βš™οΈ Configuration
202
+
203
+ ### Environment Variables
204
+ - `HUGGINGFACE_TOKEN`: Your HuggingFace token
205
+ - `HF_REPO_NAME`: Repository name (default: "shifted-mnist-cnn")
206
+ - `HF_USERNAME`: Your HuggingFace username
207
+
208
+ ### Command Line Options
209
+ ```bash
210
+ python auto_upload.py --help
211
+
212
+ # Common options:
213
+ --username USER # Your HuggingFace username (required)
214
+ --repo-name NAME # Repository name (default: shifted-mnist-cnn)
215
+ --token TOKEN # HF token (or use env var)
216
+ --models-dir DIR # Models directory (default: ../models)
217
+ --interval SECONDS # Monitor interval (default: 300)
218
+ ```
219
+
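The flags above could be resolved with command-line values falling back to the environment variables from the previous section. A hedged sketch (the argparse wiring is an assumption, not the actual `auto_upload.py` code):

```python
import argparse
import os

def parse_config(argv=None):
    """Resolve CLI flags, falling back to HF_* environment variables."""
    parser = argparse.ArgumentParser(description="Upload models to HuggingFace Hub")
    parser.add_argument("--username", default=os.environ.get("HF_USERNAME"))
    parser.add_argument("--repo-name", default=os.environ.get("HF_REPO_NAME", "shifted-mnist-cnn"))
    parser.add_argument("--token", default=os.environ.get("HUGGINGFACE_TOKEN"))
    parser.add_argument("--models-dir", default="../models")
    parser.add_argument("--interval", type=int, default=300)
    args = parser.parse_args(argv)
    if not args.username:
        parser.error("--username is required (or set HF_USERNAME)")
    return args

cfg = parse_config(["--username", "myusername"])
```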
220
+ ## πŸ”„ Complete Workflow Example
221
+
222
+ 1. **Setup once**:
223
+ ```bash
224
+ python auto_upload.py setup --username myusername
225
+ ```
226
+
227
+ 2. **Start monitoring** (in background):
228
+ ```bash
229
+ python auto_upload.py monitor --username myusername &
230
+ ```
231
+
232
+ 3. **Launch Gradio app** (in another terminal):
233
+ ```bash
234
+ python app.py
235
+ ```
236
+
237
+ 4. **Train new models** (they automatically upload!):
238
+ ```bash
239
+ cd ../src/model/shifted_CNN
240
+ python main.py --model CNN --train --epochs 20
241
+ ```
242
+
243
+ 5. **Check your HuggingFace repository**:
244
+ ```
245
+ https://huggingface.co/myusername/shifted-mnist-cnn
246
+ ```
247
+
248
+ ## 🎯 Model Naming Convention
249
+
250
+ The system expects models named like:
251
+ - `best_CNN_model_acc_99.33.pth`
252
+ - `best_TinyCNN_model_acc_99.17.pth`
253
+ - `best_MiniCNN_model_acc_97.57.pth`
254
+
255
+ Where:
256
+ - **Architecture**: CNN, TinyCNN, or MiniCNN
257
+ - **Accuracy**: Extracted from filename (optional)
258
+
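A short sketch of extracting the architecture and optional accuracy from such a filename (the regex and helper are illustrative, not taken from the repo):

```python
import re

# Matches e.g. best_TinyCNN_model_acc_99.17.pth; the accuracy suffix is optional.
CHECKPOINT_RE = re.compile(
    r"best_(?P<arch>CNN|TinyCNN|MiniCNN)_model(?:_acc_(?P<acc>\d+(?:\.\d+)?))?\.pth$"
)

def parse_checkpoint_name(filename):
    """Return (architecture, accuracy or None) parsed from a checkpoint filename."""
    m = CHECKPOINT_RE.match(filename)
    if m is None:
        raise ValueError(f"unrecognized checkpoint name: {filename}")
    acc = m.group("acc")
    return m.group("arch"), float(acc) if acc is not None else None

print(parse_checkpoint_name("best_TinyCNN_model_acc_99.17.pth"))  # ('TinyCNN', 99.17)
```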
259
+ ---
260
+
261
+ **Built with ❀️ using Gradio, PyTorch, and HuggingFace Hub**
262
+
263
+ πŸŽ‰ **Your models are now automatically backed up, versioned, and shareable on HuggingFace Hub!**
app.py ADDED
@@ -0,0 +1,669 @@
1
+ """
2
+ Gradio Demo for Shifted MNIST CNN Models
3
+ Supports 6 models:
4
+ - Shifted MNIST: CNNModel, TinyCNN, MiniCNN
5
+ - Attack CNN: Standard, Lighter, Depthwise
6
+ """
7
+
8
+ import gradio as gr
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ from PIL import Image
13
+ import numpy as np
14
+ import time
15
+ import sys
16
+ import os
17
+
18
+ # Add parent directories to path to import models
19
+ shifted_cnn_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'model', 'shifted_CNN')
20
+ attack_cnn_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'model', 'attack_cnn')
21
+
22
+ sys.path.append(shifted_cnn_path)
23
+ sys.path.append(attack_cnn_path)
24
+
25
+ # Import shifted MNIST models
26
+ import importlib.util
27
+ spec_shifted = importlib.util.spec_from_file_location("shifted_model", os.path.join(shifted_cnn_path, "model.py"))
28
+ shifted_model = importlib.util.module_from_spec(spec_shifted)
29
+ spec_shifted.loader.exec_module(shifted_model)
30
+
31
+ # Import attack CNN models
32
+ spec_attack = importlib.util.spec_from_file_location("attack_model", os.path.join(attack_cnn_path, "model.py"))
33
+ attack_model = importlib.util.module_from_spec(spec_attack)
34
+ spec_attack.loader.exec_module(attack_model)
35
+
36
+ # Get classes from modules
37
+ CNNModel = shifted_model.CNNModel
38
+ TinyCNN = shifted_model.TinyCNN
39
+ MiniCNN = shifted_model.MiniCNN
40
+ StandardCNN = attack_model.StandardCNN
41
+ LighterCNN = attack_model.LighterCNN
42
+ DepthwiseCNN = attack_model.DepthwiseCNN
43
+
44
+
45
+ # Label mapping for shifted MNIST
46
+ LABEL_MAPPING = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0}
47
+ REVERSE_MAPPING = {v: k for k, v in LABEL_MAPPING.items()}
48
+
49
+
50
+ def get_device():
51
+ """Get the best available device"""
52
+ if torch.cuda.is_available():
53
+ return torch.device('cuda')
54
+ elif torch.backends.mps.is_available():
55
+ return torch.device('mps')
56
+ else:
57
+ return torch.device('cpu')
58
+
59
+
60
+ def load_model(model_path, model_type, device):
61
+ """Load a trained model from checkpoint"""
62
+ # Create model instance
63
+ if model_type == 'CNN':
64
+ model = CNNModel(num_classes=10, dropout_rate=0.5)
65
+ elif model_type == 'TinyCNN':
66
+ model = TinyCNN(num_classes=10)
67
+ elif model_type == 'MiniCNN':
68
+ model = MiniCNN(num_classes=10)
69
+ elif model_type == 'StandardAttack':
70
+ model = StandardCNN(num_classes=10, dropout_rate=0.5)
71
+ elif model_type == 'LighterAttack':
72
+ model = LighterCNN(num_classes=10, dropout_rate=0.5)
73
+ elif model_type == 'DepthwiseAttack':
74
+ model = DepthwiseCNN(num_classes=10, dropout_rate=0.5)
75
+ else:
76
+ raise ValueError(f"Unknown model type: {model_type}")
77
+
78
+ # Load checkpoint
79
+ checkpoint = torch.load(model_path, map_location=device)
80
+
81
+ # Handle different checkpoint formats
82
+ if isinstance(checkpoint, dict):
83
+ if 'model_state_dict' in checkpoint:
84
+ # Shifted MNIST format: {'model_state_dict': ..., 'model_info': ...}
85
+ model.load_state_dict(checkpoint['model_state_dict'])
86
+ model_info = checkpoint.get('model_info', {})
87
+ else:
88
+ # Direct state dict format
89
+ model.load_state_dict(checkpoint)
90
+ model_info = {}
91
+ else:
92
+ # Fallback: assume it's a state dict
93
+ model.load_state_dict(checkpoint)
94
+ model_info = {}
95
+
96
+ # If model_info is empty, calculate parameters
97
+ if not model_info.get('total_parameters'):
98
+ total_params = sum(p.numel() for p in model.parameters())
99
+ model_info['total_parameters'] = total_params
100
+ model_info['architecture'] = model_type
101
+
102
+ model.to(device)
103
+ model.eval()
104
+
105
+ return model, model_info
106
+
107
+
108
+ def preprocess_image(image):
109
+ """Preprocess image for model input"""
110
+ # Convert to grayscale if needed
111
+ if image.mode != 'L':
112
+ image = image.convert('L')
113
+
114
+ # Resize to 28x28
115
+ image = image.resize((28, 28), Image.Resampling.LANCZOS)
116
+
117
+ # Convert to numpy array and normalize
118
+ img_array = np.array(image).astype(np.float32) / 255.0
119
+
120
+ # Apply MNIST normalization
121
+ mean = 0.1307
122
+ std = 0.3081
123
+ img_array = (img_array - mean) / std
124
+
125
+ # Convert to tensor and add batch and channel dimensions
126
+ img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)
127
+
128
+ return img_tensor
129
+
130
+
131
+ def logit_attack_lowest(logits, margin=5.0):
132
+ """
133
+ Attack by boosting lowest logit
134
+
135
+ Args:
136
+ logits: Model logits (batch_size, num_classes)
137
+ margin: How much to boost the lowest logit above highest
138
+
139
+ Returns:
140
+ attacked_logits
141
+ """
142
+ attacked_logits = logits.clone()
143
+ batch_size = logits.size(0)
144
+
145
+ for i in range(batch_size):
146
+ highest_val = torch.max(logits[i]).item()
147
+ lowest_idx = torch.argmin(logits[i]).item()
148
+ lowest_val = logits[i, lowest_idx].item()
149
+
150
+ delta_needed = (highest_val - lowest_val) + margin
151
+ attacked_logits[i, lowest_idx] += delta_needed
152
+
153
+ return attacked_logits
154
+
155
+
156
+ def predict_with_timing(model, image, device, apply_attack=False, margin=5.0):
157
+ """Make prediction with timing"""
158
+ # Preprocess image
159
+ img_tensor = preprocess_image(image).to(device)
160
+
161
+ # Only the Attack CNN models accept a return_logits argument; instead of
162
+ # inspecting the forward signature, gate on apply_attack directly
163
+ supports_return_logits = apply_attack # only attack models need raw logits
164
+
165
+ # Warm-up run (for accurate timing on GPU)
166
+ with torch.no_grad():
167
+ if supports_return_logits:
168
+ _ = model(img_tensor, return_logits=True)
169
+ else:
170
+ _ = model(img_tensor)
171
+
172
+ # Actual prediction with timing
173
+ if device.type == 'cuda': torch.cuda.synchronize() # flush pending GPU work so the timer starts clean
+ start_time = time.time()
174
+ with torch.no_grad():
175
+ if supports_return_logits:
176
+ # Attack CNN models - get logits
177
+ logits = model(img_tensor, return_logits=True)
178
+
179
+ # Apply attack if requested
180
+ if apply_attack:
181
+ logits = logit_attack_lowest(logits, margin=margin)
182
+
183
+ probabilities = F.softmax(logits, dim=1)
184
+ else:
185
+ # Shifted MNIST models - already return softmax probabilities
186
+ outputs = model(img_tensor)
187
+ # Heuristic: values outside [0, 1] cannot be softmax probabilities, so treat them as logits
188
+ if outputs.max() > 1.0 or outputs.min() < 0.0:
189
+ # Likely logits
190
+ probabilities = F.softmax(outputs, dim=1)
191
+ else:
192
+ # Already probabilities
193
+ probabilities = outputs
194
+ if device.type == 'cuda': torch.cuda.synchronize() # CUDA ops are async; wait for the forward pass to finish
+ end_time = time.time()
195
+
196
+ inference_time = (end_time - start_time) * 1000 # Convert to milliseconds
197
+
198
+ # Get predictions
199
+ probs = probabilities.cpu().numpy()[0]
200
+ predicted_label = np.argmax(probs)
201
+ confidence = probs[predicted_label] * 100
202
+
203
+ return predicted_label, confidence, probs, inference_time
204
+
205
+
206
+ def create_prediction_output(predicted_label, confidence, probs, inference_time, model_name, model_info):
207
+ """Create formatted prediction output"""
208
+ # Main prediction
209
+ result_text = f"### 🎯 Prediction Results ({model_name})\n\n"
210
+ result_text += f"**Predicted Label:** {predicted_label}\n\n"
211
+ result_text += f"**Confidence:** {confidence:.2f}%\n\n"
212
+ result_text += f"**⏱️ Inference Time:** {inference_time:.3f} ms\n\n"
213
+
214
+ # Model info
215
+ if model_info:
216
+ result_text += f"**πŸ“Š Model Info:**\n"
217
+ result_text += f"- Parameters: {model_info.get('total_parameters', 'N/A'):,}\n"
218
+ result_text += f"- Architecture: {model_info.get('architecture', 'N/A')}\n\n"
219
+
220
+ # Create probability distribution dictionary for plot - showing predicted labels
221
+ prob_dict = {}
222
+ for i in range(10):
223
+ prob_dict[f"Label {i}"] = float(probs[i])
224
+
225
+ return result_text, prob_dict
226
+
227
+
228
+ def predict_cnn(image):
229
+ """Predict using CNNModel"""
230
+ if image is None:
231
+ return "Please upload an image", {}
232
+
233
+ if cnn_model is None:
234
+ return "❌ CNNModel not loaded. Please check the model path.", {}
235
+
236
+ try:
237
+ predicted_label, conf, probs, inf_time = predict_with_timing(
238
+ cnn_model, image, device
239
+ )
240
+ text_output, prob_dict = create_prediction_output(
241
+ predicted_label, conf, probs, inf_time, "CNNModel", cnn_info
242
+ )
243
+ return text_output, prob_dict
244
+ except Exception as e:
245
+ import traceback
246
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
247
+ return error_msg, {}
248
+
249
+
250
+ def predict_tinycnn(image):
251
+ """Predict using TinyCNN"""
252
+ if image is None:
253
+ return "Please upload an image", {}
254
+
255
+ if tinycnn_model is None:
256
+ return "❌ TinyCNN not loaded. Please check the model path.", {}
257
+
258
+ try:
259
+ predicted_label, conf, probs, inf_time = predict_with_timing(
260
+ tinycnn_model, image, device
261
+ )
262
+ text_output, prob_dict = create_prediction_output(
263
+ predicted_label, conf, probs, inf_time, "TinyCNN", tinycnn_info
264
+ )
265
+ return text_output, prob_dict
266
+ except Exception as e:
267
+ import traceback
268
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
269
+ return error_msg, {}
270
+
271
+
272
+ def predict_minicnn(image):
273
+ """Predict using MiniCNN"""
274
+ if image is None:
275
+ return "Please upload an image", {}
276
+
277
+ if minicnn_model is None:
278
+ return "❌ MiniCNN not loaded. Please check the model path.", {}
279
+
280
+ try:
281
+ predicted_label, conf, probs, inf_time = predict_with_timing(
282
+ minicnn_model, image, device
283
+ )
284
+ text_output, prob_dict = create_prediction_output(
285
+ predicted_label, conf, probs, inf_time, "MiniCNN", minicnn_info
286
+ )
287
+ return text_output, prob_dict
288
+ except Exception as e:
289
+ import traceback
290
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
291
+ return error_msg, {}
292
+
293
+
294
+ def predict_standard_attack(image):
295
+ """Predict using Standard Attack CNN with attack enabled (margin=5)"""
296
+ if image is None:
297
+ return "Please upload an image", {}
298
+
299
+ if standard_attack_model is None:
300
+ return "❌ Standard Attack CNN not loaded. Please check the model path.", {}
301
+
302
+ try:
303
+ predicted_label, conf, probs, inf_time = predict_with_timing(
304
+ standard_attack_model, image, device, apply_attack=True, margin=5.0
305
+ )
306
+ text_output, prob_dict = create_prediction_output(
307
+ predicted_label, conf, probs, inf_time, "Standard Attack CNN (margin=5)", standard_attack_info
308
+ )
309
+ return text_output, prob_dict
310
+ except Exception as e:
311
+ import traceback
312
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
313
+ return error_msg, {}
314
+
315
+
316
+ def predict_lighter_attack(image):
317
+ """Predict using Lighter Attack CNN with attack enabled (margin=5)"""
318
+ if image is None:
319
+ return "Please upload an image", {}
320
+
321
+ if lighter_attack_model is None:
322
+ return "❌ Lighter Attack CNN not loaded. Please check the model path.", {}
323
+
324
+ try:
325
+ predicted_label, conf, probs, inf_time = predict_with_timing(
326
+ lighter_attack_model, image, device, apply_attack=True, margin=5.0
327
+ )
328
+ text_output, prob_dict = create_prediction_output(
329
+ predicted_label, conf, probs, inf_time, "Lighter Attack CNN (margin=5)", lighter_attack_info
330
+ )
331
+ return text_output, prob_dict
332
+ except Exception as e:
333
+ import traceback
334
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
335
+ return error_msg, {}
336
+
337
+
338
+ def predict_depthwise_attack(image):
339
+ """Predict using Depthwise Attack CNN with attack enabled (margin=5)"""
340
+ if image is None:
341
+ return "Please upload an image", {}
342
+
343
+ if depthwise_attack_model is None:
344
+ return "❌ Depthwise Attack CNN not loaded. Please check the model path.", {}
345
+
346
+ try:
347
+ predicted_label, conf, probs, inf_time = predict_with_timing(
348
+ depthwise_attack_model, image, device, apply_attack=True, margin=5.0
349
+ )
350
+ text_output, prob_dict = create_prediction_output(
351
+ predicted_label, conf, probs, inf_time, "Depthwise Attack CNN (margin=5)", depthwise_attack_info
352
+ )
353
+ return text_output, prob_dict
354
+ except Exception as e:
355
+ import traceback
356
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
357
+ return error_msg, {}
358
+
359
+
360
+ def predict_all_models(image):
361
+ """Predict using all models and compare"""
362
+ if image is None:
363
+ empty_msg = "Please upload an image"
364
+ return empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}
365
+
366
+ try:
367
+ # Shifted MNIST models
368
+ cnn_text, cnn_probs = predict_cnn(image)
369
+ tiny_text, tiny_probs = predict_tinycnn(image)
370
+ mini_text, mini_probs = predict_minicnn(image)
371
+
372
+ # Attack CNN models
373
+ standard_text, standard_probs = predict_standard_attack(image)
374
+ lighter_text, lighter_probs = predict_lighter_attack(image)
375
+ depthwise_text, depthwise_probs = predict_depthwise_attack(image)
376
+
377
+ return (cnn_text, cnn_probs,
378
+ tiny_text, tiny_probs,
379
+ mini_text, mini_probs,
380
+ standard_text, standard_probs,
381
+ lighter_text, lighter_probs,
382
+ depthwise_text, depthwise_probs)
383
+ except Exception as e:
384
+ import traceback
385
+ error_msg = f"❌ **Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
386
+ return error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}
387
+
388
+
389
+ # Initialize device
390
+ device = get_device()
391
+ print(f"πŸ–₯️ Using device: {device}")
392
+
393
+ # Load models
394
+ print("πŸ“₯ Loading models...")
395
+
396
+ # Define model paths - use checkpoints in HF_demo directory
397
+ MODEL_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')
398
+
399
+ # Direct paths to model files in checkpoints directory
400
+ cnn_model_path = os.path.join(MODEL_DIR, 'best_CNN_model_acc_99.33.pth')
401
+ tinycnn_model_path = os.path.join(MODEL_DIR, 'best_TinyCNN_model_acc_99.17.pth')
402
+ minicnn_model_path = os.path.join(MODEL_DIR, 'best_MiniCNN_model_acc_97.57.pth')
403
+ standard_attack_path = os.path.join(MODEL_DIR, 'best_standard_attack_CNN_model.pth')
404
+ lighter_attack_path = os.path.join(MODEL_DIR, 'best_lighter_attack_CNN_model.pth.pth')
405
+ depthwise_attack_path = os.path.join(MODEL_DIR, 'best_depthwise_attack_CNN_model.pth')
406
+
407
+ print(f"πŸ“‚ Model directory: {MODEL_DIR}")
408
+ print(f" CNN model path: {cnn_model_path}")
409
+ print(f" TinyCNN model path: {tinycnn_model_path}")
410
+ print(f" MiniCNN model path: {minicnn_model_path}")
411
+ print(f" Standard Attack CNN path: {standard_attack_path}")
412
+ print(f" Lighter Attack CNN path: {lighter_attack_path}")
413
+ print(f" Depthwise Attack CNN path: {depthwise_attack_path}")
414
+
415
+ # Try to load Shifted MNIST models
416
+ try:
417
+ cnn_model, cnn_info = load_model(cnn_model_path, 'CNN', device)
418
+ print(f"βœ… CNNModel loaded: {cnn_info.get('total_parameters', 'N/A'):,} parameters")
419
+ except Exception as e:
420
+ print(f"⚠️ Failed to load CNNModel: {e}")
421
+ cnn_model, cnn_info = None, {}
422
+
423
+ try:
424
+ tinycnn_model, tinycnn_info = load_model(tinycnn_model_path, 'TinyCNN', device)
425
+ print(f"βœ… TinyCNN loaded: {tinycnn_info.get('total_parameters', 'N/A'):,} parameters")
426
+ except Exception as e:
427
+ print(f"⚠️ Failed to load TinyCNN: {e}")
428
+ tinycnn_model, tinycnn_info = None, {}
429
+
430
+ try:
431
+ minicnn_model, minicnn_info = load_model(minicnn_model_path, 'MiniCNN', device)
432
+ print(f"βœ… MiniCNN loaded: {minicnn_info.get('total_parameters', 'N/A'):,} parameters")
433
+ except Exception as e:
434
+ print(f"⚠️ Failed to load MiniCNN: {e}")
435
+ minicnn_model, minicnn_info = None, {}
436
+
437
+ # Try to load Attack CNN models
438
+ try:
439
+ standard_attack_model, standard_attack_info = load_model(standard_attack_path, 'StandardAttack', device)
440
+ print(f"βœ… Standard Attack CNN loaded: {standard_attack_info.get('total_parameters', 'N/A'):,} parameters")
441
+ except Exception as e:
442
+ print(f"⚠️ Failed to load Standard Attack CNN: {e}")
443
+ standard_attack_model, standard_attack_info = None, {}
444
+
445
+ try:
446
+ lighter_attack_model, lighter_attack_info = load_model(lighter_attack_path, 'LighterAttack', device)
447
+ print(f"βœ… Lighter Attack CNN loaded: {lighter_attack_info.get('total_parameters', 'N/A'):,} parameters")
448
+ except Exception as e:
449
+ print(f"⚠️ Failed to load Lighter Attack CNN: {e}")
450
+ lighter_attack_model, lighter_attack_info = None, {}
451
+
452
+ try:
453
+ depthwise_attack_model, depthwise_attack_info = load_model(depthwise_attack_path, 'DepthwiseAttack', device)
454
+ print(f"βœ… Depthwise Attack CNN loaded: {depthwise_attack_info.get('total_parameters', 'N/A'):,} parameters")
455
+ except Exception as e:
456
+ print(f"⚠️ Failed to load Depthwise Attack CNN: {e}")
457
+ depthwise_attack_model, depthwise_attack_info = None, {}
458
+
459
+ # Create Gradio interface
460
+ with gr.Blocks(title="MNIST CNN Classifier - 6 Models Comparison", theme=gr.themes.Soft()) as demo:
461
+ gr.Markdown("""
462
+ # πŸ”’ MNIST Digit Classifier - 6 Model Comparison
463
+
464
+ This app demonstrates **six CNN architectures** trained on MNIST with **shifted labels**:
465
+
466
+ ### 🎯 Shifted MNIST Models (Original 3):
467
+ - **CNNModel**: 817K params - High accuracy baseline
468
+ - **TinyCNN**: 94K params - Balanced performance
469
+ - **MiniCNN**: 1.4K params - Ultra-lightweight
470
+
471
+ ### βš”οΈ Attack CNN Models (New 3):
472
+ - **Standard Attack CNN**: ~817K params - Standard architecture with attack defense
473
+ - **Lighter Attack CNN**: ~94K params - Lighter with attack defense
474
+ - **Depthwise Attack CNN**: ~1.4K params - Most efficient with depthwise separable convolutions
475
+
476
+ **Note:** All models show the **predicted label** (0-9) directly, exactly as they were trained to output it.
477
+ - Shifted MNIST models: Trained with shifted labels (0β†’9, 1β†’8, etc.)
478
+ - **Attack CNN models: Apply logit attack with margin=5 (boosts lowest logit above highest)**
479
+
480
+ Upload a handwritten digit image and compare predictions across all architectures!
481
+ """)
482
+
483
+ # Display model loading status
484
+ status_text = "### πŸ“Š Model Status\n\n"
485
+ status_text += "**Shifted MNIST Models:**\n\n"
486
+ if cnn_model:
487
+ status_text += f"βœ… **CNNModel** loaded ({cnn_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
488
+ else:
489
+ status_text += "❌ **CNNModel** not loaded\n\n"
490
+
491
+ if tinycnn_model:
492
+ status_text += f"βœ… **TinyCNN** loaded ({tinycnn_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
493
+ else:
494
+ status_text += "❌ **TinyCNN** not loaded\n\n"
495
+
496
+ if minicnn_model:
497
+ status_text += f"βœ… **MiniCNN** loaded ({minicnn_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
498
+ else:
499
+ status_text += "❌ **MiniCNN** not loaded\n\n"
500
+
501
+ status_text += "**Attack CNN Models:**\n\n"
502
+ if standard_attack_model:
503
+ status_text += f"βœ… **Standard Attack CNN** loaded ({standard_attack_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
504
+ else:
505
+ status_text += "❌ **Standard Attack CNN** not loaded\n\n"
506
+
507
+ if lighter_attack_model:
508
+ status_text += f"βœ… **Lighter Attack CNN** loaded ({lighter_attack_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
509
+ else:
510
+ status_text += "❌ **Lighter Attack CNN** not loaded\n\n"
511
+
512
+ if depthwise_attack_model:
513
+ status_text += f"βœ… **Depthwise Attack CNN** loaded ({depthwise_attack_info.get('total_parameters', 'N/A'):,} parameters)\n\n"
514
+ else:
515
+ status_text += "❌ **Depthwise Attack CNN** not loaded\n\n"
516
+
517
+ gr.Markdown(status_text)
518
+
519
+ with gr.Row():
520
+ with gr.Column(scale=1):
521
+ input_image = gr.Image(
522
+ type="pil",
523
+ label="Upload Digit Image",
524
+ image_mode="L",
525
+ sources=["upload", "webcam", "clipboard"]
526
+ )
527
+
528
+ gr.Markdown("---")
529
+
530
+    with gr.Tabs():
+        with gr.Tab("πŸ” Individual Models"):
+            gr.Markdown("### Shifted MNIST Models")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("#### CNNModel (817K params)")
+                    cnn_btn = gr.Button(
+                        "Predict with CNNModel",
+                        variant="primary",
+                        interactive=cnn_model is not None
+                    )
+                    cnn_output = gr.Markdown()
+                    cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### TinyCNN (94K params)")
+                    tiny_btn = gr.Button(
+                        "Predict with TinyCNN",
+                        variant="primary",
+                        interactive=tinycnn_model is not None
+                    )
+                    tiny_output = gr.Markdown()
+                    tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### MiniCNN (1.4K params)")
+                    mini_btn = gr.Button(
+                        "Predict with MiniCNN",
+                        variant="primary",
+                        interactive=minicnn_model is not None
+                    )
+                    mini_output = gr.Markdown()
+                    mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+            gr.Markdown("---")
+            gr.Markdown("### Attack CNN Models")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("#### Standard Attack CNN (817K params)")
+                    standard_btn = gr.Button(
+                        "Predict with Standard Attack",
+                        variant="secondary",
+                        interactive=standard_attack_model is not None
+                    )
+                    standard_output = gr.Markdown()
+                    standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### Lighter Attack CNN (94K params)")
+                    lighter_btn = gr.Button(
+                        "Predict with Lighter Attack",
+                        variant="secondary",
+                        interactive=lighter_attack_model is not None
+                    )
+                    lighter_output = gr.Markdown()
+                    lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### Depthwise Attack CNN (1.4K params)")
+                    depthwise_btn = gr.Button(
+                        "Predict with Depthwise Attack",
+                        variant="secondary",
+                        interactive=depthwise_attack_model is not None
+                    )
+                    depthwise_output = gr.Markdown()
+                    depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+        with gr.Tab("βš–οΈ Compare All Models"):
+            compare_btn = gr.Button(
+                "Compare All 6 Models",
+                variant="primary",
+                size="lg",
+                interactive=True
+            )
+
+            gr.Markdown("### Shifted MNIST Models")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("#### CNNModel")
+                    compare_cnn_output = gr.Markdown()
+                    compare_cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### TinyCNN")
+                    compare_tiny_output = gr.Markdown()
+                    compare_tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### MiniCNN")
+                    compare_mini_output = gr.Markdown()
+                    compare_mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+            gr.Markdown("---")
+            gr.Markdown("### Attack CNN Models")
+            with gr.Row():
+                with gr.Column():
+                    gr.Markdown("#### Standard Attack CNN")
+                    compare_standard_output = gr.Markdown()
+                    compare_standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### Lighter Attack CNN")
+                    compare_lighter_output = gr.Markdown()
+                    compare_lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+                with gr.Column():
+                    gr.Markdown("#### Depthwise Attack CNN")
+                    compare_depthwise_output = gr.Markdown()
+                    compare_depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)
+
+    # Connect buttons to functions
+    cnn_btn.click(predict_cnn, inputs=input_image, outputs=[cnn_output, cnn_plot])
+    tiny_btn.click(predict_tinycnn, inputs=input_image, outputs=[tiny_output, tiny_plot])
+    mini_btn.click(predict_minicnn, inputs=input_image, outputs=[mini_output, mini_plot])
+    standard_btn.click(predict_standard_attack, inputs=input_image, outputs=[standard_output, standard_plot])
+    lighter_btn.click(predict_lighter_attack, inputs=input_image, outputs=[lighter_output, lighter_plot])
+    depthwise_btn.click(predict_depthwise_attack, inputs=input_image, outputs=[depthwise_output, depthwise_plot])
+
+    compare_btn.click(
+        predict_all_models,
+        inputs=input_image,
+        outputs=[
+            compare_cnn_output, compare_cnn_plot,
+            compare_tiny_output, compare_tiny_plot,
+            compare_mini_output, compare_mini_plot,
+            compare_standard_output, compare_standard_plot,
+            compare_lighter_output, compare_lighter_plot,
+            compare_depthwise_output, compare_depthwise_plot
+        ]
+    )
+
+# Launch the app
+if __name__ == "__main__":
+    print("\nπŸš€ Launching Gradio app...")
+    demo.launch(
+        server_name="0.0.0.0",
+        server_port=7860,
+        share=False,
+        show_error=True
+    )
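Each `.click` wiring above pairs a prediction function with a `gr.Markdown` output and a `gr.Label` output, so every predict function is expected to take one image and return a `(markdown_summary, label_dict)` tuple. A minimal sketch of that contract, using a hypothetical `make_predictor` factory and a stand-in model (neither name appears in app.py):

```python
import time


def make_predictor(model_fn, model_name):
    """Build a click handler returning (markdown summary, {label: prob}),
    the output shape a gr.Markdown / gr.Label pair expects."""
    def predict(image):
        if image is None:
            return f"⚠️ Please upload an image for {model_name}.", {}
        start = time.perf_counter()
        probs = model_fn(image)  # hypothetical: maps image -> {digit str: probability}
        elapsed_ms = (time.perf_counter() - start) * 1000
        top = max(probs, key=probs.get)
        summary = (
            f"**Prediction:** {top}\n\n"
            f"**Inference time:** {elapsed_ms:.2f} ms"
        )
        return summary, probs
    return predict


# Stand-in "model" that always favors digit 3, just to exercise the shape
dummy = make_predictor(
    lambda img: {str(d): (0.91 if d == 3 else 0.01) for d in range(10)},
    "DummyCNN",
)
summary, probs = dummy("any-image")
```

Returning the full probability dict lets `gr.Label(num_top_classes=10)` render the whole distribution, matching the panels above.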
checkpoints/best_CNN_model_acc_99.33.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:51a5914922ff3b80b94400de986898c73c30e272f29e564e8d2c13250729e9c9
+size 3281923
checkpoints/best_MiniCNN_model_acc_97.57.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:623c173a555e5188c26a237faa2fbc7b76e1223f2d63c7085667efde158bb61d
+size 87745
checkpoints/best_TinyCNN_model_acc_99.17.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:6f04ea1c37b6324de1dfc1aed354d1d54661a1603c1142ae979deabbd98ab849
+size 388863
checkpoints/best_depthwise_attack_CNN_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e70fbe1ac7fd7049b4b49439d462244056df14d454e7e1a8b8f65c7d0e9e04c4
+size 12964
checkpoints/best_lighter_attack_CNN_model.pth.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8d259b94336d36bacd9ff444ea4587206e0a8edb20f93c2a9364f3d1ae755f21
+size 386739
checkpoints/best_standard_attack_CNN_model.pth ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ac7a174621fb3650e51698a862949cdeb295631fd46cdf4f1ff917c840907b5e
+size 3279710
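The checkpoint files above are Git LFS pointer stubs: until the real weights are fetched (`git lfs pull`), each `.pth` on disk is just the few lines of text shown, and `torch.load` fails on it with a confusing error. A small sketch of an early sanity check (the helper name is hypothetical, not part of this repo):

```python
import os
import tempfile

LFS_SPEC = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path):
    """Return True if the file looks like an un-fetched Git LFS pointer
    (a ~130-byte text stub) rather than a real checkpoint."""
    with open(path, "rb") as f:
        return f.read(len(LFS_SPEC)) == LFS_SPEC


# Demo with a temporary file mimicking one of the pointers above
pointer_text = (
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:51a5914922ff3b80b94400de986898c73c30e272f29e564e8d2c13250729e9c9\n"
    "size 3281923\n"
)
with tempfile.NamedTemporaryFile("w", suffix=".pth", delete=False) as f:
    f.write(pointer_text)
    tmp_path = f.name
print(is_lfs_pointer(tmp_path))
os.remove(tmp_path)
```

Calling such a check before loading each model would turn a cryptic deserialization error into a clear "run `git lfs pull`" message.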
requirements.txt ADDED
@@ -0,0 +1,5 @@
+gradio>=4.0.0
+torch>=2.0.0
+torchvision>=0.15.0
+Pillow>=10.0.0
+numpy>=1.24.0
setup_models.py ADDED
@@ -0,0 +1,106 @@
+#!/usr/bin/env python3
+"""
+Setup script to help configure model paths for the Gradio app
+"""
+
+import os
+import glob
+
+
+def find_models():
+    """Find available model files"""
+    model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
+    model_dir = os.path.abspath(model_dir)
+
+    print("="*70)
+    print("πŸ” Searching for model files")
+    print("="*70)
+    print(f"πŸ“‚ Model directory: {model_dir}\n")
+
+    if not os.path.exists(model_dir):
+        print(f"❌ Model directory not found: {model_dir}")
+        print("Please create the directory and add your trained models.\n")
+        return
+
+    # Find all .pth files
+    pth_files = glob.glob(os.path.join(model_dir, "*.pth"))
+
+    if not pth_files:
+        print("❌ No .pth model files found in the models directory")
+        print("\nπŸ’‘ To train models, run:")
+        print("  cd ../src/model/shifted_CNN")
+        print("  python main.py --mode train --model_type cnn --epochs 5")
+        print("  python main.py --mode train --model_type tinycnn --epochs 5")
+        print("  python main.py --mode train --model_type minicnn --epochs 5")
+        return
+
+    print(f"βœ… Found {len(pth_files)} model file(s):\n")
+
+    # Categorize models by architecture name in the filename
+    cnn_models = []
+    tinycnn_models = []
+    minicnn_models = []
+    other_models = []
+
+    for file in pth_files:
+        basename = os.path.basename(file)
+        if 'CNN_model' in basename and 'Tiny' not in basename and 'Mini' not in basename:
+            cnn_models.append(file)
+        elif 'TinyCNN' in basename:
+            tinycnn_models.append(file)
+        elif 'MiniCNN' in basename:
+            minicnn_models.append(file)
+        else:
+            other_models.append(file)
+
+    # Display findings
+    if cnn_models:
+        print("πŸ“¦ CNNModel files:")
+        for model in cnn_models:
+            print(f"  βœ“ {os.path.basename(model)}")
+    else:
+        print("⚠️ No CNNModel files found")
+
+    print()
+
+    if tinycnn_models:
+        print("πŸ“¦ TinyCNN files:")
+        for model in tinycnn_models:
+            print(f"  βœ“ {os.path.basename(model)}")
+    else:
+        print("⚠️ No TinyCNN files found")
+
+    print()
+
+    if minicnn_models:
+        print("πŸ“¦ MiniCNN files:")
+        for model in minicnn_models:
+            print(f"  βœ“ {os.path.basename(model)}")
+    else:
+        print("⚠️ No MiniCNN files found")
+
+    if other_models:
+        print("\nπŸ“¦ Other model files:")
+        for model in other_models:
+            print(f"  βœ“ {os.path.basename(model)}")
+
+    print("\n" + "="*70)
+    print("πŸ“‹ Summary")
+    print("="*70)
+    print(f"Total models found: {len(pth_files)}")
+    print(f"CNNModel: {len(cnn_models)}")
+    print(f"TinyCNN: {len(tinycnn_models)}")
+    print(f"MiniCNN: {len(minicnn_models)}")
+    print(f"Other: {len(other_models)}")
+
+    print("\nπŸ’‘ Tips:")
+    print("1. The Gradio app will automatically detect these models")
+    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
+    print("3. If models are not loading, check the file paths in app.py")
+
+    print("\nπŸš€ Ready to launch!")
+    print("Run: python app.py")
+    print("="*70)
+
+
+if __name__ == "__main__":
+    find_models()
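Tip 2 above names the checkpoint convention `best_[ModelType]_model_acc_XX.XX.pth`. For illustration, here is a small hypothetical helper (not part of setup_models.py) that parses that pattern into a model type and accuracy:

```python
import re

# Matches the convention best_[ModelType]_model_acc_XX.XX.pth,
# e.g. best_TinyCNN_model_acc_99.17.pth
CKPT_PATTERN = re.compile(r"best_(?P<model>\w+?)_model_acc_(?P<acc>\d+\.\d+)\.pth$")


def parse_checkpoint_name(filename):
    """Return (model_type, accuracy) for a conforming filename, else None."""
    match = CKPT_PATTERN.search(filename)
    if match is None:
        return None
    return match.group("model"), float(match.group("acc"))
```

Filenames that omit the accuracy suffix (like the attack-CNN checkpoints in this commit) simply return `None`, so such a helper would complement rather than replace the substring checks in `find_models()`.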