# Efficient_NetV2_Edition / last_model.py
# kikogazda's picture
# Upload 32 files
# 2ab7451 verified
# -*- coding: utf-8 -*-
"""Last_model.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1AdRILP1oqdiVuRSQr2dZZy0QgU8insn_
🚗 TwinCar Project: SOTA Training, Full Visuals, and Advanced Reporting
---
---
1. Environment Setup and Imports
Explanation:
We start by importing all necessary libraries and prepping our working environment for advanced data handling and visualization.
---
"""
# Block 1: Environment Setup and Imports
import os
import zipfile
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from torchvision import transforms
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score, hamming_loss,
cohen_kappa_score, matthews_corrcoef, jaccard_score,
confusion_matrix, classification_report
)
import timm
import scipy.io
"""2. Data Extraction and Preparation
Explanation:
We extract and organize the Stanford Cars dataset, parse .mat files to CSV for class and label mapping, and prepare all paths.
---
"""
# Block 2: Data Extraction and Preparation
from google.colab import drive

drive.mount('/content/drive')

zip_path = '/content/drive/MyDrive/stanford_cars.zip'
extract_dir = '/content/stanford_cars'

# Unpack the archive only when the target folder does not exist yet.
if not os.path.exists(extract_dir):
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall(extract_dir)
    print("✅ Dataset extracted at", extract_dir)

# Class names come from the devkit metadata file; their count defines the
# size of the classifier head built later.
meta = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_meta.mat")
class_names = [x[0] for x in meta['class_names'][0]]
NUM_CLASSES = len(class_names)

# Training annotations: field 5 is read as the file name and field 4 as the
# class id, which is shifted by -1 to make labels zero-based.
# `.item()` pulls the single stored value out of the MATLAB cell; calling
# int() on a non-0-d array is deprecated since NumPy 1.25.
train_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_train_annos.mat")['annotations'][0]
train_rows = [[x[5][0], int(x[4].item()) - 1] for x in train_annos]
df_train = pd.DataFrame(train_rows, columns=["filename", "label"])
df_train.to_csv('/content/train_labels.csv', index=False)

# Test annotations: only the file name (field 4) is read, no label.
test_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_test_annos.mat")['annotations'][0]
test_rows = [[x[4][0]] for x in test_annos]
df_test = pd.DataFrame(test_rows, columns=["filename"])
df_test.to_csv('/content/test_labels.csv', index=False)

# Image folders inside the extracted archive.
train_root = f"{extract_dir}/cars_train/cars_train"
test_root = f"{extract_dir}/cars_test/cars_test"
"""3. Advanced Dataset and Augmentations
Explanation:
We build a flexible dataset class, apply advanced augmentations, and lay the foundation for Mixup/CutMix later.
---
"""
# Block 3: Dataset and Advanced Augmentations
class StanfordCarsFromCSV(Dataset):
    """Image dataset whose samples are listed in a CSV file.

    The CSV needs a ``filename`` column and, when ``has_labels`` is true,
    a ``label`` column.

    Args:
        root_dir: Folder that the ``filename`` entries are joined onto.
        csv_file: Path of the CSV listing the samples.
        transform: Optional callable applied to the loaded RGB image.
        has_labels: When true, items are ``(image, int label)``; otherwise
            they are ``(image, filename)``.
    """

    def __init__(self, root_dir, csv_file, transform=None, has_labels=True):
        self.root_dir = root_dir
        self.transform = transform
        self.has_labels = has_labels
        self.data = pd.read_csv(csv_file)

    def __len__(self):
        """Return the number of rows in the CSV."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load the image at position ``idx`` and pair it with its label or file name."""
        record = self.data.iloc[idx]
        file_name = record['filename']
        picture = Image.open(os.path.join(self.root_dir, file_name)).convert('RGB')
        if self.transform:
            picture = self.transform(picture)
        second = int(record['label']) if self.has_labels else file_name
        return picture, second
# Per-channel statistics used to normalise every image fed to the model.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Final two steps shared by the training and validation pipelines.
_to_normalized_tensor = [
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]

# Random augmentations applied to training images before tensor conversion.
_train_augmentations = [
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.15),
]
train_transform = transforms.Compose(_train_augmentations + _to_normalized_tensor)

# Deterministic pipeline for validation/test: resize, centre crop, normalise.
val_transform = transforms.Compose(
    [transforms.Resize(256), transforms.CenterCrop(224)] + _to_normalized_tensor
)
"""4. Data Splitting, Weighted Sampling, and DataLoader
Explanation:
We split the data into train and validation sets with stratification for balanced classes,
use class weighting to counter imbalance, and create PyTorch DataLoaders for efficient training and evaluation.
---
"""
# 4. Data Splitting, Loader Setup, and Weighted Sampling
from torch.utils.data import DataLoader, WeightedRandomSampler
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# --- Settings ---
BATCH_SIZE = 32
VAL_RATIO = 0.1
RANDOM_SEED = 42

# Locations of the two split tables written below.
_TRAIN_SPLIT_CSV = '/content/train_split.csv'
_VAL_SPLIT_CSV = '/content/val_split.csv'

# --- Stratified split: hold out VAL_RATIO of the labelled rows ---
df_all = pd.read_csv('/content/train_labels.csv')
df_train, df_val = train_test_split(
    df_all,
    test_size=VAL_RATIO,
    random_state=RANDOM_SEED,
    stratify=df_all['label'],
)
df_train.to_csv(_TRAIN_SPLIT_CSV, index=False)
df_val.to_csv(_VAL_SPLIT_CSV, index=False)

# --- Datasets ---
# Train and validation both read images from train_root; only the training
# set gets the random augmentations.
train_dataset = StanfordCarsFromCSV(
    root_dir=train_root, csv_file=_TRAIN_SPLIT_CSV, transform=train_transform
)
val_dataset = StanfordCarsFromCSV(
    root_dir=train_root, csv_file=_VAL_SPLIT_CSV, transform=val_transform
)
test_dataset = StanfordCarsFromCSV(
    root_dir=test_root,
    csv_file='/content/test_labels.csv',
    transform=val_transform,
    has_labels=False,
)
# --- Weighted Sampler for Balanced Training ---
# Read the labels straight from the dataset's CSV table. Iterating the
# dataset itself would open and augment every training image only to throw
# the pixels away.
labels = train_dataset.data['label'].astype(int).tolist()
_present_classes = np.unique(labels)
class_weights = compute_class_weight(class_weight='balanced', classes=_present_classes, y=labels)
# Look weights up by class id rather than by array position, so the mapping
# stays correct even if some class id is missing from the training split.
_weight_of_class = dict(zip(_present_classes.tolist(), class_weights.tolist()))
sample_weights = [_weight_of_class[label] for label in labels]
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
# --- DataLoaders (drop_last=True for Mixup/CutMix compatibility) ---
# Arguments common to all three loaders.
_shared_loader_args = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)

# Training batches are drawn by the weighted sampler; the incomplete final
# batch is dropped.
train_loader = DataLoader(train_dataset, sampler=sampler, drop_last=True, **_shared_loader_args)
# Validation and test keep file order and every sample.
val_loader = DataLoader(val_dataset, shuffle=False, drop_last=False, **_shared_loader_args)
test_loader = DataLoader(test_dataset, shuffle=False, drop_last=False, **_shared_loader_args)

print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)} | Test samples: {len(test_dataset)}")
print(f"Train loader batches (per epoch): {len(train_loader)} (should be integer and even-sized)")
"""5. Model Initialization: EfficientNetV2 + Mixup/CutMix Ready
Explanation:
We load EfficientNetV2 with ImageNet weights for best transfer learning,
set up optimizer, scheduler, and prepare for Mixup/CutMix advanced augmentation.
---
"""
# Block 5: Model Initialization (EfficientNetV2 + Mixup/CutMix)
from timm.data import Mixup

# Use the GPU when one is visible, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained backbone with a classifier head sized for NUM_CLASSES.
model = timm.create_model(
    'efficientnetv2_rw_s',
    pretrained=True,
    num_classes=NUM_CLASSES,
    drop_rate=0.3,
).to(device)

optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
# Label smoothing is left at 0 here; the Mixup settings below carry their
# own label_smoothing value.
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)

_mixup_settings = dict(
    mixup_alpha=0.4,
    cutmix_alpha=1.0,
    cutmix_minmax=None,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=NUM_CLASSES,
)
mixup_fn = Mixup(**_mixup_settings)
"""6. Advanced Training Loop: Full Metrics, Early Stopping, and Mixup
Explanation:
This loop supports Mixup/CutMix, logs all advanced metrics, and uses early stopping with automatic best model saving.
Ready for real production—and all your plots and reporting.
---
"""
# Block 6: Advanced Training Loop
EPOCHS = 25
patience, counter = 7, 0
best_val_f1 = 0

# One history list per tracked quantity; insertion order is the column order
# of the DataFrame built from this dict afterwards.
metrics_dict = {
    name: []
    for name in (
        'train_loss', 'train_acc',
        'val_loss', 'val_acc',
        'val_precision_macro', 'val_precision_weighted',
        'val_recall_macro', 'val_recall_weighted',
        'val_f1_macro', 'val_f1_weighted',
        'val_hamming', 'val_cohen_kappa',
        'val_mcc', 'val_jaccard_macro',
        'val_top3', 'val_top5',
    )
}


def _topk_hit_rate(prob_rows, true_labels, k):
    """Return the fraction of rows whose true label is among the k highest-scoring classes."""
    return np.mean([
        true_label in np.argsort(prob_row)[-k:]
        for prob_row, true_label in zip(prob_rows, true_labels)
    ])


for epoch in range(EPOCHS):
    # ---------------- TRAIN ----------------
    model.train()
    total_loss, correct, total = 0, 0, 0
    for batch_imgs, batch_labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        batch_imgs = batch_imgs.to(device)
        batch_labels = batch_labels.to(device)
        optimizer.zero_grad()
        # mixup_fn returns blended images together with soft target vectors.
        mixed_imgs, soft_targets = mixup_fn(batch_imgs, batch_labels)
        logits = model(mixed_imgs)
        loss = criterion(logits, soft_targets)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * mixed_imgs.size(0)
        # Training "accuracy" compares against the dominant class of each soft target.
        correct += (logits.argmax(1) == soft_targets.argmax(1)).sum().item()
        total += soft_targets.size(0)
    train_loss = total_loss / total
    train_acc = correct / total
    metrics_dict['train_loss'].append(train_loss)
    metrics_dict['train_acc'].append(train_acc)

    # ---------------- VALIDATION ----------------
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    val_probs, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for batch_imgs, batch_labels in tqdm(val_loader, desc=f"Val Epoch {epoch+1}"):
            batch_imgs = batch_imgs.to(device)
            batch_labels = batch_labels.to(device)
            logits = model(batch_imgs)
            val_loss += criterion(logits, batch_labels).item() * batch_imgs.size(0)
            batch_probs = torch.softmax(logits, dim=1)
            batch_preds = logits.argmax(1)
            val_correct += (batch_preds == batch_labels).sum().item()
            val_total += batch_labels.size(0)
            val_probs.extend(batch_probs.cpu().numpy())
            val_preds.extend(batch_preds.cpu().numpy())
            val_targets.extend(batch_labels.cpu().numpy())
    val_loss /= val_total
    val_acc = val_correct / val_total
    val_preds_np = np.array(val_preds)
    val_targets_np = np.array(val_targets)
    val_probs_np = np.array(val_probs)

    # ---------------- METRICS ----------------
    val_precision_macro = precision_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_precision_weighted = precision_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_recall_macro = recall_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_recall_weighted = recall_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_f1_macro = f1_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_f1_weighted = f1_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    top3_acc = _topk_hit_rate(val_probs_np, val_targets_np, 3)
    top5_acc = _topk_hit_rate(val_probs_np, val_targets_np, 5)
    val_hamming = hamming_loss(val_targets_np, val_preds_np)
    val_cohen_kappa = cohen_kappa_score(val_targets_np, val_preds_np)
    val_mcc = matthews_corrcoef(val_targets_np, val_preds_np)
    val_jaccard_macro = jaccard_score(val_targets_np, val_preds_np, average='macro', zero_division=0)

    # Record this epoch's validation numbers.
    epoch_val_metrics = {
        'val_loss': val_loss,
        'val_acc': val_acc,
        'val_precision_macro': val_precision_macro,
        'val_precision_weighted': val_precision_weighted,
        'val_recall_macro': val_recall_macro,
        'val_recall_weighted': val_recall_weighted,
        'val_f1_macro': val_f1_macro,
        'val_f1_weighted': val_f1_weighted,
        'val_hamming': val_hamming,
        'val_cohen_kappa': val_cohen_kappa,
        'val_mcc': val_mcc,
        'val_jaccard_macro': val_jaccard_macro,
        'val_top3': top3_acc,
        'val_top5': top5_acc,
    }
    for metric_name, metric_value in epoch_val_metrics.items():
        metrics_dict[metric_name].append(metric_value)

    scheduler.step()
    print(f"Epoch {epoch+1:2d} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | F1(macro): {val_f1_macro:.4f} | Top3: {top3_acc:.3f} | Top5: {top5_acc:.3f}")

    # ---------------- EARLY STOPPING ----------------
    # A new best macro-F1 saves the weights and resets the patience counter.
    if val_f1_macro > best_val_f1:
        best_val_f1 = val_f1_macro
        torch.save(model.state_dict(), '/content/drive/MyDrive/efficientnetv2_best_model.pth')
        counter = 0
    else:
        counter += 1
    if counter >= patience:
        print("⏹️ Early stopping triggered.")
        break

print("✅ Training complete. Best model saved.")
"""7.Explanation
After training, all metrics (accuracy, loss, precision, recall, F1, top-k, etc.) are saved as a CSV for analysis and reporting.
We plot core metrics (accuracy, F1, loss, precision/recall, top-3/top-5 accuracy) with:
Large, clear fonts
Annotations for best epoch
Colorful, pro-style Seaborn plots
Publication-ready grid and tight layouts
---
"""
# 7. Metrics Export & Advanced Visualizations
import seaborn as sns

# --- Save all metrics for reproducibility and later analysis
metrics_df = pd.DataFrame(metrics_dict)
metrics_df.to_csv('/content/drive/MyDrive/metrics_log.csv', index_label='epoch')
print("✅ metrics_log.csv saved.")

sns.set(style='whitegrid', font_scale=1.3)

# 1. Accuracy & Macro F1
plt.figure(figsize=(12,7))
plt.plot(metrics_df['train_acc'], label='Train Acc', lw=2)
plt.plot(metrics_df['val_acc'], label='Val Acc', lw=2)
plt.plot(metrics_df['val_f1_macro'], label='Val F1 (macro)', lw=2)

# Mark the epoch (row index) with the highest validation macro-F1.
best_epoch = metrics_df['val_f1_macro'].idxmax()
best_f1 = metrics_df.loc[best_epoch, 'val_f1_macro']
plt.scatter(best_epoch, best_f1, c='red', s=90, label='Best Epoch', zorder=3)
plt.annotate(f'Best\n{best_f1:.2f}',
             (best_epoch, best_f1),
             textcoords="offset points", xytext=(-5,10), ha='right', fontsize=14, color='red')

plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Score', fontsize=16)
plt.title('Accuracy and Macro F1 per Epoch', fontsize=18)
# The legend is built only after the scatter so that 'Best Epoch' is listed;
# a legend created earlier does not pick up artists added afterwards.
plt.legend(loc='lower right')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_acc_f1_beautiful.png')
plt.show()
def _label_save_show(ylabel, title, legend_loc, out_path):
    """Apply the shared axis labels, legend and grid, then save and display the current figure."""
    plt.xlabel('Epoch', fontsize=16)
    plt.ylabel(ylabel, fontsize=16)
    plt.title(title, fontsize=18)
    plt.legend(loc=legend_loc)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(out_path)
    plt.show()


# 2. Loss Curves
plt.figure(figsize=(12,7))
for column, caption in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    plt.plot(metrics_df[column], label=caption, lw=2)
_label_save_show(
    'Loss',
    'Train & Validation Loss per Epoch',
    'upper right',
    '/content/drive/MyDrive/metrics_loss_beautiful.png',
)

# 3. Precision & Recall (Macro & Weighted)
plt.figure(figsize=(12,7))
for column, caption in (
    ('val_precision_macro', 'Val Precision (macro)'),
    ('val_recall_macro', 'Val Recall (macro)'),
    ('val_precision_weighted', 'Val Precision (weighted)'),
    ('val_recall_weighted', 'Val Recall (weighted)'),
):
    plt.plot(metrics_df[column], label=caption, lw=2)
_label_save_show(
    'Score',
    'Validation Precision & Recall per Epoch',
    'lower right',
    '/content/drive/MyDrive/metrics_precision_recall_beautiful.png',
)

# 4. Top-3 and Top-5 Validation Accuracy as Area Plot
plt.figure(figsize=(12,7))
epochs_axis = metrics_df.index
plt.fill_between(epochs_axis, metrics_df['val_top3'], alpha=0.3, label='Val Top-3 Acc')
plt.fill_between(epochs_axis, metrics_df['val_top5'], alpha=0.2, label='Val Top-5 Acc', color='orange')
# Outline each filled area with a solid line; these lines carry no legend label.
plt.plot(metrics_df['val_top3'], lw=2, color='blue')
plt.plot(metrics_df['val_top5'], lw=2, color='orange')
_label_save_show(
    'Accuracy',
    'Top-3 and Top-5 Validation Accuracy per Epoch',
    'lower right',
    '/content/drive/MyDrive/metrics_topk_beautiful.png',
)
"""8.Confusion Matrix & Per-Class Analysis with Advanced Visuals
Explanation
After training, it's crucial to understand not just overall metrics, but where your model succeeds and fails.
We:
Save a detailed classification report (per-class precision/recall/F1).
Draw a high-contrast confusion matrix with large ticks, tight color scaling, and readable value overlays.
Plot Top 20 Most Confused Classes for targeted debugging.
Show Top 20 Most Accurate Classes with horizontal barplots (values on bars, sorted).
---
"""
# 8. Confusion Matrix & Per-Class Analysis (Advanced Visuals)
from sklearn.metrics import classification_report, confusion_matrix
import seaborn as sns

# Reload best model for evaluation
model.load_state_dict(torch.load('/content/drive/MyDrive/efficientnetv2_best_model.pth', map_location=device))
model.eval()

# Collect all validation predictions and true labels
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Fix the label set explicitly so the report and the matrix always have one
# row per entry of class_names. Without it, sklearn infers the classes from
# the data and a class missing from both arrays would shift or mismatch the
# names.
class_ids = np.arange(len(class_names))

# Save detailed classification report (per-class)
report = classification_report(
    all_labels, all_preds,
    labels=class_ids, target_names=class_names,
    output_dict=True, zero_division=0
)
pd.DataFrame(report).transpose().to_csv('/content/drive/MyDrive/classification_report.csv')
print("✅ classification_report.csv saved.")

# Confusion Matrix (full, high-res)
cm = confusion_matrix(all_labels, all_preds, labels=class_ids)
plt.figure(figsize=(18,18))
sns.heatmap(
    cm,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('Confusion Matrix', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confusion_matrix_beautiful.png', dpi=300)
plt.show()
# Most Confused Classes (Top 20, value overlays)
# Per true class, count the samples that landed off the diagonal, then keep
# the 20 rows with the largest counts.
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
errors_per_true_class = off_diag.sum(axis=1)
most_confused = np.argsort(errors_per_true_class)[::-1][:20]
cm_top = cm[np.ix_(most_confused, most_confused)]
labels_top = [class_names[class_idx] for class_idx in most_confused]

plt.figure(figsize=(12,10))
sns.heatmap(
    cm_top,
    annot=True,
    fmt='d',
    cmap="Blues",
    xticklabels=labels_top,
    yticklabels=labels_top,
    linewidths=.2,
    cbar=False,
    annot_kws={"size":14},
)
plt.title('Most Confused Classes (Top 20)', fontsize=18)
plt.xlabel('Predicted label', fontsize=15)
plt.ylabel('True label', fontsize=15)
plt.xticks(fontsize=11, rotation=90)
plt.yticks(fontsize=11)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confused_top20_beautiful.png', dpi=300)
plt.show()

# Top-20 Most Accurate Classes (barplot, values on bars)
# Diagonal over row total; the small epsilon avoids dividing by zero for a
# class with no samples.
row_totals = cm.sum(axis=1)
acc_per_class = cm.diagonal() / (row_totals + 1e-8)
df_acc = pd.DataFrame({'class': class_names, 'accuracy': acc_per_class})
top_acc = df_acc.sort_values('accuracy', ascending=False).head(20)

plt.figure(figsize=(10,8))
sns.barplot(data=top_acc, y='class', x='accuracy', palette='Blues_d', orient='h')
plt.title('Top 20 Classes by Accuracy', fontsize=18)
plt.xlabel('Accuracy', fontsize=15)
plt.ylabel('Class', fontsize=15)
# Write each bar's value just to the right of the bar.
for bar_pos, bar_value in enumerate(top_acc['accuracy']):
    plt.text(bar_value + 0.01, bar_pos, f"{bar_value:.2f}", color='blue', va='center', fontsize=13)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/top20_accuracy_beautiful.png', dpi=300)
plt.show()
"""9. Test-Time Augmentation (TTA) & Batch Prediction
Explanation
Test-Time Augmentation boosts prediction robustness by averaging predictions over multiple random transformations of each test image.
Batch Prediction allows you to efficiently label a folder of test images with class names—production style.
"""
# 9. Test-Time Augmentation (TTA) for Validation
def _tta_variant(extra_step):
    """Build a resize / centre-crop / normalise pipeline with one extra step inserted after the resize."""
    return transforms.Compose([
        transforms.Resize(256),
        extra_step,
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])


# View 0 is the plain validation pipeline; the others add an always-applied
# horizontal flip, a rotation drawn with RandomRotation(10), and a colour jitter.
tta_transforms = [val_transform] + [
    _tta_variant(step)
    for step in (
        transforms.RandomHorizontalFlip(p=1.0),
        transforms.RandomRotation(10),
        transforms.ColorJitter(0.2, 0.2, 0.2, 0.1),
    )
]
def tta_predict(model, img_pil, tta_transforms, device='cuda'):
    """Average the model's logits over several transformed views of one image.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        img_pil: Input handed to every transform.
        tta_transforms: Iterable of callables, each returning an unbatched tensor.
        device: Device the transformed views are moved to before inference.

    Returns:
        The mean of the per-view outputs, with the leading batch axis of
        size one retained.
    """
    model.eval()
    with torch.no_grad():
        # One forward pass per view, each on a single-image batch.
        per_view_logits = [
            model(view(img_pil).unsqueeze(0).to(device)) for view in tta_transforms
        ]
    return torch.stack(per_view_logits).mean(0)
# Apply TTA to validation set
# The raw image files are loaded here instead of reusing val_loader batches.
# Those batches are already resized, cropped and normalised by val_transform;
# converting them back with ToPILImage and running the TTA pipelines on top
# would normalise a second time and feed the model distorted pixel values.
tta_val_preds, tta_val_labels = [], []
for row in tqdm(val_dataset.data.itertuples(index=False),
                total=len(val_dataset), desc="TTA Validation"):
    img_path = os.path.join(val_dataset.root_dir, row.filename)
    with Image.open(img_path) as raw_img:
        img_pil = raw_img.convert('RGB')
    avg_logits = tta_predict(model, img_pil, tta_transforms, device)
    tta_val_preds.append(avg_logits.argmax(dim=1).cpu().item())
    tta_val_labels.append(int(row.label))
tta_val_preds = np.array(tta_val_preds)
tta_val_labels = np.array(tta_val_labels)

# Metrics for TTA
tta_f1_macro = f1_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_acc = accuracy_score(tta_val_labels, tta_val_preds)
tta_precision = precision_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_recall = recall_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
print(f"TTA Validation Accuracy: {tta_acc:.4f}")
print(f"TTA Validation F1 (macro): {tta_f1_macro:.4f}")
print(f"TTA Validation Precision (macro): {tta_precision:.4f}")
print(f"TTA Validation Recall (macro): {tta_recall:.4f}")
# TTA Confusion matrix (optional)
# Passing the full label set keeps the matrix at one row/column per entry of
# class_names even if some class never occurs in the TTA labels or predictions.
cm_tta = confusion_matrix(
    tta_val_labels, tta_val_preds, labels=np.arange(len(class_names))
)
plt.figure(figsize=(18,18))
sns.heatmap(
    cm_tta,
    cmap="Purples",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('TTA Confusion Matrix (Validation)', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/tta_confusion_matrix_beautiful.png', dpi=300)
plt.show()
"""10. Extraordinary Grad-CAM++ Overlays (Grid)
Explanation
We generate Grad-CAM++ visualizations for a set of sample images.
Each visualization shows: the input image, the Grad-CAM++ heatmap overlay, and the true and predicted class for easy comparison.
All visualizations are saved both individually and as a large, labeled grid.
---
"""
# Grad-CAM++ Explanations: Multi-Image Grid (Fixed for latest grad-cam)
from pytorch_grad_cam import GradCAMPlusPlus
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
import os

os.makedirs('/content/drive/MyDrive/gradcam_outputs', exist_ok=True)


def _path_safe(text):
    """Replace path separators so a class name can be embedded in a file name."""
    return text.replace('/', '-').replace('\\', '-')


# Make sure model is on the right device
model.eval()
model.to(device)

# Pick the right target layer for EfficientNetV2 (last block)
target_layer = model.blocks[-1] if hasattr(model, "blocks") else model.layer4[-1]

# No more use_cuda argument—just instantiate
cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])

# The grid has 3 x 4 cells; never ask for more samples than the dataset holds.
num_images = min(12, len(val_dataset))
fig, axes = plt.subplots(3, 4, figsize=(18, 14))
fig.suptitle('Grad-CAM++ Explanations: True vs. Predicted', fontsize=22, weight='bold')
# Hide every cell up front so unused ones stay blank.
for ax in axes.flat:
    ax.axis('off')

for idx in range(num_images):
    img_tensor, label = val_dataset[idx]
    input_tensor = img_tensor.unsqueeze(0).to(device)

    # Plain forward pass to find the predicted class; the CAM call below runs
    # outside no_grad.
    with torch.no_grad():
        output = model(input_tensor)
    pred = output.argmax(1).item()

    targets = [ClassifierOutputTarget(pred)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Undo the normalisation so the overlay is drawn on a [0, 1] RGB image.
    image_np = img_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * np.array(imagenet_std)) + np.array(imagenet_mean)
    image_np = np.clip(image_np, 0, 1)
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

    # Save each Grad-CAM overlay individually. Class names are sanitised
    # because a '/' in a name would otherwise be read as a sub-directory.
    overlay_path = (
        f"/content/drive/MyDrive/gradcam_outputs/val_{idx}"
        f"_true_{_path_safe(class_names[label])}"
        f"_pred_{_path_safe(class_names[pred])}.png"
    )
    plt.imsave(overlay_path, cam_image)

    # Add to grid
    ax = axes[idx // 4, idx % 4]
    ax.imshow(cam_image)
    ax.set_title(
        f"True: {class_names[label][:18]}\nPred: {class_names[pred][:18]}",
        fontsize=12,
        color="green" if pred == label else "red",
        weight="bold"
    )
    ax.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig('/content/drive/MyDrive/gradcam_outputs/gradcam_grid.png', dpi=250)
plt.show()
"""11. Gradio Interactive Demo: Model + Grad-CAM++"""
# 11. Gradio Interactive Demo: EfficientNetV2 + Grad-CAM++
import gradio as gr
from PIL import Image as PILImage
def predict_and_explain(img):
    """Classify an uploaded image and return a Grad-CAM++ overlay with a caption.

    Args:
        img: PIL image supplied by the Gradio input, or ``None`` when nothing
            was uploaded.

    Returns:
        A tuple ``(overlay, text)``: the heatmap overlay as a PIL image (or
        ``None`` when there is no input) and a prediction message.
    """
    if img is None:
        return None, "Please upload an image first."

    # Use exactly the validation preprocessing. The earlier squash to
    # 224x224 before val_transform changed the aspect ratio and left the
    # overlay background different from what the model actually saw.
    input_tensor = val_transform(img.convert("RGB")).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax().item()

    targets = [ClassifierOutputTarget(pred_idx)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Rebuild the displayed image from the model input itself (undoing the
    # normalisation) so the heatmap lines up pixel-for-pixel with it.
    image_np = input_tensor[0].detach().cpu().permute(1, 2, 0).numpy()
    image_np = image_np * np.array(imagenet_std) + np.array(imagenet_mean)
    image_np = np.clip(image_np, 0, 1).astype(np.float32)
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

    pred_name = class_names[pred_idx]
    return PILImage.fromarray(cam_image), f"Prediction: {pred_name} (class index {pred_idx})"
# Input and output widgets of the demo page.
_upload_widget = gr.Image(type="pil", label="Upload Car Image")
_result_widgets = [
    gr.Image(label="Grad-CAM++ Output"),
    gr.Text(label="Prediction"),
]

demo = gr.Interface(
    fn=predict_and_explain,
    inputs=_upload_widget,
    outputs=_result_widgets,
    title="🚗 TwinCar: Car Make/Model Classifier + Explainability Demo",
    description="Upload a car photo. See the prediction (make/model/year) and a Grad-CAM++ heatmap showing what influenced the model.",
    allow_flagging='never',
)

# share=True asks Gradio for a shareable link in addition to the local server.
demo.launch(share=True)