| | |
| | """Last_model.ipynb |
| | |
| | Automatically generated by Colab. |
| | |
| | Original file is located at |
| | https://colab.research.google.com/drive/1AdRILP1oqdiVuRSQr2dZZy0QgU8insn_ |
| | |
| | 🚗 TwinCar Project: SOTA Training, Full Visuals, and Advanced Reporting |
| | |
| | |
| | --- |
| | |
| | |
| | --- |
| | |
| | 1. Environment Setup and Imports |
| | Explanation: |
| | We start by importing all necessary libraries and prepping our working environment for advanced data handling and visualization. |
| | |
| | --- |
| | """ |
| |
|
| | |
| | import os |
| | import zipfile |
| | import numpy as np |
| | import pandas as pd |
| | import matplotlib.pyplot as plt |
| | import seaborn as sns |
| | from PIL import Image |
| | from tqdm import tqdm |
| |
|
| | import torch |
| | import torch.nn as nn |
| | import torch.optim as optim |
| | from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler |
| | from torchvision import transforms |
| |
|
| | from sklearn.model_selection import train_test_split |
| | from sklearn.utils.class_weight import compute_class_weight |
| | from sklearn.metrics import ( |
| | accuracy_score, precision_score, recall_score, f1_score, hamming_loss, |
| | cohen_kappa_score, matthews_corrcoef, jaccard_score, |
| | confusion_matrix, classification_report |
| | ) |
| |
|
| | import timm |
| | import scipy.io |
| |
|
| | """2. Data Extraction and Preparation |
| | Explanation: |
| | We extract and organize the Stanford Cars dataset, parse .mat files to CSV for class and label mapping, and prepare all paths. |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
# Mount Google Drive so the dataset archive stored there becomes readable.
from google.colab import drive

drive.mount('/content/drive')

# Unpack the archive only when the target folder is missing; later runs
# reuse whatever is already on disk.
zip_path = '/content/drive/MyDrive/stanford_cars.zip'
extract_dir = '/content/stanford_cars'
dataset_already_unpacked = os.path.exists(extract_dir)
if not dataset_already_unpacked:
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall(extract_dir)
# NOTE(review): the export lost indentation; this message is assumed to be
# printed on every run, not only right after an extraction.
print("✅ Dataset extracted at", extract_dir)
| |
|
# Class-name lookup: position i in `class_names` is used as label id i.
meta = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_meta.mat")
class_names = [x[0] for x in meta['class_names'][0]]
NUM_CLASSES = len(class_names)

# Training annotations -> (filename, label) rows.
# Field 5 supplies the file name and field 4 the class id, which is shifted
# by one so that labels start at 0 (assumes the devkit stores 1-based ids).
# `.item()` extracts the scalar explicitly: calling int() on an array with
# ndim > 0 is deprecated since NumPy 1.25 and slated to become an error.
train_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_train_annos.mat")['annotations'][0]
train_rows = [[x[5][0], int(np.asarray(x[4]).item()) - 1] for x in train_annos]
df_train = pd.DataFrame(train_rows, columns=["filename", "label"])
df_train.to_csv('/content/train_labels.csv', index=False)

# Test annotations -> filename rows only; here the file name is read from
# field 4 and no label column is written.
test_annos = scipy.io.loadmat(f"{extract_dir}/car_devkit/devkit/cars_test_annos.mat")['annotations'][0]
test_rows = [[x[4][0]] for x in test_annos]
df_test = pd.DataFrame(test_rows, columns=["filename"])
df_test.to_csv('/content/test_labels.csv', index=False)

# Image folders inside the extracted archive.
train_root = f"{extract_dir}/cars_train/cars_train"
test_root = f"{extract_dir}/cars_test/cars_test"
| |
|
| | """3. Advanced Dataset and Augmentations |
| | Explanation: |
| | We build a flexible dataset class, apply advanced augmentations, and lay the foundation for Mixup/CutMix later. |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
class StanfordCarsFromCSV(Dataset):
    """Image dataset indexed by a CSV file.

    The CSV is read once at construction time. It must contain a
    ``filename`` column (path relative to ``root_dir``) and, when
    ``has_labels`` is true, a ``label`` column.

    Args:
        root_dir: Directory that the CSV file names are joined onto.
        csv_file: Path of the CSV index.
        transform: Optional callable applied to the RGB PIL image.
        has_labels: Selects what accompanies the image in each item.

    Each item is ``(image, int(label))`` when ``has_labels`` is true and
    ``(image, filename)`` otherwise.
    """

    def __init__(self, root_dir, csv_file, transform=None, has_labels=True):
        self.root_dir = root_dir
        self.data = pd.read_csv(csv_file)
        self.transform = transform
        self.has_labels = has_labels

    def __len__(self):
        return len(self.data.index)

    def __getitem__(self, idx):
        record = self.data.iloc[idx]
        filename = record['filename']
        # Always hand the transform a 3-channel image, whatever the file mode.
        image = Image.open(os.path.join(self.root_dir, filename)).convert('RGB')
        if self.transform:
            image = self.transform(image)
        companion = int(record['label']) if self.has_labels else filename
        return image, companion
| |
|
# Per-channel statistics used to normalize every image tensor.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

# Training pipeline: random geometric and photometric augmentation, then
# tensor conversion and normalization.
_train_steps = [
    transforms.RandomResizedCrop(224, scale=(0.7, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(0.4, 0.4, 0.4, 0.2),
    transforms.RandomApply([transforms.GaussianBlur(3)], p=0.15),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
train_transform = transforms.Compose(_train_steps)

# Evaluation pipeline: deterministic resize + center crop, same normalization.
_eval_steps = [
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
]
val_transform = transforms.Compose(_eval_steps)
| |
|
| | """4. Data Splitting, Weighted Sampling, and DataLoader |
| | Explanation: |
| | We split the data into train and validation sets with stratification for balanced classes, |
| | use class weighting to counter imbalance, and create PyTorch DataLoaders for efficient training and evaluation. |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
| | from torch.utils.data import DataLoader, WeightedRandomSampler |
| | from sklearn.model_selection import train_test_split |
| | from sklearn.utils.class_weight import compute_class_weight |
| |
|
| | |
# Split / loader hyper-parameters.
BATCH_SIZE = 32
VAL_RATIO = 0.1
RANDOM_SEED = 42

# Stratified train/validation split of the labelled images; the fixed seed
# makes the split reproducible.
df_all = pd.read_csv('/content/train_labels.csv')
df_train, df_val = train_test_split(
    df_all,
    test_size=VAL_RATIO,
    stratify=df_all['label'],
    random_state=RANDOM_SEED,
)

# Persist both splits: the dataset class reads its index from a CSV path.
for _split_frame, _split_path in (
    (df_train, '/content/train_split.csv'),
    (df_val, '/content/val_split.csv'),
):
    _split_frame.to_csv(_split_path, index=False)

# Train and validation images share one folder; only the CSV index and the
# transform differ. The test set carries no labels.
train_dataset = StanfordCarsFromCSV(
    root_dir=train_root, csv_file='/content/train_split.csv', transform=train_transform
)
val_dataset = StanfordCarsFromCSV(
    root_dir=train_root, csv_file='/content/val_split.csv', transform=val_transform
)
test_dataset = StanfordCarsFromCSV(
    root_dir=test_root, csv_file='/content/test_labels.csv', transform=val_transform,
    has_labels=False,
)
| |
|
| | |
# Take the labels straight from the dataset's CSV index. Iterating the
# dataset itself would open, decode and augment every training image only
# to discard the pixels.
labels = train_dataset.data['label'].astype(int).tolist()

# 'balanced' weights are inversely proportional to class frequency.
classes = np.unique(labels)
class_weights = compute_class_weight(class_weight='balanced', classes=classes, y=labels)

# Look weights up by class id rather than by array position, so the mapping
# stays correct even if the ids present in the split are not exactly 0..K-1.
weight_of_class = dict(zip(classes.tolist(), class_weights))
sample_weights = [weight_of_class[label] for label in labels]

# Draw one epoch's worth of samples with replacement, favouring rare classes.
sampler = WeightedRandomSampler(sample_weights, len(sample_weights), replacement=True)
| |
|
| | |
# Settings shared by all three loaders.
_common_loader_args = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)

# Training batches come from the weighted sampler; the incomplete last batch
# is dropped so every training batch has exactly BATCH_SIZE items.
train_loader = DataLoader(
    train_dataset, sampler=sampler, drop_last=True, **_common_loader_args
)
# Evaluation loaders keep dataset order and every sample.
val_loader = DataLoader(
    val_dataset, shuffle=False, drop_last=False, **_common_loader_args
)
test_loader = DataLoader(
    test_dataset, shuffle=False, drop_last=False, **_common_loader_args
)

print(f"Train samples: {len(train_dataset)} | Val samples: {len(val_dataset)} | Test samples: {len(test_dataset)}")
print(f"Train loader batches (per epoch): {len(train_loader)} (should be integer and even-sized)")
| |
|
| | """5. Model Initialization: EfficientNetV2 + Mixup/CutMix Ready |
| | Explanation: |
| | We load EfficientNetV2 with ImageNet weights for best transfer learning, |
| | set up optimizer, scheduler, and prepare for Mixup/CutMix advanced augmentation. |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
| | from timm.data import Mixup |
| |
|
# Use the GPU when one is available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Pretrained backbone with a classification head sized for this dataset.
model = timm.create_model(
    'efficientnetv2_rw_s',
    pretrained=True,
    num_classes=NUM_CLASSES,
    drop_rate=0.3,
).to(device)

# AdamW with a cosine learning-rate schedule spanning 25 scheduler steps.
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-5)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=25)
# No smoothing in the loss itself; the Mixup helper below is configured
# with its own label_smoothing value.
criterion = nn.CrossEntropyLoss(label_smoothing=0.0)

# Batch-level Mixup/CutMix applied to every training batch (prob=1.0).
_mixup_settings = dict(
    mixup_alpha=0.4,
    cutmix_alpha=1.0,
    cutmix_minmax=None,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=NUM_CLASSES,
)
mixup_fn = Mixup(**_mixup_settings)
| |
|
| | """6. Advanced Training Loop: Full Metrics, Early Stopping, and Mixup |
| | Explanation: |
| | This loop supports Mixup/CutMix, logs all advanced metrics, and uses early stopping with automatic best model saving. |
| | Ready for real production—and all your plots and reporting. |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
def _topk_accuracy(probs, targets, k):
    """Return the fraction of rows whose target is among the k top-scoring classes.

    Args:
        probs: 2-D array of per-class scores, one row per sample.
        targets: 1-D array of integer class ids, one per row of ``probs``.
        k: Number of highest-scoring classes that count as a hit.
    """
    # One argsort for the whole matrix instead of one per sample.
    top_k = np.argsort(probs, axis=1)[:, -k:]
    return float(np.mean((top_k == targets[:, None]).any(axis=1)))


EPOCHS = 25
patience, counter = 7, 0
# Start below any achievable score so the first epoch always writes a
# checkpoint; the evaluation code later loads that file unconditionally.
best_val_f1 = float('-inf')

# One list per metric, one entry per completed epoch.
metrics_dict = {
    'train_loss': [], 'train_acc': [],
    'val_loss': [], 'val_acc': [],
    'val_precision_macro': [], 'val_precision_weighted': [],
    'val_recall_macro': [], 'val_recall_weighted': [],
    'val_f1_macro': [], 'val_f1_weighted': [],
    'val_hamming': [], 'val_cohen_kappa': [],
    'val_mcc': [], 'val_jaccard_macro': [],
    'val_top3': [], 'val_top5': [],
}

for epoch in range(EPOCHS):
    # ---- Training pass -------------------------------------------------
    model.train()
    total_loss, correct, total = 0, 0, 0
    for imgs, labels in tqdm(train_loader, desc=f"Train Epoch {epoch+1}"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        # After this call `labels` holds one target distribution per sample,
        # not integer class ids.
        imgs, labels = mixup_fn(imgs, labels)
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        # Accuracy against the dominant class of the mixed target; with
        # Mixup/CutMix active this is only an indicative number.
        correct += (outputs.argmax(1) == labels.argmax(1)).sum().item()
        total += labels.size(0)
    train_loss = total_loss / total
    train_acc = correct / total
    metrics_dict['train_loss'].append(train_loss)
    metrics_dict['train_acc'].append(train_acc)

    # ---- Validation pass (no augmentation, no gradients) ---------------
    model.eval()
    val_loss, val_correct, val_total = 0, 0, 0
    val_probs, val_preds, val_targets = [], [], []
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Val Epoch {epoch+1}"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            v_loss = criterion(outputs, labels)
            val_loss += v_loss.item() * imgs.size(0)
            probs = torch.softmax(outputs, dim=1)
            preds = outputs.argmax(1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)
            val_probs.extend(probs.cpu().numpy())
            val_preds.extend(preds.cpu().numpy())
            val_targets.extend(labels.cpu().numpy())
    val_loss /= val_total
    val_acc = val_correct / val_total
    val_preds_np = np.array(val_preds)
    val_targets_np = np.array(val_targets)
    val_probs_np = np.array(val_probs)

    # ---- Validation metrics --------------------------------------------
    val_precision_macro = precision_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_precision_weighted = precision_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_recall_macro = recall_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_recall_weighted = recall_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    val_f1_macro = f1_score(val_targets_np, val_preds_np, average='macro', zero_division=0)
    val_f1_weighted = f1_score(val_targets_np, val_preds_np, average='weighted', zero_division=0)
    top3_acc = _topk_accuracy(val_probs_np, val_targets_np, 3)
    top5_acc = _topk_accuracy(val_probs_np, val_targets_np, 5)
    val_hamming = hamming_loss(val_targets_np, val_preds_np)
    val_cohen_kappa = cohen_kappa_score(val_targets_np, val_preds_np)
    val_mcc = matthews_corrcoef(val_targets_np, val_preds_np)
    val_jaccard_macro = jaccard_score(val_targets_np, val_preds_np, average='macro', zero_division=0)

    # ---- Logging -------------------------------------------------------
    epoch_val_metrics = {
        'val_loss': val_loss,
        'val_acc': val_acc,
        'val_precision_macro': val_precision_macro,
        'val_precision_weighted': val_precision_weighted,
        'val_recall_macro': val_recall_macro,
        'val_recall_weighted': val_recall_weighted,
        'val_f1_macro': val_f1_macro,
        'val_f1_weighted': val_f1_weighted,
        'val_hamming': val_hamming,
        'val_cohen_kappa': val_cohen_kappa,
        'val_mcc': val_mcc,
        'val_jaccard_macro': val_jaccard_macro,
        'val_top3': top3_acc,
        'val_top5': top5_acc,
    }
    for metric_name, metric_value in epoch_val_metrics.items():
        metrics_dict[metric_name].append(metric_value)

    scheduler.step()
    print(f"Epoch {epoch+1:2d} | Train Acc: {train_acc:.4f} | Val Acc: {val_acc:.4f} | F1(macro): {val_f1_macro:.4f} | Top3: {top3_acc:.3f} | Top5: {top5_acc:.3f}")

    # ---- Checkpointing and early stopping on macro F1 ------------------
    if val_f1_macro > best_val_f1:
        best_val_f1 = val_f1_macro
        torch.save(model.state_dict(), '/content/drive/MyDrive/efficientnetv2_best_model.pth')
        counter = 0
    else:
        counter += 1
        if counter >= patience:
            print("⏹️ Early stopping triggered.")
            break

print("✅ Training complete. Best model saved.")
| |
|
| | """7. Metrics Logging and Training Curves |
| | Explanation: |
| | After training, all metrics (accuracy, loss, precision, recall, F1, top-k, etc.) are saved as a CSV for analysis and reporting. |
| | |
| | We plot core metrics (accuracy, F1, loss, precision/recall, top-3/top-5 accuracy) with: |
| | |
| | Large, clear fonts |
| | |
| | Annotations for best epoch |
| | |
| | Colorful, pro-style Seaborn plots |
| | |
| | Publication-ready grid and tight layouts |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
| | import seaborn as sns |
| |
|
| | |
# Persist the per-epoch metrics; the row index is written as the 'epoch' column.
metrics_df = pd.DataFrame(metrics_dict)
metrics_df.to_csv('/content/drive/MyDrive/metrics_log.csv', index_label='epoch')
print("✅ metrics_log.csv saved.")

sns.set(style='whitegrid', font_scale=1.3)

# Accuracy and macro-F1 curves, with the best-F1 epoch highlighted.
plt.figure(figsize=(12,7))
plt.plot(metrics_df['train_acc'], label='Train Acc', lw=2)
plt.plot(metrics_df['val_acc'], label='Val Acc', lw=2)
plt.plot(metrics_df['val_f1_macro'], label='Val F1 (macro)', lw=2)
plt.xlabel('Epoch', fontsize=16)
plt.ylabel('Score', fontsize=16)
plt.title('Accuracy and Macro F1 per Epoch', fontsize=18)
plt.grid(True, alpha=0.3)
best_epoch = metrics_df['val_f1_macro'].idxmax()
best_f1 = metrics_df['val_f1_macro'][best_epoch]
plt.scatter(best_epoch, best_f1, c='red', s=90, label='Best Epoch')
plt.annotate(f'Best\n{best_f1:.2f}',
             (best_epoch, best_f1),
             textcoords="offset points", xytext=(-5,10), ha='right', fontsize=14, color='red')
# Build the legend only after the scatter exists; created earlier, it would
# omit the 'Best Epoch' entry.
plt.legend(loc='lower right')
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/metrics_acc_f1_beautiful.png')
plt.show()
| |
|
| | |
# Train vs. validation loss, one point per epoch.
fig, ax = plt.subplots(figsize=(12,7))
for loss_column, loss_label in (('train_loss', 'Train Loss'), ('val_loss', 'Val Loss')):
    ax.plot(metrics_df[loss_column], label=loss_label, lw=2)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Loss', fontsize=16)
ax.set_title('Train & Validation Loss per Epoch', fontsize=18)
ax.legend(loc='upper right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('/content/drive/MyDrive/metrics_loss_beautiful.png')
plt.show()
| |
|
| | |
# Validation precision and recall, macro- and weighted-averaged.
_pr_series = (
    ('val_precision_macro', 'Val Precision (macro)'),
    ('val_recall_macro', 'Val Recall (macro)'),
    ('val_precision_weighted', 'Val Precision (weighted)'),
    ('val_recall_weighted', 'Val Recall (weighted)'),
)
fig, ax = plt.subplots(figsize=(12,7))
for pr_column, pr_label in _pr_series:
    ax.plot(metrics_df[pr_column], label=pr_label, lw=2)
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Score', fontsize=16)
ax.set_title('Validation Precision & Recall per Epoch', fontsize=18)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('/content/drive/MyDrive/metrics_precision_recall_beautiful.png')
plt.show()
| |
|
| | |
# Top-3 / Top-5 validation accuracy: shaded area under each curve, with the
# curve itself drawn on top.
epochs_axis = metrics_df.index
top3_curve = metrics_df['val_top3']
top5_curve = metrics_df['val_top5']

fig, ax = plt.subplots(figsize=(12,7))
ax.fill_between(epochs_axis, top3_curve, alpha=0.3, label='Val Top-3 Acc')
ax.fill_between(epochs_axis, top5_curve, alpha=0.2, label='Val Top-5 Acc', color='orange')
ax.plot(top3_curve, lw=2, color='blue')
ax.plot(top5_curve, lw=2, color='orange')
ax.set_xlabel('Epoch', fontsize=16)
ax.set_ylabel('Accuracy', fontsize=16)
ax.set_title('Top-3 and Top-5 Validation Accuracy per Epoch', fontsize=18)
ax.legend(loc='lower right')
ax.grid(True, alpha=0.3)
fig.tight_layout()
fig.savefig('/content/drive/MyDrive/metrics_topk_beautiful.png')
plt.show()
| |
|
| | """8. Confusion Matrix & Per-Class Analysis with Advanced Visuals |
| | Explanation: |
| | After training, it's crucial to understand not just overall metrics, but where your model succeeds and fails. |
| | We: |
| | |
| | Save a detailed classification report (per-class precision/recall/F1). |
| | |
| | Draw a high-contrast confusion matrix with large ticks, tight color scaling, and readable value overlays. |
| | |
| | Plot Top 20 Most Confused Classes for targeted debugging. |
| | |
| | Show Top 20 Most Accurate Classes with horizontal barplots (values on bars, sorted). |
| | |
| | |
| | |
| | --- |
| | |
| | |
| | """ |
| |
|
| | |
| |
|
| | from sklearn.metrics import classification_report, confusion_matrix |
| | import seaborn as sns |
| |
|
| | |
# Restore the checkpoint written during training and switch to inference mode.
model.load_state_dict(torch.load('/content/drive/MyDrive/efficientnetv2_best_model.pth', map_location=device))
model.eval()

# Collect predictions and ground truth over the whole validation set.
all_preds, all_labels = [], []
with torch.no_grad():
    for imgs, labels in val_loader:
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
all_preds = np.array(all_preds)
all_labels = np.array(all_labels)

# Per-class precision / recall / F1.
# `labels` pins the report to every class id: without it, a class that never
# occurs in the targets or predictions makes the number of reported classes
# differ from len(class_names) and the call fails. `zero_division=0` scores
# such classes as 0 instead of emitting warnings.
report = classification_report(
    all_labels, all_preds,
    labels=np.arange(NUM_CLASSES),
    target_names=class_names,
    output_dict=True,
    zero_division=0,
)
pd.DataFrame(report).transpose().to_csv('/content/drive/MyDrive/classification_report.csv')
print("✅ classification_report.csv saved.")
| |
|
| | |
# Full confusion matrix (rows: true class, columns: predicted class).
# `labels` forces one row/column per class id so the matrix always lines up
# with `class_names`, even if some class is absent from the validation data.
cm = confusion_matrix(all_labels, all_preds, labels=np.arange(len(class_names)))
plt.figure(figsize=(18,18))
sns.heatmap(
    cm,
    cmap="Blues",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('Confusion Matrix', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confusion_matrix_beautiful.png', dpi=300)
plt.show()
| |
|
| | |
# Rank true classes by how many of their samples were misclassified
# (row sums of the confusion matrix with the diagonal zeroed out).
off_diag = cm.copy()
np.fill_diagonal(off_diag, 0)
errors_per_true_class = off_diag.sum(axis=1)
most_confused = np.argsort(errors_per_true_class)[::-1][:20]

# Restrict the confusion matrix to those 20 classes, on both axes.
cm_top = cm[np.ix_(most_confused, most_confused)]
labels_top = [class_names[class_id] for class_id in most_confused]

_confused_heatmap_args = dict(
    annot=True, fmt='d', cmap="Blues",
    xticklabels=labels_top, yticklabels=labels_top,
    linewidths=.2, cbar=False, annot_kws={"size":14},
)
plt.figure(figsize=(12,10))
sns.heatmap(cm_top, **_confused_heatmap_args)
plt.title('Most Confused Classes (Top 20)', fontsize=18)
plt.xlabel('Predicted label', fontsize=15)
plt.ylabel('True label', fontsize=15)
plt.xticks(fontsize=11, rotation=90)
plt.yticks(fontsize=11)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/confused_top20_beautiful.png', dpi=300)
plt.show()
| |
|
| | |
# Per-class accuracy = correct predictions / samples of that class; the
# small epsilon avoids a division by zero for classes without samples.
samples_per_class = cm.sum(axis=1)
acc_per_class = cm.diagonal() / (samples_per_class + 1e-8)
df_acc = pd.DataFrame({'class': class_names, 'accuracy': acc_per_class})
top_acc = df_acc.sort_values('accuracy', ascending=False).head(20)

plt.figure(figsize=(10,8))
sns.barplot(
    data=top_acc, y='class', x='accuracy', palette='Blues_d', orient='h'
)
plt.title('Top 20 Classes by Accuracy', fontsize=18)
plt.xlabel('Accuracy', fontsize=15)
plt.ylabel('Class', fontsize=15)
# Write each value just to the right of the end of its bar.
for bar_row, bar_value in enumerate(top_acc['accuracy']):
    plt.text(bar_value + 0.01, bar_row, f"{bar_value:.2f}", color='blue', va='center', fontsize=13)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/top20_accuracy_beautiful.png', dpi=300)
plt.show()
| |
|
| | """9. Test-Time Augmentation (TTA) & Batch Prediction |
| | Explanation |
| | Test-Time Augmentation boosts prediction robustness by averaging predictions over multiple random transformations of each test image. |
| | Batch Prediction allows you to efficiently label a folder of test images with class names—production style. |
| | """ |
| |
|
| | |
| |
|
def _tta_variant(augmentation):
    """Build an evaluation-style pipeline with one extra augmentation step.

    The augmentation is inserted between the resize and the center crop;
    tensor conversion and normalization follow as in the plain pipeline.
    """
    return transforms.Compose([
        transforms.Resize(256),
        augmentation,
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])


# Views averaged at test time: the plain validation view, an always-applied
# horizontal flip, a random rotation and a random colour jitter.
tta_transforms = [
    val_transform,
    _tta_variant(transforms.RandomHorizontalFlip(p=1.0)),
    _tta_variant(transforms.RandomRotation(10)),
    _tta_variant(transforms.ColorJitter(0.2, 0.2, 0.2, 0.1)),
]
| |
|
def tta_predict(model, img_pil, tta_transforms, device='cuda'):
    """Average the model's logits over several transformed views of one image.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        img_pil: Input image, passed unchanged to every transform.
        tta_transforms: Callables that each turn ``img_pil`` into a tensor
            without a batch dimension.
        device: Device the transformed tensors are moved to.

    Returns:
        The mean of the per-view outputs, each computed on a batch of one.
    """
    model.eval()
    with torch.no_grad():
        per_view_logits = [
            model(make_view(img_pil).unsqueeze(0).to(device))
            for make_view in tta_transforms
        ]
    return torch.stack(per_view_logits).mean(0)
| |
|
| | |
# TTA over the validation set, starting from the raw image files.
# The validation loader yields tensors that val_transform has already
# normalized. Converting those back with ToPILImage does not recover valid
# images, and the TTA pipelines would then resize, crop and normalize a
# second time. Reading each file directly gives every pipeline the input it
# was built for.
tta_val_preds, tta_val_labels = [], []
for row in tqdm(val_dataset.data.itertuples(index=False),
                total=len(val_dataset), desc="TTA Validation"):
    with Image.open(os.path.join(val_dataset.root_dir, row.filename)) as raw_image:
        img_pil = raw_image.convert('RGB')
    avg_logits = tta_predict(model, img_pil, tta_transforms, device)
    tta_val_preds.append(avg_logits.argmax(dim=1).cpu().item())
    tta_val_labels.append(int(row.label))

tta_val_preds = np.array(tta_val_preds)
tta_val_labels = np.array(tta_val_labels)

# Headline metrics for the TTA predictions.
tta_f1_macro = f1_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_acc = accuracy_score(tta_val_labels, tta_val_preds)
tta_precision = precision_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
tta_recall = recall_score(tta_val_labels, tta_val_preds, average='macro', zero_division=0)
print(f"TTA Validation Accuracy: {tta_acc:.4f}")
print(f"TTA Validation F1 (macro): {tta_f1_macro:.4f}")
print(f"TTA Validation Precision (macro): {tta_precision:.4f}")
print(f"TTA Validation Recall (macro): {tta_recall:.4f}")
| |
|
| | |
# Confusion matrix of the TTA predictions (rows: true, columns: predicted).
# `labels` forces one row/column per class id so the matrix always matches
# the `class_names` tick labels, even if a class is absent from the data.
cm_tta = confusion_matrix(tta_val_labels, tta_val_preds, labels=np.arange(len(class_names)))
plt.figure(figsize=(18,18))
sns.heatmap(
    cm_tta,
    cmap="Purples",
    xticklabels=class_names,
    yticklabels=class_names,
    square=True,
    cbar_kws={"shrink": 0.5, "label": "Count"},
    linewidths=.2
)
plt.title('TTA Confusion Matrix (Validation)', fontsize=20)
plt.xlabel('Predicted label', fontsize=16)
plt.ylabel('True label', fontsize=16)
plt.xticks(fontsize=8, rotation=90)
plt.yticks(fontsize=8)
plt.tight_layout()
plt.savefig('/content/drive/MyDrive/tta_confusion_matrix_beautiful.png', dpi=300)
plt.show()
| |
|
| | """10. Extraordinary Grad-CAM++ Overlays (Grid) |
| | Explanation |
| | We generate Grad-CAM++ visualizations for a set of sample images. |
| | Each visualization shows: the input image, the Grad-CAM++ heatmap overlay, and the true and predicted class for easy comparison. |
| | All visualizations are saved both individually and as a large, labeled grid. |
| | |
| | |
| | |
| | --- |
| | """ |
| |
|
| | |
| |
|
| | from pytorch_grad_cam import GradCAMPlusPlus |
| | from pytorch_grad_cam.utils.image import show_cam_on_image |
| | from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
| |
|
| | import os |
| |
|
# Folder that receives the individual overlays and the grid image.
os.makedirs('/content/drive/MyDrive/gradcam_outputs', exist_ok=True)

# Inference mode, on the selected device.
model.eval()
model.to(device)

# Explain the last stage of the network: the final entry of `blocks` when
# the model has that attribute, otherwise the final entry of `layer4`.
if hasattr(model, "blocks"):
    target_layer = model.blocks[-1]
else:
    target_layer = model.layer4[-1]

# Grad-CAM++ explainer bound to that layer.
cam = GradCAMPlusPlus(model=model, target_layers=[target_layer])
| |
|
def _filename_safe(name):
    """Replace path separators so a class name can be embedded in a file name."""
    return name.replace('/', '-').replace('\\', '-')


# Explain the first `num_images` validation samples on a 3 x 4 grid.
num_images = 12
fig, axes = plt.subplots(3, 4, figsize=(18, 14))
fig.suptitle('Grad-CAM++ Explanations: True vs. Predicted', fontsize=22, weight='bold')

for idx in range(num_images):
    img_tensor, label = val_dataset[idx]
    input_tensor = img_tensor.unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(input_tensor)
    pred = output.argmax(1).item()

    # Heatmap for the predicted class. This call sits outside no_grad
    # (assumption: the explainer needs gradients).
    targets = [ClassifierOutputTarget(pred)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Undo the normalization to get an H x W x C image in [0, 1] for the overlay.
    image_np = img_tensor.permute(1, 2, 0).cpu().numpy()
    image_np = (image_np * np.array(imagenet_std)) + np.array(imagenet_mean)
    image_np = np.clip(image_np, 0, 1)
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)

    # Save the individual overlay. Class names are sanitized first: a name
    # containing '/' would otherwise be read as a sub-directory that does
    # not exist, and the save would fail.
    true_name = _filename_safe(class_names[label])
    pred_name = _filename_safe(class_names[pred])
    overlay_path = f"/content/drive/MyDrive/gradcam_outputs/val_{idx}_true_{true_name}_pred_{pred_name}.png"
    plt.imsave(overlay_path, cam_image)

    # Grid cell: green title for a correct prediction, red otherwise.
    ax = axes[idx // 4, idx % 4]
    ax.imshow(cam_image)
    ax.set_title(
        f"True: {class_names[label][:18]}\nPred: {class_names[pred][:18]}",
        fontsize=12,
        color="green" if pred == label else "red",
        weight="bold"
    )
    ax.axis('off')

plt.tight_layout(rect=[0, 0.03, 1, 0.95])
plt.savefig('/content/drive/MyDrive/gradcam_outputs/gradcam_grid.png', dpi=250)
plt.show()
| |
|
| | """11. Gradio Interactive Demo: Model + Grad-CAM++""" |
| |
|
| | |
| |
|
| | import gradio as gr |
| | from PIL import Image as PILImage |
| |
|
def predict_and_explain(img):
    """Classify an uploaded image and overlay a Grad-CAM++ heatmap on it.

    Args:
        img: PIL image from the Gradio input, or ``None`` when the form was
            submitted without an image.

    Returns:
        A ``(overlay, message)`` tuple: the heatmap blended onto the image
        as a PIL image, and a text line naming the predicted class.
        ``(None, message)`` is returned when no image was supplied.
    """
    if img is None:
        return None, "Please upload an image."

    # Same geometry as validation (resize to 256, center-crop 224). The
    # cropped image is kept so that the heatmap is overlaid on exactly the
    # pixels the model saw; overlaying it on a differently framed copy would
    # shift the heatmap relative to the picture.
    crop_view = transforms.Compose([transforms.Resize(256), transforms.CenterCrop(224)])
    to_model_input = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=imagenet_mean, std=imagenet_std),
    ])
    image_pil = crop_view(img.convert("RGB"))
    input_tensor = to_model_input(image_pil).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(input_tensor)
    pred_idx = output.argmax(dim=1).item()

    # Heatmap for the predicted class. This call sits outside no_grad
    # (assumption: the explainer needs gradients).
    targets = [ClassifierOutputTarget(pred_idx)]
    grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0]

    # Scale the base image to float32 in [0, 1] for the overlay helper.
    image_np = np.array(image_pil).astype(np.float32) / 255.0
    cam_image = show_cam_on_image(image_np, grayscale_cam, use_rgb=True)
    pred_name = class_names[pred_idx]
    return PILImage.fromarray(cam_image), f"Prediction: {pred_name} (class index {pred_idx})"
| |
|
# Input and output widgets of the demo page.
upload_box = gr.Image(type="pil", label="Upload Car Image")
result_widgets = [
    gr.Image(label="Grad-CAM++ Output"),
    gr.Text(label="Prediction"),
]

# One-function web UI around predict_and_explain, with flagging turned off.
demo = gr.Interface(
    fn=predict_and_explain,
    inputs=upload_box,
    outputs=result_widgets,
    title="🚗 TwinCar: Car Make/Model Classifier + Explainability Demo",
    description="Upload a car photo. See the prediction (make/model/year) and a Grad-CAM++ heatmap showing what influenced the model.",
    allow_flagging='never',
)
# share=True requests a shareable link in addition to the local server.
demo.launch(share=True)