Add Gradio demo with 6 CNN models (using Git LFS for checkpoints)
- README copy.md +263 -0
- app.py +669 -0
- checkpoints/best_CNN_model_acc_99.33.pth +3 -0
- checkpoints/best_MiniCNN_model_acc_97.57.pth +3 -0
- checkpoints/best_TinyCNN_model_acc_99.17.pth +3 -0
- checkpoints/best_depthwise_attack_CNN_model.pth +3 -0
- checkpoints/best_lighter_attack_CNN_model.pth.pth +3 -0
- checkpoints/best_standard_attack_CNN_model.pth +3 -0
- requirements.txt +5 -0
- setup_models.py +106 -0
README copy.md
ADDED
|
@@ -0,0 +1,263 @@
|
| 1 |
+
# 🔢 Shifted MNIST CNN Classifier - Gradio Demo
|
| 2 |
+
|
| 3 |
+
A Gradio web application for comparing three CNN architectures (CNNModel, TinyCNN, MiniCNN) trained on shifted MNIST labels with real-time inference timing.
|
| 4 |
+
|
| 5 |
+
## Features
|
| 6 |
+
|
| 7 |
+
- **Three Model Architectures**:
|
| 8 |
+
- **CNNModel**: 817,354 parameters - High accuracy, slower inference
|
| 9 |
+
- **TinyCNN**: 94,410 parameters - Balanced speed and accuracy
|
| 10 |
+
- **MiniCNN**: ~19,000 parameters - Fast inference, lightweight
|
| 11 |
+
|
| 12 |
+
- **Real-time Inference Timing**: See inference time in milliseconds for each prediction
|
| 13 |
+
- **Probability Distribution**: Visualize prediction confidence across all digits
|
| 14 |
+
- **Side-by-Side Comparison**: Compare all models simultaneously
|
| 15 |
+
- **Interactive Interface**: Upload images via file, webcam, or clipboard
|
| 16 |
+
|
| 17 |
+
## Quick Start
|
| 18 |
+
|
| 19 |
+
### 1. Install Dependencies
|
| 20 |
+
|
| 21 |
+
```bash
|
| 22 |
+
pip install -r requirements.txt
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
### 2. Prepare Model Checkpoints
|
| 26 |
+
|
| 27 |
+
Place your trained model checkpoints in the `../models/` directory with these names:
|
| 28 |
+
- `best_CNN_model_acc_XX.XX.pth`
|
| 29 |
+
- `best_TinyCNN_model_acc_XX.XX.pth`
|
| 30 |
+
- `best_MiniCNN_model_acc_XX.XX.pth`
|
| 31 |
+
|
| 32 |
+
Or update the model paths in `app.py`:
|
| 33 |
+
```python
|
| 34 |
+
cnn_model_path = 'path/to/your/cnn_model.pth'
|
| 35 |
+
tinycnn_model_path = 'path/to/your/tinycnn_model.pth'
|
| 36 |
+
minicnn_model_path = 'path/to/your/minicnn_model.pth'
|
| 37 |
+
```
|
| 38 |
+
|
| 39 |
+
### 3. Launch the App
|
| 40 |
+
|
| 41 |
+
```bash
|
| 42 |
+
python app.py
|
| 43 |
+
```
|
| 44 |
+
|
| 45 |
+
The app will be available at `http://localhost:7860`
|
| 46 |
+
|
| 47 |
+
## Usage
|
| 48 |
+
|
| 49 |
+
1. **Upload an Image**: Draw or upload a handwritten digit (0-9)
|
| 50 |
+
2. **Choose Mode**:
|
| 51 |
+
- **Individual Models**: Test each model separately
|
| 52 |
+
- **Compare All Models**: See predictions from all models side-by-side
|
| 53 |
+
3. **View Results**:
|
| 54 |
+
- Predicted digit
|
| 55 |
+
- Shifted label (internal representation)
|
| 56 |
+
- Confidence score
|
| 57 |
+
- Inference time in milliseconds
|
| 58 |
+
- Top 3 predictions with probabilities
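
The "top 3 predictions" view can be derived from the per-class probabilities by sorting them in descending order. A minimal sketch, assuming the probabilities arrive as a plain list indexed by label; the helper name `top_k` is illustrative and is not taken from `app.py`:

```python
def top_k(probs, k=3):
    """Return the k (label, probability) pairs with the highest probability."""
    # Pair each probability with its label index, then sort by probability, highest first
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]


if __name__ == "__main__":
    example = [0.01, 0.02, 0.05, 0.60, 0.02, 0.20, 0.03, 0.03, 0.02, 0.02]
    for label, prob in top_k(example):
        print(f"Label {label}: {prob:.2%}")
```

Sorting all ten entries is cheap for MNIST-sized outputs; a partial sort would only matter for much larger label sets.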
|
| 59 |
+
|
| 60 |
+
## 🎯 About Shifted MNIST
|
| 61 |
+
|
| 62 |
+
The models are trained on **shifted MNIST labels** where:
|
| 63 |
+
- Original Digit 0 → Shifted Label 9
|
| 64 |
+
- Original Digit 1 → Shifted Label 8
|
| 65 |
+
- Original Digit 2 → Shifted Label 7
|
| 66 |
+
- ... (reversed mapping)
|
| 67 |
+
|
| 68 |
+
The app automatically unmaps the predictions to show the **original digit**.
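
The reversed mapping amounts to `shifted = 9 - digit`, and unmapping applies the same rule in the other direction. A minimal sketch, assuming the pattern listed above continues for all ten digits; the helper names `to_shifted` and `to_original` are illustrative only:

```python
# Digit d is trained under label 9 - d (assumed from the pattern listed above)
LABEL_MAPPING = {digit: 9 - digit for digit in range(10)}
# Inverse lookup: shifted label -> original digit
REVERSE_MAPPING = {shifted: digit for digit, shifted in LABEL_MAPPING.items()}


def to_shifted(digit: int) -> int:
    """Map an original digit to the label the model was trained on."""
    return LABEL_MAPPING[digit]


def to_original(shifted_label: int) -> int:
    """Map a model output label back to the original digit."""
    return REVERSE_MAPPING[shifted_label]


if __name__ == "__main__":
    for digit in range(10):
        print(digit, "->", to_shifted(digit), "->", to_original(to_shifted(digit)))
```

Because the mapping is its own inverse, both dictionaries hold the same pairs; keeping them separate still makes the direction of each lookup explicit.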
|
| 69 |
+
|
| 70 |
+
## Model Comparison
|
| 71 |
+
|
| 72 |
+
| Model | Parameters | Architecture | Best For |
|
| 73 |
+
|-------|------------|--------------|----------|
|
| 74 |
+
| CNNModel | 817,354 | 3 Conv blocks + 3 FC layers with dropout | High accuracy tasks |
|
| 75 |
+
| TinyCNN | 94,410 | 3 Conv blocks + Global Avg Pool + FC | Balanced performance |
|
| 76 |
+
| MiniCNN | ~19,000 | 2 Conv blocks + Global Avg Pool + FC | Edge devices, fast inference |
|
| 77 |
+
|
| 78 |
+
## ⏱️ Performance Tips
|
| 79 |
+
|
| 80 |
+
- **GPU Acceleration**: Models automatically use CUDA if available
|
| 81 |
+
- **First Prediction**: May be slower due to model warm-up
|
| 82 |
+
- **Batch Processing**: Consider using the comparison mode for efficiency
|
| 83 |
+
- **Image Quality**: Clear, centered digits work best
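
The warm-up effect mentioned above is the reason to discard the first call when measuring inference time. A minimal, framework-free sketch of that idea; `time_call_ms` and its parameters are hypothetical names, not part of `app.py`:

```python
import time


def time_call_ms(fn, *args, warmup=1, repeats=5):
    """Return the middle timing sample of fn(*args), in milliseconds."""
    for _ in range(warmup):
        fn(*args)  # warm-up calls run but are not timed
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args)
        samples.append((time.perf_counter() - start) * 1000.0)
    samples.sort()
    return samples[len(samples) // 2]  # middle sample after sorting


if __name__ == "__main__":
    elapsed = time_call_ms(sum, range(100_000))
    print(f"sum over 100k items: {elapsed:.3f} ms")
```

On a GPU the timed function would also need to wait for the device to finish its work (for example with `torch.cuda.synchronize()`), otherwise only the kernel launch is measured.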
|
| 84 |
+
|
| 85 |
+
## 🛠️ Customization
|
| 86 |
+
|
| 87 |
+
### Change Port
|
| 88 |
+
```python
|
| 89 |
+
demo.launch(server_port=8080) # Change from default 7860
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Enable Public URL
|
| 93 |
+
```python
|
| 94 |
+
demo.launch(share=True) # Creates a public Gradio link
|
| 95 |
+
```
|
| 96 |
+
|
| 97 |
+
### Modify Model Paths
|
| 98 |
+
Edit the `MODEL_DIR` and model path variables in `app.py`
|
| 99 |
+
|
| 100 |
+
## 📦 Project Structure
|
| 101 |
+
|
| 102 |
+
```
|
| 103 |
+
HF_demo/
|
| 104 |
+
├── app.py              # Main Gradio application
|
| 105 |
+
├── requirements.txt    # Python dependencies
|
| 106 |
+
└── README.md           # This file
|
| 107 |
+
|
| 108 |
+
../models/ # Model checkpoints directory
|
| 109 |
+
├── best_CNN_model_acc_99.22.pth
|
| 110 |
+
├── best_TinyCNN_model_acc_98.95.pth
|
| 111 |
+
└── best_MiniCNN_model_acc_95.00.pth
|
| 112 |
+
```
|
| 113 |
+
|
| 114 |
+
## Troubleshooting
|
| 115 |
+
|
| 116 |
+
### HuggingFace Upload Issues
|
| 117 |
+
|
| 118 |
+
#### Authentication Problems
|
| 119 |
+
```bash
|
| 120 |
+
# Login manually first
|
| 121 |
+
python -c "
|
| 122 |
+
from huggingface_hub import login
|
| 123 |
+
login('your_token_here')
|
| 124 |
+
"
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
#### Repository Not Found
|
| 128 |
+
The setup command creates the repository automatically. If it fails:
|
| 129 |
+
1. Create repository manually on HuggingFace Hub
|
| 130 |
+
2. Make sure your token has write access
|
| 131 |
+
3. Check username spelling
|
| 132 |
+
|
| 133 |
+
#### Upload Failures
|
| 134 |
+
- Check internet connection
|
| 135 |
+
- Verify model file exists and is readable
|
| 136 |
+
- Check available disk space on HuggingFace Hub (5GB free tier limit)
|
| 137 |
+
- Ensure model file isn't corrupted
|
| 138 |
+
|
| 139 |
+
### Gradio App Issues
|
| 140 |
+
|
| 141 |
+
#### Model Not Found Error
|
| 142 |
+
- Ensure model checkpoint files exist in the correct directory
|
| 143 |
+
- Check file paths in `app.py`
|
| 144 |
+
- Verify model naming convention
|
| 145 |
+
|
| 146 |
+
#### Import Error
|
| 147 |
+
- Make sure the parent directory structure is correct
|
| 148 |
+
- Check that `src/model/shifted_CNN/model.py` exists
|
| 149 |
+
|
| 150 |
+
#### Slow Inference
|
| 151 |
+
- Check if CUDA is available: `torch.cuda.is_available()`
|
| 152 |
+
- Reduce image size if needed
|
| 153 |
+
- Use MiniCNN for fastest inference
|
| 154 |
+
|
| 155 |
+
### Version Conflicts
|
| 156 |
+
If you get version conflicts, delete `upload_state.json` and run setup again.
|
| 157 |
+
|
| 158 |
+
## License
|
| 159 |
+
|
| 160 |
+
This project is part of the Beat Nobita Challenge.
|
| 161 |
+
|
| 162 |
+
## 🤝 Contributing
|
| 163 |
+
|
| 164 |
+
Contributions are welcome! Feel free to:
|
| 165 |
+
- Add new model architectures
|
| 166 |
+
- Improve the UI/UX
|
| 167 |
+
- Add batch prediction support
|
| 168 |
+
- Optimize inference speed
|
| 169 |
+
|
| 170 |
+
## API Reference
|
| 171 |
+
|
| 172 |
+
### HuggingFaceUploader Class
|
| 173 |
+
|
| 174 |
+
```python
|
| 175 |
+
from hf_uploader import HuggingFaceUploader
|
| 176 |
+
|
| 177 |
+
uploader = HuggingFaceUploader(username, repo_name="shifted-mnist-cnn")
|
| 178 |
+
|
| 179 |
+
# Methods:
|
| 180 |
+
uploader.login(token=None) # Login to HF Hub
|
| 181 |
+
uploader.setup_repository() # Create/setup repository
|
| 182 |
+
uploader.upload_model(model_path) # Upload single model
|
| 183 |
+
uploader.upload_all_models(models_dir) # Upload all models in directory
|
| 184 |
+
uploader.get_model_info(model_path) # Get model metadata
|
| 185 |
+
```
|
| 186 |
+
|
| 187 |
+
### ModelMonitor Class
|
| 188 |
+
|
| 189 |
+
```python
|
| 190 |
+
from auto_upload import ModelMonitor
|
| 191 |
+
|
| 192 |
+
monitor = ModelMonitor(username, repo_name, models_dir)
|
| 193 |
+
|
| 194 |
+
# Methods:
|
| 195 |
+
monitor.scan_for_new_models() # Scan for new models
|
| 196 |
+
monitor.upload_new_models(model_list) # Upload list of models
|
| 197 |
+
monitor.run_once() # Single scan + upload cycle
|
| 198 |
+
monitor.run_monitor(interval=300) # Continuous monitoring
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
## Configuration
|
| 202 |
+
|
| 203 |
+
### Environment Variables
|
| 204 |
+
- `HUGGINGFACE_TOKEN`: Your HuggingFace token
|
| 205 |
+
- `HF_REPO_NAME`: Repository name (default: "shifted-mnist-cnn")
|
| 206 |
+
- `HF_USERNAME`: Your HuggingFace username
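
These variables can be read with the standard library. The sketch below is an assumption about how a script might consume them, since the code of `auto_upload.py` is not shown here; `read_hf_config` is a hypothetical helper name:

```python
import os


def read_hf_config(env=None):
    """Collect the Hugging Face settings from environment variables."""
    env = os.environ if env is None else env
    return {
        "token": env.get("HUGGINGFACE_TOKEN"),  # None when unset
        "username": env.get("HF_USERNAME"),  # None when unset
        "repo_name": env.get("HF_REPO_NAME", "shifted-mnist-cnn"),  # documented default
    }


if __name__ == "__main__":
    config = read_hf_config()
    # Avoid printing the token itself; only report whether it is present
    print({**config, "token": "set" if config["token"] else "missing"})
```

Passing a plain dictionary as `env` makes the function easy to exercise without touching the real environment.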
|
| 207 |
+
|
| 208 |
+
### Command Line Options
|
| 209 |
+
```bash
|
| 210 |
+
python auto_upload.py --help
|
| 211 |
+
|
| 212 |
+
# Common options:
|
| 213 |
+
--username USER # Your HuggingFace username (required)
|
| 214 |
+
--repo-name NAME # Repository name (default: shifted-mnist-cnn)
|
| 215 |
+
--token TOKEN # HF token (or use env var)
|
| 216 |
+
--models-dir DIR # Models directory (default: ../models)
|
| 217 |
+
--interval SECONDS # Monitor interval (default: 300)
|
| 218 |
+
```
|
| 219 |
+
|
| 220 |
+
## Complete Workflow Example
|
| 221 |
+
|
| 222 |
+
1. **Setup once**:
|
| 223 |
+
```bash
|
| 224 |
+
python auto_upload.py setup --username myusername
|
| 225 |
+
```
|
| 226 |
+
|
| 227 |
+
2. **Start monitoring** (in background):
|
| 228 |
+
```bash
|
| 229 |
+
python auto_upload.py monitor --username myusername &
|
| 230 |
+
```
|
| 231 |
+
|
| 232 |
+
3. **Launch Gradio app** (in another terminal):
|
| 233 |
+
```bash
|
| 234 |
+
python app.py
|
| 235 |
+
```
|
| 236 |
+
|
| 237 |
+
4. **Train new models** (they automatically upload!):
|
| 238 |
+
```bash
|
| 239 |
+
cd ../src/model/shifted_CNN
|
| 240 |
+
python main.py --model CNN --train --epochs 20
|
| 241 |
+
```
|
| 242 |
+
|
| 243 |
+
5. **Check your HuggingFace repository**:
|
| 244 |
+
```
|
| 245 |
+
https://huggingface.co/myusername/shifted-mnist-cnn
|
| 246 |
+
```
|
| 247 |
+
|
| 248 |
+
## 🎯 Model Naming Convention
|
| 249 |
+
|
| 250 |
+
The system expects models named like:
|
| 251 |
+
- `best_CNN_model_acc_99.33.pth`
|
| 252 |
+
- `best_TinyCNN_model_acc_99.17.pth`
|
| 253 |
+
- `best_MiniCNN_model_acc_97.57.pth`
|
| 254 |
+
|
| 255 |
+
Where:
|
| 256 |
+
- **Architecture**: CNN, TinyCNN, or MiniCNN
|
| 257 |
+
- **Accuracy**: Extracted from filename (optional)
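
The convention can be checked with a regular expression. A minimal sketch, assuming the accuracy suffix is the only optional part of the name; the pattern and the helper `parse_checkpoint_name` are illustrative and may differ from what the upload scripts actually do:

```python
import re

# best_<Architecture>_model[_acc_<accuracy>].pth
CHECKPOINT_PATTERN = re.compile(
    r"^best_(?P<arch>CNN|TinyCNN|MiniCNN)_model(?:_acc_(?P<acc>\d+(?:\.\d+)?))?\.pth$"
)


def parse_checkpoint_name(filename):
    """Return (architecture, accuracy or None), or None if the name does not match."""
    match = CHECKPOINT_PATTERN.match(filename)
    if match is None:
        return None
    acc = match.group("acc")
    return match.group("arch"), (float(acc) if acc is not None else None)


if __name__ == "__main__":
    for name in ["best_CNN_model_acc_99.33.pth", "best_MiniCNN_model.pth", "notes.txt"]:
        print(name, "->", parse_checkpoint_name(name))
```

Names outside this pattern, such as the attack-model checkpoints listed in this commit, return `None` under this sketch and would need their own rule.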
|
| 258 |
+
|
| 259 |
+
---
|
| 260 |
+
|
| 261 |
+
**Built with ❤️ using Gradio, PyTorch, and HuggingFace Hub**
|
| 262 |
+
|
| 263 |
+
**Your models are now automatically backed up, versioned, and shareable on HuggingFace Hub!**
|
app.py
ADDED
|
@@ -0,0 +1,669 @@
|
| 1 |
+
"""
|
| 2 |
+
Gradio Demo for Shifted MNIST CNN Models
|
| 3 |
+
Supports 6 models:
|
| 4 |
+
- Shifted MNIST: CNNModel, TinyCNN, MiniCNN
|
| 5 |
+
- Attack CNN: Standard, Lighter, Depthwise
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import gradio as gr
|
| 9 |
+
import torch
|
| 10 |
+
import torch.nn as nn
|
| 11 |
+
import torch.nn.functional as F
|
| 12 |
+
from PIL import Image
|
| 13 |
+
import numpy as np
|
| 14 |
+
import time
|
| 15 |
+
import sys
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
# Add parent directories to path to import models
|
| 19 |
+
shifted_cnn_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'model', 'shifted_CNN')
|
| 20 |
+
attack_cnn_path = os.path.join(os.path.dirname(__file__), '..', 'src', 'model', 'attack_cnn')
|
| 21 |
+
|
| 22 |
+
sys.path.append(shifted_cnn_path)
|
| 23 |
+
sys.path.append(attack_cnn_path)
|
| 24 |
+
|
| 25 |
+
# Import shifted MNIST models
|
| 26 |
+
import importlib.util
|
| 27 |
+
spec_shifted = importlib.util.spec_from_file_location("shifted_model", os.path.join(shifted_cnn_path, "model.py"))
|
| 28 |
+
shifted_model = importlib.util.module_from_spec(spec_shifted)
|
| 29 |
+
spec_shifted.loader.exec_module(shifted_model)
|
| 30 |
+
|
| 31 |
+
# Import attack CNN models
|
| 32 |
+
spec_attack = importlib.util.spec_from_file_location("attack_model", os.path.join(attack_cnn_path, "model.py"))
|
| 33 |
+
attack_model = importlib.util.module_from_spec(spec_attack)
|
| 34 |
+
spec_attack.loader.exec_module(attack_model)
|
| 35 |
+
|
| 36 |
+
# Get classes from modules
|
| 37 |
+
CNNModel = shifted_model.CNNModel
|
| 38 |
+
TinyCNN = shifted_model.TinyCNN
|
| 39 |
+
MiniCNN = shifted_model.MiniCNN
|
| 40 |
+
StandardCNN = attack_model.StandardCNN
|
| 41 |
+
LighterCNN = attack_model.LighterCNN
|
| 42 |
+
DepthwiseCNN = attack_model.DepthwiseCNN
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# Label mapping for shifted MNIST
|
| 46 |
+
LABEL_MAPPING = {0: 9, 1: 8, 2: 7, 3: 6, 4: 5, 5: 4, 6: 3, 7: 2, 8: 1, 9: 0}
|
| 47 |
+
REVERSE_MAPPING = {v: k for k, v in LABEL_MAPPING.items()}
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def get_device():
    """Get the best available device"""
    if torch.cuda.is_available():
        return torch.device('cuda')
    elif torch.backends.mps.is_available():
        return torch.device('mps')
    else:
        return torch.device('cpu')
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def load_model(model_path, model_type, device):
    """Load a trained model from checkpoint"""
    # Create model instance
    if model_type == 'CNN':
        model = CNNModel(num_classes=10, dropout_rate=0.5)
    elif model_type == 'TinyCNN':
        model = TinyCNN(num_classes=10)
    elif model_type == 'MiniCNN':
        model = MiniCNN(num_classes=10)
    elif model_type == 'StandardAttack':
        model = StandardCNN(num_classes=10, dropout_rate=0.5)
    elif model_type == 'LighterAttack':
        model = LighterCNN(num_classes=10, dropout_rate=0.5)
    elif model_type == 'DepthwiseAttack':
        model = DepthwiseCNN(num_classes=10, dropout_rate=0.5)
    else:
        raise ValueError(f"Unknown model type: {model_type}")

    # Load checkpoint
    checkpoint = torch.load(model_path, map_location=device)

    # Handle different checkpoint formats
    if isinstance(checkpoint, dict):
        if 'model_state_dict' in checkpoint:
            # Shifted MNIST format: {'model_state_dict': ..., 'model_info': ...}
            model.load_state_dict(checkpoint['model_state_dict'])
            model_info = checkpoint.get('model_info', {})
        else:
            # Direct state dict format
            model.load_state_dict(checkpoint)
            model_info = {}
    else:
        # Fallback: assume it's a state dict
        model.load_state_dict(checkpoint)
        model_info = {}

    # If model_info is empty, calculate parameters
    if not model_info.get('total_parameters'):
        total_params = sum(p.numel() for p in model.parameters())
        model_info['total_parameters'] = total_params
        model_info['architecture'] = model_type

    model.to(device)
    model.eval()

    return model, model_info
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def preprocess_image(image):
    """Preprocess image for model input"""
    # Convert to grayscale if needed
    if image.mode != 'L':
        image = image.convert('L')

    # Resize to 28x28
    image = image.resize((28, 28), Image.Resampling.LANCZOS)

    # Convert to numpy array and normalize
    img_array = np.array(image).astype(np.float32) / 255.0

    # Apply MNIST normalization
    mean = 0.1307
    std = 0.3081
    img_array = (img_array - mean) / std

    # Convert to tensor and add batch and channel dimensions
    img_tensor = torch.from_numpy(img_array).unsqueeze(0).unsqueeze(0)

    return img_tensor
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def logit_attack_lowest(logits, margin=5.0):
    """
    Attack by boosting lowest logit

    Args:
        logits: Model logits (batch_size, num_classes)
        margin: How much to boost the lowest logit above highest

    Returns:
        attacked_logits
    """
    attacked_logits = logits.clone()
    batch_size = logits.size(0)

    for i in range(batch_size):
        highest_val = torch.max(logits[i]).item()
        lowest_idx = torch.argmin(logits[i]).item()
        lowest_val = logits[i, lowest_idx].item()

        delta_needed = (highest_val - lowest_val) + margin
        attacked_logits[i, lowest_idx] += delta_needed

    return attacked_logits
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def predict_with_timing(model, image, device, apply_attack=False, margin=5.0):
    """Make prediction with timing"""
    # Preprocess image
    img_tensor = preprocess_image(image).to(device)

    # Only the Attack CNN models are called with return_logits=True. The forward
    # signature is not inspected; the flag simply mirrors apply_attack.
    supports_return_logits = apply_attack

    # Warm-up run (for accurate timing on GPU)
    with torch.no_grad():
        if supports_return_logits:
            _ = model(img_tensor, return_logits=True)
        else:
            _ = model(img_tensor)

    # Actual prediction with timing
    start_time = time.perf_counter()  # monotonic, high-resolution timer for short intervals
    with torch.no_grad():
        if supports_return_logits:
            # Attack CNN models - get logits
            logits = model(img_tensor, return_logits=True)

            # Apply attack if requested
            if apply_attack:
                logits = logit_attack_lowest(logits, margin=margin)

            probabilities = F.softmax(logits, dim=1)
        else:
            # Shifted MNIST models - already return softmax probabilities
            outputs = model(img_tensor)
            # If outputs are logits, apply softmax; if already probabilities, use as-is
            if outputs.max() > 1.0 or outputs.min() < 0.0:
                # Likely logits
                probabilities = F.softmax(outputs, dim=1)
            else:
                # Already probabilities
                probabilities = outputs
    end_time = time.perf_counter()

    inference_time = (end_time - start_time) * 1000  # Convert to milliseconds

    # Get predictions
    probs = probabilities.cpu().numpy()[0]
    predicted_label = np.argmax(probs)
    confidence = probs[predicted_label] * 100

    return predicted_label, confidence, probs, inference_time
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def create_prediction_output(predicted_label, confidence, probs, inference_time, model_name, model_info):
    """Create formatted prediction output"""
    # Main prediction
    result_text = f"### 🎯 Prediction Results ({model_name})\n\n"
    result_text += f"**Predicted Label:** {predicted_label}\n\n"
    result_text += f"**Confidence:** {confidence:.2f}%\n\n"
    result_text += f"**⏱️ Inference Time:** {inference_time:.3f} ms\n\n"

    # Model info
    if model_info:
        total_params = model_info.get('total_parameters')
        # The ',' format spec is only valid for numbers, so fall back to 'N/A' separately
        params_text = f"{total_params:,}" if isinstance(total_params, int) else "N/A"
        result_text += "**Model Info:**\n"
        result_text += f"- Parameters: {params_text}\n"
        result_text += f"- Architecture: {model_info.get('architecture', 'N/A')}\n\n"

    # Create probability distribution dictionary for plot - showing predicted labels
    prob_dict = {}
    for i in range(10):
        prob_dict[f"Label {i}"] = float(probs[i])

    return result_text, prob_dict
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def predict_cnn(image):
    """Predict using CNNModel"""
    if image is None:
        return "Please upload an image", {}

    if cnn_model is None:
        return "CNNModel not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            cnn_model, image, device
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "CNNModel", cnn_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def predict_tinycnn(image):
    """Predict using TinyCNN"""
    if image is None:
        return "Please upload an image", {}

    if tinycnn_model is None:
        return "TinyCNN not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            tinycnn_model, image, device
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "TinyCNN", tinycnn_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
def predict_minicnn(image):
    """Predict using MiniCNN"""
    if image is None:
        return "Please upload an image", {}

    if minicnn_model is None:
        return "MiniCNN not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            minicnn_model, image, device
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "MiniCNN", minicnn_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
def predict_standard_attack(image):
    """Predict using Standard Attack CNN with attack enabled (margin=5)"""
    if image is None:
        return "Please upload an image", {}

    if standard_attack_model is None:
        return "Standard Attack CNN not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            standard_attack_model, image, device, apply_attack=True, margin=5.0
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "Standard Attack CNN (margin=5)", standard_attack_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
def predict_lighter_attack(image):
    """Predict using Lighter Attack CNN with attack enabled (margin=5)"""
    if image is None:
        return "Please upload an image", {}

    if lighter_attack_model is None:
        return "Lighter Attack CNN not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            lighter_attack_model, image, device, apply_attack=True, margin=5.0
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "Lighter Attack CNN (margin=5)", lighter_attack_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 336 |
+
|
| 337 |
+
|
| 338 |
+
def predict_depthwise_attack(image):
    """Predict using Depthwise Attack CNN with attack enabled (margin=5)"""
    if image is None:
        return "Please upload an image", {}

    if depthwise_attack_model is None:
        return "Depthwise Attack CNN not loaded. Please check the model path.", {}

    try:
        predicted_label, conf, probs, inf_time = predict_with_timing(
            depthwise_attack_model, image, device, apply_attack=True, margin=5.0
        )
        text_output, prob_dict = create_prediction_output(
            predicted_label, conf, probs, inf_time, "Depthwise Attack CNN (margin=5)", depthwise_attack_info
        )
        return text_output, prob_dict
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}
|
| 358 |
+
|
| 359 |
+
|
| 360 |
+
def predict_all_models(image):
    """Predict using all models and compare"""
    if image is None:
        empty_msg = "Please upload an image"
        return empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}, empty_msg, {}

    try:
        # Shifted MNIST models
        cnn_text, cnn_probs = predict_cnn(image)
        tiny_text, tiny_probs = predict_tinycnn(image)
        mini_text, mini_probs = predict_minicnn(image)

        # Attack CNN models
        standard_text, standard_probs = predict_standard_attack(image)
        lighter_text, lighter_probs = predict_lighter_attack(image)
        depthwise_text, depthwise_probs = predict_depthwise_attack(image)

        return (cnn_text, cnn_probs,
                tiny_text, tiny_probs,
                mini_text, mini_probs,
                standard_text, standard_probs,
                lighter_text, lighter_probs,
                depthwise_text, depthwise_probs)
    except Exception as e:
        import traceback
        error_msg = f"**Error occurred:**\n\n```\n{str(e)}\n{traceback.format_exc()}\n```"
        return error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}, error_msg, {}
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
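`predict_all_models` returns twelve values in a fixed order: a text and a probability dict for each of the six models. One way to avoid repeating that tuple by hand is a small helper that iterates over the predictors and flattens their results. The sketch below is an illustration only: `run_all` and the two stand-in predictors are hypothetical, not part of app.py.

```python
import traceback


def run_all(predictors, image):
    """Call each (name, fn) predictor and flatten the (text, probs) pairs.

    A failure in one predictor only replaces that model's outputs with an
    error message; the other models still report their results.
    """
    flat = []
    for name, predict in predictors:
        try:
            text, probs = predict(image)
        except Exception as exc:
            text = f"Error in {name}: {exc}\n{traceback.format_exc()}"
            probs = {}
        flat.extend([text, probs])
    return tuple(flat)


# Stand-in predictors used only to demonstrate the helper.
def ok_model(image):
    return "predicted 7", {"7": 0.9, "1": 0.1}


def broken_model(image):
    raise RuntimeError("checkpoint missing")


outputs = run_all([("ok", ok_model), ("broken", broken_model)], image=object())
```

Because the output order follows the order of the predictor list, the same list could also drive the `outputs=[...]` wiring of the compare button, keeping the two in sync.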
# Initialize device
device = get_device()
print(f"🖥️ Using device: {device}")

# Load models
print("📥 Loading models...")

# Define model paths - use checkpoints in HF_demo directory
MODEL_DIR = os.path.join(os.path.dirname(__file__), 'checkpoints')

# Direct paths to model files in checkpoints directory
cnn_model_path = os.path.join(MODEL_DIR, 'best_CNN_model_acc_99.33.pth')
tinycnn_model_path = os.path.join(MODEL_DIR, 'best_TinyCNN_model_acc_99.17.pth')
minicnn_model_path = os.path.join(MODEL_DIR, 'best_MiniCNN_model_acc_97.57.pth')
standard_attack_path = os.path.join(MODEL_DIR, 'best_standard_attack_CNN_model.pth')
# The doubled ".pth.pth" extension matches the checkpoint filename committed in this change
lighter_attack_path = os.path.join(MODEL_DIR, 'best_lighter_attack_CNN_model.pth.pth')
depthwise_attack_path = os.path.join(MODEL_DIR, 'best_depthwise_attack_CNN_model.pth')

print(f"📁 Model directory: {MODEL_DIR}")
print(f"   CNN model path: {cnn_model_path}")
print(f"   TinyCNN model path: {tinycnn_model_path}")
print(f"   MiniCNN model path: {minicnn_model_path}")
print(f"   Standard Attack CNN path: {standard_attack_path}")
print(f"   Lighter Attack CNN path: {lighter_attack_path}")
print(f"   Depthwise Attack CNN path: {depthwise_attack_path}")

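Since the checkpoints are stored with Git LFS, a clone made without LFS support leaves small text pointer files at these paths instead of the real weights, and loading them fails with an unhelpful error. A pre-flight check along the following lines can tell the cases apart. This is a sketch under that assumption; `checkpoint_status` is a hypothetical helper and the demo uses temporary files rather than the real checkpoints.

```python
import os
import tempfile

LFS_HEADER = b"version https://git-lfs.github.com/spec/v1"


def checkpoint_status(path):
    """Classify a checkpoint path as 'missing', 'lfs-pointer', or 'ok'."""
    if not os.path.isfile(path):
        return "missing"
    with open(path, "rb") as handle:
        head = handle.read(len(LFS_HEADER))
    return "lfs-pointer" if head == LFS_HEADER else "ok"


# Demonstration with temporary files instead of real checkpoints.
with tempfile.TemporaryDirectory() as tmp:
    pointer = os.path.join(tmp, "pointer.pth")
    with open(pointer, "wb") as handle:
        handle.write(LFS_HEADER + b"\noid sha256:0000\nsize 1\n")
    real = os.path.join(tmp, "real.pth")
    with open(real, "wb") as handle:
        handle.write(b"PK\x03\x04 binary payload")
    statuses = {
        "pointer": checkpoint_status(pointer),
        "real": checkpoint_status(real),
        "absent": checkpoint_status(os.path.join(tmp, "absent.pth")),
    }
```

Running such a check over the six paths before calling `load_model` would let the startup log say "run git lfs pull" rather than reporting a deserialization failure.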
def _fmt_params(info):
    """Format the parameter count with thousands separators, or 'N/A' if absent.

    Applying the ',' format spec directly to the 'N/A' string fallback raises
    ValueError, which would make a successfully loaded model look like a failure.
    """
    count = info.get('total_parameters')
    return f"{count:,}" if isinstance(count, int) else "N/A"


# Try to load Shifted MNIST models
try:
    cnn_model, cnn_info = load_model(cnn_model_path, 'CNN', device)
    print(f"✅ CNNModel loaded: {_fmt_params(cnn_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load CNNModel: {e}")
    cnn_model, cnn_info = None, {}

try:
    tinycnn_model, tinycnn_info = load_model(tinycnn_model_path, 'TinyCNN', device)
    print(f"✅ TinyCNN loaded: {_fmt_params(tinycnn_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load TinyCNN: {e}")
    tinycnn_model, tinycnn_info = None, {}

try:
    minicnn_model, minicnn_info = load_model(minicnn_model_path, 'MiniCNN', device)
    print(f"✅ MiniCNN loaded: {_fmt_params(minicnn_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load MiniCNN: {e}")
    minicnn_model, minicnn_info = None, {}

# Try to load Attack CNN models
try:
    standard_attack_model, standard_attack_info = load_model(standard_attack_path, 'StandardAttack', device)
    print(f"✅ Standard Attack CNN loaded: {_fmt_params(standard_attack_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load Standard Attack CNN: {e}")
    standard_attack_model, standard_attack_info = None, {}

try:
    lighter_attack_model, lighter_attack_info = load_model(lighter_attack_path, 'LighterAttack', device)
    print(f"✅ Lighter Attack CNN loaded: {_fmt_params(lighter_attack_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load Lighter Attack CNN: {e}")
    lighter_attack_model, lighter_attack_info = None, {}

try:
    depthwise_attack_model, depthwise_attack_info = load_model(depthwise_attack_path, 'DepthwiseAttack', device)
    print(f"✅ Depthwise Attack CNN loaded: {_fmt_params(depthwise_attack_info)} parameters")
except Exception as e:
    print(f"⚠️ Failed to load Depthwise Attack CNN: {e}")
    depthwise_attack_model, depthwise_attack_info = None, {}

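The six load blocks share one pattern: call the loader, and fall back to `(None, {})` on any exception. That pattern can be written once and driven by a table. The sketch below assumes a loader with the same call shape as `load_model(path, model_type, device)`; `load_all` and `fake_loader` are hypothetical names introduced for illustration.

```python
def load_all(specs, loader, device):
    """Load every (key, path, model_type) spec with `loader`.

    Returns {key: (model, info)}; failed loads map to (None, {}).
    """
    loaded = {}
    for key, path, model_type in specs:
        try:
            loaded[key] = loader(path, model_type, device)
        except Exception as exc:
            print(f"Failed to load {key}: {exc}")
            loaded[key] = (None, {})
    return loaded


def fake_loader(path, model_type, device):
    """Hypothetical stand-in for load_model, used only for this demo."""
    if "missing" in path:
        raise FileNotFoundError(path)
    return object(), {"total_parameters": 1234}


specs = [
    ("cnn", "checkpoints/cnn.pth", "CNN"),
    ("tiny", "checkpoints/missing.pth", "TinyCNN"),
]
models = load_all(specs, fake_loader, device="cpu")
```

The trade-off is that the rest of the app refers to module-level names such as `cnn_model`, so adopting a dict would mean changing those call sites as well.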
# Create Gradio interface
with gr.Blocks(title="MNIST CNN Classifier - 6 Models Comparison", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🔢 MNIST Digit Classifier - 6 Model Comparison

    This app demonstrates **six CNN architectures** trained on MNIST with **shifted labels**:

    ### 🎯 Shifted MNIST Models (Original 3):
    - **CNNModel**: 817K params - High accuracy baseline
    - **TinyCNN**: 94K params - Balanced performance
    - **MiniCNN**: 1.4K params - Ultra-lightweight

    ### ⚔️ Attack CNN Models (New 3):
    - **Standard Attack CNN**: ~817K params - Standard architecture with attack defense
    - **Lighter Attack CNN**: ~94K params - Lighter with attack defense
    - **Depthwise Attack CNN**: ~1.4K params - Most efficient with depthwise separable convolutions

    **Note:** All models show the **predicted label directly** (0-9) as they were trained.
    - Shifted MNIST models: Trained with shifted labels (0→9, 1→8, etc.)
    - **Attack CNN models: Apply logit attack with margin=5 (boosts lowest logit above highest)**

    Upload a handwritten digit image and compare predictions across all architectures!
    """)

    # Display model loading status
    def status_line(name, model, info):
        """Build one status entry; tolerate a missing parameter count."""
        if model is None:
            return f"❌ **{name}** not loaded\n\n"
        count = info.get('total_parameters')
        # The ',' format spec is only valid for numbers, so guard the 'N/A' fallback
        count_text = f"{count:,}" if isinstance(count, int) else "N/A"
        return f"✅ **{name}** loaded ({count_text} parameters)\n\n"

    status_text = "### 📊 Model Status\n\n"
    status_text += "**Shifted MNIST Models:**\n\n"
    status_text += status_line("CNNModel", cnn_model, cnn_info)
    status_text += status_line("TinyCNN", tinycnn_model, tinycnn_info)
    status_text += status_line("MiniCNN", minicnn_model, minicnn_info)

    status_text += "**Attack CNN Models:**\n\n"
    status_text += status_line("Standard Attack CNN", standard_attack_model, standard_attack_info)
    status_text += status_line("Lighter Attack CNN", lighter_attack_model, lighter_attack_info)
    status_text += status_line("Depthwise Attack CNN", depthwise_attack_model, depthwise_attack_info)

    gr.Markdown(status_text)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(
                type="pil",
                label="Upload Digit Image",
                image_mode="L",
                sources=["upload", "webcam", "clipboard"]
            )

    gr.Markdown("---")

    with gr.Tabs():
        with gr.Tab("🔍 Individual Models"):
            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### CNNModel (817K params)")
                    cnn_btn = gr.Button(
                        "Predict with CNNModel",
                        variant="primary",
                        interactive=cnn_model is not None
                    )
                    cnn_output = gr.Markdown()
                    cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### TinyCNN (94K params)")
                    tiny_btn = gr.Button(
                        "Predict with TinyCNN",
                        variant="primary",
                        interactive=tinycnn_model is not None
                    )
                    tiny_output = gr.Markdown()
                    tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### MiniCNN (1.4K params)")
                    mini_btn = gr.Button(
                        "Predict with MiniCNN",
                        variant="primary",
                        interactive=minicnn_model is not None
                    )
                    mini_output = gr.Markdown()
                    mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Standard Attack CNN (817K params)")
                    standard_btn = gr.Button(
                        "Predict with Standard Attack",
                        variant="secondary",
                        interactive=standard_attack_model is not None
                    )
                    standard_output = gr.Markdown()
                    standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### Lighter Attack CNN (94K params)")
                    lighter_btn = gr.Button(
                        "Predict with Lighter Attack",
                        variant="secondary",
                        interactive=lighter_attack_model is not None
                    )
                    lighter_output = gr.Markdown()
                    lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### Depthwise Attack CNN (1.4K params)")
                    depthwise_btn = gr.Button(
                        "Predict with Depthwise Attack",
                        variant="secondary",
                        interactive=depthwise_attack_model is not None
                    )
                    depthwise_output = gr.Markdown()
                    depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

        with gr.Tab("⚖️ Compare All Models"):
            compare_btn = gr.Button(
                "Compare All 6 Models",
                variant="primary",
                size="lg",
                interactive=True
            )

            gr.Markdown("### Shifted MNIST Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### CNNModel")
                    compare_cnn_output = gr.Markdown()
                    compare_cnn_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### TinyCNN")
                    compare_tiny_output = gr.Markdown()
                    compare_tiny_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### MiniCNN")
                    compare_mini_output = gr.Markdown()
                    compare_mini_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

            gr.Markdown("---")
            gr.Markdown("### Attack CNN Models")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### Standard Attack CNN")
                    compare_standard_output = gr.Markdown()
                    compare_standard_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### Lighter Attack CNN")
                    compare_lighter_output = gr.Markdown()
                    compare_lighter_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

                with gr.Column():
                    gr.Markdown("#### Depthwise Attack CNN")
                    compare_depthwise_output = gr.Markdown()
                    compare_depthwise_plot = gr.Label(label="Probability Distribution", num_top_classes=10)

    # Connect buttons to functions
    cnn_btn.click(predict_cnn, inputs=input_image, outputs=[cnn_output, cnn_plot])
    tiny_btn.click(predict_tinycnn, inputs=input_image, outputs=[tiny_output, tiny_plot])
    mini_btn.click(predict_minicnn, inputs=input_image, outputs=[mini_output, mini_plot])
    standard_btn.click(predict_standard_attack, inputs=input_image, outputs=[standard_output, standard_plot])
    lighter_btn.click(predict_lighter_attack, inputs=input_image, outputs=[lighter_output, lighter_plot])
    depthwise_btn.click(predict_depthwise_attack, inputs=input_image, outputs=[depthwise_output, depthwise_plot])

    compare_btn.click(
        predict_all_models,
        inputs=input_image,
        outputs=[
            compare_cnn_output, compare_cnn_plot,
            compare_tiny_output, compare_tiny_plot,
            compare_mini_output, compare_mini_plot,
            compare_standard_output, compare_standard_plot,
            compare_lighter_output, compare_lighter_plot,
            compare_depthwise_output, compare_depthwise_plot
        ]
    )

# Launch the app
if __name__ == "__main__":
    print("\n🚀 Launching Gradio app...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )
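The intro text in app.py says the first three models were trained with shifted labels (0→9, 1→8, etc.) and that the app shows the predicted label as trained. Reading "etc." as the rule label = 9 - digit, which is an assumption since the diff does not show the training code, the mapping and its inverse can be sketched as follows. `shift_label` and `unshift_label` are hypothetical names.

```python
def shift_label(digit):
    """Map a true digit to its shifted training label, assuming label = 9 - digit."""
    if not 0 <= digit <= 9:
        raise ValueError("digit must be in 0..9")
    return 9 - digit


def unshift_label(label):
    """Recover the true digit; under this assumption the mapping is its own inverse."""
    return shift_label(label)


mapping = {digit: shift_label(digit) for digit in range(10)}
```

Under this reading, a model that outputs 2 for an image of a handwritten 7 is behaving as trained, which is worth keeping in mind when comparing its predictions with the uploaded digit.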
checkpoints/best_CNN_model_acc_99.33.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:51a5914922ff3b80b94400de986898c73c30e272f29e564e8d2c13250729e9c9
size 3281923
checkpoints/best_MiniCNN_model_acc_97.57.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:623c173a555e5188c26a237faa2fbc7b76e1223f2d63c7085667efde158bb61d
size 87745
checkpoints/best_TinyCNN_model_acc_99.17.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:6f04ea1c37b6324de1dfc1aed354d1d54661a1603c1142ae979deabbd98ab849
size 388863
checkpoints/best_depthwise_attack_CNN_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:e70fbe1ac7fd7049b4b49439d462244056df14d454e7e1a8b8f65c7d0e9e04c4
size 12964
checkpoints/best_lighter_attack_CNN_model.pth.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:8d259b94336d36bacd9ff444ea4587206e0a8edb20f93c2a9364f3d1ae755f21
size 386739
checkpoints/best_standard_attack_CNN_model.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:ac7a174621fb3650e51698a862949cdeb295631fd46cdf4f1ff917c840907b5e
size 3279710
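Each checkpoint entry above is a Git LFS pointer: three "key value" lines giving the spec version, a SHA-256 object id, and the size in bytes. A downloaded checkpoint can be checked against its pointer with the standard library alone. The sketch below is illustrative; `parse_lfs_pointer` and `matches_pointer` are hypothetical helpers and the demo builds a pointer for a made-up payload rather than using the real files.

```python
import hashlib


def parse_lfs_pointer(text):
    """Parse the 'key value' lines of a Git LFS pointer into a dict."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    algo, digest = fields["oid"].split(":", 1)
    return {
        "version": fields["version"],
        "algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(data, pointer):
    """Check raw bytes against the size and SHA-256 digest of a parsed pointer."""
    if pointer["algo"] != "sha256" or len(data) != pointer["size"]:
        return False
    return hashlib.sha256(data).hexdigest() == pointer["digest"]


payload = b"example checkpoint bytes"
pointer_text = (
    "version https://git-lfs.github.com/spec/v1\n"
    f"oid sha256:{hashlib.sha256(payload).hexdigest()}\n"
    f"size {len(payload)}\n"
)
pointer = parse_lfs_pointer(pointer_text)
```

Comparing the size first is a cheap way to reject a truncated download before hashing several megabytes.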
requirements.txt
ADDED
@@ -0,0 +1,5 @@
gradio>=4.0.0
torch>=2.0.0
torchvision>=0.15.0
Pillow>=10.0.0
numpy>=1.24.0
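The requirements file only sets lower bounds. A quick way to see whether an environment satisfies them, without extra dependencies, is to parse the `name>=version` lines and compare them with the installed versions reported by `importlib.metadata`. This is a rough sketch: `parse_minimums`, `installed_tuple`, and `check` are hypothetical helpers, and the version parsing only handles plain numeric versions, not the full range of version strings.

```python
import re
from importlib import metadata


def parse_minimums(requirements_text):
    """Parse 'name>=version' lines into {name: (major, minor, patch)}."""
    minimums = {}
    for line in requirements_text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        name, _, version = line.partition(">=")
        minimums[name.strip()] = tuple(int(p) for p in version.strip().split("."))
    return minimums


def installed_tuple(name):
    """Return the installed version as a 3-tuple of ints, or None if absent."""
    try:
        raw = metadata.version(name)
    except metadata.PackageNotFoundError:
        return None
    parts = []
    for piece in raw.split(".")[:3]:
        match = re.match(r"\d+", piece)
        if match is None:
            break
        parts.append(int(match.group()))
    parts += [0] * (3 - len(parts))  # pad short versions such as "2.1"
    return tuple(parts)


def check(minimums):
    """Map each package to 'ok', 'too old', or 'missing'."""
    report = {}
    for name, minimum in minimums.items():
        found = installed_tuple(name)
        if found is None:
            report[name] = "missing"
        else:
            report[name] = "ok" if found >= minimum else "too old"
    return report


minimums = parse_minimums("gradio>=4.0.0\ntorch>=2.0.0\ntorchvision>=0.15.0\n")
```

Calling `check(minimums)` in the target environment then gives a per-package report; the result depends on what is installed, so no output is shown here.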
setup_models.py
ADDED
@@ -0,0 +1,106 @@
#!/usr/bin/env python3
"""
Setup script to help configure model paths for the Gradio app
"""

import os
import glob

def find_models():
    """Find available model files"""
    model_dir = os.path.join(os.path.dirname(__file__), '..', 'models')
    model_dir = os.path.abspath(model_dir)

    print("="*70)
    print("🔍 Searching for model files")
    print("="*70)
    print(f"📁 Model directory: {model_dir}\n")

    if not os.path.exists(model_dir):
        print(f"❌ Model directory not found: {model_dir}")
        print("Please create the directory and add your trained models.\n")
        return

    # Find all .pth files
    pth_files = glob.glob(os.path.join(model_dir, "*.pth"))

    if not pth_files:
        print("❌ No .pth model files found in the models directory")
        print("\n💡 To train models, run:")
        print("   cd ../src/model/shifted_CNN")
        print("   python main.py --mode train --model_type cnn --epochs 5")
        print("   python main.py --mode train --model_type tinycnn --epochs 5")
        print("   python main.py --mode train --model_type minicnn --epochs 5")
        return

    print(f"✅ Found {len(pth_files)} model file(s):\n")

    # Categorize models
    cnn_models = []
    tinycnn_models = []
    minicnn_models = []
    other_models = []

    for file in pth_files:
        basename = os.path.basename(file)
        # Attack checkpoints (e.g. best_standard_attack_CNN_model.pth) also contain
        # 'CNN_model', so exclude them here and let them fall through to "other"
        if ('CNN_model' in basename and 'Tiny' not in basename
                and 'Mini' not in basename and 'attack' not in basename):
            cnn_models.append(file)
        elif 'TinyCNN' in basename:
            tinycnn_models.append(file)
        elif 'MiniCNN' in basename:
            minicnn_models.append(file)
        else:
            other_models.append(file)

    # Display findings
    if cnn_models:
        print("📦 CNNModel files:")
        for model in cnn_models:
            print(f"   ✓ {os.path.basename(model)}")
    else:
        print("⚠️ No CNNModel files found")

    print()

    if tinycnn_models:
        print("📦 TinyCNN files:")
        for model in tinycnn_models:
            print(f"   ✓ {os.path.basename(model)}")
    else:
        print("⚠️ No TinyCNN files found")

    print()

    if minicnn_models:
        print("📦 MiniCNN files:")
        for model in minicnn_models:
            print(f"   ✓ {os.path.basename(model)}")
    else:
        print("⚠️ No MiniCNN files found")

    if other_models:
        print("\n📦 Other model files:")
        for model in other_models:
            print(f"   ✓ {os.path.basename(model)}")

    print("\n" + "="*70)
    print("📊 Summary")
    print("="*70)
    print(f"Total models found: {len(pth_files)}")
    print(f"CNNModel: {len(cnn_models)}")
    print(f"TinyCNN: {len(tinycnn_models)}")
    print(f"MiniCNN: {len(minicnn_models)}")
    print(f"Other: {len(other_models)}")

    print("\n💡 Tips:")
    print("1. The Gradio app will automatically detect these models")
    print("2. Models should be named with pattern: best_[ModelType]_model_acc_XX.XX.pth")
    print("3. If models are not loading, check the file paths in app.py")

    print("\n🚀 Ready to launch!")
    print("Run: python app.py")
    print("="*70)


if __name__ == "__main__":
    find_models()
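The script's tips mention the naming pattern `best_[ModelType]_model_acc_XX.XX.pth`, and the substring checks in `find_models` depend on it. A regular expression can extract the model type and the accuracy in one step and also tolerates the attack checkpoints, which have no accuracy suffix. The pattern below is an assumption derived from the filenames in this commit; `CHECKPOINT_RE` and `describe` are hypothetical names.

```python
import re

# best_<kind>_model[_acc_<number>] followed by one or more ".pth" suffixes
CHECKPOINT_RE = re.compile(
    r"^best_(?P<kind>.+?)_model(?:_acc_(?P<acc>\d+(?:\.\d+)?))?(?:\.pth)+$"
)


def describe(filename):
    """Return (kind, accuracy or None) for a checkpoint name, or None if it does not match."""
    match = CHECKPOINT_RE.match(filename)
    if match is None:
        return None
    acc = match.group("acc")
    return match.group("kind"), (float(acc) if acc else None)


names = [
    "best_CNN_model_acc_99.33.pth",
    "best_MiniCNN_model_acc_97.57.pth",
    "best_lighter_attack_CNN_model.pth.pth",
    "notes.txt",
]
described = {name: describe(name) for name in names}
```

Grouping on the extracted `kind` instead of substring tests avoids the case where an attack checkpoint is counted as a plain CNNModel file.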