Commit 43124a6: Initial commit of Food101 Classification
Files changed:

- .gitattributes +3 -0
- .gitignore +18 -0
- README.md +206 -0
- assets/banner.png +3 -0
- assets/confusion_matrix.png +3 -0
- assets/gradio.png +3 -0
- assets/onion_rings.jpg +3 -0
- assets/oysters.jpg +3 -0
- assets/pizza.jpg +3 -0
- assets/ramen.jpg +3 -0
- checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt +3 -0
- notebooks/food101_classification.ipynb +1055 -0
- requirements.txt +0 -0
- scripts/app.py +82 -0
- scripts/class_names.py +22 -0
- scripts/main.py +229 -0
- scripts/models.py +266 -0
- scripts/prepare_data.py +240 -0
.gitattributes
ADDED
@@ -0,0 +1,3 @@
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.png filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,18 @@
+ # Python
+ __pycache__/
+ *.pyc
+
+ # Virtual Environment
+ venv/
+ .venv/
+
+ # Data and Logs
+ data/
+ logs/
+ notebooks/data/
+ notebooks/logs/
+
+ # IDE files
+ .vscode/
+ .idea/
+
README.md
ADDED
@@ -0,0 +1,206 @@


[![Python](https://www.python.org/)](https://www.python.org/)[![PyTorch](https://pytorch.org/)](https://pytorch.org/)[![License](LICENSE)](LICENSE)

# 🍽️ Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning

This repository contains the code for an end-to-end deep learning project to classify 101 food categories from the challenging Food-101 dataset. The project demonstrates a systematic approach to model selection, fine-tuning, and hyperparameter optimization, achieving a final validation accuracy of **85.4%** on the full dataset.

The entire training and evaluation pipeline is built using modern, reproducible practices with PyTorch Lightning.
---

## 📑 Table of Contents

- [🍽️ Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning](#️-food-101-image-classification-with-efficientnetv2-s-and-pytorch-lightning)
- [📑 Table of Contents](#-table-of-contents)
- [🎯 Project Highlights](#-project-highlights)
- [💡 Real-World Applications](#-real-world-applications)
- [🧫 Experimental Results](#-experimental-results)
- [📊 Final Results](#-final-results)
- [🔬 Performance Analysis and Error Diagnosis](#-performance-analysis-and-error-diagnosis)
  - [🍤 Lowest-Performing Classes](#-lowest-performing-classes)
  - [Root Cause Analysis of Misclassifications](#root-cause-analysis-of-misclassifications)
  - [Future Work](#future-work)
- [🧪 Methodology and Experimental Process](#-methodology-and-experimental-process)
- [📁 Repository Structure](#-repository-structure)
- [🚀 Getting Started](#-getting-started)
  - [Prerequisites](#prerequisites)
  - [Installation](#installation)
  - [Usage](#usage)
- [💻 Technologies Used](#-technologies-used)

---
## 🎯 Project Highlights

- **High-Performance Model** ⚡: Utilizes a pre-trained `EfficientNetV2-S`, selected for its excellent balance of accuracy and computational efficiency, suitable for potential edge deployment.
- **Reproducible Pipeline** 🔄: Encapsulates the entire workflow—from data loading to training and evaluation—in a clean and organized `LightningModule` and `DataModule`.
- **Efficient Experimentation** ⏱️: Overcame hardware limitations by implementing dataset subsetting for rapid prototyping.
- **Advanced Fine-Tuning** 🛠️: Implemented a robust fine-tuning strategy, unfreezing the final three blocks of the feature extractor and using the `Adam` optimizer with a `CosineAnnealingLR` scheduler for stable convergence.
- **In-Depth Analysis** 🔎: Went beyond simple accuracy by calculating and logging per-class F1-scores and accuracies, enabling a deep dive into the model's strengths and weaknesses.
- **Live Deployment** 📺: The final model is deployed and accessible as an interactive Gradio web application on Hugging Face Spaces.

---
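The dataset subsetting used for rapid prototyping can be sketched in plain Python. This mirrors the reproducible shuffled-subset logic in the project's `CustomFood101` dataset (a fixed seed of 42, as in the notebook); the 75,750-image figure is the standard Food-101 training-split size:

```python
import random

def subset_indices(n_samples: int, fraction: float, seed: int = 42) -> list:
    """Return a reproducible, shuffled subset of dataset indices."""
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)  # fixed seed -> same subset every run
    return indices[: int(n_samples * fraction)]

# e.g. prototype on 50% of the 75,750-image Food-101 training split
train_idx = subset_indices(75_750, 0.5)
print(len(train_idx))  # -> 37875
```

Because the shuffle is seeded, every run (and every collaborator) trains on the same subset, which keeps the experiments in the results table comparable.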

## 💡 Real-World Applications

Beyond being a technical challenge, this food classification model serves as a foundation for numerous real-world applications in health, hospitality, and smart home technology.

- **Health and Nutrition Tracking**
  - **Automated Calorie Counting:** Users can snap a photo of their meal, and an app can automatically identify each food item to provide an instant estimate of calories, macros, and other nutritional information.
  - **Dietary Management:** Assists individuals with allergies or specific dietary needs (e.g., diabetes, gluten-free) by helping them identify and log their food intake accurately.
- **Restaurant and Hospitality Tech**
  - **Self-Checkout Systems:** In cafeterias or quick-service restaurants, a camera-based system could identify all items on a tray to automate the billing process, reducing queues and improving efficiency.
  - **Interactive Menus:** Allow diners to point their phone at a dish to get more information, such as ingredients, allergen warnings, or customer reviews.
- **Smart Home and Appliances**
  - **Smart Refrigerators:** A fridge equipped with a camera could identify leftover dishes, suggest recipes based on available food, and help track food spoilage to reduce waste.

---

## 🧫 Experimental Results

This project followed an iterative approach. The table below summarizes the key experiments and their outcomes, showing the progression from the initial baseline to the final model.

| Model | Training Strategy | Data % | Key Hyperparameters | Final Val Accuracy |
| :--- | :--- | :--- | :--- | :--- |
| `EfficientNet-B2` | Simple fine-tune (last block) | 50% | `lr=1e-4` | ~64% |
| `EfficientNet-B2` | Unfreeze last 3 blocks | 50% | `lr=1e-3` | 82.0% |
| `EfficientNet-B2` | Two-stage fine-tuning | 50% | `lr1=1e-3`, `lr2=1e-5` | Performance degraded |
| **`EfficientNetV2-S`** | Unfreeze last 3 blocks | 50% | `lr=1e-4` (tuned) | 82.4% |
| **`EfficientNetV2-S`** | Unfreeze last 3 blocks, more advanced transforms | 50% | `lr=1e-4` (tuned) | ~82.4% (no meaningful change) |
| **`EfficientNetV2-S`** | **Unfreeze last 3 blocks** | **100%** | **`lr=1e-4` (tuned)** | **85.4%** |

---

## 📊 Final Results

After systematically iterating on model architecture and hyperparameters, the final model achieved the following performance on the full Food-101 validation set:

| Metric | Score |
| :------------------ | :------ |
| Validation Accuracy | **85.4%** |

![Confusion Matrix](assets/confusion_matrix.png)

*A confusion matrix visualization helps diagnose the model's performance on a per-class basis.*

This model is deployed and accessible as an interactive Gradio web application on Hugging Face Spaces.

![Gradio App Screenshot](assets/gradio.png)

Check out my [Food101 Gradio Demo](https://huggingface.co/spaces/your-username/food101-demo).

---
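The demo's top-5 label output boils down to a softmax over the model's 101 logits. A dependency-free sketch (the three class names and logit values below are toy stand-ins, not real model output):

```python
import math

def top_k_probs(logits, class_names, k=5):
    """Softmax the logits and return the k most likely (name, prob) pairs."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(class_names, probs), key=lambda t: t[1], reverse=True)
    return ranked[:k]

# Toy 3-class example; the real model produces 101 logits
top2 = top_k_probs([2.0, 0.5, 0.1], ["pizza", "ramen", "oysters"], k=2)
print(top2)  # highest-probability class first
```

Gradio's `Label` output component accepts exactly this kind of name-to-probability mapping, which is why demos typically return the softmaxed top-k rather than raw logits.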

## 🔬 Performance Analysis and Error Diagnosis

Beyond the aggregate accuracy, a per-class analysis was conducted to identify the model's specific limitations and diagnose the root causes of misclassifications.

The model performed exceptionally well on many classes but struggled with a distinct set of categories, primarily due to visual ambiguity and high variability in appearance.

#### 🍤 Lowest-Performing Classes

The following five classes had the lowest validation accuracy:

| Class Name | Index | Validation Accuracy |
| :------------------ | :---- | :------------------ |
| `shrimp_and_grits` | 93 | 44.0% |
| `ravioli` | 77 | 59.2% |
| `apple_pie` | 0 | 61.6% |
| `huevos_rancheros` | 56 | 63.2% |
| `falafel` | 36 | 63.6% |

#### Root Cause Analysis of Misclassifications

* **High Intra-Class Variation**: The model struggled with dishes that have no single, consistent appearance.
* **Fine-Grained Confusion**: Errors occurred between visually similar classes like `ravioli` vs. `dumplings`.
* **Ambiguous Features**: Foods like `falafel` resemble many other small fried dishes, making classification tricky.

#### Future Work

Improvements could include:

- Detailed confusion matrix analysis 🔍
- More aggressive data augmentation 📈
- Larger architectures for fine-grained recognition 🏋️
- Training for longer ⏳

---
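The per-class accuracies above can be derived from predictions with a few lines of plain Python. A simplified sketch (the project logs these via TorchMetrics; the two-class labels below are illustrative):

```python
from collections import defaultdict

def per_class_accuracy(y_true, y_pred):
    """Map each class to correct predictions / true occurrences of that class."""
    correct = defaultdict(int)
    total = defaultdict(int)
    for t, p in zip(y_true, y_pred):
        total[t] += 1
        correct[t] += int(t == p)
    return {c: correct[c] / total[c] for c in total}

# Toy example with two classes (0 = apple_pie, 1 = ravioli)
acc = per_class_accuracy([0, 0, 0, 1, 1], [0, 0, 1, 1, 0])
worst = min(acc, key=acc.get)  # class 0: 2/3 correct, class 1: 1/2 correct
```

Sorting this mapping ascending is exactly how a "lowest-performing classes" table like the one above is produced.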

## 🧪 Methodology and Experimental Process

Steps taken in the project:

1. **Baseline Establishment** 🏁 – EfficientNet-B2 achieved ~64%.
2. **Architecture Selection** 🏗️ – EfficientNetV2-S chosen for its balance of accuracy and size.
3. **Transforms Selection** 🎨 – TrivialAugmentWide, plus stronger options such as RandomResizedCrop and RandAugment.
4. **Fine-Tuning Strategy** 🔧 – Final 3 blocks of the feature extractor unfrozen for training.
5. **Final Model Training** 🏆 – Full dataset with Adam, CosineAnnealingLR, and EarlyStopping → 85.4%.

---
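The CosineAnnealingLR schedule used in the final run follows a closed-form decay. A plain-Python version of the formula that `torch.optim.lr_scheduler.CosineAnnealingLR` computes (no warm restarts; the `t_max` of 25 epochs here is illustrative, not the project's exact setting):

```python
import math

def cosine_annealing_lr(step: int, base_lr: float = 1e-4,
                        eta_min: float = 0.0, t_max: int = 25) -> float:
    """Learning rate at `step` under cosine annealing:
    eta_min + (base_lr - eta_min) * (1 + cos(pi * step / t_max)) / 2
    """
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2

print(cosine_annealing_lr(0))   # starts at the base lr (1e-4)
print(cosine_annealing_lr(25))  # decays to eta_min by t_max
```

The smooth, monotone decay toward `eta_min` is what gives the "stable convergence" noted in the highlights: late epochs take ever-smaller steps instead of oscillating.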

## 📁 Repository Structure

```bash
food-101-classification/
├── assets/
├── checkpoints/
├── data/
├── logs/
├── notebooks/
│   └── food101_classification.ipynb
├── scripts/
│   ├── main.py
│   ├── models.py
│   ├── class_names.py
│   ├── app.py
│   └── prepare_data.py
├── .gitignore
├── requirements.txt
└── README.md
```

---

## 🚀 Getting Started

### Prerequisites

- Python 3.10+ 🐍
- PyTorch 🔥
- CUDA-enabled GPU (recommended) 🎮

### Installation

1. **Clone the repository:**

   ```bash
   git clone https://github.com/Deathshot78/Food101-Classification
   cd Food101-Classification
   ```

2. **Install the dependencies:**

   ```bash
   pip install -r requirements.txt
   ```

### Usage

Run training with a subset for quick testing:

```bash
python scripts/main.py
```

### 💻 Technologies Used

- Python
- PyTorch
- PyTorch Lightning
- TorchMetrics
- Gradio
- Matplotlib & Seaborn
assets/banner.png ADDED (Git LFS)
assets/confusion_matrix.png ADDED (Git LFS)
assets/gradio.png ADDED (Git LFS)
assets/onion_rings.jpg ADDED (Git LFS)
assets/oysters.jpg ADDED (Git LFS)
assets/pizza.jpg ADDED (Git LFS)
assets/ramen.jpg ADDED (Git LFS)
checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt
ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6663c5df734fa5e9a50fab801a746ebd095732faacb1938368444518c3265615
+ size 230292623
notebooks/food101_classification.ipynb
ADDED
@@ -0,0 +1,1055 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "430db510",
   "metadata": {},
   "source": [
    "# Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning\n",
    "\n",
    "This repository contains the code for an end-to-end deep learning project to classify 101 food categories from the challenging Food-101 dataset. The project demonstrates a systematic approach to model selection, fine-tuning, and hyperparameter optimization, achieving a final validation accuracy of **85.4%** on the full dataset.\n",
    "\n",
    "The entire training and evaluation pipeline is built using modern, reproducible practices with PyTorch Lightning."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1116006e",
   "metadata": {},
   "source": [
    "## 1. Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "531943f8",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import matplotlib.pyplot as plt\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "import os\n",
    "import pytorch_lightning as pl\n",
    "import torch.optim.lr_scheduler as lr_scheduler\n",
    "import torchvision\n",
    "\n",
    "from torchmetrics.functional import accuracy\n",
    "from torchvision import transforms, datasets\n",
    "from torchinfo import summary\n",
    "from pytorch_lightning.callbacks import EarlyStopping\n",
    "from pytorch_lightning.loggers import TensorBoardLogger\n",
    "from torch import nn\n",
    "from pathlib import Path\n",
    "from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "08a5c10b",
   "metadata": {},
   "source": [
    "## 2. Quick inspection of the top model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e20ad559",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Here we inspect the model's classifier layer to match the number of classes in Food101\n",
    "\n",
    "weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
    "model = torchvision.models.efficientnet_v2_s(weights=weights)\n",
    "effnet_v2_s_transforms = weights.transforms()\n",
    "\n",
    "print(model.classifier)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3442b5a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Inspect the model\n",
    "\n",
    "summary(model=model,\n",
    "        input_size=(1, 3, 224, 224),\n",
    "        col_names=['input_size', 'output_size', 'num_params', 'trainable'],\n",
    "        col_width=20,\n",
    "        row_settings=['var_names'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bb353575",
   "metadata": {},
   "outputs": [],
   "source": [
    "# This will be the base transforms for training\n",
    "\n",
    "effnet_v2_s_transforms = weights.transforms()\n",
    "train_transforms = torchvision.transforms.Compose([\n",
    "    torchvision.transforms.TrivialAugmentWide(),\n",
    "    effnet_v2_s_transforms])\n",
    "\n",
    "train_transforms"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "363ce600",
   "metadata": {},
   "source": [
    "## 3. Dataset and PyTorch Lightning DataModule Classes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "2a04cc09",
   "metadata": {},
   "outputs": [],
   "source": [
    "from torchvision import datasets\n",
    "from pathlib import Path\n",
    "import os\n",
    "import pytorch_lightning as pl\n",
    "from torch.utils.data import DataLoader, Subset, Dataset\n",
    "from torchvision import transforms as T\n",
    "import numpy as np\n",
    "import torchvision\n",
    "from torchvision.datasets import Food101\n",
    "from typing import Dict, Tuple, Any\n",
    "import random\n",
    "\n",
    "\n",
    "def get_model_components(\n",
    "    model_name: str,\n",
    "    return_classifier: bool = False,\n",
    "    augmentation_level: str = \"default\"\n",
    ") -> Dict[str, Any]:\n",
    "    \"\"\"\n",
    "    Retrieves pre-trained model components from torchvision.\n",
    "\n",
    "    This function fetches the appropriate weights and transforms for a given\n",
    "    model. It supports different levels of training data augmentation.\n",
    "\n",
    "    Args:\n",
    "        model_name (str): The name of the model to get components for.\n",
    "            Supported models include \"EfficientNet_V2_S\" and \"EfficientNet_B2\".\n",
    "        return_classifier (bool, optional): If True, the model's classifier\n",
    "            head is also returned. Defaults to False.\n",
    "        augmentation_level (str, optional): The level of data augmentation to use\n",
    "            for the training set. Can be \"default\" or \"strong\".\n",
    "            Defaults to \"default\".\n",
    "\n",
    "    Returns:\n",
    "        Dict[str, Any]: A dictionary containing the requested components.\n",
    "            Always includes 'train_transforms' and 'val_transforms'.\n",
    "            Includes 'classifier' if return_classifier is True.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If model_name or augmentation_level is not supported.\n",
    "    \"\"\"\n",
    "    model_registry = {\n",
    "        \"EfficientNet_V2_S\": (\n",
    "            torchvision.models.efficientnet_v2_s,\n",
    "            torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
    "        ),\n",
    "        \"EfficientNet_B2\": (\n",
    "            torchvision.models.efficientnet_b2,\n",
    "            torchvision.models.EfficientNet_B2_Weights.DEFAULT\n",
    "        )\n",
    "    }\n",
    "\n",
    "    if model_name not in model_registry:\n",
    "        raise ValueError(f\"Model '{model_name}' is not supported. \"\n",
    "                         f\"Supported models are: {list(model_registry.keys())}\")\n",
    "\n",
    "    # 1. Look up the model and weights classes\n",
    "    model_class, weights_class = model_registry[model_name]\n",
    "    weights = weights_class\n",
    "    val_transforms = weights.transforms()\n",
    "\n",
    "    # 2. Create the training transforms based on the desired level\n",
    "    if augmentation_level == \"default\":\n",
    "        train_transforms = T.Compose([\n",
    "            T.TrivialAugmentWide(),\n",
    "            val_transforms  # val_transforms includes ToTensor and Normalize\n",
    "        ])\n",
    "    elif augmentation_level == \"strong\":\n",
    "        # Note: We don't need to add ToTensor() or Normalize() here because\n",
    "        # they are already included inside the 'val_transforms' pipeline.\n",
    "        train_transforms = T.Compose([\n",
    "            T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),\n",
    "            T.RandomHorizontalFlip(p=0.5),\n",
    "            T.RandAugment(num_ops=2, magnitude=9),\n",
    "            # RandomErasing should be applied to a tensor, so we apply it after\n",
    "            # val_transforms, which handles the PIL -> Tensor conversion.\n",
    "            val_transforms,\n",
    "            T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')\n",
    "        ])\n",
    "    else:\n",
    "        raise ValueError(f\"Augmentation level '{augmentation_level}' is not supported. \"\n",
    "                         f\"Choose from 'default' or 'strong'.\")\n",
    "\n",
    "    # 3. Prepare the dictionary to be returned\n",
    "    components = {\n",
    "        \"train_transforms\": train_transforms,\n",
    "        \"val_transforms\": val_transforms\n",
    "    }\n",
    "\n",
    "    # 4. Optionally, instantiate the model to get the classifier\n",
    "    if return_classifier:\n",
    "        model = model_class(weights=weights)\n",
    "        components[\"classifier\"] = model.classifier\n",
    "\n",
    "    return components\n",
    "\n",
    "class CustomFood101(Dataset):\n",
    "    \"\"\"A PyTorch Dataset for Food101 with conditional downloading and subset support.\n",
    "\n",
    "    This class wraps the torchvision Food101 dataset. It only downloads the data\n",
    "    if the specified directory doesn't already exist. It can also create a\n",
    "    reproducible, shuffled subset of the data for faster experimentation.\n",
    "\n",
    "    Args:\n",
    "        split (str): The dataset split, either \"train\" or \"test\".\n",
    "        transform (callable, optional): A function/transform to apply to the images.\n",
    "        data_dir (str, optional): The directory to store the data. Defaults to \"data\".\n",
    "        subset_fraction (float, optional): The fraction of the dataset to use.\n",
    "            Defaults to 0.1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, split, transform=None, data_dir=\"data\", subset_fraction: float = 0.1):\n",
    "        # Check if the dataset already exists before setting the download flag.\n",
    "        dataset_path = os.path.join(data_dir, \"food-101\")\n",
    "        should_download = not os.path.isdir(dataset_path)\n",
    "\n",
    "        # 1. Load the full dataset metadata with the conditional flag\n",
    "        self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)\n",
    "        self.classes = self.full_dataset.classes\n",
    "\n",
    "        # 2. Create a reproducible subset of indices\n",
    "        if subset_fraction < 1.0:\n",
    "            num_samples = int(len(self.full_dataset) * subset_fraction)\n",
    "            all_indices = list(range(len(self.full_dataset)))\n",
    "            # Shuffle with a fixed seed for reproducibility\n",
    "            random.Random(42).shuffle(all_indices)\n",
    "            self.indices = all_indices[:num_samples]\n",
    "        else:\n",
    "            self.indices = list(range(len(self.full_dataset)))\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Returns the total number of samples in the subset.\"\"\"\n",
    "        return len(self.indices)\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        \"\"\"\n",
    "        Fetches the sample for the given subset index and applies the transform.\n",
    "        \"\"\"\n",
    "        # Map the subset index to the actual index in the full dataset\n",
    "        original_idx = self.indices[idx]\n",
    "        image, label = self.full_dataset[original_idx]\n",
    "        return image, label\n",
    "\n",
    "class Food101DataModule(pl.LightningDataModule):\n",
    "    \"\"\"A PyTorch Lightning DataModule for the Food101 dataset.\n",
    "\n",
    "    This module encapsulates all data-related logic, including downloading,\n",
    "    processing, and creating DataLoaders for the training, validation, and\n",
    "    test sets. It uses the CustomFood101 dataset internally and allows for\n",
    "    controlling the fraction of data used in the training and validation splits.\n",
    "\n",
    "    Args:\n",
    "        data_dir (str, optional): Root directory for the data. Defaults to \"data\".\n",
    "        batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.\n",
    "        num_workers (int, optional): Number of workers for data loading. Defaults to 2.\n",
    "        train_transforms (callable, optional): Transformations for the training set.\n",
    "        val_transforms (callable, optional): Transformations for the validation/test set.\n",
    "        subset_fraction (float, optional): The fraction of data to use for training\n",
    "            and validation. Defaults to 0.5.\n",
    "    \"\"\"\n",
    "    def __init__(self, data_dir=\"data\", batch_size=32, num_workers=2,\n",
    "                 train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):\n",
    "        super().__init__()\n",
    "        self.data_dir = data_dir\n",
    "        self.batch_size = batch_size\n",
    "        self.num_workers = num_workers\n",
    "        self.train_transforms = train_transforms\n",
    "        self.val_transforms = val_transforms\n",
    "        self.subset_fraction = subset_fraction\n",
    "\n",
    "        self.classes = []\n",
    "\n",
    "    def prepare_data(self):\n",
    "        \"\"\"Downloads data if needed.\"\"\"\n",
    "        CustomFood101(split='train', data_dir=self.data_dir)\n",
|
| 124 |
+
"import pytorch_lightning as pl\n",
|
| 125 |
+
"from torch.utils.data import DataLoader, Subset\n",
|
| 126 |
+
"from torchvision import datasets\n",
|
| 127 |
+
"from torchvision import transforms as T\n",
|
| 128 |
+
"import numpy as np\n",
|
| 129 |
+
"import torchvision\n",
|
| 130 |
+
"from torchvision.datasets import Food101\n",
|
| 131 |
+
"from torch.utils.data import DataLoader, Dataset\n",
|
| 132 |
+
"from typing import Dict, Tuple, Any\n",
|
| 133 |
+
"import random\n",
|
| 134 |
+
"\n",
|
| 135 |
+
"\n",
|
| 136 |
+
"def get_model_components(\n",
|
| 137 |
+
" model_name: str, \n",
|
| 138 |
+
" return_classifier: bool = False, \n",
|
| 139 |
+
" augmentation_level: str = \"default\"\n",
|
| 140 |
+
") -> Dict[str, Any]:\n",
|
| 141 |
+
" \"\"\"\n",
|
| 142 |
+
" Retrieves pre-trained model components from torchvision.\n",
|
| 143 |
+
"\n",
|
| 144 |
+
" This function fetches the appropriate weights and transforms for a given\n",
|
| 145 |
+
" model. It supports different levels of training data augmentation.\n",
|
| 146 |
+
"\n",
|
| 147 |
+
" Args:\n",
|
| 148 |
+
" model_name (str): The name of the model to get components for.\n",
|
| 149 |
+
" Supported models include \"EfficientNet_V2_S\" and \"EfficientNet_B2\".\n",
|
| 150 |
+
" return_classifier (bool, optional): If True, the model's classifier\n",
|
| 151 |
+
" head is also returned. Defaults to False.\n",
|
| 152 |
+
" augmentation_level (str, optional): The level of data augmentation to use\n",
|
| 153 |
+
" for the training set. Can be \"default\" or \"strong\". \n",
|
| 154 |
+
" Defaults to \"default\".\n",
|
| 155 |
+
"\n",
|
| 156 |
+
" Returns:\n",
|
| 157 |
+
" Dict[str, Any]: A dictionary containing the requested components.\n",
|
| 158 |
+
" Always includes 'train_transforms' and 'val_transforms'.\n",
|
| 159 |
+
" Includes 'classifier' if return_classifier is True.\n",
|
| 160 |
+
" \n",
|
| 161 |
+
" Raises:\n",
|
| 162 |
+
" ValueError: If model_name or augmentation_level is not supported.\n",
|
| 163 |
+
" \"\"\"\n",
|
| 164 |
+
" model_registry = {\n",
|
| 165 |
+
" \"EfficientNet_V2_S\": (\n",
|
| 166 |
+
" torchvision.models.efficientnet_v2_s,\n",
|
| 167 |
+
" torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
|
| 168 |
+
" ),\n",
|
| 169 |
+
" \"EfficientNet_B2\": (\n",
|
| 170 |
+
" torchvision.models.efficientnet_b2,\n",
|
| 171 |
+
" torchvision.models.EfficientNet_B2_Weights.DEFAULT\n",
|
| 172 |
+
" )\n",
|
| 173 |
+
" }\n",
|
| 174 |
+
"\n",
|
| 175 |
+
" if model_name not in model_registry:\n",
|
| 176 |
+
" raise ValueError(f\"Model '{model_name}' is not supported. \"\n",
|
| 177 |
+
" f\"Supported models are: {list(model_registry.keys())}\")\n",
|
| 178 |
+
"\n",
|
| 179 |
+
" # 1. Look up the model and weights classes\n",
|
| 180 |
+
" model_class, weights_class = model_registry[model_name]\n",
|
| 181 |
+
" weights = weights_class\n",
|
| 182 |
+
" val_transforms = weights.transforms()\n",
|
| 183 |
+
"\n",
|
| 184 |
+
" # 2. Create the training transforms based on the desired level\n",
|
| 185 |
+
" if augmentation_level == \"default\":\n",
|
| 186 |
+
" train_transforms = T.Compose([\n",
|
| 187 |
+
" T.TrivialAugmentWide(),\n",
|
| 188 |
+
" val_transforms # val_transforms includes ToTensor and Normalize\n",
|
| 189 |
+
" ])\n",
|
| 190 |
+
" elif augmentation_level == \"strong\":\n",
|
| 191 |
+
" # Note: We don't need to add ToTensor() or Normalize() here because\n",
|
| 192 |
+
" # they are already included inside the 'val_transforms' pipeline.\n",
|
| 193 |
+
" train_transforms = T.Compose([\n",
|
| 194 |
+
" T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),\n",
|
| 195 |
+
" T.RandomHorizontalFlip(p=0.5),\n",
|
| 196 |
+
" T.RandAugment(num_ops=2, magnitude=9),\n",
|
| 197 |
+
" # RandomErasing should be applied to a tensor, so we apply it after\n",
|
| 198 |
+
" # val_transforms, which handles the PIL -> Tensor conversion.\n",
|
| 199 |
+
" val_transforms, \n",
|
| 200 |
+
" T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')\n",
|
| 201 |
+
" ])\n",
|
| 202 |
+
" else:\n",
|
| 203 |
+
" raise ValueError(f\"Augmentation level '{augmentation_level}' is not supported. \"\n",
|
| 204 |
+
" f\"Choose from 'default' or 'strong'.\")\n",
|
| 205 |
+
" \n",
|
| 206 |
+
" # 3. Prepare the dictionary to be returned\n",
|
| 207 |
+
" components = {\n",
|
| 208 |
+
" \"train_transforms\": train_transforms,\n",
|
| 209 |
+
" \"val_transforms\": val_transforms\n",
|
| 210 |
+
" }\n",
|
| 211 |
+
"\n",
|
| 212 |
+
" # 4. Optionally, instantiate the model to get the classifier\n",
|
| 213 |
+
" if return_classifier:\n",
|
| 214 |
+
" model = model_class(weights=weights)\n",
|
| 215 |
+
" components[\"classifier\"] = model.classifier\n",
|
| 216 |
+
"\n",
|
| 217 |
+
" return components\n",
|
| 218 |
+
" \n",
|
| 219 |
+
"class CustomFood101(Dataset):\n",
|
| 220 |
+
" \"\"\"A PyTorch Dataset for Food101 with conditional downloading and subset support.\n",
|
| 221 |
+
"\n",
|
| 222 |
+
" This class wraps the torchvision Food101 dataset. It only downloads the data\n",
|
| 223 |
+
" if the specified directory doesn't already exist. It can also create a\n",
|
| 224 |
+
" reproducible, shuffled subset of the data for faster experimentation.\n",
|
| 225 |
+
"\n",
|
| 226 |
+
" Args:\n",
|
| 227 |
+
" split (str): The dataset split, either \"train\" or \"test\".\n",
|
| 228 |
+
" transform (callable, optional): A function/transform to apply to the images.\n",
|
| 229 |
+
"        data_dir (str, optional): The directory to store the data. Defaults to \"data\".\n",
"        subset_fraction (float, optional): The fraction of the dataset to use.\n",
"            Defaults to 0.1.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, split, transform=None, data_dir=\"data\", subset_fraction: float = 0.1):\n",
"        # Check if the dataset already exists before setting the download flag.\n",
"        dataset_path = os.path.join(data_dir, \"food-101\")\n",
"        should_download = not os.path.isdir(dataset_path)\n",
"\n",
"        # 1. Load the full dataset metadata with the conditional flag\n",
"        self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)\n",
"        self.classes = self.full_dataset.classes\n",
"\n",
"        # 2. Create a reproducible subset of indices\n",
"        if subset_fraction < 1.0:\n",
"            num_samples = int(len(self.full_dataset) * subset_fraction)\n",
"            all_indices = list(range(len(self.full_dataset)))\n",
"            # Shuffle with a fixed seed for reproducibility\n",
"            random.Random(42).shuffle(all_indices)\n",
"            self.indices = all_indices[:num_samples]\n",
"        else:\n",
"            self.indices = list(range(len(self.full_dataset)))\n",
"\n",
"    def __len__(self):\n",
"        \"\"\"Returns the total number of samples in the subset.\"\"\"\n",
"        return len(self.indices)\n",
"\n",
"    def __getitem__(self, idx):\n",
"        \"\"\"\n",
"        Fetches the sample for the given subset index and applies the transform.\n",
"        \"\"\"\n",
"        # Map the subset index to the actual index in the full dataset\n",
"        original_idx = self.indices[idx]\n",
"        image, label = self.full_dataset[original_idx]\n",
"        return image, label\n",
"\n",
"class Food101DataModule(pl.LightningDataModule):\n",
"    \"\"\"A PyTorch Lightning DataModule for the Food101 dataset.\n",
"\n",
"    This module encapsulates all data-related logic, including downloading,\n",
"    processing, and creating DataLoaders for the training, validation, and\n",
"    test sets. It uses the CustomFood101 dataset internally and allows for\n",
"    controlling the fraction of data used in the training and validation splits.\n",
"\n",
"    Args:\n",
"        data_dir (str, optional): Root directory for the data. Defaults to \"data\".\n",
"        batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.\n",
"        num_workers (int, optional): Number of workers for data loading. Defaults to 2.\n",
"        train_transforms (callable, optional): Transformations for the training set.\n",
"        val_transforms (callable, optional): Transformations for the validation/test set.\n",
"        subset_fraction (float, optional): The fraction of data to use for training\n",
"            and validation. Defaults to 0.5.\n",
"    \"\"\"\n",
"    def __init__(self, data_dir=\"data\", batch_size=32, num_workers=2,\n",
"                 train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):\n",
"        super().__init__()\n",
"        self.data_dir = data_dir\n",
"        self.batch_size = batch_size\n",
"        self.num_workers = num_workers\n",
"        self.train_transforms = train_transforms\n",
"        self.val_transforms = val_transforms\n",
"        self.subset_fraction = subset_fraction\n",
"\n",
"        self.classes = []\n",
"\n",
"    def prepare_data(self):\n",
"        \"\"\"Downloads data if needed.\"\"\"\n",
"        CustomFood101(split='train', data_dir=self.data_dir)\n",
"        CustomFood101(split='test', data_dir=self.data_dir)\n",
"\n",
"    def setup(self, stage=None):\n",
"        \"\"\"Assigns datasets, passing the subset_fraction.\"\"\"\n",
"        if stage == 'fit' or stage is None:\n",
"            self.train_dataset = CustomFood101(split='train', transform=self.train_transforms,\n",
"                                               data_dir=self.data_dir, subset_fraction=self.subset_fraction)\n",
"            self.val_dataset = CustomFood101(split='test', transform=self.val_transforms,\n",
"                                             data_dir=self.data_dir, subset_fraction=self.subset_fraction)\n",
"            self.classes = self.train_dataset.classes\n",
"\n",
"        if stage == 'test' or stage is None:\n",
"            self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,\n",
"                                              data_dir=self.data_dir, subset_fraction=1.0) # Use full test set\n",
"            if not self.classes:\n",
"                self.classes = self.test_dataset.classes\n",
"\n",
"    def train_dataloader(self):\n",
"        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)\n",
"\n",
"    def val_dataloader(self):\n",
"        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)\n",
"\n",
"    def test_dataloader(self):\n",
"        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)\n"
]
},
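{
"cell_type": "code",
"execution_count": null,
"id": "c0ffee01",
"metadata": {},
"outputs": [],
"source": [
"# Optional sanity check for CustomFood101 (a sketch; assumes the Food-101\n",
"# archive has already been downloaded under data/ by a previous run).\n",
"# Because sampling uses a fixed seed (42), the same subset is drawn every time.\n",
"subset_a = CustomFood101(split='train', data_dir='data', subset_fraction=0.1)\n",
"subset_b = CustomFood101(split='train', data_dir='data', subset_fraction=0.1)\n",
"assert subset_a.indices == subset_b.indices  # reproducible sampling\n",
"print(f\"Subset size: {len(subset_a)} of {len(subset_a.full_dataset)} images\")\n"
]
},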
{
"cell_type": "code",
"execution_count": null,
"id": "93ffd2e2",
"metadata": {},
"outputs": [],
"source": [
"# Define configuration for the script\n",
"DATA_DIR = \"data\"\n",
"MODEL_NAME = \"EfficientNet_V2_S\"\n",
"BATCH_SIZE = 32\n",
"\n",
"print(f\"Running data preparation script for model: {MODEL_NAME}\")\n",
"\n",
"# 1. Get model-specific transforms\n",
"components = get_model_components(MODEL_NAME)\n",
"train_transforms = components[\"train_transforms\"]\n",
"val_transforms = components[\"val_transforms\"]\n",
"\n",
"# 2. Instantiate the DataModule\n",
"datamodule = Food101DataModule(\n",
"    data_dir=DATA_DIR,\n",
"    batch_size=BATCH_SIZE,\n",
"    train_transforms=train_transforms,\n",
"    val_transforms=val_transforms,\n",
"    subset_fraction=0.1 # Use a small subset for quick verification\n",
")\n",
"\n",
"# 3. Trigger download and setup\n",
"datamodule.prepare_data()\n",
"datamodule.setup(stage='fit')\n",
"\n",
"# 4. (Optional) Verification Step\n",
"print(\"\\n--- Verifying Dataloader ---\")\n",
"# Get one batch from the training dataloader\n",
"train_dl = datamodule.train_dataloader()\n",
"images, labels = next(iter(train_dl))\n",
"\n",
"print(f\"Number of classes: {len(datamodule.classes)}\")\n",
"print(f\"Image batch shape: {images.shape}\")\n",
"print(f\"Label batch shape: {labels.shape}\")\n",
"print(\"--- Verification Complete ---\")"
]
},
{
"cell_type": "markdown",
"id": "3edf64c2",
"metadata": {},
"source": [
"## 4. Model Classes"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "eb264fe4",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torchvision\n",
"import pytorch_lightning as pl\n",
"from torch import nn\n",
"from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix\n",
"import seaborn as sns\n",
"import matplotlib.pyplot as plt\n",
"import pandas as pd\n",
"import numpy as np\n",
"\n",
"class EffNetV2_S(pl.LightningModule):\n",
"    \"\"\"A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.\n",
"\n",
"    This module encapsulates the EfficientNetV2-S model and provides a flexible\n",
"    fine-tuning strategy. It can be configured for Stage 1 (training only the\n",
"    classifier and later feature blocks) or Stage 2 (training the entire model).\n",
"\n",
"    Args:\n",
"        lr (float, optional): The learning rate. Defaults to 1e-3.\n",
"        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.\n",
"        num_classes (int, optional): The number of output classes. Defaults to 101.\n",
"        class_names (list, optional): A list of class names for logging. Defaults to None.\n",
"        freeze_features (bool, optional): If True, freezes the backbone and unfreezes\n",
"            only the later blocks (Stage 1). If False, all features are trainable\n",
"            (Stage 2). Defaults to True.\n",
"        unfreeze_from_block (int, optional): Which feature block to start unfreezing\n",
"            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        lr: float = 1e-3,\n",
"        weight_decay: float = 1e-4,\n",
"        num_classes: int = 101,\n",
"        class_names: list = None,\n",
"        freeze_features: bool = True, # True = Stage 1, False = Stage 2\n",
"        unfreeze_from_block: int = -3 # Only used if freeze_features=True\n",
"    ):\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]\n",
"\n",
"        # Load pretrained weights\n",
"        weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
"        self.model = torchvision.models.efficientnet_v2_s(weights=weights)\n",
"\n",
"        # ---- Freezing strategy ----\n",
"        if freeze_features:\n",
"            # Freeze all first\n",
"            for param in self.model.parameters():\n",
"                param.requires_grad = False\n",
"            # Unfreeze from a specific block (default: last 3 blocks)\n",
"            for param in self.model.features[unfreeze_from_block:].parameters():\n",
"                param.requires_grad = True\n",
"        else:\n",
"            # Stage 2: unfreeze everything\n",
"            for param in self.model.parameters():\n",
"                param.requires_grad = True\n",
"\n",
"        # Classifier head\n",
"        self.model.classifier = nn.Sequential(\n",
"            nn.Dropout(p=0.2, inplace=True),\n",
"            nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True)\n",
"        )\n",
"\n",
"        # Loss & metrics\n",
"        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
"        self.train_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.train_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
"        self.val_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
"        self.val_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.test_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = self.loss_fn(logits, y)\n",
"        self.train_accuracy(logits, y)\n",
"        self.train_f1(logits, y)\n",
"        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)\n",
"        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)\n",
"        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = self.loss_fn(logits, y)\n",
"        self.val_accuracy(logits, y)\n",
"        self.val_f1(logits, y)\n",
"        self.log('val_loss', loss, prog_bar=True)\n",
"        self.log('val_acc', self.val_accuracy, prog_bar=True)\n",
"        self.log('val_f1', self.val_f1, prog_bar=True)\n",
"        self.val_conf_matrix.update(logits, y)\n",
"\n",
"    def on_validation_epoch_end(self):\n",
"        cm = self.val_conf_matrix.compute()\n",
"        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)\n",
"        print(\"\\n--- Per-Class Validation Accuracy ---\")\n",
"        for i, acc in enumerate(per_class_acc):\n",
"            self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)\n",
"            print(f\"{self.class_names[i]:<20}: {acc.item():.4f}\")\n",
"        print(\"------------------------------------\")\n",
"        self.val_conf_matrix.reset()\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        self.test_conf_matrix.update(logits, y)\n",
"\n",
"    def on_test_end(self):\n",
"        cm = self.test_conf_matrix.compute()\n",
"        print(\"\\nGenerating final confusion matrix plot...\")\n",
"        # plot_confusion_matrix is defined in the training cell below and is\n",
"        # available in the notebook namespace by the time testing runs.\n",
"        plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)\n",
"        self.test_conf_matrix.reset()\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.Adam(\n",
"            self.parameters(),\n",
"            lr=self.hparams.lr,\n",
"            weight_decay=self.hparams.weight_decay\n",
"        )\n",
"        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
"            optimizer,\n",
"            T_max=self.trainer.max_epochs,\n",
"            eta_min=1e-6\n",
"        )\n",
"        return {\"optimizer\": optimizer, \"lr_scheduler\": {\"scheduler\": scheduler, \"interval\": \"epoch\"}}\n",
"\n",
"class EffNetb2(pl.LightningModule):\n",
"    \"\"\"A PyTorch Lightning Module for fine-tuning EfficientNet-B2.\n",
"\n",
"    This module encapsulates the EfficientNet-B2 model and provides a flexible\n",
"    fine-tuning strategy. It can be configured for Stage 1 (training only the\n",
"    classifier and later feature blocks) or Stage 2 (training the entire model).\n",
"\n",
"    Args:\n",
"        lr (float, optional): The learning rate. Defaults to 1e-3.\n",
"        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.\n",
"        num_classes (int, optional): The number of output classes. Defaults to 101.\n",
"        class_names (list, optional): A list of class names for logging. Defaults to None.\n",
"        freeze_features (bool, optional): If True, freezes the backbone and unfreezes\n",
"            only the later blocks (Stage 1). If False, all features are trainable\n",
"            (Stage 2). Defaults to True.\n",
"        unfreeze_from_block (int, optional): Which feature block to start unfreezing\n",
"            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self,\n",
"        lr: float = 1e-3,\n",
"        weight_decay: float = 1e-4,\n",
"        num_classes: int = 101,\n",
"        class_names: list = None,\n",
"        freeze_features: bool = True,\n",
"        unfreeze_from_block: int = -3\n",
"    ):\n",
"        super().__init__()\n",
"        self.save_hyperparameters()\n",
"        self.class_names = class_names if class_names is not None else [str(i) for i in range(num_classes)]\n",
"\n",
"        # Model setup\n",
"        weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT\n",
"        self.model = torchvision.models.efficientnet_b2(weights=weights)\n",
"\n",
"        # --- Flexible Freezing Strategy ---\n",
"        if self.hparams.freeze_features:\n",
"            # Stage 1: Freeze all first\n",
"            for param in self.model.parameters():\n",
"                param.requires_grad = False\n",
"            # Unfreeze from a specific block (default: last 3 blocks)\n",
"            for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():\n",
"                param.requires_grad = True\n",
"        else:\n",
"            # Stage 2: unfreeze everything\n",
"            for param in self.model.parameters():\n",
"                param.requires_grad = True\n",
"\n",
"        # Classifier head\n",
"        self.model.classifier = nn.Sequential(\n",
"            nn.Dropout(p=0.3, inplace=True),\n",
"            nn.Linear(in_features=1408, out_features=self.hparams.num_classes)\n",
"        )\n",
"\n",
"        # Metrics\n",
"        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
"        self.train_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.train_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
"        self.val_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
"        self.val_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"        self.test_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
"\n",
"    def forward(self, x):\n",
"        return self.model(x)\n",
"\n",
"    def training_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = self.loss_fn(logits, y)\n",
"        self.train_accuracy(logits, y)\n",
"        self.train_f1(logits, y)\n",
"        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)\n",
"        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)\n",
"        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)\n",
"        return loss\n",
"\n",
"    def validation_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        loss = self.loss_fn(logits, y)\n",
"        self.val_accuracy(logits, y)\n",
"        self.val_f1(logits, y)\n",
"        self.log('val_loss', loss, prog_bar=True)\n",
"        self.log('val_acc', self.val_accuracy, prog_bar=True)\n",
"        self.log('val_f1', self.val_f1, prog_bar=True)\n",
"        self.val_conf_matrix.update(logits, y)\n",
"\n",
"    def on_validation_epoch_end(self):\n",
"        cm = self.val_conf_matrix.compute()\n",
"\n",
"        # Add a small epsilon (1e-6) to the denominator for numerical stability.\n",
"        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)\n",
"\n",
"        print(\"\\n--- Per-Class Validation Accuracy ---\")\n",
"        for i, acc in enumerate(per_class_acc):\n",
"            class_name = self.class_names[i]\n",
"            self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)\n",
"            print(f\"{class_name:<20}: {acc.item():.4f}\")\n",
"        print(\"------------------------------------\")\n",
"\n",
"        self.val_conf_matrix.reset()\n",
"\n",
"    def test_step(self, batch, batch_idx):\n",
"        x, y = batch\n",
"        logits = self(x)\n",
"        self.test_conf_matrix.update(logits, y)\n",
"\n",
"    def on_test_end(self):\n",
"        cm = self.test_conf_matrix.compute()\n",
"        print(\"\\nGenerating final confusion matrix plot...\")\n",
"        # plot_confusion_matrix is defined in the training cell below\n",
"        plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)\n",
"        self.test_conf_matrix.reset()\n",
"\n",
"    def configure_optimizers(self):\n",
"        optimizer = torch.optim.Adam(\n",
"            self.parameters(),\n",
"            lr=self.hparams.lr,\n",
"            weight_decay=self.hparams.weight_decay\n",
"        )\n",
"        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
"            optimizer,\n",
"            T_max=self.trainer.max_epochs,\n",
"            eta_min=1e-6\n",
"        )\n",
"        return {\n",
"            \"optimizer\": optimizer,\n",
"            \"lr_scheduler\": {\n",
"                \"scheduler\": scheduler,\n",
"                \"interval\": \"epoch\",\n",
"            },\n",
"        }\n"
]
},
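{
"cell_type": "code",
"execution_count": null,
"id": "feedbeef",
"metadata": {},
"outputs": [],
"source": [
"# Optional check of the Stage 1 freezing strategy (a sketch; instantiating\n",
"# the module downloads the pretrained EfficientNetV2-S weights if needed).\n",
"# With freeze_features=True, only the last feature blocks and the classifier\n",
"# head should report requires_grad=True.\n",
"stage1_model = EffNetV2_S(freeze_features=True)\n",
"trainable = sum(p.numel() for p in stage1_model.parameters() if p.requires_grad)\n",
"total = sum(p.numel() for p in stage1_model.parameters())\n",
"print(f\"Trainable parameters: {trainable:,} / {total:,}\")\n"
]
},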
{
"cell_type": "markdown",
"id": "3f5bf233",
"metadata": {},
"source": [
"## 5. Training and Plotting the Confusion Matrix"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "6b080afd",
"metadata": {},
"outputs": [],
"source": [
"import pytorch_lightning as pl\n",
"from pytorch_lightning import Trainer, LightningModule\n",
"from pytorch_lightning.loggers import CSVLogger\n",
"from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint\n",
"from typing import List, Optional\n",
"import matplotlib.pyplot as plt\n",
"import seaborn as sns\n",
"import numpy as np\n",
"import pandas as pd\n",
"\n",
"DATA_DIR = \"data\"\n",
"MODEL_NAME = \"EfficientNet_V2_S\"\n",
"BATCH_SIZE = 32\n",
"SUBSET_FRACTION = 0.2 # Using a smaller subset for quick testing\n",
"CHECKPOINT_PATH = \"checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt\" # Path to your trained model checkpoint\n",
"\n",
"def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], figsize: tuple = (25, 25)):\n",
"    \"\"\"\n",
"    Creates and saves a multi-class confusion matrix plot.\n",
"\n",
"    This function normalizes the confusion matrix to show prediction\n",
"    percentages for each class, visualizes it as a heatmap, and saves\n",
"    the resulting figure to a file.\n",
"\n",
"    Args:\n",
"        cm (np.ndarray): The confusion matrix from torchmetrics or scikit-learn.\n",
"        class_names (List[str]): A list of class names for the labels.\n",
"        figsize (tuple, optional): The size of the figure. Defaults to (25, 25).\n",
"    \"\"\"\n",
"    # 1. Normalize the confusion matrix to show percentages\n",
"    # Add a small epsilon to prevent division by zero\n",
"    cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)\n",
"\n",
"    # 2. Create a DataFrame for a beautiful plot with labels\n",
"    df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)\n",
"\n",
"    # 3. Create the plot\n",
"    plt.figure(figsize=figsize)\n",
"    heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues') # Annotations off for 101 classes\n",
"\n",
"    # 4. Format the plot\n",
"    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)\n",
"    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)\n",
"\n",
"    plt.ylabel('True Label')\n",
"    plt.xlabel('Predicted Label')\n",
"    plt.title('Normalized Confusion Matrix')\n",
"    plt.tight_layout()\n",
"\n",
"    # 5. Save the figure and show the plot\n",
"    plt.savefig('confusion_matrix.png', dpi=300)\n",
"    print(\"Confusion matrix plot saved to confusion_matrix.png\")\n",
"    plt.show()\n",
"\n",
"def run_training_session(\n",
"    model_name: str = \"EfficientNet_V2_S\",\n",
"    batch_size: int = 32,\n",
"    data_dir: str = 'data',\n",
"    subset_fraction: float = 1.0,\n",
"    checkpoint_path: str = \"checkpoints/\",\n",
"    lr: float = 1e-3,\n",
"    weight_decay: float = 1e-4,\n",
"    freeze_features: bool = True,\n",
"    early_stopping_patience: int = 5,\n",
"    max_epochs: int = 100,\n",
"    accelerator: str = 'auto',\n",
"    resume_from_checkpoint: Optional[str] = None\n",
") -> Trainer:\n",
"    \"\"\"\n",
"    Sets up and runs a complete training session for a specified model.\n",
"\n",
"    This function handles the entire pipeline: data preparation, model\n",
"    instantiation, logger and callback setup, and trainer execution.\n",
"\n",
"    Args:\n",
"        model_name (str): The name of the model architecture to train.\n",
"        batch_size (int): The number of samples per batch.\n",
"        data_dir (str): The root directory for the dataset.\n",
"        subset_fraction (float): The fraction of the dataset to use for training.\n",
"        checkpoint_path (str): Directory to save model checkpoints.\n",
"        lr (float): The learning rate for the optimizer.\n",
"        weight_decay (float): The weight decay for the optimizer.\n",
"        freeze_features (bool): Flag to control the fine-tuning strategy\n",
"            (e.g., for two-stage training).\n",
"        early_stopping_patience (int): Number of epochs with no improvement\n",
"            after which training will be stopped.\n",
"        max_epochs (int): The maximum number of epochs to train for.\n",
"        accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').\n",
"        resume_from_checkpoint (Optional[str]): Path to a checkpoint file to\n",
"            resume training from. Defaults to None.\n",
"\n",
"    Returns:\n",
"        Trainer: The PyTorch Lightning Trainer object after fitting is complete.\n",
"    \"\"\"\n",
"    # A registry to map model names to their actual classes\n",
"    model_class_registry = {\n",
"        \"EfficientNet_V2_S\": EffNetV2_S,\n",
"        \"EfficientNet_B2\": EffNetb2,\n",
"    }\n",
"    if model_name not in model_class_registry:\n",
"        raise ValueError(f\"Model '{model_name}' is not a recognized class.\")\n",
"\n",
"    # Get model-specific transforms\n",
"    components = get_model_components(model_name)\n",
"    train_transforms = components[\"train_transforms\"]\n",
"    val_transforms = components[\"val_transforms\"]\n",
"\n",
"    # Set up the DataModule\n",
"    food_datamodule = Food101DataModule(\n",
"        data_dir=data_dir,\n",
"        batch_size=batch_size,\n",
"        train_transforms=train_transforms,\n",
"        val_transforms=val_transforms,\n",
"        subset_fraction=subset_fraction\n",
"    )\n",
"    food_datamodule.prepare_data()\n",
"    food_datamodule.setup()\n",
"\n",
"    # Instantiate the model dynamically\n",
"    model_class = model_class_registry[model_name]\n",
"    model = model_class(\n",
"        num_classes=len(food_datamodule.classes),\n",
"        class_names=food_datamodule.classes,\n",
"        lr=lr,\n",
"        weight_decay=weight_decay,\n",
"        freeze_features=freeze_features\n",
"    )\n",
"\n",
"    # Set up logger and callbacks\n",
"    logger = CSVLogger(save_dir=\"logs/\", name=model_name)\n",
"\n",
"    early_stop_callback = EarlyStopping(\n",
"        monitor=\"val_loss\",\n",
"        patience=early_stopping_patience,\n",
"        mode=\"min\"\n",
"    )\n",
"    best_model_checkpoint = ModelCheckpoint(\n",
"        dirpath=checkpoint_path,\n",
"        filename=\"best-model-{epoch:02d}-{val_acc:.4f}\",\n",
"        save_top_k=1,\n",
"        monitor=\"val_acc\",\n",
"        mode=\"max\"\n",
"    )\n",
"\n",
"    callbacks = [early_stop_callback, best_model_checkpoint]\n",
"\n",
"    # Instantiate the Trainer\n",
"    trainer = Trainer(\n",
"        max_epochs=max_epochs,\n",
"        accelerator=accelerator,\n",
"        callbacks=callbacks,\n",
"        logger=logger,\n",
"    )\n",
"\n",
"    # Start training\n",
"    trainer.fit(\n",
"        model,\n",
"        datamodule=food_datamodule,\n",
"        ckpt_path=resume_from_checkpoint\n",
"    )\n",
"\n",
"    return trainer\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "04c534dc",
"metadata": {},
"outputs": [],
"source": [
"# --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---\n",
"config = {\n",
"    \"model_name\": \"EfficientNet_V2_S\",\n",
"    \"batch_size\": 32,\n",
"    \"lr\": 1e-4,\n",
"    \"epochs\": 50,\n",
"    \"subset_fraction\": 1.0, # Use 1.0 for the full dataset\n",
"    \"freeze_features\": True,\n",
"    \"early_stopping_patience\": 10\n",
"}\n",
"\n",
"# --- 2. PRINT CONFIGURATION AND START TRAINING ---\n",
"print(\"--- Starting Training Session ---\")\n",
"for key, value in config.items():\n",
"    print(f\"  {key}: {value}\")\n",
"print(\"---------------------------------\")\n",
"\n",
"run_training_session(\n",
"    model_name=config[\"model_name\"],\n",
"    batch_size=config[\"batch_size\"],\n",
"    lr=config[\"lr\"],\n",
"    max_epochs=config[\"epochs\"],\n",
"    subset_fraction=config[\"subset_fraction\"],\n",
"    freeze_features=config[\"freeze_features\"],\n",
"    early_stopping_patience=config[\"early_stopping_patience\"]\n",
")\n",
"\n",
"print(\"\\n--- Training Session Complete ---\")\n",
"\n",
"print(\"\\n--- Starting Evaluation on Test Set ---\")\n",
"\n",
"print(f\"Loading model from checkpoint: {CHECKPOINT_PATH}\")\n",
"\n",
"# Step 1: Set up the DataModule for the test set\n",
"components = get_model_components(MODEL_NAME)\n",
"val_transforms = components[\"val_transforms\"]\n",
"\n",
"datamodule = Food101DataModule(\n",
"    data_dir=DATA_DIR,\n",
"    batch_size=BATCH_SIZE,\n",
"    val_transforms=val_transforms\n",
")\n",
"# This prepares the test dataloader specifically\n",
"datamodule.setup(stage='test')\n",
"\n",
"# Step 2: Load the trained model from the checkpoint file\n",
"model = EffNetV2_S.load_from_checkpoint(CHECKPOINT_PATH)\n",
"model.class_names = datamodule.classes\n",
"model.eval() # Set the model to evaluation mode\n",
"\n",
"# Step 3: Create a Trainer and run the test\n",
"trainer = pl.Trainer(accelerator='auto')\n",
"\n",
"# This call will run the test_step and automatically trigger the\n",
"# on_test_end hook in your model, which generates the plot.\n",
"trainer.test(model, datamodule=datamodule)\n",
"\n",
"print(\"\\nEvaluation complete. The confusion matrix plot has been saved.\")"
]
},
{
"cell_type": "markdown",
"id": "2325adef",
"metadata": {},
"source": [
"## 6. Local Gradio Demo"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "44decdea",
"metadata": {},
"outputs": [],
"source": [
"FOOD101_CLASSES = [\n",
"    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',\n",
"    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',\n",
"    'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',\n",
"    'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',\n",
"    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',\n",
"    'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',\n",
"    'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',\n",
"    'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',\n",
"    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',\n",
"    'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',\n",
"    'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',\n",
"    'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream',\n",
"    'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese',\n",
"    'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings',\n",
"    'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',\n",
"    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',\n",
"    'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',\n",
" 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', \n",
|
| 933 |
+
" 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', \n",
|
| 934 |
+
" 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'\n",
|
| 935 |
+
"]"
|
| 936 |
+
]
|
| 937 |
+
},
|
| 938 |
+
{
|
| 939 |
+
"cell_type": "code",
|
| 940 |
+
"execution_count": null,
|
| 941 |
+
"id": "10bdf9fd",
|
| 942 |
+
"metadata": {},
|
| 943 |
+
"outputs": [],
|
| 944 |
+
"source": [
|
| 945 |
+
"import gradio as gr\n",
|
| 946 |
+
"import torch\n",
|
| 947 |
+
"from gradio.themes.base import Base\n",
|
| 948 |
+
"from torchvision.datasets import Food101\n",
|
| 949 |
+
"\n",
|
| 950 |
+
"# --- 1. Configuration ---\n",
|
| 951 |
+
"MODEL_PATH = \"checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt\" \n",
|
| 952 |
+
"MODEL_NAME = \"EfficientNet_V2_S\"\n",
|
| 953 |
+
"\n",
|
| 954 |
+
"theme = gr.themes.Soft(\n",
|
| 955 |
+
" primary_hue=\"orange\",\n",
|
| 956 |
+
" secondary_hue=\"blue\",\n",
|
| 957 |
+
").set(\n",
|
| 958 |
+
"\n",
|
| 959 |
+
" body_background_fill=\"#f2f2f2\"\n",
|
| 960 |
+
")\n",
|
| 961 |
+
"\n",
|
| 962 |
+
"# --- 2. Load Model and Assets ---\n",
|
| 963 |
+
"print(\"Loading model and assets...\")\n",
|
| 964 |
+
"model = EffNetV2_S.load_from_checkpoint(MODEL_PATH)\n",
|
| 965 |
+
"model.eval()\n",
|
| 966 |
+
"\n",
|
| 967 |
+
"components = get_model_components(MODEL_NAME)\n",
|
| 968 |
+
"transforms = components[\"val_transforms\"]\n",
|
| 969 |
+
"class_names = FOOD101_CLASSES \n",
|
| 970 |
+
"\n",
|
| 971 |
+
"print(\"Model and assets loaded successfully.\")\n",
|
| 972 |
+
"\n",
|
| 973 |
+
"# --- 3. Prediction Function ---\n",
|
| 974 |
+
"def predict(image):\n",
|
| 975 |
+
" \"\"\"\n",
|
| 976 |
+
" Takes a PIL image, preprocesses it, and returns the model's top 3 predictions.\n",
|
| 977 |
+
" \"\"\"\n",
|
| 978 |
+
" # 1. Preprocess the image and add a batch dimension\n",
|
| 979 |
+
" input_tensor = transforms(image).unsqueeze(0)\n",
|
| 980 |
+
" \n",
|
| 981 |
+
" # 2. Move the input tensor to the same device as the model\n",
|
| 982 |
+
" # This ensures both the model and the data are on the GPU.\n",
|
| 983 |
+
" device = next(model.parameters()).device\n",
|
| 984 |
+
" input_tensor = input_tensor.to(device)\n",
|
| 985 |
+
" \n",
|
| 986 |
+
" # 3. Make a prediction\n",
|
| 987 |
+
" with torch.no_grad():\n",
|
| 988 |
+
" output = model(input_tensor)\n",
|
| 989 |
+
" \n",
|
| 990 |
+
" # 4. Post-process the output\n",
|
| 991 |
+
" probabilities = torch.nn.functional.softmax(output[0], dim=0)\n",
|
| 992 |
+
" confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}\n",
|
| 993 |
+
" \n",
|
| 994 |
+
" return confidences\n",
|
| 995 |
+
" \n",
|
| 996 |
+
"\n",
|
| 997 |
+
"demo = gr.Interface(\n",
|
| 998 |
+
" fn=predict,\n",
|
| 999 |
+
" inputs=gr.Image(type=\"pil\", label=\"Upload a Food Image\"),\n",
|
| 1000 |
+
" outputs=gr.Label(num_top_classes=3, label=\"Top Predictions\"),\n",
|
| 1001 |
+
" theme=theme,\n",
|
| 1002 |
+
" \n",
|
| 1003 |
+
" # UI Enhancements\n",
|
| 1004 |
+
" title=\"🍔 Food-101 Image Classifier 🍟\",\n",
|
| 1005 |
+
" description=(\n",
|
| 1006 |
+
" \"What's on your plate? Upload an image or try one of the examples below to classify it. \"\n",
|
| 1007 |
+
" \"This demo uses an EfficientNetV2-S model fine-tuned on the Food-101 dataset.\"\n",
|
| 1008 |
+
" ),\n",
|
| 1009 |
+
" article=(\n",
|
| 1010 |
+
" \"<p style='text-align: center;'>A project by Daniel Kiani. \"\n",
|
| 1011 |
+
" \"<a href='https://github.com/Deathshot78/Food101-Classification' target='_blank'>Check out the code on GitHub!</a></p>\"\n",
|
| 1012 |
+
" ),\n",
|
| 1013 |
+
" examples=[\n",
|
| 1014 |
+
" [\"assets/ramen.jpg\"],\n",
|
| 1015 |
+
" [\"assets/pizza.jpg\"],\n",
|
| 1016 |
+
" [\"assets/oysters.jpg\"],\n",
|
| 1017 |
+
" [\"assets/onion_rings.jpg\"]\n",
|
| 1018 |
+
" ]\n",
|
| 1019 |
+
")"
|
| 1020 |
+
]
|
| 1021 |
+
},
|
| 1022 |
+
{
|
| 1023 |
+
"cell_type": "code",
|
| 1024 |
+
"execution_count": null,
|
| 1025 |
+
"id": "b536610d",
|
| 1026 |
+
"metadata": {},
|
| 1027 |
+
"outputs": [],
|
| 1028 |
+
"source": [
|
| 1029 |
+
"# Launch the Gradio app locally\n",
|
| 1030 |
+
"demo.launch()"
|
| 1031 |
+
]
|
| 1032 |
+
}
|
| 1033 |
+
],
|
| 1034 |
+
"metadata": {
|
| 1035 |
+
"kernelspec": {
|
| 1036 |
+
"display_name": "Python 3",
|
| 1037 |
+
"language": "python",
|
| 1038 |
+
"name": "python3"
|
| 1039 |
+
},
|
| 1040 |
+
"language_info": {
|
| 1041 |
+
"codemirror_mode": {
|
| 1042 |
+
"name": "ipython",
|
| 1043 |
+
"version": 3
|
| 1044 |
+
},
|
| 1045 |
+
"file_extension": ".py",
|
| 1046 |
+
"mimetype": "text/x-python",
|
| 1047 |
+
"name": "python",
|
| 1048 |
+
"nbconvert_exporter": "python",
|
| 1049 |
+
"pygments_lexer": "ipython3",
|
| 1050 |
+
"version": "3.10.6"
|
| 1051 |
+
}
|
| 1052 |
+
},
|
| 1053 |
+
"nbformat": 4,
|
| 1054 |
+
"nbformat_minor": 5
|
| 1055 |
+
}
|
requirements.txt
ADDED
Binary file (320 Bytes)
scripts/app.py
ADDED
@@ -0,0 +1,82 @@
import gradio as gr
import torch

from models import EffNetV2_S
from prepare_data import get_model_components
from class_names import FOOD101_CLASSES

# --- 1. Configuration ---
MODEL_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"
MODEL_NAME = "EfficientNet_V2_S"

theme = gr.themes.Soft(
    primary_hue="orange",
    secondary_hue="blue",
).set(
    body_background_fill="#f2f2f2"
)

# --- 2. Load Model and Assets ---
print("Loading model and assets...")
model = EffNetV2_S.load_from_checkpoint(MODEL_PATH)
model.eval()

components = get_model_components(MODEL_NAME)
transforms = components["val_transforms"]
class_names = FOOD101_CLASSES

print("Model and assets loaded successfully.")

# --- 3. Prediction Function ---
def predict(image):
    """
    Takes a PIL image, preprocesses it, and returns the model's top 3 predictions.
    """
    # 1. Preprocess the image and add a batch dimension
    input_tensor = transforms(image).unsqueeze(0)

    # 2. Move the input tensor to the same device as the model,
    #    so the model and the data are on the same hardware.
    device = next(model.parameters()).device
    input_tensor = input_tensor.to(device)

    # 3. Make a prediction
    with torch.no_grad():
        output = model(input_tensor)

    # 4. Post-process the output into a {class_name: probability} mapping
    probabilities = torch.nn.functional.softmax(output[0], dim=0)
    confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}

    return confidences


# --- 4. Build the Interface ---
demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil", label="Upload a Food Image"),
    outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
    theme=theme,

    # UI Enhancements
    title="🍔 Food-101 Image Classifier 🍟",
    description=(
        "What's on your plate? Upload an image or try one of the examples below to classify it. "
        "This demo uses an EfficientNetV2-S model fine-tuned on the Food-101 dataset."
    ),
    article=(
        "<p style='text-align: center;'>A project by Daniel Kiani. "
        "<a href='https://github.com/Deathshot78/Food101-Classification' target='_blank'>Check out the code on GitHub!</a></p>"
    ),
    examples=[
        ["assets/ramen.jpg"],
        ["assets/pizza.jpg"],
        ["assets/oysters.jpg"],
        ["assets/onion_rings.jpg"]
    ]
)

# --- 5. Launch the App ---
if __name__ == "__main__":
    demo.launch(debug=True)
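The post-processing step in `predict` — softmax over the raw logits, then a name-to-probability mapping that `gr.Label` sorts for its top-k display — can be sketched framework-free. This pure-Python stand-in is illustrative only and is not part of the repo; `top_k_confidences` is a hypothetical helper name:

```python
import math

def top_k_confidences(logits, class_names, k=3):
    """Numerically stable softmax over raw logits, then the k most
    confident (class_name, probability) pairs, highest first."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]  # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(zip(class_names, probs), key=lambda pair: pair[1], reverse=True)
    return ranked[:k]

# e.g. logits for three classes, keep the two most confident
print(top_k_confidences([3.0, 1.0, 0.2], ["ramen", "pizza", "sushi"], k=2))
```

The app returns the full 101-entry mapping instead and lets `gr.Label(num_top_classes=3)` do the truncation, which keeps the prediction function independent of the UI setting.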
scripts/class_names.py
ADDED
@@ -0,0 +1,22 @@
FOOD101_CLASSES = [
    'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
    'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
    'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
    'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
    'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',
    'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
    'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
    'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
    'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',
    'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
    'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
    'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream',
    'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese',
    'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings',
    'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
    'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',
    'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',
    'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese',
    'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',
    'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'
]
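Because `predict` maps probabilities back to names by index, this list must stay exactly 101 entries long and in the alphabetical order torchvision's `Food101` dataset uses for its targets. A small sanity check could look like the following sketch (illustrative, not part of the repo; `validate_class_list` and its `expected` parameter are hypothetical):

```python
def validate_class_list(classes, expected=101):
    """Return True if the list looks like a valid Food-101 label set:
    the expected number of unique, lowercase snake_case names in
    alphabetical order (matching torchvision's target indices)."""
    return (
        len(classes) == expected
        and len(set(classes)) == expected          # no duplicates
        and list(classes) == sorted(classes)       # alphabetical, index-stable
        and all(c == c.lower() and " " not in c for c in classes)
    )

# validate_class_list(FOOD101_CLASSES) should return True
```

Running this once at app startup would catch an accidentally reordered or truncated list before it silently mislabels predictions.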
scripts/main.py
ADDED
@@ -0,0 +1,229 @@
from prepare_data import Food101DataModule, get_model_components
from models import EffNetV2_S, EffNetb2
import pytorch_lightning as pl
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from typing import List, Optional
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd

DATA_DIR = "data"
MODEL_NAME = "EfficientNet_V2_S"
BATCH_SIZE = 32
SUBSET_FRACTION = 0.2  # Using a smaller subset for quick testing
CHECKPOINT_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"  # Path to your trained model checkpoint

def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], figsize: tuple = (25, 25)):
    """
    Creates and saves a multi-class confusion matrix plot.

    This function normalizes the confusion matrix to show prediction
    percentages for each class, visualizes it as a heatmap, and saves
    the resulting figure to a file.

    Args:
        cm (np.ndarray): The confusion matrix from torchmetrics or scikit-learn.
        class_names (List[str]): A list of class names for the labels.
        figsize (tuple, optional): The size of the figure. Defaults to (25, 25).
    """
    # 1. Normalize the confusion matrix to show percentages.
    # Add a small epsilon to prevent division by zero.
    cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)

    # 2. Create a DataFrame so the heatmap gets labeled axes
    df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)

    # 3. Create the plot
    plt.figure(figsize=figsize)
    heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues')  # Annotations off for 101 classes

    # 4. Format the plot
    heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)
    heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)

    plt.ylabel('True Label')
    plt.xlabel('Predicted Label')
    plt.title('Normalized Confusion Matrix')
    plt.tight_layout()

    # 5. Save the figure and show the plot
    plt.savefig('confusion_matrix.png', dpi=300)
    print("Confusion matrix plot saved to confusion_matrix.png")
    plt.show()

def run_training_session(
    model_name: str = "EfficientNet_V2_S",
    batch_size: int = 32,
    data_dir: str = 'data',
    subset_fraction: float = 1.0,
    checkpoint_path: str = "checkpoints/",
    lr: float = 1e-3,
    weight_decay: float = 1e-4,
    freeze_features: bool = True,
    early_stopping_patience: int = 5,
    max_epochs: int = 100,
    accelerator: str = 'auto',
    resume_from_checkpoint: Optional[str] = None
) -> Trainer:
    """
    Sets up and runs a complete training session for a specified model.

    This function handles the entire pipeline: data preparation, model
    instantiation, logger and callback setup, and trainer execution.

    Args:
        model_name (str): The name of the model architecture to train.
        batch_size (int): The number of samples per batch.
        data_dir (str): The root directory for the dataset.
        subset_fraction (float): The fraction of the dataset to use for training.
        checkpoint_path (str): Directory to save model checkpoints.
        lr (float): The learning rate for the optimizer.
        weight_decay (float): The weight decay for the optimizer.
        freeze_features (bool): Flag to control the fine-tuning strategy
            (e.g., for two-stage training).
        early_stopping_patience (int): Number of epochs with no improvement
            after which training will be stopped.
        max_epochs (int): The maximum number of epochs to train for.
        accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').
        resume_from_checkpoint (Optional[str]): Path to a checkpoint file to
            resume training from. Defaults to None.

    Returns:
        Trainer: The PyTorch Lightning Trainer object after fitting is complete.
    """
    # A registry to map model names to their actual classes
    model_class_registry = {
        "EfficientNet_V2_S": EffNetV2_S,
        "EfficientNet_B2": EffNetb2,
    }
    if model_name not in model_class_registry:
        raise ValueError(f"Model '{model_name}' is not a recognized class.")

    # Get model-specific transforms
    components = get_model_components(model_name)
    train_transforms = components["train_transforms"]
    val_transforms = components["val_transforms"]

    # Set up the DataModule
    food_datamodule = Food101DataModule(
        data_dir=data_dir,
        batch_size=batch_size,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        subset_fraction=subset_fraction
    )
    food_datamodule.prepare_data()
    food_datamodule.setup()

    # Instantiate the model dynamically
    model_class = model_class_registry[model_name]
    model = model_class(
        num_classes=len(food_datamodule.classes),
        class_names=food_datamodule.classes,
        lr=lr,
        weight_decay=weight_decay,
        freeze_features=freeze_features
    )

    # Set up logger and callbacks
    logger = CSVLogger(save_dir="logs/", name=model_name)

    early_stop_callback = EarlyStopping(
        monitor="val_loss",
        patience=early_stopping_patience,
        mode="min"
    )
    best_model_checkpoint = ModelCheckpoint(
        dirpath=checkpoint_path,
        filename="best-model-{epoch:02d}-{val_acc:.4f}",
        save_top_k=1,
        monitor="val_acc",
        mode="max"
    )

    callbacks = [early_stop_callback, best_model_checkpoint]

    # Instantiate the Trainer
    trainer = Trainer(
        max_epochs=max_epochs,
        accelerator=accelerator,
        callbacks=callbacks,
        logger=logger,
    )

    # Start training
    trainer.fit(
        model,
        datamodule=food_datamodule,
        ckpt_path=resume_from_checkpoint
    )

    return trainer

# ===================================================================
# Main Execution Block
# ===================================================================
if __name__ == "__main__":

    # --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---
    config = {
        "model_name": "EfficientNet_V2_S",
        "batch_size": 32,
        "lr": 1e-4,
        "epochs": 50,
        "subset_fraction": 1.0,  # Use 1.0 for the full dataset
        "freeze_features": True,
        "early_stopping_patience": 10
    }

    # --- 2. PRINT CONFIGURATION AND START TRAINING ---
    print("--- Starting Training Session ---")
    for key, value in config.items():
        print(f"  {key}: {value}")
    print("---------------------------------")

    run_training_session(
        model_name=config["model_name"],
        batch_size=config["batch_size"],
        lr=config["lr"],
        max_epochs=config["epochs"],
        subset_fraction=config["subset_fraction"],
        freeze_features=config["freeze_features"],
        early_stopping_patience=config["early_stopping_patience"]
    )

    print("\n--- Training Session Complete ---")

    print("\n--- Starting Evaluation on Test Set ---")

    print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")

    # Step 1: Set up the DataModule for the test set
    components = get_model_components(MODEL_NAME)
    val_transforms = components["val_transforms"]

    datamodule = Food101DataModule(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        val_transforms=val_transforms
    )
    # This prepares the test dataloader specifically
    datamodule.setup(stage='test')

    # Step 2: Load the trained model from the checkpoint file
    model = EffNetV2_S.load_from_checkpoint(CHECKPOINT_PATH)
    model.class_names = datamodule.classes
    model.eval()  # Set the model to evaluation mode

    # Step 3: Create a Trainer and run the test
    trainer = pl.Trainer(accelerator='auto')

    # This call runs the test_step and the on_test_end hook in the
    # model, which computes the test confusion matrix.
    trainer.test(model, datamodule=datamodule)

    print("\nEvaluation complete.")
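The same row-wise divide sits at the heart of both `plot_confusion_matrix` and the per-class accuracies logged in the model's validation hook: each diagonal entry divided by its row sum, with a small epsilon guarding empty rows. A minimal pure-Python sketch of that computation (illustrative only; `per_class_accuracy` is a hypothetical helper, not part of the repo):

```python
def per_class_accuracy(cm, eps=1e-6):
    """For each true class (row of the confusion matrix), return the
    fraction of its samples that were predicted correctly. The epsilon
    mirrors the division-by-zero guard used in plot_confusion_matrix."""
    return [row[i] / (sum(row) + eps) for i, row in enumerate(cm)]

# 2-class toy matrix: class 0 was right 8/10 times, class 1 was right 9/10
cm = [[8, 2],
      [1, 9]]
print(per_class_accuracy(cm))  # roughly [0.8, 0.9]
```

Scanning this vector for the lowest entries is a quick way to find the classes the fine-tuned model still confuses, without reading the full 101x101 heatmap.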
scripts/models.py
ADDED
@@ -0,0 +1,266 @@
import torch
import torchvision
import pytorch_lightning as pl
from torch import nn
from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import numpy as np

class EffNetV2_S(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.

    This module encapsulates the EfficientNetV2-S model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. Defaults to None.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,   # True = Stage 1, False = Stage 2
        unfreeze_from_block: int = -3   # Only used if freeze_features=True
    ):
        super().__init__()
        self.save_hyperparameters()
        self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]

        # Load pretrained weights
        weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_v2_s(weights=weights)

        # ---- Freezing strategy ----
        if freeze_features:
            # Freeze all first
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[unfreeze_from_block:].parameters():
                param.requires_grad = True
        else:
            # Stage 2: unfreeze everything
            for param in self.model.parameters():
                param.requires_grad = True

        # Classifier head
        self.model.classifier = nn.Sequential(
            nn.Dropout(p=0.2, inplace=True),
            nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True)
        )

        # Loss & metrics
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
        self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
        self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
        self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
        self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.train_accuracy(logits, y)
        self.train_f1(logits, y)
        self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
        self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_fn(logits, y)
        self.val_accuracy(logits, y)
        self.val_f1(logits, y)
        self.log('val_loss', loss, prog_bar=True)
        self.log('val_acc', self.val_accuracy, prog_bar=True)
        self.log('val_f1', self.val_f1, prog_bar=True)
        self.val_conf_matrix.update(logits, y)

    def on_validation_epoch_end(self):
        cm = self.val_conf_matrix.compute()
        per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
        print("\n--- Per-Class Validation Accuracy ---")
        for i, acc in enumerate(per_class_acc):
            self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)
            print(f"{self.class_names[i]:<20}: {acc.item():.4f}")
        print("------------------------------------")
        self.val_conf_matrix.reset()

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        self.test_conf_matrix.update(logits, y)

    def on_test_end(self):
        cm = self.test_conf_matrix.compute()
        overall_acc = (cm.diag().sum() / cm.sum()).item()
        print(f"\nTest confusion matrix computed (overall accuracy: {overall_acc:.4f}).")
        self.test_conf_matrix.reset()

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=self.hparams.lr,
            weight_decay=self.hparams.weight_decay
        )
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=self.trainer.max_epochs,
            eta_min=1e-6
        )
        return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}

class EffNetb2(pl.LightningModule):
    """A PyTorch Lightning Module for fine-tuning EfficientNet-B2.

    This module encapsulates the EfficientNet-B2 model and provides a flexible
    fine-tuning strategy. It can be configured for Stage 1 (training only the
    classifier and later feature blocks) or Stage 2 (training the entire model).

    Args:
        lr (float, optional): The learning rate. Defaults to 1e-3.
        weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
        num_classes (int, optional): The number of output classes. Defaults to 101.
        class_names (list, optional): A list of class names for logging. Defaults to None.
        freeze_features (bool, optional): If True, freezes the backbone and unfreezes
            only the later blocks (Stage 1). If False, all features are trainable
            (Stage 2). Defaults to True.
        unfreeze_from_block (int, optional): Which feature block to start unfreezing
            from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
    """

    def __init__(
        self,
        lr: float = 1e-3,
        weight_decay: float = 1e-4,
        num_classes: int = 101,
        class_names: list = None,
        freeze_features: bool = True,
        unfreeze_from_block: int = -3
    ):
        super().__init__()
        self.save_hyperparameters()
        self.class_names = class_names if class_names is not None else [str(i) for i in range(num_classes)]

        # Model setup
        weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
        self.model = torchvision.models.efficientnet_b2(weights=weights)

        # --- Flexible Freezing Strategy ---
        if self.hparams.freeze_features:
            # Stage 1: Freeze all first
            for param in self.model.parameters():
                param.requires_grad = False
            # Unfreeze from a specific block (default: last 3 blocks)
            for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():
|
| 176 |
+
param.requires_grad = True
|
| 177 |
+
else:
|
| 178 |
+
# Stage 2: unfreeze everything
|
| 179 |
+
for param in self.model.parameters():
|
| 180 |
+
param.requires_grad = True
|
| 181 |
+
|
| 182 |
+
# Classifier head
|
| 183 |
+
self.model.classifier = nn.Sequential(
|
| 184 |
+
nn.Dropout(p=0.3, inplace=True),
|
| 185 |
+
nn.Linear(in_features=1408, out_features=self.hparams.num_classes)
|
| 186 |
+
)
|
| 187 |
+
|
| 188 |
+
# Metrics
|
| 189 |
+
self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
|
| 190 |
+
self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
|
| 191 |
+
self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
|
| 192 |
+
self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
|
| 193 |
+
self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
|
| 194 |
+
self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
|
| 195 |
+
self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
|
| 196 |
+
|
| 197 |
+
def forward(self, x):
|
| 198 |
+
return self.model(x)
|
| 199 |
+
|
| 200 |
+
def training_step(self, batch, batch_idx):
|
| 201 |
+
x, y = batch
|
| 202 |
+
logits = self(x)
|
| 203 |
+
loss = self.loss_fn(logits, y)
|
| 204 |
+
self.train_accuracy(logits, y)
|
| 205 |
+
self.train_f1(logits, y)
|
| 206 |
+
self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
|
| 207 |
+
self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
|
| 208 |
+
self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
|
| 209 |
+
return loss
|
| 210 |
+
|
| 211 |
+
def validation_step(self, batch, batch_idx):
|
| 212 |
+
x, y = batch
|
| 213 |
+
logits = self(x)
|
| 214 |
+
loss = self.loss_fn(logits, y)
|
| 215 |
+
self.val_accuracy(logits, y)
|
| 216 |
+
self.val_f1(logits, y)
|
| 217 |
+
self.log('val_loss', loss, prog_bar=True)
|
| 218 |
+
self.log('val_acc', self.val_accuracy, prog_bar=True)
|
| 219 |
+
self.log('val_f1', self.val_f1, prog_bar=True)
|
| 220 |
+
self.val_conf_matrix.update(logits, y)
|
| 221 |
+
|
| 222 |
+
def on_validation_epoch_end(self):
|
| 223 |
+
cm = self.val_conf_matrix.compute()
|
| 224 |
+
|
| 225 |
+
# Add a small epsilon (1e-6) to the denominator for numerical stability.
|
| 226 |
+
per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
|
| 227 |
+
|
| 228 |
+
print("\n--- Per-Class Validation Accuracy ---")
|
| 229 |
+
for i, acc in enumerate(per_class_acc):
|
| 230 |
+
class_name = self.class_names[i]
|
| 231 |
+
self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)
|
| 232 |
+
print(f"{class_name:<20}: {acc.item():.4f}")
|
| 233 |
+
print("------------------------------------")
|
| 234 |
+
|
| 235 |
+
self.val_conf_matrix.reset()
|
| 236 |
+
|
| 237 |
+
def test_step(self, batch, batch_idx):
|
| 238 |
+
x, y = batch
|
| 239 |
+
logits = self(x)
|
| 240 |
+
self.test_conf_matrix.update(logits, y)
|
| 241 |
+
|
| 242 |
+
def on_test_end(self):
|
| 243 |
+
cm = self.test_conf_matrix.compute()
|
| 244 |
+
print("\nGenerating final confusion matrix plot...")
|
| 245 |
+
# Assuming plot_confusion_matrix is defined elsewhere
|
| 246 |
+
# plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
|
| 247 |
+
self.test_conf_matrix.reset()
|
| 248 |
+
|
| 249 |
+
def configure_optimizers(self):
|
| 250 |
+
optimizer = torch.optim.Adam(
|
| 251 |
+
self.parameters(),
|
| 252 |
+
lr=self.hparams.lr,
|
| 253 |
+
weight_decay=self.hparams.weight_decay
|
| 254 |
+
)
|
| 255 |
+
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 256 |
+
optimizer,
|
| 257 |
+
T_max=self.trainer.max_epochs,
|
| 258 |
+
eta_min=1e-6
|
| 259 |
+
)
|
| 260 |
+
return {
|
| 261 |
+
"optimizer": optimizer,
|
| 262 |
+
"lr_scheduler": {
|
| 263 |
+
"scheduler": scheduler,
|
| 264 |
+
"interval": "epoch",
|
| 265 |
+
},
|
| 266 |
+
}
|
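The per-class accuracy computed in `on_validation_epoch_end` is just the confusion-matrix diagonal divided by row sums (rows are true classes in the torchmetrics convention), with a small epsilon guarding against empty rows. A minimal dependency-free sketch of that computation; `per_class_accuracy` and the toy matrix below are illustrative, not part of this repo:

```python
# Sketch of the per-class accuracy logic from on_validation_epoch_end,
# on a plain nested-list confusion matrix (rows = true class, cols = predicted).

def per_class_accuracy(cm, eps=1e-6):
    """Diagonal over row sums; eps keeps empty rows from dividing by zero."""
    return [cm[i][i] / (sum(row) + eps) for i, row in enumerate(cm)]

cm = [
    [8, 1, 1],   # class 0: 8 of 10 correct
    [0, 9, 1],   # class 1: 9 of 10 correct
    [2, 0, 8],   # class 2: 8 of 10 correct
]
accs = per_class_accuracy(cm)
print([round(a, 2) for a in accs])  # [0.8, 0.9, 0.8]
```

The same idea applies unchanged to the 101x101 matrix produced by torchmetrics' `ConfusionMatrix.compute()`, where `cm.diag() / cm.sum(dim=1)` does this in one vectorized step.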
scripts/prepare_data.py
ADDED
@@ -0,0 +1,240 @@
import os
import random
from typing import Any, Dict

import pytorch_lightning as pl
import torchvision
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.datasets import Food101


def get_model_components(
    model_name: str,
    return_classifier: bool = False,
    augmentation_level: str = "default"
) -> Dict[str, Any]:
    """
    Retrieves pre-trained model components from torchvision.

    This function fetches the appropriate weights and transforms for a given
    model. It supports different levels of training data augmentation.

    Args:
        model_name (str): The name of the model to get components for.
            Supported models include "EfficientNet_V2_S" and "EfficientNet_B2".
        return_classifier (bool, optional): If True, the model's classifier
            head is also returned. Defaults to False.
        augmentation_level (str, optional): The level of data augmentation to use
            for the training set. Can be "default" or "strong".
            Defaults to "default".

    Returns:
        Dict[str, Any]: A dictionary containing the requested components.
            Always includes 'train_transforms' and 'val_transforms'.
            Includes 'classifier' if return_classifier is True.

    Raises:
        ValueError: If model_name or augmentation_level is not supported.
    """
    model_registry = {
        "EfficientNet_V2_S": (
            torchvision.models.efficientnet_v2_s,
            torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
        ),
        "EfficientNet_B2": (
            torchvision.models.efficientnet_b2,
            torchvision.models.EfficientNet_B2_Weights.DEFAULT
        )
    }

    if model_name not in model_registry:
        raise ValueError(f"Model '{model_name}' is not supported. "
                         f"Supported models are: {list(model_registry.keys())}")

    # 1. Look up the model constructor and its pre-trained weights
    model_class, weights = model_registry[model_name]
    val_transforms = weights.transforms()

    # 2. Create the training transforms based on the desired level
    if augmentation_level == "default":
        train_transforms = T.Compose([
            T.TrivialAugmentWide(),
            val_transforms  # val_transforms includes ToTensor and Normalize
        ])
    elif augmentation_level == "strong":
        # Note: we don't need to add ToTensor() or Normalize() here because
        # they are already included inside the 'val_transforms' pipeline.
        train_transforms = T.Compose([
            T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),
            T.RandomHorizontalFlip(p=0.5),
            T.RandAugment(num_ops=2, magnitude=9),
            # RandomErasing must be applied to a tensor, so we apply it after
            # val_transforms, which handles the PIL -> Tensor conversion.
            val_transforms,
            T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
        ])
    else:
        raise ValueError(f"Augmentation level '{augmentation_level}' is not supported. "
                         f"Choose from 'default' or 'strong'.")

    # 3. Prepare the dictionary to be returned
    components = {
        "train_transforms": train_transforms,
        "val_transforms": val_transforms
    }

    # 4. Optionally, instantiate the model to get the classifier
    if return_classifier:
        model = model_class(weights=weights)
        components["classifier"] = model.classifier

    return components

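The dispatch pattern used by `get_model_components` (a name-to-components dict, plus a loud `ValueError` listing what is supported) can be shown without pulling in torchvision; `REGISTRY` and `lookup` below are illustrative stand-ins for `model_registry` and the lookup logic, not part of this codebase:

```python
# Illustrative stand-in for the model-registry dispatch in get_model_components:
# map a name to its components, and fail with a helpful message otherwise.

REGISTRY = {
    "small": ("builder_small", "weights_small"),
    "large": ("builder_large", "weights_large"),
}

def lookup(name):
    if name not in REGISTRY:
        raise ValueError(
            f"Model '{name}' is not supported. "
            f"Supported models are: {list(REGISTRY)}"
        )
    return REGISTRY[name]

print(lookup("small"))  # ('builder_small', 'weights_small')
```

Keeping the registry as plain data makes adding a new backbone a one-line change, and the error message stays in sync with the supported set automatically.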
class CustomFood101(Dataset):
    """A PyTorch Dataset for Food101 with conditional downloading and subset support.

    This class wraps the torchvision Food101 dataset. It only downloads the data
    if the specified directory doesn't already exist. It can also create a
    reproducible, shuffled subset of the data for faster experimentation.

    Args:
        split (str): The dataset split, either "train" or "test".
        transform (callable, optional): A function/transform to apply to the images.
        data_dir (str, optional): The directory to store the data. Defaults to "data".
        subset_fraction (float, optional): The fraction of the dataset to use.
            Defaults to 0.1 (10% of the split); pass 1.0 to use the full dataset.
    """

    def __init__(self, split, transform=None, data_dir="data", subset_fraction: float = 0.1):
        # Check if the dataset already exists before setting the download flag.
        dataset_path = os.path.join(data_dir, "food-101")
        should_download = not os.path.isdir(dataset_path)

        # 1. Load the full dataset with the conditional download flag
        self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)
        self.classes = self.full_dataset.classes

        # 2. Create a reproducible subset of indices
        if subset_fraction < 1.0:
            num_samples = int(len(self.full_dataset) * subset_fraction)
            all_indices = list(range(len(self.full_dataset)))
            # Shuffle with a fixed seed for reproducibility
            random.Random(42).shuffle(all_indices)
            self.indices = all_indices[:num_samples]
        else:
            self.indices = list(range(len(self.full_dataset)))

    def __len__(self):
        """Returns the number of samples in the subset."""
        return len(self.indices)

    def __getitem__(self, idx):
        """Fetches the sample for the given subset index and applies the transform."""
        # Map the subset index to the actual index in the full dataset
        original_idx = self.indices[idx]
        image, label = self.full_dataset[original_idx]
        return image, label

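The subset selection above is deterministic: shuffling all indices with a fixed-seed `random.Random(42)` and taking the first fraction yields the same subset on every run, so train/val subsets stay stable across experiments. A minimal sketch of just that piece (`subset_indices` is an illustrative helper, not part of this repo):

```python
import random

def subset_indices(n, fraction, seed=42):
    """Reproducible shuffled subset: same seed and n always give the same indices."""
    idx = list(range(n))
    random.Random(seed).shuffle(idx)  # seeded Random leaves the global RNG untouched
    return idx[: int(n * fraction)]

a = subset_indices(1000, 0.1)
b = subset_indices(1000, 0.1)
print(len(a), a == b)  # 100 True
```

Using a dedicated `random.Random(seed)` instance (rather than `random.seed`) keeps the global random state untouched, so augmentation randomness elsewhere is unaffected.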
class Food101DataModule(pl.LightningDataModule):
    """A PyTorch Lightning DataModule for the Food101 dataset.

    This module encapsulates all data-related logic, including downloading,
    processing, and creating DataLoaders for the training, validation, and
    test sets. It uses the CustomFood101 dataset internally and allows for
    controlling the fraction of data used in the training and validation splits.

    Args:
        data_dir (str, optional): Root directory for the data. Defaults to "data".
        batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.
        num_workers (int, optional): Number of workers for data loading. Defaults to 2.
        train_transforms (callable, optional): Transformations for the training set.
        val_transforms (callable, optional): Transformations for the validation/test set.
        subset_fraction (float, optional): The fraction of data to use for training
            and validation. Defaults to 0.5.
    """

    def __init__(self, data_dir="data", batch_size=32, num_workers=2,
                 train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.train_transforms = train_transforms
        self.val_transforms = val_transforms
        self.subset_fraction = subset_fraction

        self.classes = []

    def prepare_data(self):
        """Downloads the data if needed."""
        CustomFood101(split='train', data_dir=self.data_dir)
        CustomFood101(split='test', data_dir=self.data_dir)

    def setup(self, stage=None):
        """Assigns datasets, passing the subset_fraction."""
        if stage == 'fit' or stage is None:
            self.train_dataset = CustomFood101(split='train', transform=self.train_transforms,
                                               data_dir=self.data_dir, subset_fraction=self.subset_fraction)
            self.val_dataset = CustomFood101(split='test', transform=self.val_transforms,
                                             data_dir=self.data_dir, subset_fraction=self.subset_fraction)
            self.classes = self.train_dataset.classes

        if stage == 'test' or stage is None:
            self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,
                                              data_dir=self.data_dir, subset_fraction=1.0)  # Use the full test set
            if not self.classes:
                self.classes = self.test_dataset.classes

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)


if __name__ == "__main__":
    # Define configuration for the script
    DATA_DIR = "data"
    MODEL_NAME = "EfficientNet_V2_S"
    BATCH_SIZE = 32

    print(f"Running data preparation script for model: {MODEL_NAME}")

    # 1. Get model-specific transforms
    components = get_model_components(MODEL_NAME)
    train_transforms = components["train_transforms"]
    val_transforms = components["val_transforms"]

    # 2. Instantiate the DataModule
    datamodule = Food101DataModule(
        data_dir=DATA_DIR,
        batch_size=BATCH_SIZE,
        train_transforms=train_transforms,
        val_transforms=val_transforms,
        subset_fraction=0.1  # Use a small subset for quick verification
    )

    # 3. Trigger download and setup
    datamodule.prepare_data()
    datamodule.setup(stage='fit')

    # 4. (Optional) Verification step
    print("\n--- Verifying Dataloader ---")
    # Get one batch from the training dataloader
    train_dl = datamodule.train_dataloader()
    images, labels = next(iter(train_dl))

    print(f"Number of classes: {len(datamodule.classes)}")
    print(f"Image batch shape: {images.shape}")
    print(f"Label batch shape: {labels.shape}")
    print("--- Verification Complete ---")