DanielKiani committed on
Commit 43124a6
0 Parent(s)

Initial commit of Food101 Classification
.gitattributes ADDED
@@ -0,0 +1,3 @@
1
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
2
+ *.png filter=lfs diff=lfs merge=lfs -text
3
+ *.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,18 @@
1
+ # Python
2
+ __pycache__/
3
+ *.pyc
4
+
5
+ # Virtual Environment
6
+ venv/
7
+ .venv/
8
+
9
+ # Data and Logs
10
+ data/
11
+ logs/
12
+ notebooks/data/
13
+ notebooks/logs/
14
+
15
+ # IDE files
16
+ .vscode/
17
+ .idea/
18
+
README.md ADDED
@@ -0,0 +1,206 @@
1
+ ![Food101 Classification Banner](assets/banner.png)
2
+
3
+ [![Python](https://img.shields.io/badge/Python-3.10-blue?logo=python)](https://www.python.org/)[![PyTorch](https://img.shields.io/badge/PyTorch-2.7.1-EE4C2C?logo=pytorch)](https://pytorch.org/)![Made with ML](https://img.shields.io/badge/Made%20with-ML-blueviolet?logo=openai)[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](LICENSE)
4
+
5
+ # 🍽️ Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning
6
+
7
+ This repository contains the code for an end-to-end deep learning project to classify 101 food categories from the challenging Food-101 dataset. The project demonstrates a systematic approach to model selection, fine-tuning, and hyperparameter optimization, achieving a final validation accuracy of **85.4%** on the full dataset.
8
+
9
+ The entire training and evaluation pipeline is built using modern, reproducible practices with PyTorch Lightning.
10
+
11
+ ---
12
+
13
+ ## 📑 Table of Contents
14
+
15
+ - [🍽️ Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning](#️-food-101-image-classification-with-efficientnetv2-s-and-pytorch-lightning)
16
+ - [📑 Table of Contents](#-table-of-contents)
17
+ - [🎯 Project Highlights](#-project-highlights)
18
+ - [💡 Real-World Applications](#-real-world-applications)
19
+ - [🧫 Experimental Results](#-experimental-results)
20
+ - [📊 Final Results](#-final-results)
21
+ - [🔬 Performance Analysis and Error Diagnosis](#-performance-analysis-and-error-diagnosis)
22
+ - [🍤 Lowest-Performing Classes](#-lowest-performing-classes)
23
+ - [Root Cause Analysis of Misclassifications](#root-cause-analysis-of-misclassifications)
24
+ - [Future Work](#future-work)
25
+ - [🧪 Methodology and Experimental Process](#-methodology-and-experimental-process)
26
+ - [📁 Repository Structure](#-repository-structure)
27
+ - [🚀 Getting Started](#-getting-started)
28
+ - [Prerequisites](#prerequisites)
29
+ - [Installation](#installation)
30
+ - [Usage](#usage)
31
+ - [💻 Technologies Used](#-technologies-used)
32
+
33
+ ---
34
+
35
+ ## 🎯 Project Highlights
36
+
37
+ - **High-Performance Model** ⚡: Utilizes a pre-trained `EfficientNetV2-S`, selected for its excellent balance of accuracy and computational efficiency suitable for potential edge deployment.
38
+ - **Reproducible Pipeline** 🔄: Encapsulates the entire workflow—from data loading to training and evaluation—in a clean and organized `LightningModule` and `DataModule`.
39
+ - **Efficient Experimentation** ⏱️: Overcame hardware limitations by implementing dataset subsetting for rapid prototyping.
40
+ - **Advanced Fine-Tuning** 🛠️: Implemented a robust fine-tuning strategy, unfreezing the final three blocks of the feature extractor and using the `Adam` optimizer with a `CosineAnnealingLR` scheduler for stable convergence.
41
+ - **In-Depth Analysis** 🔎: Went beyond simple accuracy by calculating and logging per-class F1-scores and accuracies, enabling a deep dive into the model's strengths and weaknesses.
42
+ - **Live Deployment** 📺: The final model is deployed and accessible as an interactive Gradio web application on Hugging Face Spaces.
43
+
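The `CosineAnnealingLR` schedule mentioned above decays the learning rate along a half-cosine from its initial value toward (near) zero. A minimal plain-Python sketch of the underlying formula, separate from the project's actual training code:

```python
import math

def cosine_annealed_lr(step: int, total_steps: int, lr_max: float, lr_min: float = 0.0) -> float:
    """Learning rate after `step` of `total_steps` under cosine annealing."""
    cos_factor = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return lr_min + (lr_max - lr_min) * cos_factor

# With lr=1e-4 as in the final runs: full LR at the start, ~0 at the end.
print(cosine_annealed_lr(0, 20, 1e-4))   # 0.0001
print(cosine_annealed_lr(20, 20, 1e-4))  # 0.0
```

The smooth decay avoids the abrupt drops of step schedules, which helps the partially unfrozen backbone converge stably.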
44
+ ---
45
+
46
+ ## 💡 Real-World Applications
47
+
48
+ Beyond being a technical challenge, this food classification model serves as a foundation for numerous real-world applications in health, hospitality, and smart home technology.
49
+
50
+ - **Health and Nutrition Tracking**
51
+ - **Automated Calorie Counting:** Users can snap a photo of their meal, and an app can automatically identify each food item to provide an instant estimate of calories, macros, and other nutritional information.
52
+ - **Dietary Management:** Assists individuals with allergies or specific dietary needs (e.g., diabetes, gluten-free) by helping them identify and log their food intake accurately.
53
+ - **Restaurant and Hospitality Tech**
54
+ - **Self-Checkout Systems:** In cafeterias or quick-service restaurants, a camera-based system could identify all items on a tray to automate the billing process, reducing queues and improving efficiency.
55
+ - **Interactive Menus:** Allow diners to point their phone at a dish to get more information, such as ingredients, allergen warnings, or customer reviews.
56
+
57
+ - **Smart Home and Appliances**
58
+ - **Smart Refrigerators:** A fridge equipped with a camera could identify leftover dishes, suggest recipes based on available food, and help track food spoilage to reduce waste.
59
+
60
+ ---
61
+
62
+ ## 🧫 Experimental Results
63
+
64
+ This project followed an iterative approach. The table below summarizes the key experiments and their outcomes, showing the progression from the initial baseline to the final model.
65
+
66
+ | Model | Training Strategy | Data % | Key Hyperparameters | Final Val Accuracy |
67
+ | :--- | :--- | :--- | :--- | :--- |
68
+ | `EfficientNet-B2` | Simple fine-tune (last block) | 50% | `lr=1e-4` | ~64% |
69
+ | `EfficientNet-B2` | Unfreeze last 3 blocks | 50% | `lr=1e-3` | 82.0% |
70
+ | `EfficientNet-B2` | Two-Stage Fine-Tuning | 50% | `lr1=1e-3`, `lr2=1e-5` | Performance Degraded |
71
+ | **`EfficientNetV2-S`** | Unfreeze last 3 blocks | 50% | `lr=1e-4` (Tuned) | 82.4% |
72
+ | **`EfficientNetV2-S`** | Unfreeze last 3 blocks and more advanced transforms | 50% | `lr=1e-4` (Tuned) | ~82.4% (no meaningful change) |
73
+ | **`EfficientNetV2-S`** | **Unfreeze last 3 blocks** | **100%** | **`lr=1e-4` (Tuned)** | **85.4%** |
74
+
75
+ ---
76
+
77
+ ## 📊 Final Results
78
+
79
+ After systematically iterating on model architecture and hyperparameters, the final model achieved the following performance on the full Food-101 validation set:
80
+
81
+ | Metric | Score |
82
+ | :------------------ | :------ |
83
+ | Validation Accuracy | **85.4%** |
84
+
85
+ ![Confusion Matrix Plot](assets/confusion_matrix.png)
86
+ *A confusion matrix visualization helps diagnose the model's performance on a per-class basis.*
87
+
88
+ This model is deployed and accessible as an interactive Gradio web application on Hugging Face Spaces.
89
+
90
+ ![Gradio](assets/gradio.png)
91
+
92
+ Check out my [Food101 Gradio Demo](https://huggingface.co/spaces/your-username/food101-demo).
93
+
94
+ ---
95
+
96
+ ## 🔬 Performance Analysis and Error Diagnosis
97
+
98
+ Beyond the aggregate accuracy, a per-class analysis was conducted to identify the model's specific limitations and diagnose the root causes of misclassifications.
99
+
100
+ The model performed exceptionally well on many classes but struggled with a distinct set of categories, primarily due to visual ambiguity and high variability in appearance.
101
+
102
+ #### 🍤 Lowest-Performing Classes
103
+
104
+ The following five classes had the lowest validation accuracy:
105
+
106
+ | Class Name | Index | Validation Accuracy |
107
+ | :------------------ | :---- | :------------------ |
108
+ | `shrimp_and_grits` | 93 | 44.0% |
109
+ | `ravioli` | 77 | 59.2% |
110
+ | `apple_pie` | 0 | 61.6% |
111
+ | `huevos_rancheros` | 56 | 63.2% |
112
+ | `falafel` | 36 | 63.6% |
113
+
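Per-class accuracies like those in the table come straight from the confusion matrix: each class's diagonal count divided by its row total. A small plain-Python illustration with a made-up 3-class matrix (the numbers are hypothetical, not the actual Food-101 counts):

```python
def per_class_accuracy(conf_matrix):
    """Row i holds true-class-i counts; the diagonal entry is the correct predictions."""
    accs = []
    for i, row in enumerate(conf_matrix):
        total = sum(row)
        accs.append(row[i] / total if total else 0.0)
    return accs

# Toy confusion matrix (rows = true class, cols = predicted class)
cm = [
    [44, 30, 26],   # a hard class, e.g. only 44/100 correct
    [5, 90, 5],
    [10, 10, 80],
]
print(per_class_accuracy(cm))  # [0.44, 0.9, 0.8]
```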
114
+ #### Root Cause Analysis of Misclassifications
115
+
116
+ * **High Intra-Class Variation**: The model struggled with dishes that have no single, consistent appearance.
117
+ * **Fine-Grained Confusion**: Errors occurred between visually similar classes like `ravioli` vs. `dumplings`.
118
+ * **Ambiguous Features**: Foods like `falafel` resemble many small fried dishes, making classification tricky.
119
+
120
+ #### Future Work
121
+
122
+ Improvements could include:
123
+
124
+ - Detailed confusion matrix analysis 🔍
125
+ - More aggressive data augmentation 📈
126
+ - Larger architectures for fine-grained recognition 🏋️
127
+ - Training for longer 🏋️
128
+
129
+ ---
130
+
131
+ ## 🧪 Methodology and Experimental Process
132
+
133
+ Steps taken in the project:
134
+
135
+ 1. **Baseline Establishment** 🏁 – EfficientNet-B2 achieved ~64%.
136
+ 2. **Architecture Selection** 🏗️ – EfficientNetV2-S chosen for balance of accuracy and size.
137
+ 3. **Transforms Selection** 🎨 – TrivialAugmentWide + RandomResizedCrop, RandAugment, etc.
138
+ 4. **Fine-Tuning Strategy** 🔧 – Final 3 blocks unfrozen for training.
139
+ 5. **Final Model Training** 🏆 – Full dataset, Adam, CosineAnnealingLR, EarlyStopping → 85.4%.
140
+
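The EarlyStopping callback in step 5 boils down to halting when the monitored metric stops improving for a set number of epochs. A minimal patience-counter sketch in plain Python (illustrative only, not the Lightning callback itself):

```python
def early_stop_epoch(val_accs, patience=3):
    """Return the epoch index at which training would stop, or None if it never stops."""
    best = float("-inf")
    bad_epochs = 0
    for epoch, acc in enumerate(val_accs):
        if acc > best:
            best = acc
            bad_epochs = 0   # improvement resets the counter
        else:
            bad_epochs += 1
            if bad_epochs >= patience:
                return epoch
    return None

print(early_stop_epoch([0.60, 0.70, 0.69, 0.68, 0.67]))  # 4
```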
141
+ ---
142
+
143
+ ## 📁 Repository Structure
144
+
145
+ ```bash
146
+ food-101-classification/
147
+ ├── data/
148
+ ├── logs/
149
+ ├── scripts/
150
+ │ ├── main.py
151
+ │ ├── models.py
152
+ │ ├── class_names.py
153
+ │ ├── app.py
154
+ │ └── prepare_data.py
155
+ ├── .gitignore
156
+ ├── requirements.txt
157
+ └── README.md
158
+ ```
159
+
160
+ ---
161
+
162
+ ## 🚀 Getting Started
163
+
164
+ ### Prerequisites
165
+
166
+ - Python 3.10+ 🐍
167
+ - PyTorch 🔥
168
+ - CUDA-enabled GPU (recommended) 🎮
169
+
170
+ ### Installation
171
+
172
+ 1. **Clone the repository:**
173
+
174
+ ```bash
175
+ git clone https://github.com/Deathshot78/Food101-Classification
176
+ cd Food101-Classification
177
+ ```
178
+
179
+ 2. **Install the dependencies:**
180
+
181
+ ```bash
182
+ pip install -r requirements.txt
183
+ ```
184
+
185
+ ### Usage
186
+
187
+ Run training with a subset for quick testing:
188
+
189
+ ```bash
190
+ python scripts/main.py
191
+ ```
192
+
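Quick testing relies on a reproducible subset of the dataset: the project's `CustomFood101` shuffles the indices with a fixed seed and keeps the first fraction. The core of that logic, sketched standalone:

```python
import random

def subset_indices(dataset_len: int, fraction: float, seed: int = 42):
    """Reproducible shuffled subset of dataset indices (mirrors the CustomFood101 logic)."""
    indices = list(range(dataset_len))
    random.Random(seed).shuffle(indices)  # fixed seed => same subset every run
    return indices[: int(dataset_len * fraction)]

idx = subset_indices(1000, 0.1)
print(len(idx))  # 100
```

Because the seed is fixed, every run (and every worker) sees the same 10% of the data, which keeps quick experiments comparable.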
193
+ ## 💻 Technologies Used
194
+
195
+ - Python
196
+
197
+ - PyTorch
198
+
199
+ - PyTorch Lightning
200
+
201
+ - TorchMetrics
202
+
203
+ - Gradio
204
+
205
+ - Matplotlib & Seaborn
206
+
assets/banner.png ADDED

Git LFS Details

  • SHA256: 9093c3aa4394c9f611a6983c479b21e2e26f8f68cd35d9d5a7f0e589cfd8202f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.67 MB
assets/confusion_matrix.png ADDED

Git LFS Details

  • SHA256: 5b76963efa5e9a0ab0ca5ac3dc2d1ba63f96e77ec45ecc6386cfdaab0051bf94
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
assets/gradio.png ADDED

Git LFS Details

  • SHA256: 293eeb6050e65e7ad27b9a6317090f7d721a58413d297c6e7ead2f3657b5a0f7
  • Pointer size: 131 Bytes
  • Size of remote file: 615 kB
assets/onion_rings.jpg ADDED

Git LFS Details

  • SHA256: 0b259ff866df09cab2c71c479c9d5e4b3273baa585b836349d89beaceb05149e
  • Pointer size: 130 Bytes
  • Size of remote file: 12.2 kB
assets/oysters.jpg ADDED

Git LFS Details

  • SHA256: aed7e730dddb604351a5b451b2503e350f7cb2cd78cad33076867736effa9bad
  • Pointer size: 131 Bytes
  • Size of remote file: 130 kB
assets/pizza.jpg ADDED

Git LFS Details

  • SHA256: 00896d5bdd63e1f074009ff7c607ce89cd3e37c8e62c7f190d57e2059f560b79
  • Pointer size: 130 Bytes
  • Size of remote file: 16 kB
assets/ramen.jpg ADDED

Git LFS Details

  • SHA256: 8a7186765988f80cbae7438b5aec7238d043b182676527ebc089d84114c1cc67
  • Pointer size: 131 Bytes
  • Size of remote file: 239 kB
checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6663c5df734fa5e9a50fab801a746ebd095732faacb1938368444518c3265615
3
+ size 230292623
notebooks/food101_classification.ipynb ADDED
@@ -0,0 +1,1055 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "430db510",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Food-101 Image Classification with EfficientNetV2-S and PyTorch Lightning\n",
9
+ "\n",
10
+ "This repository contains the code for an end-to-end deep learning project to classify 101 food categories from the challenging Food-101 dataset. The project demonstrates a systematic approach to model selection, fine-tuning, and hyperparameter optimization, achieving a final validation accuracy of **85.4%** on the full dataset.\n",
11
+ "\n",
12
+ "The entire training and evaluation pipeline is built using modern, reproducible practices with PyTorch Lightning."
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "markdown",
17
+ "id": "1116006e",
18
+ "metadata": {},
19
+ "source": [
20
+ "## 1. Imports"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "id": "531943f8",
27
+ "metadata": {},
28
+ "outputs": [],
29
+ "source": [
30
+ "import torch\n",
31
+ "import matplotlib.pyplot as plt\n",
32
+ "import pandas as pd\n",
33
+ "import numpy as np\n",
34
+ "import os\n",
35
+ "import pytorch_lightning as pl\n",
36
+ "import torch.optim.lr_scheduler as lr_scheduler\n",
37
+ "import torchvision\n",
38
+ "\n",
39
+ "from torchmetrics.functional import accuracy\n",
40
+ "from torchvision import transforms, datasets\n",
41
+ "from torchinfo import summary\n",
42
+ "from pytorch_lightning.callbacks import EarlyStopping\n",
43
+ "from pytorch_lightning.loggers import TensorBoardLogger\n",
44
+ "from torch import nn\n",
45
+ "from pathlib import Path\n",
46
+ "from torch.utils.data import DataLoader, TensorDataset, Dataset, random_split"
47
+ ]
48
+ },
49
+ {
50
+ "cell_type": "markdown",
51
+ "id": "08a5c10b",
52
+ "metadata": {},
53
+ "source": [
54
+ "## 2. Quick inspection of the top model"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": null,
60
+ "id": "e20ad559",
61
+ "metadata": {},
62
+ "outputs": [],
63
+ "source": [
64
+ "# Here we inspect the model's classifier layer to match the number of classes in Food101\n",
65
+ "\n",
66
+ "weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
67
+ "model = torchvision.models.efficientnet_v2_s(weights=weights)\n",
68
+ "effnet_v2_s_transforms = weights.transforms()\n",
69
+ "\n",
70
+ "print(model.classifier)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "3442b5a9",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "# Inspect the model\n",
81
+ "\n",
82
+ "summary(model=model,\n",
83
+ " input_size=(1, 3, 224, 224),\n",
84
+ " col_names=['input_size', 'output_size', 'num_params', 'trainable'],\n",
85
+ " col_width=20,\n",
86
+ " row_settings=['var_names'])"
87
+ ]
88
+ },
89
+ {
90
+ "cell_type": "code",
91
+ "execution_count": null,
92
+ "id": "bb353575",
93
+ "metadata": {},
94
+ "outputs": [],
95
+ "source": [
96
+ "# This will be the base transforms for training \n",
97
+ "\n",
98
+ "effnet_v2_s_transforms = weights.transforms()\n",
99
+ "train_transforms = torchvision.transforms.Compose([\n",
100
+ " torchvision.transforms.TrivialAugmentWide(),\n",
101
+ " effnet_v2_s_transforms])\n",
102
+ "\n",
103
+ "train_transforms"
104
+ ]
105
+ },
106
+ {
107
+ "cell_type": "markdown",
108
+ "id": "363ce600",
109
+ "metadata": {},
110
+ "source": [
111
+ "## 3. Dataset and PyTorch Lightning DataModule Classes"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "id": "2a04cc09",
118
+ "metadata": {},
119
+ "outputs": [],
120
+ "source": [
121
+ "from torchvision import datasets\n",
122
+ "from pathlib import Path\n",
123
+ "import os\n",
124
+ "import pytorch_lightning as pl\n",
125
+ "from torch.utils.data import DataLoader, Subset\n",
126
+ "from torchvision import datasets\n",
127
+ "from torchvision import transforms as T\n",
128
+ "import numpy as np\n",
129
+ "import torchvision\n",
130
+ "from torchvision.datasets import Food101\n",
131
+ "from torch.utils.data import DataLoader, Dataset\n",
132
+ "from typing import Dict, Tuple, Any\n",
133
+ "import random\n",
134
+ "\n",
135
+ "\n",
136
+ "def get_model_components(\n",
137
+ " model_name: str, \n",
138
+ " return_classifier: bool = False, \n",
139
+ " augmentation_level: str = \"default\"\n",
140
+ ") -> Dict[str, Any]:\n",
141
+ " \"\"\"\n",
142
+ " Retrieves pre-trained model components from torchvision.\n",
143
+ "\n",
144
+ " This function fetches the appropriate weights and transforms for a given\n",
145
+ " model. It supports different levels of training data augmentation.\n",
146
+ "\n",
147
+ " Args:\n",
148
+ " model_name (str): The name of the model to get components for.\n",
149
+ " Supported models include \"EfficientNet_V2_S\" and \"EfficientNet_B2\".\n",
150
+ " return_classifier (bool, optional): If True, the model's classifier\n",
151
+ " head is also returned. Defaults to False.\n",
152
+ " augmentation_level (str, optional): The level of data augmentation to use\n",
153
+ " for the training set. Can be \"default\" or \"strong\". \n",
154
+ " Defaults to \"default\".\n",
155
+ "\n",
156
+ " Returns:\n",
157
+ " Dict[str, Any]: A dictionary containing the requested components.\n",
158
+ " Always includes 'train_transforms' and 'val_transforms'.\n",
159
+ " Includes 'classifier' if return_classifier is True.\n",
160
+ " \n",
161
+ " Raises:\n",
162
+ " ValueError: If model_name or augmentation_level is not supported.\n",
163
+ " \"\"\"\n",
164
+ " model_registry = {\n",
165
+ " \"EfficientNet_V2_S\": (\n",
166
+ " torchvision.models.efficientnet_v2_s,\n",
167
+ " torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
168
+ " ),\n",
169
+ " \"EfficientNet_B2\": (\n",
170
+ " torchvision.models.efficientnet_b2,\n",
171
+ " torchvision.models.EfficientNet_B2_Weights.DEFAULT\n",
172
+ " )\n",
173
+ " }\n",
174
+ "\n",
175
+ " if model_name not in model_registry:\n",
176
+ " raise ValueError(f\"Model '{model_name}' is not supported. \"\n",
177
+ " f\"Supported models are: {list(model_registry.keys())}\")\n",
178
+ "\n",
179
+ " # 1. Look up the model and weights classes\n",
180
+ " model_class, weights_class = model_registry[model_name]\n",
181
+ " weights = weights_class\n",
182
+ " val_transforms = weights.transforms()\n",
183
+ "\n",
184
+ " # 2. Create the training transforms based on the desired level\n",
185
+ " if augmentation_level == \"default\":\n",
186
+ " train_transforms = T.Compose([\n",
187
+ " T.TrivialAugmentWide(),\n",
188
+ " val_transforms # val_transforms includes ToTensor and Normalize\n",
189
+ " ])\n",
190
+ " elif augmentation_level == \"strong\":\n",
191
+ " # Note: We don't need to add ToTensor() or Normalize() here because\n",
192
+ " # they are already included inside the 'val_transforms' pipeline.\n",
193
+ " train_transforms = T.Compose([\n",
194
+ " T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),\n",
195
+ " T.RandomHorizontalFlip(p=0.5),\n",
196
+ " T.RandAugment(num_ops=2, magnitude=9),\n",
197
+ " # RandomErasing should be applied to a tensor, so we apply it after\n",
198
+ " # val_transforms, which handles the PIL -> Tensor conversion.\n",
199
+ " val_transforms, \n",
200
+ " T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')\n",
201
+ " ])\n",
202
+ " else:\n",
203
+ " raise ValueError(f\"Augmentation level '{augmentation_level}' is not supported. \"\n",
204
+ " f\"Choose from 'default' or 'strong'.\")\n",
205
+ " \n",
206
+ " # 3. Prepare the dictionary to be returned\n",
207
+ " components = {\n",
208
+ " \"train_transforms\": train_transforms,\n",
209
+ " \"val_transforms\": val_transforms\n",
210
+ " }\n",
211
+ "\n",
212
+ " # 4. Optionally, instantiate the model to get the classifier\n",
213
+ " if return_classifier:\n",
214
+ " model = model_class(weights=weights)\n",
215
+ " components[\"classifier\"] = model.classifier\n",
216
+ "\n",
217
+ " return components\n",
218
+ " \n",
219
+ "class CustomFood101(Dataset):\n",
220
+ " \"\"\"A PyTorch Dataset for Food101 with conditional downloading and subset support.\n",
221
+ "\n",
222
+ " This class wraps the torchvision Food101 dataset. It only downloads the data\n",
223
+ " if the specified directory doesn't already exist. It can also create a\n",
224
+ " reproducible, shuffled subset of the data for faster experimentation.\n",
225
+ "\n",
226
+ " Args:\n",
227
+ " split (str): The dataset split, either \"train\" or \"test\".\n",
228
+ " transform (callable, optional): A function/transform to apply to the images.\n",
229
+ " data_dir (str, optional): The directory to store the data. Defaults to \"data\".\n",
230
+ " subset_fraction (float, optional): The fraction of the dataset to use.\n",
231
+ " Defaults to 1.0 (using the full dataset).\n",
232
+ " \"\"\"\n",
233
+ "\n",
234
+ " def __init__(self, split, transform=None, data_dir=\"data\", subset_fraction: float = 0.1):\n",
235
+ " # Check if the dataset already exists before setting the download flag.\n",
236
+ " dataset_path = os.path.join(data_dir, \"food-101\")\n",
237
+ " should_download = not os.path.isdir(dataset_path)\n",
238
+ "\n",
239
+ " # 1. Load the full dataset metadata with the conditional flag\n",
240
+ " self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)\n",
241
+ " self.classes = self.full_dataset.classes\n",
242
+ "\n",
243
+ " # 2. Create a reproducible subset of indices\n",
244
+ " if subset_fraction < 1.0:\n",
245
+ " num_samples = int(len(self.full_dataset) * subset_fraction)\n",
246
+ " all_indices = list(range(len(self.full_dataset)))\n",
247
+ " # Shuffle with a fixed seed for reproducibility\n",
248
+ " random.Random(42).shuffle(all_indices)\n",
249
+ " self.indices = all_indices[:num_samples]\n",
250
+ " else:\n",
251
+ " self.indices = list(range(len(self.full_dataset)))\n",
252
+ "\n",
253
+ " def __len__(self):\n",
254
+ " \"\"\"Returns the total number of samples in the subset.\"\"\"\n",
255
+ " return len(self.indices)\n",
256
+ "\n",
257
+ " def __getitem__(self, idx):\n",
258
+ " \"\"\"\n",
259
+ " Fetches the sample for the given subset index and applies the transform.\n",
260
+ " \"\"\"\n",
261
+ " # Map the subset index to the actual index in the full dataset\n",
262
+ " original_idx = self.indices[idx]\n",
263
+ " image, label = self.full_dataset[original_idx]\n",
264
+ " return image, label\n",
265
+ "\n",
266
+ "class Food101DataModule(pl.LightningDataModule):\n",
267
+ " \"\"\"A PyTorch Lightning DataModule for the Food101 dataset.\n",
268
+ "\n",
269
+ " This module encapsulates all data-related logic, including downloading,\n",
270
+ " processing, and creating DataLoaders for the training, validation, and\n",
271
+ " test sets. It uses the CustomFood101 dataset internally and allows for\n",
272
+ " controlling the fraction of data used in the training and validation splits.\n",
273
+ "\n",
274
+ " Args:\n",
275
+ " data_dir (str, optional): Root directory for the data. Defaults to \"data\".\n",
276
+ " batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.\n",
277
+ " num_workers (int, optional): Number of workers for data loading. Defaults to 2.\n",
278
+ " train_transforms (callable, optional): Transformations for the training set.\n",
279
+ " val_transforms (callable, optional): Transformations for the validation/test set.\n",
280
+ " subset_fraction (float, optional): The fraction of data to use for training\n",
281
+ " and validation. Defaults to 1.0.\n",
282
+ " \"\"\"\n",
283
+ " def __init__(self, data_dir=\"data\", batch_size=32, num_workers=2,\n",
284
+ " train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):\n",
285
+ " super().__init__()\n",
286
+ " self.data_dir = data_dir\n",
287
+ " self.batch_size = batch_size\n",
288
+ " self.num_workers = num_workers\n",
289
+ " self.train_transforms = train_transforms\n",
290
+ " self.val_transforms = val_transforms\n",
291
+ " self.subset_fraction = subset_fraction\n",
292
+ "\n",
293
+ " self.classes = []\n",
294
+ "\n",
295
+ " def prepare_data(self):\n",
296
+ " \"\"\"Downloads data if needed.\"\"\"\n",
297
+ " CustomFood101(split='train', data_dir=self.data_dir)\n",
298
+ " CustomFood101(split='test', data_dir=self.data_dir)\n",
299
+ "\n",
300
+ " def setup(self, stage=None):\n",
301
+ " \"\"\"Assigns datasets, passing the subset_fraction.\"\"\"\n",
302
+ " if stage == 'fit' or stage is None:\n",
303
+ " self.train_dataset = CustomFood101(split='train', transform=self.train_transforms,\n",
304
+ " data_dir=self.data_dir, subset_fraction=self.subset_fraction)\n",
305
+ " self.val_dataset = CustomFood101(split='test', transform=self.val_transforms,\n",
306
+ " data_dir=self.data_dir, subset_fraction=self.subset_fraction)\n",
307
+ " self.classes = self.train_dataset.classes\n",
308
+ "\n",
309
+ " if stage == 'test' or stage is None:\n",
310
+ " self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,\n",
311
+ " data_dir=self.data_dir, subset_fraction=1.0) # Use full test set\n",
312
+ " if not self.classes:\n",
313
+ " self.classes = self.test_dataset.classes\n",
314
+ "\n",
315
+ " def train_dataloader(self):\n",
316
+ " return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)\n",
317
+ "\n",
318
+ " def val_dataloader(self):\n",
319
+ " return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)\n",
320
+ "\n",
321
+ " def test_dataloader(self):\n",
322
+ " return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)\n"
323
+ ]
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "execution_count": null,
328
+ "id": "93ffd2e2",
329
+ "metadata": {},
330
+ "outputs": [],
331
+ "source": [
332
+ "# Define configuration for the script\n",
333
+ "DATA_DIR = \"data\"\n",
334
+ "MODEL_NAME = \"EfficientNet_V2_S\"\n",
335
+ "BATCH_SIZE = 32\n",
336
+ "\n",
337
+ "print(f\"Running data preparation script for model: {MODEL_NAME}\")\n",
338
+ "\n",
339
+ "# 1. Get model-specific transforms\n",
340
+ "components = get_model_components(MODEL_NAME)\n",
341
+ "train_transforms = components[\"train_transforms\"]\n",
342
+ "val_transforms = components[\"val_transforms\"]\n",
343
+ "\n",
344
+ "# 2. Instantiate the DataModule\n",
345
+ "datamodule = Food101DataModule(\n",
346
+ " data_dir=DATA_DIR,\n",
347
+ " batch_size=BATCH_SIZE,\n",
348
+ " train_transforms=train_transforms,\n",
349
+ " val_transforms=val_transforms,\n",
350
+ " subset_fraction=0.1 # Use a small subset for quick verification\n",
351
+ ")\n",
352
+ "\n",
353
+ "# 3. Trigger download and setup\n",
354
+ "datamodule.prepare_data()\n",
355
+ "datamodule.setup(stage='fit')\n",
356
+ "\n",
357
+ "# 4. (Optional) Verification Step\n",
358
+ "print(\"\\n--- Verifying Dataloader ---\")\n",
359
+ "# Get one batch from the training dataloader\n",
360
+ "train_dl = datamodule.train_dataloader()\n",
361
+ "images, labels = next(iter(train_dl))\n",
362
+ "\n",
363
+ "print(f\"Number of classes: {len(datamodule.classes)}\")\n",
364
+ "print(f\"Image batch shape: {images.shape}\")\n",
365
+ "print(f\"Label batch shape: {labels.shape}\")\n",
366
+ "print(\"--- Verification Complete ---\") "
367
+ ]
368
+ },
369
+ {
370
+ "cell_type": "markdown",
371
+ "id": "3edf64c2",
372
+ "metadata": {},
373
+ "source": [
374
+ "## 4. Model Classes"
375
+ ]
376
+ },
377
+ {
378
+ "cell_type": "code",
379
+ "execution_count": null,
380
+ "id": "eb264fe4",
381
+ "metadata": {},
382
+ "outputs": [],
383
+ "source": [
384
+ "import torch\n",
385
+ "import torchvision\n",
386
+ "import pytorch_lightning as pl\n",
387
+ "from torch import nn\n",
388
+ "from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix\n",
389
+ "import seaborn as sns\n",
390
+ "import matplotlib.pyplot as plt\n",
391
+ "import pandas as pd\n",
392
+ "import numpy as np\n",
393
+ "\n",
394
+ "class EffNetV2_S(pl.LightningModule):\n",
395
+ " \"\"\"A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.\n",
396
+ "\n",
397
+ " This module encapsulates the EfficientNetV2-S model and provides a flexible\n",
398
+ " fine-tuning strategy. It can be configured for Stage 1 (training only the\n",
399
+ " classifier and later feature blocks) or Stage 2 (training the entire model).\n",
400
+ "\n",
401
+ " Args:\n",
402
+ " lr (float, optional): The learning rate. Defaults to 1e-3.\n",
403
+ " weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.\n",
404
+ " num_classes (int, optional): The number of output classes. Defaults to 101.\n",
405
+ " class_names (list, optional): A list of class names for logging. Defaults to None.\n",
406
+ " freeze_features (bool, optional): If True, freezes the backbone and unfreezes\n",
407
+ " only the later blocks (Stage 1). If False, all features are trainable\n",
408
+ " (Stage 2). Defaults to True.\n",
409
+ " unfreeze_from_block (int, optional): Which feature block to start unfreezing\n",
410
+ " from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).\n",
411
+ " \"\"\"\n",
412
+ " \n",
413
+ " def __init__(\n",
414
+ " self,\n",
415
+ " lr: float = 1e-3,\n",
416
+ " weight_decay: float = 1e-4,\n",
417
+ " num_classes: int = 101,\n",
418
+ " class_names: list = None,\n",
419
+ " freeze_features: bool = True, # True = Stage 1, False = Stage 2\n",
420
+ " unfreeze_from_block: int = -3 # Only used if freeze_features=True\n",
421
+ " ):\n",
422
+ " super().__init__()\n",
423
+ " self.save_hyperparameters()\n",
424
+ " self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]\n",
425
+ "\n",
426
+ " # Load pretrained weights\n",
427
+ " weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT\n",
428
+ " self.model = torchvision.models.efficientnet_v2_s(weights=weights)\n",
429
+ "\n",
430
+ " # ---- Freezing strategy ----\n",
431
+ " if freeze_features:\n",
432
+ " # Freeze all first\n",
433
+ " for param in self.model.parameters():\n",
434
+ " param.requires_grad = False\n",
435
+ " # Unfreeze from a specific block (default: last 3 blocks)\n",
436
+ " for param in self.model.features[unfreeze_from_block:].parameters():\n",
437
+ " param.requires_grad = True\n",
438
+ " else:\n",
439
+ " # Stage 2: unfreeze everything\n",
440
+ " for param in self.model.parameters():\n",
441
+ " param.requires_grad = True\n",
442
+ "\n",
443
+ " # Classifier head\n",
444
+ " self.model.classifier = nn.Sequential(\n",
445
+ " nn.Dropout(p=0.2, inplace=True),\n",
446
+ " nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True)\n",
447
+ " )\n",
448
+ "\n",
449
+ " # Loss & metrics\n",
450
+ " self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
451
+ " self.train_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
452
+ " self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
453
+ " self.train_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
454
+ " self.val_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
455
+ " self.val_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
456
+ " self.test_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
457
+ "\n",
458
+ " def forward(self, x):\n",
459
+ " return self.model(x)\n",
460
+ "\n",
461
+ " def training_step(self, batch, batch_idx):\n",
462
+ " x, y = batch\n",
463
+ " logits = self(x)\n",
464
+ " loss = self.loss_fn(logits, y)\n",
465
+ " self.train_accuracy(logits, y)\n",
466
+ " self.train_f1(logits, y)\n",
467
+ " self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)\n",
468
+ " self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)\n",
469
+ " self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)\n",
470
+ " return loss\n",
471
+ "\n",
472
+ " def validation_step(self, batch, batch_idx):\n",
473
+ " x, y = batch\n",
474
+ " logits = self(x)\n",
475
+ " loss = self.loss_fn(logits, y)\n",
476
+ " self.val_accuracy(logits, y)\n",
477
+ " self.val_f1(logits, y)\n",
478
+ " self.log('val_loss', loss, prog_bar=True)\n",
479
+ " self.log('val_acc', self.val_accuracy, prog_bar=True)\n",
480
+ " self.log('val_f1', self.val_f1, prog_bar=True)\n",
481
+ " self.val_conf_matrix.update(logits, y)\n",
482
+ "\n",
483
+ " def on_validation_epoch_end(self):\n",
484
+ " cm = self.val_conf_matrix.compute()\n",
485
+ " per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)\n",
486
+ " print(\"\\n--- Per-Class Validation Accuracy ---\")\n",
487
+ " for i, acc in enumerate(per_class_acc):\n",
488
+ " self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)\n",
489
+ " print(f\"{self.class_names[i]:<20}: {acc.item():.4f}\")\n",
490
+ " print(\"------------------------------------\")\n",
491
+ " self.val_conf_matrix.reset()\n",
492
+ "\n",
493
+ " def test_step(self, batch, batch_idx):\n",
494
+ " x, y = batch\n",
495
+ " logits = self(x)\n",
496
+ " self.test_conf_matrix.update(logits, y)\n",
497
+ "\n",
498
+ " def on_test_end(self):\n",
499
+ " cm = self.test_conf_matrix.compute()\n",
500
+ " print(\"\\nGenerating final confusion matrix plot...\")\n",
501
+ " self.test_conf_matrix.reset()\n",
502
+ "\n",
503
+ " def configure_optimizers(self):\n",
504
+ " optimizer = torch.optim.Adam(\n",
505
+ " self.parameters(),\n",
506
+ " lr=self.hparams.lr,\n",
507
+ " weight_decay=self.hparams.weight_decay\n",
508
+ " )\n",
509
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
510
+ " optimizer,\n",
511
+ " T_max=self.trainer.max_epochs,\n",
512
+ " eta_min=1e-6\n",
513
+ " )\n",
514
+ " return {\"optimizer\": optimizer, \"lr_scheduler\": {\"scheduler\": scheduler, \"interval\": \"epoch\"}}\n",
515
+ " \n",
516
+ "class EffNetb2(pl.LightningModule):\n",
517
+ " \"\"\"A PyTorch Lightning Module for fine-tuning EfficientNet-B2.\n",
518
+ "\n",
519
+ " This module encapsulates the EfficientNet-B2 model and provides a flexible\n",
520
+ " fine-tuning strategy. It can be configured for Stage 1 (training only the\n",
521
+ " classifier and later feature blocks) or Stage 2 (training the entire model).\n",
522
+ "\n",
523
+ " Args:\n",
524
+ " lr (float, optional): The learning rate. Defaults to 1e-3.\n",
525
+ " weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.\n",
526
+ " num_classes (int, optional): The number of output classes. Defaults to 101.\n",
527
+ " class_names (list, optional): A list of class names for logging. Defaults to None.\n",
528
+ " freeze_features (bool, optional): If True, freezes the backbone and unfreezes\n",
529
+ " only the later blocks (Stage 1). If False, all features are trainable\n",
530
+ " (Stage 2). Defaults to True.\n",
531
+ " unfreeze_from_block (int, optional): Which feature block to start unfreezing\n",
532
+ " from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).\n",
533
+ " \"\"\"\n",
534
+ "\n",
535
+ " def __init__(\n",
536
+ " self,\n",
537
+ " lr: float = 1e-3,\n",
538
+ " weight_decay: float = 1e-4,\n",
539
+ " num_classes: int = 101,\n",
540
+ " class_names: list = None,\n",
541
+ " freeze_features: bool = True,\n",
542
+ " unfreeze_from_block: int = -3\n",
543
+ " ):\n",
544
+ " super().__init__()\n",
545
+ " self.save_hyperparameters()\n",
546
+ " self.class_names = class_names if class_names is not None else [str(i) for i in range(num_classes)]\n",
547
+ "\n",
548
+ " # Model setup\n",
549
+ " weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT\n",
550
+ " self.model = torchvision.models.efficientnet_b2(weights=weights)\n",
551
+ " \n",
552
+ " # --- : Flexible Freezing Strategy ---\n",
553
+ " if self.hparams.freeze_features:\n",
554
+ " # Stage 1: Freeze all first\n",
555
+ " for param in self.model.parameters():\n",
556
+ " param.requires_grad = False\n",
557
+ " # Unfreeze from a specific block (default: last 3 blocks)\n",
558
+ " for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():\n",
559
+ " param.requires_grad = True\n",
560
+ " else:\n",
561
+ " # Stage 2: unfreeze everything\n",
562
+ " for param in self.model.parameters():\n",
563
+ " param.requires_grad = True\n",
564
+ "\n",
565
+ " # Classifier head\n",
566
+ " self.model.classifier = nn.Sequential(\n",
567
+ " nn.Dropout(p=0.3, inplace=True),\n",
568
+ " nn.Linear(in_features=1408, out_features=self.hparams.num_classes)\n",
569
+ " )\n",
570
+ "\n",
571
+ " # Metrics\n",
572
+ " self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)\n",
573
+ " self.train_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
574
+ " self.val_accuracy = Accuracy(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
575
+ " self.train_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
576
+ " self.val_f1 = F1Score(task=\"multiclass\", num_classes=self.hparams.num_classes, average='macro')\n",
577
+ " self.val_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
578
+ " self.test_conf_matrix = ConfusionMatrix(task=\"multiclass\", num_classes=self.hparams.num_classes)\n",
579
+ "\n",
580
+ " def forward(self, x):\n",
581
+ " return self.model(x)\n",
582
+ "\n",
583
+ " def training_step(self, batch, batch_idx):\n",
584
+ " x, y = batch\n",
585
+ " logits = self(x)\n",
586
+ " loss = self.loss_fn(logits, y)\n",
587
+ " self.train_accuracy(logits, y)\n",
588
+ " self.train_f1(logits, y)\n",
589
+ " self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)\n",
590
+ " self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)\n",
591
+ " self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)\n",
592
+ " return loss\n",
593
+ "\n",
594
+ " def validation_step(self, batch, batch_idx):\n",
595
+ " x, y = batch\n",
596
+ " logits = self(x)\n",
597
+ " loss = self.loss_fn(logits, y)\n",
598
+ " self.val_accuracy(logits, y)\n",
599
+ " self.val_f1(logits, y)\n",
600
+ " self.log('val_loss', loss, prog_bar=True)\n",
601
+ " self.log('val_acc', self.val_accuracy, prog_bar=True)\n",
602
+ " self.log('val_f1', self.val_f1, prog_bar=True)\n",
603
+ " self.val_conf_matrix.update(logits, y)\n",
604
+ "\n",
605
+ " def on_validation_epoch_end(self):\n",
606
+ " cm = self.val_conf_matrix.compute()\n",
607
+ "\n",
608
+ " # Add a small epsilon (1e-6) to the denominator for numerical stability.\n",
609
+ " per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)\n",
610
+ "\n",
611
+ " print(\"\\n--- Per-Class Validation Accuracy ---\")\n",
612
+ " for i, acc in enumerate(per_class_acc):\n",
613
+ " class_name = self.class_names[i]\n",
614
+ " self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)\n",
615
+ " print(f\"{class_name:<20}: {acc.item():.4f}\")\n",
616
+ " print(\"------------------------------------\")\n",
617
+ "\n",
618
+ " self.val_conf_matrix.reset()\n",
619
+ "\n",
620
+ " def test_step(self, batch, batch_idx):\n",
621
+ " x, y = batch\n",
622
+ " logits = self(x)\n",
623
+ " self.test_conf_matrix.update(logits, y)\n",
624
+ "\n",
625
+ " def on_test_end(self):\n",
626
+ " cm = self.test_conf_matrix.compute()\n",
627
+ " print(\"\\nGenerating final confusion matrix plot...\")\n",
628
+ " # Assuming plot_confusion_matrix is defined elsewhere\n",
629
+ " # plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)\n",
630
+ " self.test_conf_matrix.reset()\n",
631
+ "\n",
632
+ " def configure_optimizers(self):\n",
633
+ " optimizer = torch.optim.Adam(\n",
634
+ " self.parameters(),\n",
635
+ " lr=self.hparams.lr,\n",
636
+ " weight_decay=self.hparams.weight_decay\n",
637
+ " )\n",
638
+ " scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(\n",
639
+ " optimizer,\n",
640
+ " T_max=self.trainer.max_epochs,\n",
641
+ " eta_min=1e-6\n",
642
+ " )\n",
643
+ " return {\n",
644
+ " \"optimizer\": optimizer,\n",
645
+ " \"lr_scheduler\": {\n",
646
+ " \"scheduler\": scheduler,\n",
647
+ " \"interval\": \"epoch\",\n",
648
+ " },\n",
649
+ " }\n"
650
+ ]
651
+ },
652
+ {
653
+ "cell_type": "markdown",
654
+ "id": "3f5bf233",
655
+ "metadata": {},
656
+ "source": [
657
+ "## 5. Training and plotting the Confusion Matrix"
658
+ ]
659
+ },
660
+ {
661
+ "cell_type": "code",
662
+ "execution_count": null,
663
+ "id": "6b080afd",
664
+ "metadata": {},
665
+ "outputs": [],
666
+ "source": [
667
+ "import pytorch_lightning as pl\n",
668
+ "from pytorch_lightning import Trainer, LightningModule\n",
669
+ "from pytorch_lightning.loggers import CSVLogger\n",
670
+ "from pytorch_lightning.callbacks import EarlyStopping ,ModelCheckpoint\n",
671
+ "from typing import Optional\n",
672
+ "import matplotlib.pyplot as plt\n",
673
+ "import seaborn as sns\n",
674
+ "import numpy as np\n",
675
+ "import pandas as pd\n",
676
+ "from typing import List\n",
677
+ "\n",
678
+ "DATA_DIR = \"data\"\n",
679
+ "MODEL_NAME = \"EfficientNet_V2_S\"\n",
680
+ "BATCH_SIZE = 32\n",
681
+ "SUBSET_FRACTION = 0.2 # Useing a smaller subset for quick testing\n",
682
+ "CHECKPOINT_PATH = \"checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt\" # Path to your trained model checkpoint\n",
683
+ "\n",
684
+ "def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], figsize: tuple = (25, 25)):\n",
685
+ " \"\"\"\n",
686
+ " Creates and saves a multi-class confusion matrix plot.\n",
687
+ "\n",
688
+ " This function normalizes the confusion matrix to show prediction\n",
689
+ " percentages for each class, visualizes it as a heatmap, and saves\n",
690
+ " the resulting figure to a file.\n",
691
+ "\n",
692
+ " Args:\n",
693
+ " cm (np.ndarray): The confusion matrix from torchmetrics or scikit-learn.\n",
694
+ " class_names (List[str]): A list of class names for the labels.\n",
695
+ " figsize (tuple, optional): The size of the figure. Defaults to (25, 25).\n",
696
+ " \"\"\"\n",
697
+ " # 1. Normalize the confusion matrix to show percentages\n",
698
+ " # Add a small epsilon to prevent division by zero\n",
699
+ " cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)\n",
700
+ "\n",
701
+ " # 2. Create a DataFrame for a beautiful plot with labels\n",
702
+ " df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)\n",
703
+ "\n",
704
+ " # 3. Create the plot\n",
705
+ " plt.figure(figsize=figsize)\n",
706
+ " heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues') # Annotations off for 101 classes\n",
707
+ "\n",
708
+ " # 4. Format the plot\n",
709
+ " heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)\n",
710
+ " heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)\n",
711
+ "\n",
712
+ " plt.ylabel('True Label')\n",
713
+ " plt.xlabel('Predicted Label')\n",
714
+ " plt.title('Normalized Confusion Matrix')\n",
715
+ " plt.tight_layout()\n",
716
+ "\n",
717
+ " # 5. Save the figure and show the plot\n",
718
+ " plt.savefig('confusion_matrix.png', dpi=300)\n",
719
+ " print(\"Confusion matrix plot saved to confusion_matrix.png\")\n",
720
+ " plt.show()\n",
721
+ "\n",
722
+ "def run_training_session(\n",
723
+ " model_name: str = \"EfficientNet_V2_S\",\n",
724
+ " batch_size: int = 32,\n",
725
+ " data_dir: str = 'data',\n",
726
+ " subset_fraction: float = 1.0,\n",
727
+ " checkpoint_path: str = \"checkpoints/\",\n",
728
+ " lr: float = 1e-3,\n",
729
+ " weight_decay: float = 1e-4,\n",
730
+ " freeze_features: bool = True,\n",
731
+ " early_stopping_patience: int = 5,\n",
732
+ " max_epochs: int = 100,\n",
733
+ " accelerator: str = 'auto',\n",
734
+ " resume_from_checkpoint: Optional[str] = None\n",
735
+ ") -> Trainer:\n",
736
+ " \"\"\"\n",
737
+ " Sets up and runs a complete training session for a specified model.\n",
738
+ "\n",
739
+ " This function handles the entire pipeline: data preparation, model\n",
740
+ " instantiation, logger and callback setup, and trainer execution.\n",
741
+ "\n",
742
+ " Args:\n",
743
+ " model_name (str): The name of the model architecture to train.\n",
744
+ " batch_size (int): The number of samples per batch.\n",
745
+ " data_dir (str): The root directory for the dataset.\n",
746
+ " subset_fraction (float): The fraction of the dataset to use for training.\n",
747
+ " checkpoint_path (str): Directory to save model checkpoints.\n",
748
+ " lr (float): The learning rate for the optimizer.\n",
749
+ " weight_decay (float): The weight decay for the optimizer.\n",
750
+ " freeze_features (bool): Flag to control the fine-tuning strategy\n",
751
+ " (e.g., for two-stage training).\n",
752
+ " early_stopping_patience (int): Number of epochs with no improvement\n",
753
+ " after which training will be stopped.\n",
754
+ " max_epochs (int): The maximum number of epochs to train for.\n",
755
+ " accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').\n",
756
+ " resume_from_checkpoint (Optional[str]): Path to a checkpoint file to\n",
757
+ " resume training from. Defaults to None.\n",
758
+ "\n",
759
+ " Returns:\n",
760
+ " Trainer: The PyTorch Lightning Trainer object after fitting is complete.\n",
761
+ " \"\"\"\n",
762
+ " # A registry to map model names to their actual classes\n",
763
+ " model_class_registry = {\n",
764
+ " \"EfficientNet_V2_S\": EffNetV2_S,\n",
765
+ " \"EfficientNet_B2\": EffNetb2,\n",
766
+ " }\n",
767
+ " if model_name not in model_class_registry:\n",
768
+ " raise ValueError(f\"Model '{model_name}' is not a recognized class.\")\n",
769
+ "\n",
770
+ " # Get model-specific transforms\n",
771
+ " components = get_model_components(model_name)\n",
772
+ " train_transforms = components[\"train_transforms\"]\n",
773
+ " val_transforms = components[\"val_transforms\"]\n",
774
+ "\n",
775
+ " # Set up the DataModule\n",
776
+ " food_datamodule = Food101DataModule(\n",
777
+ " data_dir=data_dir,\n",
778
+ " batch_size=batch_size,\n",
779
+ " train_transforms=train_transforms,\n",
780
+ " val_transforms=val_transforms,\n",
781
+ " subset_fraction=subset_fraction\n",
782
+ " )\n",
783
+ " food_datamodule.prepare_data()\n",
784
+ " food_datamodule.setup()\n",
785
+ "\n",
786
+ " # Instantiate the model dynamically\n",
787
+ " model_class = model_class_registry[model_name]\n",
788
+ " model = model_class(\n",
789
+ " num_classes=len(food_datamodule.classes),\n",
790
+ " class_names=food_datamodule.classes,\n",
791
+ " lr=lr,\n",
792
+ " weight_decay=weight_decay,\n",
793
+ " freeze_features=freeze_features\n",
794
+ " )\n",
795
+ "\n",
796
+ " # Set up logger and callbacks\n",
797
+ " logger = CSVLogger(save_dir=\"logs/\", name=model_name)\n",
798
+ " \n",
799
+ " early_stop_callback = EarlyStopping(\n",
800
+ " monitor=\"val_loss\",\n",
801
+ " patience=early_stopping_patience,\n",
802
+ " mode=\"min\"\n",
803
+ " )\n",
804
+ " best_model_checkpoint = ModelCheckpoint(\n",
805
+ " dirpath=checkpoint_path,\n",
806
+ " filename=\"best-model-{epoch:02d}-{val_acc:.4f}\",\n",
807
+ " save_top_k=1,\n",
808
+ " monitor=\"val_acc\",\n",
809
+ " mode=\"max\"\n",
810
+ " )\n",
811
+ " \n",
812
+ " callbacks = [early_stop_callback, best_model_checkpoint]\n",
813
+ "\n",
814
+ " # Instantiate the Trainer\n",
815
+ " trainer = Trainer(\n",
816
+ " max_epochs=max_epochs,\n",
817
+ " accelerator=accelerator,\n",
818
+ " callbacks=callbacks,\n",
819
+ " logger=logger,\n",
820
+ " )\n",
821
+ "\n",
822
+ " # Start training\n",
823
+ " trainer.fit(\n",
824
+ " model,\n",
825
+ " datamodule=food_datamodule,\n",
826
+ " ckpt_path=resume_from_checkpoint \n",
827
+ " )\n",
828
+ " \n",
829
+ " return trainer\n"
830
+ ]
831
+ },
832
+ {
833
+ "cell_type": "code",
834
+ "execution_count": null,
835
+ "id": "04c534dc",
836
+ "metadata": {},
837
+ "outputs": [],
838
+ "source": [
839
+ "# --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---\n",
840
+ "config = {\n",
841
+ " \"model_name\": \"EfficientNet_V2_S\",\n",
842
+ " \"batch_size\": 32,\n",
843
+ " \"lr\": 1e-4,\n",
844
+ " \"epochs\": 50,\n",
845
+ " \"subset_fraction\": 1.0, # Use 1.0 for the full dataset\n",
846
+ " \"freeze_features\": True,\n",
847
+ " \"early_stopping_patience\": 10\n",
848
+ "}\n",
849
+ "\n",
850
+ "# --- 2. PRINT CONFIGURATION AND START TRAINING ---\n",
851
+ "print(\"--- Starting Training Session ---\")\n",
852
+ "for key, value in config.items():\n",
853
+ " print(f\" {key}: {value}\")\n",
854
+ "print(\"---------------------------------\")\n",
855
+ "\n",
856
+ "run_training_session(\n",
857
+ " model_name=config[\"model_name\"],\n",
858
+ " batch_size=config[\"batch_size\"],\n",
859
+ " lr=config[\"lr\"],\n",
860
+ " max_epochs=config[\"epochs\"],\n",
861
+ " subset_fraction=config[\"subset_fraction\"],\n",
862
+ " freeze_features=config[\"freeze_features\"],\n",
863
+ " early_stopping_patience=config[\"early_stopping_patience\"]\n",
864
+ ")\n",
865
+ "\n",
866
+ "print(\"\\n--- Training Session Complete ---\")\n",
867
+ "\n",
868
+ "print(\"\\n--- Starting Evaluation on Test Set ---\")\n",
869
+ "\n",
870
+ "print(f\"Loading model from checkpoint: {CHECKPOINT_PATH}\")\n",
871
+ "\n",
872
+ "# Step 1: Set up the DataModule for the test set\n",
873
+ "components = get_model_components(MODEL_NAME)\n",
874
+ "val_transforms = components[\"val_transforms\"]\n",
875
+ "\n",
876
+ "datamodule = Food101DataModule(\n",
877
+ " data_dir=DATA_DIR,\n",
878
+ " batch_size=BATCH_SIZE,\n",
879
+ " val_transforms=val_transforms\n",
880
+ ")\n",
881
+ "# This prepares the test dataloader specifically\n",
882
+ "datamodule.setup(stage='test')\n",
883
+ "\n",
884
+ "# Step 2: Load the trained model from the checkpoint file\n",
885
+ "model = EffNetV2_S.load_from_checkpoint(CHECKPOINT_PATH)\n",
886
+ "model.class_names = datamodule.classes\n",
887
+ "model.eval() # Set the model to evaluation mode\n",
888
+ "\n",
889
+ "# Step 3: Create a Trainer and run the test\n",
890
+ "trainer = pl.Trainer(accelerator='auto')\n",
891
+ "\n",
892
+ "# This call will run the test_step and automatically trigger the \n",
893
+ "# on_test_end hook in your model, which generates the plot.\n",
894
+ "trainer.test(model, datamodule=datamodule)\n",
895
+ "\n",
896
+ "print(\"\\nEvaluation complete. The confusion matrix plot has been saved.\")"
897
+ ]
898
+ },
899
+ {
900
+ "cell_type": "markdown",
901
+ "id": "2325adef",
902
+ "metadata": {},
903
+ "source": [
904
+ "## 6. Local Gradio Demo"
905
+ ]
906
+ },
907
+ {
908
+ "cell_type": "code",
909
+ "execution_count": null,
910
+ "id": "44decdea",
911
+ "metadata": {},
912
+ "outputs": [],
913
+ "source": [
914
+ "FOOD101_CLASSES = [\n",
915
+ " 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare', \n",
916
+ " 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito', \n",
917
+ " 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake', \n",
918
+ " 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla', \n",
919
+ " 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros', \n",
920
+ " 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame', \n",
921
+ " 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict', \n",
922
+ " 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras', \n",
923
+ " 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari', \n",
924
+ " 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad', \n",
925
+ " 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger', \n",
926
+ " 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream', \n",
927
+ " 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese', \n",
928
+ " 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings', \n",
929
+ " 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck', \n",
930
+ " 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich', \n",
931
+ " 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi', \n",
932
+ " 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese', \n",
933
+ " 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake', \n",
934
+ " 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'\n",
935
+ "]"
936
+ ]
937
+ },
938
+ {
939
+ "cell_type": "code",
940
+ "execution_count": null,
941
+ "id": "10bdf9fd",
942
+ "metadata": {},
943
+ "outputs": [],
944
+ "source": [
945
+ "import gradio as gr\n",
946
+ "import torch\n",
947
+ "from gradio.themes.base import Base\n",
948
+ "from torchvision.datasets import Food101\n",
949
+ "\n",
950
+ "# --- 1. Configuration ---\n",
951
+ "MODEL_PATH = \"checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt\" \n",
952
+ "MODEL_NAME = \"EfficientNet_V2_S\"\n",
953
+ "\n",
954
+ "theme = gr.themes.Soft(\n",
955
+ " primary_hue=\"orange\",\n",
956
+ " secondary_hue=\"blue\",\n",
957
+ ").set(\n",
958
+ "\n",
959
+ " body_background_fill=\"#f2f2f2\"\n",
960
+ ")\n",
961
+ "\n",
962
+ "# --- 2. Load Model and Assets ---\n",
963
+ "print(\"Loading model and assets...\")\n",
964
+ "model = EffNetV2_S.load_from_checkpoint(MODEL_PATH)\n",
965
+ "model.eval()\n",
966
+ "\n",
967
+ "components = get_model_components(MODEL_NAME)\n",
968
+ "transforms = components[\"val_transforms\"]\n",
969
+ "class_names = FOOD101_CLASSES \n",
970
+ "\n",
971
+ "print(\"Model and assets loaded successfully.\")\n",
972
+ "\n",
973
+ "# --- 3. Prediction Function ---\n",
974
+ "def predict(image):\n",
975
+ " \"\"\"\n",
976
+ " Takes a PIL image, preprocesses it, and returns the model's top 3 predictions.\n",
977
+ " \"\"\"\n",
978
+ " # 1. Preprocess the image and add a batch dimension\n",
979
+ " input_tensor = transforms(image).unsqueeze(0)\n",
980
+ " \n",
981
+ " # 2. Move the input tensor to the same device as the model\n",
982
+ " # This ensures both the model and the data are on the GPU.\n",
983
+ " device = next(model.parameters()).device\n",
984
+ " input_tensor = input_tensor.to(device)\n",
985
+ " \n",
986
+ " # 3. Make a prediction\n",
987
+ " with torch.no_grad():\n",
988
+ " output = model(input_tensor)\n",
989
+ " \n",
990
+ " # 4. Post-process the output\n",
991
+ " probabilities = torch.nn.functional.softmax(output[0], dim=0)\n",
992
+ " confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}\n",
993
+ " \n",
994
+ " return confidences\n",
995
+ " \n",
996
+ "\n",
997
+ "demo = gr.Interface(\n",
998
+ " fn=predict,\n",
999
+ " inputs=gr.Image(type=\"pil\", label=\"Upload a Food Image\"),\n",
1000
+ " outputs=gr.Label(num_top_classes=3, label=\"Top Predictions\"),\n",
1001
+ " theme=theme,\n",
1002
+ " \n",
1003
+ " # UI Enhancements\n",
1004
+ " title=\"🍔 Food-101 Image Classifier 🍟\",\n",
1005
+ " description=(\n",
1006
+ " \"What's on your plate? Upload an image or try one of the examples below to classify it. \"\n",
1007
+ " \"This demo uses an EfficientNetV2-S model fine-tuned on the Food-101 dataset.\"\n",
1008
+ " ),\n",
1009
+ " article=(\n",
1010
+ " \"<p style='text-align: center;'>A project by Daniel Kiani. \"\n",
1011
+ " \"<a href='https://github.com/Deathshot78/Food101-Classification' target='_blank'>Check out the code on GitHub!</a></p>\"\n",
1012
+ " ),\n",
1013
+ " examples=[\n",
1014
+ " [\"assets/ramen.jpg\"],\n",
1015
+ " [\"assets/pizza.jpg\"],\n",
1016
+ " [\"assets/oysters.jpg\"],\n",
1017
+ " [\"assets/onion_rings.jpg\"]\n",
1018
+ " ]\n",
1019
+ ")"
1020
+ ]
1021
+ },
1022
+ {
1023
+ "cell_type": "code",
1024
+ "execution_count": null,
1025
+ "id": "b536610d",
1026
+ "metadata": {},
1027
+ "outputs": [],
1028
+ "source": [
1029
+ "# Launch the Gradio app locally\n",
1030
+ "demo.launch()"
1031
+ ]
1032
+ }
1033
+ ],
1034
+ "metadata": {
1035
+ "kernelspec": {
1036
+ "display_name": "Python 3",
1037
+ "language": "python",
1038
+ "name": "python3"
1039
+ },
1040
+ "language_info": {
1041
+ "codemirror_mode": {
1042
+ "name": "ipython",
1043
+ "version": 3
1044
+ },
1045
+ "file_extension": ".py",
1046
+ "mimetype": "text/x-python",
1047
+ "name": "python",
1048
+ "nbconvert_exporter": "python",
1049
+ "pygments_lexer": "ipython3",
1050
+ "version": "3.10.6"
1051
+ }
1052
+ },
1053
+ "nbformat": 4,
1054
+ "nbformat_minor": 5
1055
+ }
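The per-class accuracy in `on_validation_epoch_end` and the row normalization in `plot_confusion_matrix` can be sketched on a toy matrix. The 3-class matrix below is an illustrative stand-in for the real 101-class one; the epsilon trick is the same one the notebook uses to avoid division by zero:

```python
import numpy as np

# Toy 3-class confusion matrix (rows = true labels, cols = predictions).
cm = np.array([
    [8, 1, 1],
    [2, 6, 2],
    [0, 0, 10],
])

# Per-class accuracy: diagonal over row sums, with a small epsilon for
# numerical stability (mirrors on_validation_epoch_end).
per_class_acc = cm.diagonal() / (cm.sum(axis=1) + 1e-6)

# Row-normalized matrix, as computed by plot_confusion_matrix before
# handing the values to seaborn's heatmap.
cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)
```

Each row of `cm_normalized` sums to (almost exactly) 1, so the heatmap cells read directly as "fraction of true-class samples predicted as each class".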
requirements.txt ADDED
Binary file (320 Bytes).
 
scripts/app.py ADDED
@@ -0,0 +1,82 @@
1
+ import gradio as gr
2
+ import torch
3
+ from gradio.themes.base import Base
4
+ from torchvision.datasets import Food101
5
+ from models import EffNetV2_S
6
+ from prepare_data import get_model_components
7
+ from class_names import FOOD101_CLASSES
8
+
9
+ # --- 1. Configuration ---
10
+ MODEL_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"
11
+ MODEL_NAME = "EfficientNet_V2_S"
12
+
13
+ theme = gr.themes.Soft(
14
+ primary_hue="orange",
15
+ secondary_hue="blue",
16
+ ).set(
17
+
18
+ body_background_fill="#f2f2f2"
19
+ )
20
+
21
+ # --- 2. Load Model and Assets ---
22
+ print("Loading model and assets...")
23
+ model = EffNetV2_S.load_from_checkpoint(MODEL_PATH)
24
+ model.eval()
25
+
26
+ components = get_model_components(MODEL_NAME)
27
+ transforms = components["val_transforms"]
28
+ class_names = FOOD101_CLASSES
29
+
30
+ print("Model and assets loaded successfully.")
31
+
32
+ # --- 3. Prediction Function ---
33
+ def predict(image):
34
+ """
35
+ Takes a PIL image, preprocesses it, and returns the model's top 3 predictions.
36
+ """
37
+ # 1. Preprocess the image and add a batch dimension
38
+ input_tensor = transforms(image).unsqueeze(0)
39
+
40
+ # 2. Move the input tensor to the same device as the model
41
+ # This ensures both the model and the data are on the GPU.
42
+ device = next(model.parameters()).device
43
+ input_tensor = input_tensor.to(device)
44
+
45
+ # 3. Make a prediction
46
+ with torch.no_grad():
47
+ output = model(input_tensor)
48
+
49
+ # 4. Post-process the output
50
+ probabilities = torch.nn.functional.softmax(output[0], dim=0)
51
+ confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
52
+
53
+ return confidences
54
+
55
+
56
+ demo = gr.Interface(
57
+ fn=predict,
58
+ inputs=gr.Image(type="pil", label="Upload a Food Image"),
59
+ outputs=gr.Label(num_top_classes=3, label="Top Predictions"),
60
+ theme=theme,
61
+
62
+ # UI Enhancements
63
+ title="🍔 Food-101 Image Classifier 🍟",
64
+ description=(
65
+ "What's on your plate? Upload an image or try one of the examples below to classify it. "
66
+ "This demo uses an EfficientNetV2-S model fine-tuned on the Food-101 dataset."
67
+ ),
68
+ article=(
69
+ "<p style='text-align: center;'>A project by Daniel Kiani. "
70
+ "<a href='https://github.com/Deathshot78/Food101-Classification' target='_blank'>Check out the code on GitHub!</a></p>"
71
+ ),
72
+ examples=[
73
+ ["assets/ramen.jpg"],
74
+ ["assets/pizza.jpg"],
75
+ ["assets/oysters.jpg"],
76
+ ["assets/onion_rings.jpg"]
77
+ ]
78
+ )
79
+
80
+ # --- 5. Launch the App ---
81
+ if __name__ == "__main__":
82
+ demo.launch(debug=True)
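The post-processing step in `predict()` can be illustrated without a trained model. The NumPy sketch below mirrors the softmax-and-dict logic on made-up logits for three hypothetical class names (the real app uses `torch.nn.functional.softmax` over all 101 classes):

```python
import numpy as np

# Toy logits for three hypothetical classes, standing in for the
# model's real 101-dimensional output.
logits = np.array([2.0, 1.0, 0.0])
class_names = ['pizza', 'ramen', 'sushi']

# Numerically stable softmax, equivalent to
# torch.nn.functional.softmax(output[0], dim=0).
exp = np.exp(logits - logits.max())
probabilities = exp / exp.sum()

# The same {class_name: confidence} mapping predict() returns;
# gr.Label sorts it and shows the top-k entries.
confidences = {class_names[i]: float(probabilities[i]) for i in range(len(class_names))}
top_class = max(confidences, key=confidences.get)
```

Returning the full dictionary (rather than pre-sorting) lets `gr.Label(num_top_classes=3)` handle ranking and display on its own.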
scripts/class_names.py ADDED
@@ -0,0 +1,22 @@
1
+ FOOD101_CLASSES = [
2
+ 'apple_pie', 'baby_back_ribs', 'baklava', 'beef_carpaccio', 'beef_tartare',
3
+ 'beet_salad', 'beignets', 'bibimbap', 'bread_pudding', 'breakfast_burrito',
4
+ 'bruschetta', 'caesar_salad', 'cannoli', 'caprese_salad', 'carrot_cake',
5
+ 'ceviche', 'cheesecake', 'cheese_plate', 'chicken_curry', 'chicken_quesadilla',
6
+ 'chicken_wings', 'chocolate_cake', 'chocolate_mousse', 'churros',
7
+ 'clam_chowder', 'club_sandwich', 'crab_cakes', 'creme_brulee', 'croque_madame',
8
+ 'cup_cakes', 'deviled_eggs', 'donuts', 'dumplings', 'edamame', 'eggs_benedict',
9
+ 'escargots', 'falafel', 'filet_mignon', 'fish_and_chips', 'foie_gras',
10
+ 'french_fries', 'french_onion_soup', 'french_toast', 'fried_calamari',
11
+ 'fried_rice', 'frozen_yogurt', 'garlic_bread', 'gnocchi', 'greek_salad',
12
+ 'grilled_cheese_sandwich', 'grilled_salmon', 'guacamole', 'gyoza', 'hamburger',
13
+ 'hot_and_sour_soup', 'hot_dog', 'huevos_rancheros', 'hummus', 'ice_cream',
14
+ 'lasagna', 'lobster_bisque', 'lobster_roll_sandwich', 'macaroni_and_cheese',
15
+ 'macarons', 'miso_soup', 'mussels', 'nachos', 'omelette', 'onion_rings',
16
+ 'oysters', 'pad_thai', 'paella', 'pancakes', 'panna_cotta', 'peking_duck',
17
+ 'pho', 'pizza', 'pork_chop', 'poutine', 'prime_rib', 'pulled_pork_sandwich',
18
+ 'ramen', 'ravioli', 'red_velvet_cake', 'risotto', 'samosa', 'sashimi',
19
+ 'scallops', 'seaweed_salad', 'shrimp_and_grits', 'spaghetti_bolognese',
20
+ 'spaghetti_carbonara', 'spring_rolls', 'steak', 'strawberry_shortcake',
21
+ 'sushi', 'tacos', 'takoyaki', 'tiramisu', 'tuna_tartare', 'waffles'
22
+ ]
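Since the model's output index `i` is looked up directly in this list, its length and ordering matter. A minimal sanity check, using a three-entry stand-in for the full 101-class list, might look like:

```python
# Stand-in for FOOD101_CLASSES; the real list has 101 entries matching
# torchvision's Food101 label order.
classes = ['apple_pie', 'baby_back_ribs', 'baklava']

# Labels must be unique and sorted the same way the dataset sorts its
# class folders, or logits will map to the wrong names.
assert len(classes) == len(set(classes))
assert classes == sorted(classes)
```

Running the same two assertions against the real `FOOD101_CLASSES` (plus `len(FOOD101_CLASSES) == 101`) catches accidental edits to the list early.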
scripts/main.py ADDED
@@ -0,0 +1,229 @@
+ from prepare_data import Food101DataModule, CustomFood101, get_model_components
+ from models import EffNetV2_S, EffNetb2
+ import pytorch_lightning as pl
+ from pytorch_lightning import Trainer, LightningModule
+ from pytorch_lightning.loggers import CSVLogger
+ from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
+ from typing import List, Optional
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+ import numpy as np
+ import pandas as pd
+
+ DATA_DIR = "data"
+ MODEL_NAME = "EfficientNet_V2_S"
+ BATCH_SIZE = 32
+ SUBSET_FRACTION = 0.2  # Use a smaller subset for quick testing
+ CHECKPOINT_PATH = "checkpoints/best-model-epoch=22-val_acc=0.8541.ckpt"  # Path to your trained model checkpoint
+
+ def plot_confusion_matrix(cm: np.ndarray, class_names: List[str], figsize: tuple = (25, 25)):
+     """
+     Creates and saves a multi-class confusion matrix plot.
+
+     This function normalizes the confusion matrix to show prediction
+     percentages for each class, visualizes it as a heatmap, and saves
+     the resulting figure to a file.
+
+     Args:
+         cm (np.ndarray): The confusion matrix from torchmetrics or scikit-learn.
+         class_names (List[str]): A list of class names for the labels.
+         figsize (tuple, optional): The size of the figure. Defaults to (25, 25).
+     """
+     # 1. Normalize the confusion matrix to show percentages.
+     # Add a small epsilon to prevent division by zero.
+     cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)
+
+     # 2. Create a DataFrame so the heatmap carries class labels
+     df_cm = pd.DataFrame(cm_normalized, index=class_names, columns=class_names)
+
+     # 3. Create the plot
+     plt.figure(figsize=figsize)
+     heatmap = sns.heatmap(df_cm, annot=False, cmap='Blues')  # Annotations off for 101 classes
+
+     # 4. Format the plot
+     heatmap.yaxis.set_ticklabels(heatmap.yaxis.get_ticklabels(), rotation=0, ha='right', fontsize=8)
+     heatmap.xaxis.set_ticklabels(heatmap.xaxis.get_ticklabels(), rotation=45, ha='right', fontsize=8)
+
+     plt.ylabel('True Label')
+     plt.xlabel('Predicted Label')
+     plt.title('Normalized Confusion Matrix')
+     plt.tight_layout()
+
+     # 5. Save the figure and show the plot
+     plt.savefig('confusion_matrix.png', dpi=300)
+     print("Confusion matrix plot saved to confusion_matrix.png")
+     plt.show()
+
+ def run_training_session(
+     model_name: str = "EfficientNet_V2_S",
+     batch_size: int = 32,
+     data_dir: str = 'data',
+     subset_fraction: float = 1.0,
+     checkpoint_path: str = "checkpoints/",
+     lr: float = 1e-3,
+     weight_decay: float = 1e-4,
+     freeze_features: bool = True,
+     early_stopping_patience: int = 5,
+     max_epochs: int = 100,
+     accelerator: str = 'auto',
+     resume_from_checkpoint: Optional[str] = None
+ ) -> Trainer:
+     """
+     Sets up and runs a complete training session for a specified model.
+
+     This function handles the entire pipeline: data preparation, model
+     instantiation, logger and callback setup, and trainer execution.
+
+     Args:
+         model_name (str): The name of the model architecture to train.
+         batch_size (int): The number of samples per batch.
+         data_dir (str): The root directory for the dataset.
+         subset_fraction (float): The fraction of the dataset to use for training.
+         checkpoint_path (str): Directory to save model checkpoints.
+         lr (float): The learning rate for the optimizer.
+         weight_decay (float): The weight decay for the optimizer.
+         freeze_features (bool): Flag to control the fine-tuning strategy
+             (e.g., for two-stage training).
+         early_stopping_patience (int): Number of epochs with no improvement
+             after which training will be stopped.
+         max_epochs (int): The maximum number of epochs to train for.
+         accelerator (str): The hardware accelerator to use ('auto', 'cpu', 'gpu').
+         resume_from_checkpoint (Optional[str]): Path to a checkpoint file to
+             resume training from. Defaults to None.
+
+     Returns:
+         Trainer: The PyTorch Lightning Trainer object after fitting is complete.
+     """
+     # A registry to map model names to their actual classes
+     model_class_registry = {
+         "EfficientNet_V2_S": EffNetV2_S,
+         "EfficientNet_B2": EffNetb2,
+     }
+     if model_name not in model_class_registry:
+         raise ValueError(f"Model '{model_name}' is not a recognized class.")
+
+     # Get model-specific transforms
+     components = get_model_components(model_name)
+     train_transforms = components["train_transforms"]
+     val_transforms = components["val_transforms"]
+
+     # Set up the DataModule
+     food_datamodule = Food101DataModule(
+         data_dir=data_dir,
+         batch_size=batch_size,
+         train_transforms=train_transforms,
+         val_transforms=val_transforms,
+         subset_fraction=subset_fraction
+     )
+     food_datamodule.prepare_data()
+     food_datamodule.setup()
+
+     # Instantiate the model dynamically
+     model_class = model_class_registry[model_name]
+     model = model_class(
+         num_classes=len(food_datamodule.classes),
+         class_names=food_datamodule.classes,
+         lr=lr,
+         weight_decay=weight_decay,
+         freeze_features=freeze_features
+     )
+
+     # Set up logger and callbacks
+     logger = CSVLogger(save_dir="logs/", name=model_name)
+
+     early_stop_callback = EarlyStopping(
+         monitor="val_loss",
+         patience=early_stopping_patience,
+         mode="min"
+     )
+     best_model_checkpoint = ModelCheckpoint(
+         dirpath=checkpoint_path,
+         filename="best-model-{epoch:02d}-{val_acc:.4f}",
+         save_top_k=1,
+         monitor="val_acc",
+         mode="max"
+     )
+
+     callbacks = [early_stop_callback, best_model_checkpoint]
+
+     # Instantiate the Trainer
+     trainer = Trainer(
+         max_epochs=max_epochs,
+         accelerator=accelerator,
+         callbacks=callbacks,
+         logger=logger,
+     )
+
+     # Start training
+     trainer.fit(
+         model,
+         datamodule=food_datamodule,
+         ckpt_path=resume_from_checkpoint
+     )
+
+     return trainer
+
+ # ===================================================================
+ # Main Execution Block
+ # ===================================================================
+ if __name__ == "__main__":
+
+     # --- 1. DEFINE YOUR TRAINING CONFIGURATION HERE ---
+     config = {
+         "model_name": "EfficientNet_V2_S",
+         "batch_size": 32,
+         "lr": 1e-4,
+         "epochs": 50,
+         "subset_fraction": 1.0,  # Use 1.0 for the full dataset
+         "freeze_features": True,
+         "early_stopping_patience": 10
+     }
+
+     # --- 2. PRINT CONFIGURATION AND START TRAINING ---
+     print("--- Starting Training Session ---")
+     for key, value in config.items():
+         print(f"  {key}: {value}")
+     print("---------------------------------")
+
+     run_training_session(
+         model_name=config["model_name"],
+         batch_size=config["batch_size"],
+         lr=config["lr"],
+         max_epochs=config["epochs"],
+         subset_fraction=config["subset_fraction"],
+         freeze_features=config["freeze_features"],
+         early_stopping_patience=config["early_stopping_patience"]
+     )
+
+     print("\n--- Training Session Complete ---")
+
+     print("\n--- Starting Evaluation on Test Set ---")
+     print(f"Loading model from checkpoint: {CHECKPOINT_PATH}")
+
+     # Step 1: Set up the DataModule for the test set
+     components = get_model_components(MODEL_NAME)
+     val_transforms = components["val_transforms"]
+
+     datamodule = Food101DataModule(
+         data_dir=DATA_DIR,
+         batch_size=BATCH_SIZE,
+         val_transforms=val_transforms
+     )
+     # This prepares the test dataloader specifically
+     datamodule.setup(stage='test')
+
+     # Step 2: Load the trained model from the checkpoint file
+     model = EffNetV2_S.load_from_checkpoint(CHECKPOINT_PATH)
+     model.class_names = datamodule.classes
+     model.eval()  # Set the model to evaluation mode
+
+     # Step 3: Create a Trainer and run the test
+     trainer = pl.Trainer(accelerator='auto')
+
+     # This call runs the test_step for every batch and triggers the
+     # on_test_end hook in the model, which computes the confusion matrix.
+     trainer.test(model, datamodule=datamodule)
+
+     print("\nEvaluation complete.")
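The `plot_confusion_matrix` helper in `scripts/main.py` row-normalizes the raw counts before plotting, adding a small epsilon so a class with no test samples does not cause a division by zero. A self-contained sketch of just that normalization step on a toy 3-class matrix (the counts are illustrative):

```python
import numpy as np

# Toy 3-class confusion matrix: rows = true labels, columns = predictions.
cm = np.array([[8, 1, 1],
               [2, 6, 2],
               [0, 0, 0]])  # last class unseen -> row sum of zero

# Row-normalize with a small epsilon, exactly as in plot_confusion_matrix,
# so the all-zero row yields zeros instead of NaNs.
cm_normalized = cm.astype('float') / (cm.sum(axis=1)[:, np.newaxis] + 1e-6)

print(cm_normalized[0])  # roughly [0.8, 0.1, 0.1]
print(cm_normalized[2])  # all zeros, no division-by-zero warning
```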
scripts/models.py ADDED
@@ -0,0 +1,266 @@
+ import torch
+ import torchvision
+ import pytorch_lightning as pl
+ from torch import nn
+ from torchmetrics.classification import Accuracy, F1Score, ConfusionMatrix
+
+ class EffNetV2_S(pl.LightningModule):
+     """A PyTorch Lightning Module for fine-tuning EfficientNetV2-S.
+
+     This module encapsulates the EfficientNetV2-S model and provides a flexible
+     fine-tuning strategy. It can be configured for Stage 1 (training only the
+     classifier and later feature blocks) or Stage 2 (training the entire model).
+
+     Args:
+         lr (float, optional): The learning rate. Defaults to 1e-3.
+         weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
+         num_classes (int, optional): The number of output classes. Defaults to 101.
+         class_names (list, optional): A list of class names for logging. Defaults to None.
+         freeze_features (bool, optional): If True, freezes the backbone and unfreezes
+             only the later blocks (Stage 1). If False, all features are trainable
+             (Stage 2). Defaults to True.
+         unfreeze_from_block (int, optional): Which feature block to start unfreezing
+             from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
+     """
+
+     def __init__(
+         self,
+         lr: float = 1e-3,
+         weight_decay: float = 1e-4,
+         num_classes: int = 101,
+         class_names: list = None,
+         freeze_features: bool = True,   # True = Stage 1, False = Stage 2
+         unfreeze_from_block: int = -3   # Only used if freeze_features=True
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+         self.class_names = class_names if class_names else [str(i) for i in range(num_classes)]
+
+         # Load pretrained weights
+         weights = torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
+         self.model = torchvision.models.efficientnet_v2_s(weights=weights)
+
+         # ---- Freezing strategy ----
+         if freeze_features:
+             # Stage 1: freeze everything first...
+             for param in self.model.parameters():
+                 param.requires_grad = False
+             # ...then unfreeze from a specific block (default: last 3 blocks)
+             for param in self.model.features[unfreeze_from_block:].parameters():
+                 param.requires_grad = True
+         else:
+             # Stage 2: unfreeze everything
+             for param in self.model.parameters():
+                 param.requires_grad = True
+
+         # Classifier head
+         self.model.classifier = nn.Sequential(
+             nn.Dropout(p=0.2, inplace=True),
+             nn.Linear(in_features=1280, out_features=self.hparams.num_classes, bias=True)
+         )
+
+         # Loss & metrics
+         self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
+         self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
+         self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
+         self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
+         self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
+         self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
+         self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+         self.train_accuracy(logits, y)
+         self.train_f1(logits, y)
+         self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+         self.val_accuracy(logits, y)
+         self.val_f1(logits, y)
+         self.log('val_loss', loss, prog_bar=True)
+         self.log('val_acc', self.val_accuracy, prog_bar=True)
+         self.log('val_f1', self.val_f1, prog_bar=True)
+         self.val_conf_matrix.update(logits, y)
+
+     def on_validation_epoch_end(self):
+         cm = self.val_conf_matrix.compute()
+         # Add a small epsilon to the denominator for numerical stability.
+         per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
+         print("\n--- Per-Class Validation Accuracy ---")
+         for i, acc in enumerate(per_class_acc):
+             self.log(f'val_acc/{self.class_names[i]}', acc.item(), on_epoch=True)
+             print(f"{self.class_names[i]:<20}: {acc.item():.4f}")
+         print("------------------------------------")
+         self.val_conf_matrix.reset()
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         self.test_conf_matrix.update(logits, y)
+
+     def on_test_end(self):
+         cm = self.test_conf_matrix.compute()
+         print("\nFinal test confusion matrix computed.")
+         # Plotting is delegated to the caller, e.g.:
+         # plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
+         self.test_conf_matrix.reset()
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(
+             self.parameters(),
+             lr=self.hparams.lr,
+             weight_decay=self.hparams.weight_decay
+         )
+         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             T_max=self.trainer.max_epochs,
+             eta_min=1e-6
+         )
+         return {"optimizer": optimizer, "lr_scheduler": {"scheduler": scheduler, "interval": "epoch"}}
+
+ class EffNetb2(pl.LightningModule):
+     """A PyTorch Lightning Module for fine-tuning EfficientNet-B2.
+
+     This module encapsulates the EfficientNet-B2 model and provides a flexible
+     fine-tuning strategy. It can be configured for Stage 1 (training only the
+     classifier and later feature blocks) or Stage 2 (training the entire model).
+
+     Args:
+         lr (float, optional): The learning rate. Defaults to 1e-3.
+         weight_decay (float, optional): Weight decay for the optimizer. Defaults to 1e-4.
+         num_classes (int, optional): The number of output classes. Defaults to 101.
+         class_names (list, optional): A list of class names for logging. Defaults to None.
+         freeze_features (bool, optional): If True, freezes the backbone and unfreezes
+             only the later blocks (Stage 1). If False, all features are trainable
+             (Stage 2). Defaults to True.
+         unfreeze_from_block (int, optional): Which feature block to start unfreezing
+             from. Used only if freeze_features is True. Defaults to -3 (last 3 blocks).
+     """
+
+     def __init__(
+         self,
+         lr: float = 1e-3,
+         weight_decay: float = 1e-4,
+         num_classes: int = 101,
+         class_names: list = None,
+         freeze_features: bool = True,
+         unfreeze_from_block: int = -3
+     ):
+         super().__init__()
+         self.save_hyperparameters()
+         self.class_names = class_names if class_names is not None else [str(i) for i in range(num_classes)]
+
+         # Model setup
+         weights = torchvision.models.EfficientNet_B2_Weights.DEFAULT
+         self.model = torchvision.models.efficientnet_b2(weights=weights)
+
+         # --- Flexible Freezing Strategy ---
+         if self.hparams.freeze_features:
+             # Stage 1: freeze everything first...
+             for param in self.model.parameters():
+                 param.requires_grad = False
+             # ...then unfreeze from a specific block (default: last 3 blocks)
+             for param in self.model.features[self.hparams.unfreeze_from_block:].parameters():
+                 param.requires_grad = True
+         else:
+             # Stage 2: unfreeze everything
+             for param in self.model.parameters():
+                 param.requires_grad = True
+
+         # Classifier head
+         self.model.classifier = nn.Sequential(
+             nn.Dropout(p=0.3, inplace=True),
+             nn.Linear(in_features=1408, out_features=self.hparams.num_classes)
+         )
+
+         # Loss & metrics
+         self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)
+         self.train_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
+         self.val_accuracy = Accuracy(task="multiclass", num_classes=self.hparams.num_classes)
+         self.train_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
+         self.val_f1 = F1Score(task="multiclass", num_classes=self.hparams.num_classes, average='macro')
+         self.val_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
+         self.test_conf_matrix = ConfusionMatrix(task="multiclass", num_classes=self.hparams.num_classes)
+
+     def forward(self, x):
+         return self.model(x)
+
+     def training_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+         self.train_accuracy(logits, y)
+         self.train_f1(logits, y)
+         self.log('train_loss', loss, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('train_acc', self.train_accuracy, on_step=False, on_epoch=True, prog_bar=True)
+         self.log('train_f1', self.train_f1, on_step=False, on_epoch=True, prog_bar=True)
+         return loss
+
+     def validation_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         loss = self.loss_fn(logits, y)
+         self.val_accuracy(logits, y)
+         self.val_f1(logits, y)
+         self.log('val_loss', loss, prog_bar=True)
+         self.log('val_acc', self.val_accuracy, prog_bar=True)
+         self.log('val_f1', self.val_f1, prog_bar=True)
+         self.val_conf_matrix.update(logits, y)
+
+     def on_validation_epoch_end(self):
+         cm = self.val_conf_matrix.compute()
+
+         # Add a small epsilon (1e-6) to the denominator for numerical stability.
+         per_class_acc = cm.diag() / (cm.sum(dim=1) + 1e-6)
+
+         print("\n--- Per-Class Validation Accuracy ---")
+         for i, acc in enumerate(per_class_acc):
+             class_name = self.class_names[i]
+             self.log(f'val_acc/{class_name}', acc.item(), on_epoch=True)
+             print(f"{class_name:<20}: {acc.item():.4f}")
+         print("------------------------------------")
+
+         self.val_conf_matrix.reset()
+
+     def test_step(self, batch, batch_idx):
+         x, y = batch
+         logits = self(x)
+         self.test_conf_matrix.update(logits, y)
+
+     def on_test_end(self):
+         cm = self.test_conf_matrix.compute()
+         print("\nFinal test confusion matrix computed.")
+         # Plotting is delegated to the caller, e.g.:
+         # plot_confusion_matrix(cm.cpu().numpy(), class_names=self.class_names)
+         self.test_conf_matrix.reset()
+
+     def configure_optimizers(self):
+         optimizer = torch.optim.Adam(
+             self.parameters(),
+             lr=self.hparams.lr,
+             weight_decay=self.hparams.weight_decay
+         )
+         scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
+             optimizer,
+             T_max=self.trainer.max_epochs,
+             eta_min=1e-6
+         )
+         return {
+             "optimizer": optimizer,
+             "lr_scheduler": {
+                 "scheduler": scheduler,
+                 "interval": "epoch",
+             },
+         }
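Both modules pair Adam with `CosineAnnealingLR`, which decays the learning rate from `lr` down to `eta_min=1e-6` over `max_epochs`. A pure-Python sketch of the schedule's closed form (the `t_max=50` horizon is illustrative; PyTorch's scheduler follows this same cosine curve):

```python
import math

def cosine_annealing_lr(epoch: int, base_lr: float = 1e-3,
                        eta_min: float = 1e-6, t_max: int = 50) -> float:
    """Closed form of cosine annealing: decays base_lr to eta_min over t_max epochs."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1 + math.cos(math.pi * epoch / t_max))

print(cosine_annealing_lr(0))   # starts at base_lr
print(cosine_annealing_lr(50))  # ends at eta_min
```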
scripts/prepare_data.py ADDED
@@ -0,0 +1,240 @@
+ import os
+ import random
+ from typing import Any, Dict
+
+ import pytorch_lightning as pl
+ import torchvision
+ from torch.utils.data import DataLoader, Dataset
+ from torchvision import transforms as T
+ from torchvision.datasets import Food101
+
+
+ def get_model_components(
+     model_name: str,
+     return_classifier: bool = False,
+     augmentation_level: str = "default"
+ ) -> Dict[str, Any]:
+     """
+     Retrieves pre-trained model components from torchvision.
+
+     This function fetches the appropriate weights and transforms for a given
+     model. It supports different levels of training data augmentation.
+
+     Args:
+         model_name (str): The name of the model to get components for.
+             Supported models include "EfficientNet_V2_S" and "EfficientNet_B2".
+         return_classifier (bool, optional): If True, the model's classifier
+             head is also returned. Defaults to False.
+         augmentation_level (str, optional): The level of data augmentation to use
+             for the training set. Can be "default" or "strong".
+             Defaults to "default".
+
+     Returns:
+         Dict[str, Any]: A dictionary containing the requested components.
+             Always includes 'train_transforms' and 'val_transforms'.
+             Includes 'classifier' if return_classifier is True.
+
+     Raises:
+         ValueError: If model_name or augmentation_level is not supported.
+     """
+     model_registry = {
+         "EfficientNet_V2_S": (
+             torchvision.models.efficientnet_v2_s,
+             torchvision.models.EfficientNet_V2_S_Weights.DEFAULT
+         ),
+         "EfficientNet_B2": (
+             torchvision.models.efficientnet_b2,
+             torchvision.models.EfficientNet_B2_Weights.DEFAULT
+         )
+     }
+
+     if model_name not in model_registry:
+         raise ValueError(f"Model '{model_name}' is not supported. "
+                          f"Supported models are: {list(model_registry.keys())}")
+
+     # 1. Look up the model constructor and its pretrained weights
+     model_class, weights = model_registry[model_name]
+     val_transforms = weights.transforms()
+
+     # 2. Create the training transforms based on the desired level
+     if augmentation_level == "default":
+         train_transforms = T.Compose([
+             T.TrivialAugmentWide(),
+             val_transforms  # val_transforms includes ToTensor and Normalize
+         ])
+     elif augmentation_level == "strong":
+         # Note: ToTensor() and Normalize() are not added here because they
+         # are already included inside the 'val_transforms' pipeline.
+         train_transforms = T.Compose([
+             T.RandomResizedCrop(size=val_transforms.crop_size, scale=(0.7, 1.0)),
+             T.RandomHorizontalFlip(p=0.5),
+             T.RandAugment(num_ops=2, magnitude=9),
+             # RandomErasing must be applied to a tensor, so it comes after
+             # val_transforms, which handles the PIL -> Tensor conversion.
+             val_transforms,
+             T.RandomErasing(p=0.25, scale=(0.02, 0.33), ratio=(0.3, 3.3), value='random')
+         ])
+     else:
+         raise ValueError(f"Augmentation level '{augmentation_level}' is not supported. "
+                          f"Choose from 'default' or 'strong'.")
+
+     # 3. Prepare the dictionary to be returned
+     components = {
+         "train_transforms": train_transforms,
+         "val_transforms": val_transforms
+     }
+
+     # 4. Optionally, instantiate the model to get the classifier
+     if return_classifier:
+         model = model_class(weights=weights)
+         components["classifier"] = model.classifier
+
+     return components
+
+ class CustomFood101(Dataset):
+     """A PyTorch Dataset for Food101 with conditional downloading and subset support.
+
+     This class wraps the torchvision Food101 dataset. It only downloads the data
+     if the specified directory doesn't already exist. It can also create a
+     reproducible, shuffled subset of the data for faster experimentation.
+
+     Args:
+         split (str): The dataset split, either "train" or "test".
+         transform (callable, optional): A function/transform to apply to the images.
+         data_dir (str, optional): The directory to store the data. Defaults to "data".
+         subset_fraction (float, optional): The fraction of the dataset to use.
+             Defaults to 0.1.
+     """
+
+     def __init__(self, split, transform=None, data_dir="data", subset_fraction: float = 0.1):
+         # Check if the dataset already exists before setting the download flag.
+         dataset_path = os.path.join(data_dir, "food-101")
+         should_download = not os.path.isdir(dataset_path)
+
+         # 1. Load the full dataset with the conditional download flag
+         self.full_dataset = Food101(root=data_dir, split=split, transform=transform, download=should_download)
+         self.classes = self.full_dataset.classes
+
+         # 2. Create a reproducible subset of indices
+         if subset_fraction < 1.0:
+             num_samples = int(len(self.full_dataset) * subset_fraction)
+             all_indices = list(range(len(self.full_dataset)))
+             # Shuffle with a fixed seed for reproducibility
+             random.Random(42).shuffle(all_indices)
+             self.indices = all_indices[:num_samples]
+         else:
+             self.indices = list(range(len(self.full_dataset)))
+
+     def __len__(self):
+         """Returns the total number of samples in the subset."""
+         return len(self.indices)
+
+     def __getitem__(self, idx):
+         """Fetches the sample for the given subset index and applies the transform."""
+         # Map the subset index to the actual index in the full dataset
+         original_idx = self.indices[idx]
+         image, label = self.full_dataset[original_idx]
+         return image, label
+
+ class Food101DataModule(pl.LightningDataModule):
+     """A PyTorch Lightning DataModule for the Food101 dataset.
+
+     This module encapsulates all data-related logic, including downloading,
+     processing, and creating DataLoaders for the training, validation, and
+     test sets. It uses the CustomFood101 dataset internally and allows for
+     controlling the fraction of data used in the training and validation splits.
+
+     Args:
+         data_dir (str, optional): Root directory for the data. Defaults to "data".
+         batch_size (int, optional): The batch size for DataLoaders. Defaults to 32.
+         num_workers (int, optional): Number of workers for data loading. Defaults to 2.
+         train_transforms (callable, optional): Transformations for the training set.
+         val_transforms (callable, optional): Transformations for the validation/test set.
+         subset_fraction (float, optional): The fraction of data to use for training
+             and validation. Defaults to 0.5.
+     """
+     def __init__(self, data_dir="data", batch_size=32, num_workers=2,
+                  train_transforms=None, val_transforms=None, subset_fraction: float = 0.5):
+         super().__init__()
+         self.data_dir = data_dir
+         self.batch_size = batch_size
+         self.num_workers = num_workers
+         self.train_transforms = train_transforms
+         self.val_transforms = val_transforms
+         self.subset_fraction = subset_fraction
+
+         self.classes = []
+
+     def prepare_data(self):
+         """Downloads data if needed."""
+         CustomFood101(split='train', data_dir=self.data_dir)
+         CustomFood101(split='test', data_dir=self.data_dir)
+
+     def setup(self, stage=None):
+         """Assigns datasets, passing the subset_fraction."""
+         if stage == 'fit' or stage is None:
+             self.train_dataset = CustomFood101(split='train', transform=self.train_transforms,
+                                                data_dir=self.data_dir, subset_fraction=self.subset_fraction)
+             self.val_dataset = CustomFood101(split='test', transform=self.val_transforms,
+                                              data_dir=self.data_dir, subset_fraction=self.subset_fraction)
+             self.classes = self.train_dataset.classes
+
+         if stage == 'test' or stage is None:
+             self.test_dataset = CustomFood101(split='test', transform=self.val_transforms,
+                                               data_dir=self.data_dir, subset_fraction=1.0)  # Use the full test set
+             if not self.classes:
+                 self.classes = self.test_dataset.classes
+
+     def train_dataloader(self):
+         return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=self.num_workers)
+
+     def val_dataloader(self):
+         return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
+
+     def test_dataloader(self):
+         return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=self.num_workers)
+
+
+ if __name__ == "__main__":
+     # Define configuration for the script
+     DATA_DIR = "data"
+     MODEL_NAME = "EfficientNet_V2_S"
+     BATCH_SIZE = 32
+
+     print(f"Running data preparation script for model: {MODEL_NAME}")
+
+     # 1. Get model-specific transforms
+     components = get_model_components(MODEL_NAME)
+     train_transforms = components["train_transforms"]
+     val_transforms = components["val_transforms"]
+
+     # 2. Instantiate the DataModule
+     datamodule = Food101DataModule(
+         data_dir=DATA_DIR,
+         batch_size=BATCH_SIZE,
+         train_transforms=train_transforms,
+         val_transforms=val_transforms,
+         subset_fraction=0.1  # Use a small subset for quick verification
+     )
+
+     # 3. Trigger download and setup
+     datamodule.prepare_data()
+     datamodule.setup(stage='fit')
+
+     # 4. (Optional) Verification step
+     print("\n--- Verifying Dataloader ---")
+     # Get one batch from the training dataloader
+     train_dl = datamodule.train_dataloader()
+     images, labels = next(iter(train_dl))
+
+     print(f"Number of classes: {len(datamodule.classes)}")
+     print(f"Image batch shape: {images.shape}")
+     print(f"Label batch shape: {labels.shape}")
+     print("--- Verification Complete ---")
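`CustomFood101` makes its subsets reproducible by shuffling all indices with a fixed-seed `random.Random(42)` before slicing. The core idea in isolation, as a self-contained sketch (the size below is the Food-101 train split, 101 classes x 750 images):

```python
import random

def reproducible_subset(dataset_size: int, fraction: float, seed: int = 42) -> list:
    """Shuffle all indices with a fixed seed and keep the first `fraction` of them."""
    indices = list(range(dataset_size))
    random.Random(seed).shuffle(indices)  # fixed seed -> same order on every run
    return indices[:int(dataset_size * fraction)]

subset_a = reproducible_subset(dataset_size=75750, fraction=0.1)
subset_b = reproducible_subset(dataset_size=75750, fraction=0.1)
print(len(subset_a))         # 7575
print(subset_a == subset_b)  # True: the subset is identical across runs
```

Using a seeded `random.Random` instance rather than the module-level `random.shuffle` keeps the subset stable even if other code seeds or consumes the global RNG.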