---
license: mit
tags:
- image-classification
- tensorflow
- Grad-CAM
- BrainTumor
datasets:
- AIOmarRehan/Brain_Tumor_MRI_Dataset
---
# My Model

This model classifies brain MRI images from the Brain Tumor dataset and visualizes its predictions with Grad-CAM; you can try the model out on my profile.
# **Brain Tumor Classification Using InceptionV3 and Grad-CAM**

A complete deep learning pipeline for **brain tumor classification** using MRI scans.
This project demonstrates:

* **End-to-end data preprocessing**
* **Augmentation & dataset balancing**
* **Efficient tf.data pipelines**
* **Transfer learning with InceptionV3**
* **Deep model evaluation**
* **Grad-CAM interpretability**
* **LaTeX mathematical explanations**

---
## **1. Dataset Exploration & Inspection**

We begin by recursively scanning all MRI images and creating a structured DataFrame:

```python
from pathlib import Path
import pandas as pd

image_extensions = {'.jpg', '.jpeg', '.png'}
paths = [
    (path.parts[-2], path.name, str(path))
    for path in Path("/content/my_data").rglob('*.*')
    if path.suffix.lower() in image_extensions
]

df = pd.DataFrame(paths, columns=['class', 'image', 'full_path'])
df = df.sort_values('class').reset_index(drop=True)
df.head()
```

Count images per class:

```python
class_count = df['class'].value_counts()
print(class_count)
```
### **Visualizations**

```python
import matplotlib.pyplot as plt

plt.figure(figsize=(32, 16))
class_count.plot(kind='bar', edgecolor='black')
plt.title('Number of Images per Class')
plt.show()
```

### **Insights**

* Classes are **imbalanced**
* Images have **variable resolution**
* Some outliers require **cleaning**

---
## **2. Data Cleaning & Quality Checks**

### **Duplicate removal using MD5 hashes**

```python
import hashlib

def get_hash(file_path):
    with open(file_path, 'rb') as f:
        return hashlib.md5(f.read()).hexdigest()

df['file_hash'] = df['full_path'].apply(get_hash)
df_unique = df.drop_duplicates(subset='file_hash', keep='first')
```

### **Additional checks**

* Corrupted image detection
* Resolution anomalies
* Brightness/contrast outliers

Cleaning ensures a **robust dataset** with minimal noise.
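The corrupted-image check listed above can be sketched with Pillow's `verify()`. This is a minimal illustration, not the notebook's actual implementation; the helper name `find_corrupted` is mine:

```python
from PIL import Image

def find_corrupted(paths):
    """Return the paths that fail to open or verify as images."""
    bad = []
    for p in paths:
        try:
            with Image.open(p) as im:
                im.verify()  # raises on truncated or non-image files
        except Exception:
            bad.append(p)
    return bad
```

Running it over `df['full_path']` would flag files to drop before training.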
---
## **3. Data Augmentation & Class Balancing**

Target ~2,000 images per class using heavy augmentation:

```python
from tensorflow.keras.preprocessing.image import ImageDataGenerator

datagen = ImageDataGenerator(
    rotation_range=20,
    width_shift_range=0.1,
    height_shift_range=0.1,
    shear_range=0.1,
    zoom_range=0.1,
    horizontal_flip=True,
    fill_mode='nearest'
)
```

Used for minority-class upsampling and preventing overfitting.
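The number of augmented images each class needs follows directly from the class counts computed in Section 1. A small helper sketch (the function name and the example counts are illustrative, not taken from the dataset):

```python
def augmentation_plan(class_counts, target=2000):
    """Map each class to the number of augmented images needed to reach `target`."""
    return {cls: max(0, target - n) for cls, n in class_counts.items()}

# Hypothetical counts: minority classes need upsampling, others need none
plan = augmentation_plan({"glioma": 1321, "notumor": 2000, "pituitary": 1457})
```

Each deficit can then be filled by drawing that many samples from `datagen.flow` over the class's images.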
---
## **4. Image Preprocessing Pipeline**

```python
import tensorflow as tf

def preprocess_image(path, target_size=(512, 512), augment=True):
    img = tf.io.read_file(path)
    # expand_animations=False guarantees a 3-D tensor (animated GIFs would
    # otherwise decode to 4-D and break tf.image.resize)
    img = tf.image.decode_image(img, channels=3, expand_animations=False)
    img = tf.image.resize(img, target_size)
    img = tf.cast(img, tf.float32) / 255.0

    if augment:
        img = tf.image.random_flip_left_right(img)
        img = tf.image.random_flip_up_down(img)
        img = tf.image.random_brightness(img, max_delta=0.1)
        img = tf.image.random_contrast(img, 0.9, 1.1)

    return tf.clip_by_value(img, 0.0, 1.0)
```

* **Train set:** augmentation enabled
* **Validation/Test sets:** kept clean

---
## **5. Dataset Preparation with `tf.data`**

```python
AUTOTUNE = tf.data.AUTOTUNE
batch_size = 32

train_ds = tf.data.Dataset.from_tensor_slices((train_paths, train_labels))
train_ds = train_ds.shuffle(len(train_paths))
train_ds = train_ds.map(
    lambda x, y: (preprocess_image(x, augment=True), y),
    num_parallel_calls=AUTOTUNE
)
train_ds = train_ds.batch(batch_size).prefetch(AUTOTUNE)
```

Benefits:

* Parallel loading
* Smart prefetching
* Maximized GPU utilization

---
## **6. Model Architecture: InceptionV3**

Transfer learning from ImageNet:

```python
from tensorflow.keras.applications.inception_v3 import InceptionV3
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model

inception = InceptionV3(input_shape=input_shape, weights='imagenet', include_top=False)

for layer in inception.layers:
    layer.trainable = False

x = GlobalAveragePooling2D()(inception.output)
x = Dense(512, activation='relu')(x)
x = Dropout(0.5)(x)
prediction = Dense(len(le.classes_), activation='softmax')(x)

model = Model(inputs=inception.input, outputs=prediction)
```

### Why InceptionV3?

* Factorized convolutions
* Multi-scale feature extraction
* Lightweight and fast
* Strong performance in medical imaging

---
## **7. Training & Callbacks**

```python
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint, ReduceLROnPlateau

model.compile(
    loss='sparse_categorical_crossentropy',
    optimizer='adam',
    metrics=['accuracy']
)

callbacks = [
    EarlyStopping(monitor='val_loss', patience=40, restore_best_weights=True),
    ModelCheckpoint("best_model.h5", save_best_only=True, monitor='val_loss'),
    ReduceLROnPlateau(monitor='val_loss', factor=0.5, patience=10, min_lr=1e-5)
]
```

Training:

```python
history = model.fit(train_ds, validation_data=val_ds, epochs=50, callbacks=callbacks)
```

---
## **8. Training Curves**

```python
import matplotlib.pyplot as plt

plt.plot(history.history['accuracy'], label='Train Accuracy')
plt.plot(history.history['val_accuracy'], label='Val Accuracy')
plt.title('Training vs Validation Accuracy')
plt.legend()
plt.show()
```

* Curves indicate **smooth convergence**
* Small train/val gap → **limited overfitting**

---
## **9. Performance Metrics**

### Confusion Matrix

```python
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay

cm = confusion_matrix(y_true, y_pred)
ConfusionMatrixDisplay(cm, display_labels=le.classes_).plot(cmap='Blues')
```

<p align="center">
  <img src="https://files.catbox.moe/wuynop.png" width="100%">
</p>
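Per-class recall falls out of the confusion matrix by normalizing each row by its sum. An illustrative helper, not part of the original notebook:

```python
import numpy as np

def per_class_recall(cm):
    """Diagonal over row sums: fraction of each true class predicted correctly."""
    cm = np.asarray(cm, dtype=float)
    return np.diag(cm) / cm.sum(axis=1)
```

Applied to `cm` above, this gives one recall value per tumor class, which is often more informative than overall accuracy on an imbalanced medical dataset.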
### Multi-class AUC (One-vs-Rest)

**Macro AUC formula:**

<img src="https://latex.codecogs.com/svg.image?\color{white}\text{AUC}_{macro}=\frac{1}{K}\sum_{i=1}^{K}\text{AUC}_i"/>

```python
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_curve, auc

y_true_bin = label_binarize(y_true, classes=np.arange(len(le.classes_)))
```

<p align="center">
  <img src="https://files.catbox.moe/w3fazk.png" width="100%">
</p>
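The macro AUC formula above averages the one-vs-rest AUC of each class. With binarized labels as in the snippet, it can be computed like this (a sketch using `roc_auc_score` for the per-class terms; the function and variable names are illustrative):

```python
import numpy as np
from sklearn.preprocessing import label_binarize
from sklearn.metrics import roc_auc_score

def macro_auc(y_true, y_score, n_classes):
    """Mean of per-class one-vs-rest AUCs: (1/K) * sum_k AUC_k."""
    y_bin = label_binarize(y_true, classes=np.arange(n_classes))
    aucs = [roc_auc_score(y_bin[:, k], y_score[:, k]) for k in range(n_classes)]
    return float(np.mean(aucs))
```

Here `y_score` would be the model's softmax output of shape `(n_samples, n_classes)`.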
---
## **10. Grad-CAM: Interpretability**

Grad-CAM highlights the regions the model relies on for classification.

### Grad-CAM heatmap:

<img src="https://latex.codecogs.com/svg.image?\color{white}L^c_{\text{Grad-CAM}}=\text{ReLU}\left(\sum_k\alpha_k^cA^k\right)" />

Where:

<img src="https://latex.codecogs.com/svg.image?\color{white}\alpha_k^c=\frac{1}{Z}\sum_{i}\sum_{j}\frac{\partial y^c}{\partial A_{ij}^k}" />

Python implementation:

```python
def gradcam(model, img, cls=None):
    # locate the last convolutional layer
    lc = next(l for l in reversed(model.layers) if "conv" in l.name.lower())
    gm = tf.keras.Model(model.input, [lc.output, model.output])

    with tf.GradientTape() as t:
        conv, pred = gm(img[None])                # add batch dimension
        cls = tf.argmax(pred[0]) if cls is None else cls
        loss = pred[:, cls]                       # target-class score

    g = t.gradient(loss, conv)                    # d(score)/d(feature maps)
    w = tf.reduce_mean(g, axis=(0, 1, 2))         # channel weights alpha_k
    cam = tf.reduce_sum(w * conv[0], -1)          # weighted feature-map sum

    cam = tf.nn.relu(cam)                         # keep positive evidence only
    cam /= tf.reduce_max(cam) + 1e-8              # normalize to [0, 1]
    return cam.numpy()
```

Visualization example:

```python
plt.figure(figsize=(20, 10))
for i, img in enumerate(sample_images):
    cam = gradcam(model, img)
    plt.subplot(2, 5, i + 1)
    plt.imshow(img)
    # upscale the low-resolution CAM onto the image via `extent`
    plt.imshow(cam, cmap='jet', alpha=0.4,
               extent=(0, img.shape[1], img.shape[0], 0))
    plt.axis("off")
    plt.title(f"True Label: {le.classes_[sample_labels[i]]}")
plt.show()
```

<p align="center">
  <img src="https://files.catbox.moe/ysg2yc.png" width="100%">
</p>

> **Note:** When the model is highly confident in a prediction, the Grad-CAM gradients become near-zero, producing little to no heatmap activation.
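This saturation effect can be seen on a toy softmax: the gradient of the winning class probability with respect to its own logit is p·(1 − p), which vanishes as p → 1. An illustrative NumPy sketch (not from the notebook):

```python
import numpy as np

def softmax(z):
    e = np.exp(z - z.max())
    return e / e.sum()

def winning_prob_grad(z, c):
    """Gradient of p_c with respect to the logits: p_c * (onehot_c - p)."""
    p = softmax(z)
    onehot = (np.arange(len(z)) == c).astype(float)
    return p[c] * (onehot - p)
```

For logits like `[12, 0, 0]` every component of this gradient is tiny, while for `[1, 0, 0]` it is orders of magnitude larger, which is why confident predictions yield weak heatmaps.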
---
## **11. Technical LaTeX Notes**

### Sparse Categorical Crossentropy

<img src="https://latex.codecogs.com/svg.image?\color{white}L=-\frac{1}{N}\sum_{i=1}^{N}\log(p_{i,y_i})" />
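The loss above can be checked numerically with a minimal NumPy sketch of the formula (names are illustrative):

```python
import numpy as np

def sparse_cce(probs, labels):
    """L = -(1/N) * sum_i log(p_{i, y_i})."""
    n = len(labels)
    return float(-np.mean(np.log(probs[np.arange(n), labels])))
```

Only the probability assigned to each sample's true class enters the loss, which is why integer labels suffice and no one-hot encoding is needed.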
### Global Average Pooling

<img src="https://latex.codecogs.com/svg.image?\color{white}f_c=\frac{1}{h \cdot \omega}\sum_{i=1}^{h}\sum_{j=1}^{\omega}F_{i,j,c}" />
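Global average pooling is simply the mean over the two spatial dimensions of each feature map. An illustrative NumPy equivalent of the formula:

```python
import numpy as np

def global_average_pool(F):
    """f_c: mean of F[i, j, c] over the spatial indices i, j, per channel c."""
    return F.mean(axis=(0, 1))
```

This collapses an `(h, w, channels)` feature volume to a `channels`-length vector, which is what feeds the dense head in Section 6.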
---
## **12. Model Saving**

```python
model.save("InceptionV3_Brain_Tumor_MRI.h5")
```

---
## **13. Results**

> **Note:** Click the image below to view the video showcasing the project’s results.

<a href="https://files.catbox.moe/27ct3j.mp4">
  <img src="https://images.unsplash.com/photo-1611162616475-46b635cb6868?q=80&w=1974&auto=format&fit=crop&ixlib=rb-4.1.0&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D" width="400">
</a>

---
## **Key Takeaways**

* Strong data cleaning = reliable model
* Heavy augmentation reduces bias
* InceptionV3 provides excellent feature extraction
* Evaluation metrics help assess clinical reliability
* Grad-CAM adds essential interpretability