# DFU / model.py
# Uploaded by EngReem85 (commit f077fac, "Create model.py", verified)
# Import system libraries
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split
# Import data handling tools
import pandas as pd
import seaborn as sns
sns.set_style('darkgrid')
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.utils import shuffle
from torch.utils.data import WeightedRandomSampler
from skimage.feature import local_binary_pattern
# Import deep learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
# Define dataset path and classes
# Dataset root and label set; a label's position in CLASSES is its integer class index.
DATASET_PATH = "/kaggle/input/ms-dfu/DFU_CLASSES(4)"
CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]
# Create every output directory up front (no-op when it already exists).
for _output_dir in (
    "/kaggle/working/logs",
    "/kaggle/working/predictions",
    "/kaggle/working/visualizations",
):
    os.makedirs(_output_dir, exist_ok=True)
del _output_dir  # keep the module namespace free of the loop variable
# Squeeze-and-Excitation Block
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel gating.

    Global-average-pools the input to one value per channel, passes it through
    a bottleneck MLP (Linear -> ReLU -> Linear -> sigmoid) and rescales every
    channel of the input by the resulting gate in (0, 1).

    Args:
        in_channels: number of channels C of the 4-D input (batch, C, H, W).
        reduction: bottleneck ratio; the hidden width is
            ``max(1, in_channels // reduction)``.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        # Clamp to >= 1: for in_channels < reduction the integer division is 0,
        # which would create an empty bottleneck whose gate is a constant 0.5.
        hidden = max(1, in_channels // reduction)
        self.fc1 = nn.Linear(in_channels, hidden, bias=False)
        self.fc2 = nn.Linear(hidden, in_channels, bias=False)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = self.global_pool(x).view(batch, channels)  # squeeze: (batch, C)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(batch, channels, 1, 1)  # excitation gate
        return x * y
# Focal Loss Implementation
class FocalLoss(nn.Module):
    """Multi-class focal loss with a single scalar ``alpha``.

    For each sample, ``loss = alpha * (1 - p_t) ** gamma * CE``. Here ``CE`` is
    the per-sample cross-entropy and ``p_t = exp(-CE)`` is the predicted
    probability of the true class. Returns the mean over the batch.
    """

    def __init__(self, gamma=3.0, alpha=0.5):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        true_class_prob = torch.exp(-per_sample_ce)
        # Down-weight well-classified samples (p_t close to 1).
        modulating = (1 - true_class_prob) ** self.gamma
        return (self.alpha * modulating * per_sample_ce).mean()
# Channel-Centric Depth-wise Group Shuffle (CCDGS) Block
class CCDGSBlock(nn.Module):
    """Channel-Centric Depth-wise Group Shuffle block.

    Pipeline: grouped 3x3 conv -> BN -> ReLU, channel shuffle across
    ``group_size`` groups, depth-wise 3x3 conv -> BN -> ReLU, point-wise 1x1
    conv -> BN -> ReLU, then global average pooling. The output is
    ``(batch, in_channels, 1, 1)``. ``in_channels`` must be divisible by
    ``group_size``, as the grouped convolution requires.
    """

    def __init__(self, in_channels, group_size=4):
        super(CCDGSBlock, self).__init__()
        self.group_size = group_size
        # Submodule names and creation order are kept so saved state_dicts load.
        self.group_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=1,
            groups=group_size, bias=False,
        )
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.depth_conv = nn.Conv2d(
            in_channels, in_channels, kernel_size=3, padding=1,
            groups=in_channels, bias=False,
        )
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.point_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(in_channels)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def channel_shuffle(self, x, groups):
        """Interleave channels across ``groups`` (ShuffleNet-style)."""
        n, c, h, w = x.size()
        # (n, groups, c/groups, h, w) -> swap the two channel axes -> flatten back.
        grouped = x.view(n, groups, c // groups, h, w)
        return grouped.transpose(1, 2).contiguous().view(n, -1, h, w)

    def forward(self, x):
        out = F.relu(self.bn1(self.group_conv(x)))
        out = self.channel_shuffle(out, self.group_size)
        out = F.relu(self.bn2(self.depth_conv(out)))
        out = F.relu(self.bn3(self.point_conv(out)))
        return self.global_pool(out)
# Triplet Attention Module
class TripletAttention(nn.Module):
    """Triplet attention: three rotation-based cross-dimension attention branches.

    Each branch rotates the input so that a different axis sits at dim 1. It
    then Z-pools (max and mean) over that axis, applies a 2->1 conv and a
    sigmoid, rotates the map back, and gates the input with it:

    * branch 1 (``conv1``): pools over H -> map over (C, W)
    * branch 2 (``conv2``): pools over W -> map over (H, C)
    * branch 3 (``conv3``): pools over C -> map over (H, W)

    The output is the mean of the three gated tensors and has the input's
    shape. The input is assumed to be (batch, C, H, W); the rotations use
    dims 1-3. ``in_channels`` is accepted for interface compatibility but is
    unused, because the convs only ever see the 2 pooled channels.
    """

    def __init__(self, in_channels, kernel_size=7):
        super(TripletAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.conv2 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.conv3 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)

    def z_pool(self, x):
        """Stack max and mean over dim 1 into a 2-channel tensor."""
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([max_pool, avg_pool], dim=1)

    def _gate(self, x, conv, dims=None):
        """Gate ``x`` with an attention map computed after a +90 deg rotation in ``dims``."""
        rotated = x if dims is None else torch.rot90(x, 1, dims)
        attn = torch.sigmoid(conv(self.z_pool(rotated)))
        if dims is not None:
            attn = torch.rot90(attn, -1, dims)  # back to a shape broadcastable with x
        return x * attn

    def forward(self, x):
        # Rotating dims [1, 2] moves H to dim 1, so Z-pool removes H and the map
        # spans (C, W). Rotating [2, 3] instead would still pool over C and just
        # repeat branch 3 with a rotated kernel.
        y1 = self._gate(x, self.conv1, [1, 2])
        y2 = self._gate(x, self.conv2, [1, 3])
        y3 = self._gate(x, self.conv3)
        return (y1 + y2 + y3) / 3.0
# Dense-ShuffleGCANet Model
class DenseShuffleGCANet(nn.Module):
    """DenseNet-169 features -> CCDGS -> triplet attention -> SE, fused with handcrafted features.

    The pooled 1664-d CNN embedding is concatenated with a
    ``handcrafted_feature_dim`` vector and classified by a two-layer head
    (``fc1`` -> ReLU -> ``fc2``) with dropout 0.6 before each layer.

    Args:
        num_classes: number of output logits.
        handcrafted_feature_dim: length of the handcrafted feature vector.
        backbone_weights: torchvision weights spec passed to ``densenet169``.
            The default ``'IMAGENET1K_V1'`` matches previous behaviour; pass
            ``None`` for random init (e.g. offline).

    NOTE(review): CCDGSBlock as defined in this file ends with global average
    pooling, so triplet attention, SE and ``global_pool`` here operate on a
    1x1 spatial map. Confirm this is intended.
    """

    def __init__(self, num_classes=4, handcrafted_feature_dim=41, backbone_weights='IMAGENET1K_V1'):
        super(DenseShuffleGCANet, self).__init__()
        self.handcrafted_feature_dim = handcrafted_feature_dim
        densenet = models.densenet169(weights=backbone_weights)
        self.features = densenet.features
        self.ccdgs = CCDGSBlock(in_channels=1664, group_size=4)
        self.triplet_attention = TripletAttention(in_channels=1664)
        self.se_block = SEBlock(in_channels=1664)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.6)
        self.fc1 = nn.Linear(1664 + handcrafted_feature_dim, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x, handcrafted_features=None):
        """Return class logits for images ``x`` and optional handcrafted features.

        When ``handcrafted_features`` is None, a zero vector is used in its place.
        """
        x = self.features(x)
        x = self.ccdgs(x)
        x = self.triplet_attention(x)
        x = self.se_block(x)
        x = self.global_pool(x)
        x = self.flatten(x)
        if handcrafted_features is None:
            # fc1 always expects 1664 + handcrafted_feature_dim inputs, so skipping
            # the concat would raise a shape error; use zeros instead.
            handcrafted_features = x.new_zeros(x.size(0), self.handcrafted_feature_dim)
        x = torch.cat([x, handcrafted_features], dim=1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x
# Function to display sample images
def display_sample_images(images, labels, split_name, classes, num_samples=4):
    """Save a grid with up to ``num_samples`` example images per class.

    Each row is one class (in ``classes`` order) and each column one sample.
    ``images`` are converted from BGR to RGB before plotting. ``labels`` are
    compared against class indices. The figure is written to
    ``/kaggle/working/visualizations/<split_name lowercased>_samples.png``.
    """
    n_rows = len(classes)
    plt.figure(figsize=(15, 10))
    for row, class_name in enumerate(classes):
        members = [pos for pos, lbl in enumerate(labels) if lbl == row]
        for col, pos in enumerate(members[:num_samples]):
            rgb = cv2.cvtColor(images[pos], cv2.COLOR_BGR2RGB)
            plt.subplot(n_rows, num_samples, row * num_samples + col + 1)
            plt.imshow(rgb)
            plt.title(f'{class_name}')
            plt.axis('off')
    plt.suptitle(f'{split_name} Sample Images')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f'/kaggle/working/visualizations/{split_name.lower()}_samples.png')
    plt.close()
# Function to visualize handcrafted features
# def visualize_handcrafted_features(images, labels, classes, num_samples=2):
# for class_idx, class_name in enumerate(classes):
# class_indices = [i for i, label in enumerate(labels) if label == class_idx]
# if not class_indices:
# continue
# selected_indices = class_indices[:num_samples]
# for idx in selected_indices:
# img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
# gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
# gray_eq = cv2.equalizeHist(gray_blur)
# median = np.median(gray_eq)
# lower_threshold = int(max(0, 0.66 * median))
# upper_threshold = int(min(255, 1.33 * median))
# edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
# print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
# edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
# lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
# lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
# color_hist = []
# for channel in range(img.shape[2]):
# hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
# color_hist.extend(hist)
# plt.figure(figsize=(18, 4))
# plt.subplot(1, 5, 1)
# plt.imshow(img)
# plt.title(f'Original ({class_name})')
# plt.axis('off')
# plt.subplot(1, 5, 2)
# plt.imshow(edges, cmap='gray')
# plt.title('Canny Edge Map')
# plt.axis('off')
# plt.subplot(1, 5, 3)
# plt.bar(range(len(lbp_hist)), lbp_hist)
# plt.title('LBP Histogram')
# plt.subplot(1, 5, 4)
# plt.bar(range(len(color_hist)), color_hist)
# plt.title('Color Histogram')
# plt.subplot(1, 5, 5)
# plt.bar(range(len(edge_hist)), edge_hist)
# plt.title('Edge Histogram')
# plt.tight_layout()
# plt.savefig(f'/kaggle/working/visualizations/handcrafted_features_{class_name}_{idx}.png')
# plt.close()
# Function to visualize handcrafted features
# def visualize_handcrafted_features(images, labels, classes, num_samples=1):
# for class_idx, class_name in enumerate(classes):
# class_indices = [i for i, label in enumerate(labels) if label == class_idx]
# if not class_indices:
# continue
# selected_indices = class_indices[:num_samples]
# for idx in selected_indices:
# img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
# gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
# gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
# gray_eq = cv2.equalizeHist(gray_blur)
# median = np.median(gray_eq)
# lower_threshold = int(max(0, 0.66 * median))
# upper_threshold = int(min(255, 1.33 * median))
# edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
# print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
# edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
# lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
# lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
# color_hist = []
# for channel in range(img.shape[2]):
# hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
# color_hist.extend(hist)
# plt.figure(figsize=(18, 4))
# plt.subplot(1, 5, 1)
# plt.imshow(img)
# plt.title(f'Original ({class_name})')
# plt.axis('off')
# plt.subplot(1, 5, 2)
# plt.imshow(edges, cmap='gray')
# plt.title('Canny Edge Map')
# plt.axis('off')
# # LBP Histogram with bold values
# plt.subplot(1, 5, 3)
# bars = plt.bar(range(len(lbp_hist)), lbp_hist)
# plt.title('LBP Histogram')
# for bar in bars:
# height = bar.get_height()
# plt.text(bar.get_x() + bar.get_width()/2., height,
# f'{height:.2f}',
# ha='center', va='bottom',
# fontsize=4, fontweight='bold') # Bold and slightly larger
# # Color Histogram with bold values
# plt.subplot(1, 5, 4)
# bars = plt.bar(range(len(color_hist)), color_hist)
# plt.title('Color Histogram')
# for bar in bars:
# height = bar.get_height()
# plt.text(bar.get_x() + bar.get_width()/2., height,
# f'{height:.2f}',
# ha='center', va='bottom',
# fontsize=4, fontweight='bold') # Bold and slightly larger
# # Edge Histogram with bold values
# plt.subplot(1, 5, 5)
# bars = plt.bar(range(len(edge_hist)), edge_hist)
# plt.title('Edge Histogram')
# for bar in bars:
# height = bar.get_height()
# plt.text(bar.get_x() + bar.get_width()/2., height,
# f'{height:.2f}',
# ha='center', va='bottom',
# fontsize=4, fontweight='bold') # Bold and slightly larger
# plt.tight_layout()
# plt.savefig(f'/kaggle/working/visualizations/handcrafted_features_{class_name}_{idx}.png')
# plt.close()
def _save_labelled_bar_chart(values, title, figsize, out_path):
    """Save a bar chart of ``values`` with each bar annotated by its height (2 d.p.)."""
    plt.figure(figsize=figsize)
    bars = plt.bar(range(len(values)), values)
    plt.title(title)
    for bar in bars:
        top = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., top,
                 f'{top:.2f}',
                 ha='center', va='bottom',
                 fontsize=8, fontweight='bold')
    plt.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close()


def _save_image_panel(image, title, out_path, cmap=None):
    """Save ``image`` alone on a 5x5-inch figure with ``title`` and no axes."""
    plt.figure(figsize=(5, 5))
    plt.imshow(image, cmap=cmap)
    plt.title(title)
    plt.axis('off')
    plt.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close()


def visualize_handcrafted_features(images, labels, classes, num_samples=1):
    """Save the handcrafted-feature views for the first few images of each class.

    For up to ``num_samples`` images per class, five PNGs are written into
    ``/kaggle/working/visualizations/handcrafted_features/class_<idx>_<name>/``:
    the RGB image, its Canny edge map, and annotated bar charts of the
    uniform-LBP, per-channel colour and edge-map histograms. ``images`` are
    converted from BGR to RGB first; ``labels`` are compared against class
    indices. A class directory is created even if that class has no samples.
    """
    root_dir = '/kaggle/working/visualizations/handcrafted_features'
    os.makedirs(root_dir, exist_ok=True)
    for class_idx, class_name in enumerate(classes):
        class_dir = os.path.join(root_dir, f"class_{class_idx}_{class_name}")
        os.makedirs(class_dir, exist_ok=True)
        chosen = [pos for pos, lbl in enumerate(labels) if lbl == class_idx][:num_samples]
        for idx in chosen:
            rgb = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
            gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

            # Canny thresholds come from the median of the blurred,
            # histogram-equalised grayscale image (0.66x and 1.33x the median).
            equalised = cv2.equalizeHist(cv2.GaussianBlur(gray, (5, 5), 0))
            median = np.median(equalised)
            edges = cv2.Canny(equalised,
                              int(max(0, 0.66 * median)),
                              int(min(255, 1.33 * median)))
            print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")

            edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
            lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
            lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
            color_hist = [
                value
                for channel in range(rgb.shape[2])
                for value in np.histogram(rgb[:, :, channel], bins=8, range=(0, 256), density=True)[0]
            ]

            _save_image_panel(rgb, f'Original ({class_name})',
                              os.path.join(class_dir, f'sample_{idx}_original.png'))
            _save_image_panel(edges, 'Canny Edge Map',
                              os.path.join(class_dir, f'sample_{idx}_edges.png'), cmap='gray')
            _save_labelled_bar_chart(lbp_hist, 'LBP Histogram', (8, 4),
                                     os.path.join(class_dir, f'sample_{idx}_lbp_hist.png'))
            _save_labelled_bar_chart(color_hist, 'Color Histogram', (10, 4),
                                     os.path.join(class_dir, f'sample_{idx}_color_hist.png'))
            _save_labelled_bar_chart(edge_hist, 'Edge Histogram', (8, 4),
                                     os.path.join(class_dir, f'sample_{idx}_edge_hist.png'))
            print(f"Saved separate visualizations for class {class_name} sample {idx} in: {class_dir}")
# Function to extract handcrafted features
def extract_handcrafted_features(image):
    """Compute the 41-value handcrafted descriptor used alongside the CNN features.

    The result concatenates, in order:
      * a 9-bin histogram of uniform LBP codes (P=8, R=1) of the grayscale image;
      * an 8-bin histogram per colour channel (24 values, in the input's channel order);
      * an 8-bin histogram of the Canny edge map. Its thresholds are 0.66x and
        1.33x the median of the blurred, histogram-equalised grayscale image.
    Every histogram uses ``density=True``.

    Args:
        image: either an HWC uint8 numpy array, or a CHW ``torch.Tensor``
            normalised with the ImageNet mean/std. A tensor is de-normalised
            back to the 0-255 range first.

    Returns:
        ``torch.float32`` tensor of shape (41,).

    NOTE(review): ``load_images`` in this file passes BGR arrays, but the
    grayscale conversion below uses ``COLOR_RGB2GRAY``, and the tensor path
    yields RGB. Confirm which channel order is intended.
    """
    if isinstance(image, torch.Tensor):
        image = image.detach().cpu().numpy()
        image = np.transpose(image, (1, 2, 0))
        image = (image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])) * 255
        # De-normalised values can fall slightly outside [0, 255]. Casting those
        # directly to uint8 wraps around (or is undefined), so clip first.
        image = np.ascontiguousarray(np.clip(image, 0, 255).astype(np.uint8))
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
    gray_eq = cv2.equalizeHist(gray_blur)
    median = np.median(gray_eq)
    lower_threshold = int(max(0, 0.66 * median))
    upper_threshold = int(min(255, 1.33 * median))
    edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
    print(f"Edge detection stats - Lower threshold: {lower_threshold}, Upper threshold: {upper_threshold}, Edge pixels: {np.sum(edges > 0)}")
    edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
    lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
    # Uniform LBP with P=8 yields codes 0..9; bins [0, 1, ..., 9] give 9 bins.
    lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
    color_hist = []
    for channel in range(image.shape[2]):
        hist, _ = np.histogram(image[:, :, channel], bins=8, range=(0, 256), density=True)
        color_hist.extend(hist)
    features = np.concatenate([lbp_hist, color_hist, edge_hist])
    return torch.tensor(features, dtype=torch.float32)
# Function to load images with handcrafted features
def load_images(split_path, classes, use_csv=False):
    """Load images as BGR uint8 arrays together with integer labels and handcrafted features.

    Two modes are supported:
      * ``use_csv=True`` and ``DATASET_PATH/labels.csv`` exists: rows are read
        from the CSV. ``image_path`` is relative to ``DATASET_PATH`` and
        ``label`` is a class name from ``classes``. NOTE(review): this mode
        ignores ``split_path`` and loads every CSV row. Confirm that is intended.
      * otherwise: images are read from ``split_path/<class_name>/`` for each
        entry of ``classes``, matching .jpg/.jpeg/.png in any letter case.

    In both modes a label is the index of its class name in ``classes``.
    Files that cannot be read or featurised are skipped with a warning.

    Returns:
        ``(image_data, labels, handcrafted_features)`` lists of equal length.
    """
    image_data = []
    labels = []
    handcrafted_features = []

    def _load_one(img_path):
        """Return ``(bgr_array, features)`` for ``img_path``, or None after a warning."""
        try:
            # The context manager closes the file handle PIL keeps open.
            with Image.open(img_path) as img:
                img_array = np.array(img.convert('RGB'))
            img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
            features = extract_handcrafted_features(img_array)
        except Exception as e:  # best-effort loading: skip corrupt/unreadable files
            print(f"Warning: Failed to load image {img_path}: {e}")
            return None
        return img_array, features

    print(f"Loading images from: {split_path}")
    csv_path = os.path.join(DATASET_PATH, "labels.csv")
    if use_csv and os.path.exists(csv_path):
        print("Found labels.csv, loading dataset from CSV")
        df = pd.read_csv(csv_path)
        print("CSV columns:", df.columns)
        for _, row in df.iterrows():
            img_path = os.path.join(DATASET_PATH, row['image_path'])
            label_name = row['label']
            # Index into the ``classes`` argument, as the directory branch does.
            if label_name not in classes:
                print(f"Warning: Label {label_name} not in CLASSES, skipping")
                continue
            loaded = _load_one(img_path)
            if loaded is None:
                continue
            image_data.append(loaded[0])
            labels.append(classes.index(label_name))
            handcrafted_features.append(loaded[1])
    else:
        if not os.path.exists(split_path):
            print(f"Error: Directory {split_path} does not exist")
            return image_data, labels, handcrafted_features
        # Case-insensitive patterns so files such as 'x.PNG' are not skipped.
        patterns = ('*.[jJ][pP][gG]', '*.[jJ][pP][eE][gG]', '*.[pP][nN][gG]')
        for class_idx, class_name in enumerate(classes):
            class_path = os.path.join(split_path, class_name)
            print(f"Checking class: {class_name} at {class_path}")
            if not os.path.exists(class_path):
                print(f"Warning: Class directory {class_path} does not exist")
                continue
            all_files = glob(os.path.join(class_path, '*'))
            print(f"All files in {class_path}: {all_files[:5]}")
            image_paths = [
                path for pattern in patterns
                for path in glob(os.path.join(class_path, pattern))
            ]
            print(f"Found {len(image_paths)} images for class {class_name}")
            for img_path in image_paths:
                loaded = _load_one(img_path)
                if loaded is None:
                    continue
                image_data.append(loaded[0])
                labels.append(class_idx)
                handcrafted_features.append(loaded[1])
    print(f"Total images loaded: {len(image_data)}")
    return image_data, labels, handcrafted_features
# Function to visualize dataset distribution
# def visualize_data_distribution(split_path, split_name, classes):
# class_counts = []
# for class_name in classes:
# class_path = os.path.join(split_path, class_name)
# image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
# glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
# glob(os.path.join(class_path, '*.png'))
# class_counts.append(len(image_paths))
# print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")
# plt.figure(figsize=(10, 6))
# plt.bar(classes, class_counts)
# plt.title(f'{split_name} Dataset Distribution')
# plt.xlabel('Classes')
# plt.ylabel('Number of Images')
# plt.xticks(rotation=45, ha='right')
# plt.tight_layout()
# plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
# plt.close()
# Function to visualize dataset distribution
def visualize_data_distribution(split_path, split_name, classes):
    """Count images per class and save a bar chart annotated with the counts.

    Counts files in ``split_path/<class_name>/`` with a .jpg, .jpeg or .png
    extension in any letter case. The chart is written to
    ``/kaggle/working/visualizations/<split_name lowercased>_distribution.png``.
    """
    # Case-insensitive patterns so files such as 'x.PNG' are counted too.
    patterns = ('*.[jJ][pP][gG]', '*.[jJ][pP][eE][gG]', '*.[pP][nN][gG]')
    class_counts = []
    for class_name in classes:
        class_path = os.path.join(split_path, class_name)
        image_paths = [
            path for pattern in patterns
            for path in glob(os.path.join(class_path, pattern))
        ]
        class_counts.append(len(image_paths))
        print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")
    plt.figure(figsize=(10, 6))
    bars = plt.bar(classes, class_counts)
    plt.title(f'{split_name} Dataset Distribution')
    plt.xlabel('Classes')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45, ha='right')
    # Write each count on top of its bar.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height}',
                 ha='center', va='bottom',
                 fontsize=10, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
    plt.close()
# Custom Dataset Class
class FootUlcerDataset(Dataset):
    """In-memory dataset yielding ``(image_tensor, label, features)`` triples.

    Args:
        images: sequence of HWC uint8 BGR arrays (converted to RGB per item).
        labels: sequence of integer class indices.
        handcrafted_features: per-image feature objects, returned unchanged.
        train: when True (the default, and the previous behaviour), apply
            random augmentation. Labels 2 and 3 get a heavier pipeline
            (rotation, colour jitter, random resized crop to 224); other labels
            get only a horizontal flip. When False, only ToTensor and ImageNet
            normalisation are applied, giving deterministic tensors for
            validation and testing.

    NOTE(review): only the heavy pipeline resizes images. Batching therefore
    assumes all stored images already share a common size (presumably
    224x224). TODO confirm.
    """

    def __init__(self, images, labels, handcrafted_features, train=True):
        self.images = images
        self.labels = labels
        self.handcrafted_features = handcrafted_features
        self.train = train
        # Build the pipelines once rather than on every __getitem__ call.
        self.minority_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.default_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])
        self.eval_transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        label = self.labels[idx]
        features = self.handcrafted_features[idx]
        image = Image.fromarray(cv2.cvtColor(self.images[idx], cv2.COLOR_BGR2RGB))
        if not self.train:
            transform = self.eval_transform
        elif label in [2, 3]:
            transform = self.minority_transform
        else:
            transform = self.default_transform
        return transform(image), label, features
# Function to print model summary
def print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41):
    """Print each submodule's output shape and parameter count, plus the model total.

    The function runs one dummy forward pass, ``model(x, handcrafted_features)``,
    on the model's device. ``x`` has shape ``(1, *input_size)`` and the
    handcrafted features have shape ``(1, handcrafted_feature_dim)``; both are
    random. Forward hooks print one row per submodule as it executes.

    The "Param #" column counts a module's parameters *including* its children,
    so container rows overlap their children's rows. The total is therefore
    computed once from ``model.parameters()`` to avoid double counting.
    The model's train/eval mode is restored afterwards.
    """
    device = next(model.parameters()).device
    was_training = model.training
    model.eval()
    print("\nModel Summary:")
    print("=" * 80)
    print(f"{'Layer':<30} {'Output Shape':<25} {'Param #':<15}")
    print("-" * 80)
    x = torch.randn(1, *input_size).to(device)
    handcrafted_features = torch.randn(1, handcrafted_feature_dim).to(device)

    def register_hook(module, input, output):
        class_name = str(module.__class__.__name__)
        param_count = sum(p.numel() for p in module.parameters())
        output_shape = list(output.shape) if isinstance(output, torch.Tensor) else "N/A"
        print(f"{class_name:<30} {str(output_shape):<25} {param_count:<15}")

    hooks = [
        module.register_forward_hook(register_hook)
        for module in model.modules()
        if module is not model
    ]
    try:
        with torch.no_grad():
            model(x, handcrafted_features)
    finally:
        # Always detach the hooks and restore the caller's mode, even on error.
        for hook in hooks:
            hook.remove()
        model.train(was_training)
    total_params = sum(p.numel() for p in model.parameters())
    print("-" * 80)
    print(f"Total Parameters: {total_params:,}")
    print("=" * 80)
# Function to plot ROC curves
def plot_roc_curves(labels, probabilities, split_name, classes, model_idx=None):
    """Plot one-vs-rest ROC curves with AUC for each class and save the figure.

    Args:
        labels: integer class indices, array-like of shape (n,).
        probabilities: per-class scores, array-like of shape (n, len(classes));
            column ``i`` scores ``classes[i]``.
        split_name: used in the title and in the output file name.
        classes: class names in index order.
        model_idx: if given, appended to the title and file name; otherwise the
            figure is saved as the ensemble plot.

    Saves ``/kaggle/working/visualizations/roc_<split>_model_<idx>.png`` or
    ``.../roc_<split>_ensemble.png``. Classes whose ROC curve is undefined (no
    positive or no negative samples) are skipped with a warning.
    """
    # Accept plain lists too: probabilities[:, i] needs a 2-D ndarray.
    labels = np.asarray(labels)
    probabilities = np.asarray(probabilities)
    plt.figure(figsize=(10, 8))
    for i, class_name in enumerate(classes):
        is_positive = labels == i
        # Without both positives and negatives, TPR/FPR are undefined (NaN).
        if is_positive.all() or not is_positive.any():
            print(f"Warning: skipping ROC for class {class_name} in {split_name}: needs both positive and negative samples")
            continue
        fpr, tpr, _ = roc_curve(is_positive, probabilities[:, i])
        roc_auc = auc(fpr, tpr)
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else ''))
    plt.legend(loc='lower right')
    plt.grid(True)
    filename = f'/kaggle/working/visualizations/roc_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
    plt.savefig(filename)
    plt.close()
# Function to visualize feature extraction layer by layer
# def visualize_feature_extraction(model, dataloader, device, classes, num_samples=1):
# model.eval()
# feature_maps = {}
# layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2']
# print("\nModel structure (named modules):")
# for name, module in model.named_modules():
# print(f"Layer: {name}, Module: {type(module).__name__}")
# print("\nRegistering forward hooks for layers:", layer_names)
# def get_hook(name):
# def hook(module, input, output):
# feature_maps[name] = output.detach()
# print(f"Captured output for {name}, shape: {output.shape}")
# return hook
# hooks = []
# for name in layer_names:
# module = getattr(model, name, None)
# if module:
# hooks.append(module.register_forward_hook(get_hook(name)))
# print(f"Hook registered for {name}")
# else:
# print(f"Warning: Layer {name} not found in model")
# images_list = []
# labels_list = []
# probs_list = []
# features_list = []
# with torch.no_grad():
# for images, labels, features in dataloader:
# images, labels, features = images.to(device), labels.to(device), features.to(device)
# print(f"Processing batch with {images.shape[0]} images, features shape: {features.shape}")
# outputs = model(images, features)
# probs = F.softmax(outputs, dim=1)
# images_list.extend(images.cpu().numpy())
# labels_list.extend(labels.cpu().numpy())
# probs_list.extend(probs.cpu().numpy())
# features_list.extend(features.cpu().numpy())
# break
# print(f"Removing {len(hooks)} hooks")
# for hook in hooks:
# hook.remove()
# print(f"Feature maps captured: {list(feature_maps.keys())}")
# for idx in range(min(num_samples, len(images_list))):
# img = images_list[idx].transpose(1, 2, 0)
# img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
# img = np.clip(img, 0, 1)
# true_label = CLASSES[labels_list[idx]]
# plt.figure(figsize=(5, 5))
# plt.imshow(img)
# plt.title(f'Input Image (Class: {true_label})')
# plt.axis('off')
# input_img_path = f'/kaggle/working/visualizations/input_image_sample_{idx}.png'
# plt.savefig(input_img_path)
# plt.close()
# print(f"Saved input image to: {input_img_path}")
# for layer_name in layer_names:
# if layer_name not in feature_maps:
# print(f"No feature map for {layer_name}, skipping visualization")
# continue
# features = feature_maps[layer_name][idx]
# print(f"Visualizing {layer_name}, feature shape: {features.shape}, feature dim:{features.dim()}")
# if features.dim() == 3:
# num_channels = min(features.shape[0], 16)
# plt.figure(figsize=(15, 10))
# for i in range(num_channels):
# plt.subplot(4, 4, i + 1)
# feature_map = features[i].cpu().numpy()
# feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)
# plt.imshow(feature_map, cmap='viridis')
# plt.title(f'Channel {i+1}')
# plt.axis('off')
# plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})')
# plt.tight_layout(rect=[0, 0, 1, 0.95])
# feature_map_path = f'/kaggle/working/visualizations/feature_maps_{layer_name}_{idx}.png'
# plt.savefig(feature_map_path)
# plt.close()
# print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}")
# else:
# values = features.flatten().cpu().numpy()
# plt.figure(figsize=(10, 5))
# plt.bar(range(len(values)), values, color='blue') # Changed bar color to blue
# plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
# plt.xlabel('Index')
# plt.ylabel('Value')
# ## Use only if it,s looking good, the grid part
# plt.grid(True, linestyle='--', alpha=0.6) # Added light grid for better readability
# plt.tight_layout()
# vector_path = f'/kaggle/working/visualizations/feature_vector_{layer_name}_{idx}.png'
# plt.savefig(vector_path)
# plt.close()
# print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")
# plt.figure(figsize=(8, 6))
# bars = plt.bar(classes, probs_list[idx]) # Store the bar objects
# plt.title(f'Classification Probabilities (True: {true_label})')
# plt.xlabel('Classes')
# plt.ylabel('Probability')
# plt.xticks(rotation=45)
# # Add value labels on top of each bar
# for bar in bars:
# height = bar.get_height()
# plt.text(bar.get_x() + bar.get_width()/2., height,
# f'{height:.3f}',
# ha='center', va='bottom',
# fontsize=10, fontweight='bold')
# plt.tight_layout()
# probs_path = f'/kaggle/working/visualizations/classification_probs_{idx}.png'
# plt.savefig(probs_path)
# plt.close()
# print(f"Saved classification probabilities to: {probs_path}")
# print("\nListing saved visualization files:")
# os.system('ls /kaggle/working/visualizations/')
def _register_activation_hooks(model, layer_names, feature_maps):
    """Hook each top-level ``model.<name>`` so its latest output lands in ``feature_maps[name]``."""
    def get_hook(name):
        def hook(module, input, output):
            feature_maps[name] = output.detach()
            print(f"Captured output for {name}, shape: {output.shape}")
        return hook

    hooks = []
    for name in layer_names:
        module = getattr(model, name, None)
        if module is not None:
            hooks.append(module.register_forward_hook(get_hook(name)))
            print(f"Hook registered for {name}")
        else:
            print(f"Warning: Layer {name} not found in model")
    return hooks


def _collect_class_samples(model, dataloader, device, num_classes, per_class, feature_maps):
    """Run batches until ``per_class`` samples of every class are collected (or data runs out).

    Returns ``{class_idx: [(image_ndarray, class_idx, probs_ndarray, activations), ...]}``.
    ``activations`` maps each hooked layer name to that sample's own CPU activation.
    """
    samples = {class_idx: [] for class_idx in range(num_classes)}
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            probs = F.softmax(model(images, features), dim=1)
            for i in range(len(images)):
                class_idx = labels[i].item()
                if len(samples[class_idx]) >= per_class:
                    continue
                # The hooks overwrite feature_maps on every forward pass, so take
                # row i now. Indexing the maps later (e.g. by class index) would
                # pick an unrelated sample from the last batch.
                activations = {name: output[i].cpu() for name, output in feature_maps.items()}
                samples[class_idx].append(
                    (images[i].cpu().numpy(), class_idx, probs[i].cpu().numpy(), activations)
                )
            if all(len(found) >= per_class for found in samples.values()):
                break
    return samples


def _plot_input_image(img_chw, class_idx, true_label):
    """Undo ImageNet normalisation on a CHW array and save it as the class's input image."""
    img = img_chw.transpose(1, 2, 0)
    img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
    img = np.clip(img, 0, 1)
    plt.figure(figsize=(5, 5))
    plt.imshow(img)
    plt.title(f'Input Image (Class: {true_label})')
    plt.axis('off')
    input_img_path = f'/kaggle/working/visualizations/class_{class_idx}_input_image.png'
    plt.savefig(input_img_path)
    plt.close()
    print(f"Saved input image for class {true_label} to: {input_img_path}")


def _plot_channel_maps(activation, class_idx, layer_name, true_label):
    """Save up to 16 min-max-normalised channels of a (C, H, W) activation in a 4x4 grid."""
    num_channels = min(activation.shape[0], 16)
    plt.figure(figsize=(15, 10))
    for i in range(num_channels):
        plt.subplot(4, 4, i + 1)
        feature_map = activation[i].cpu().numpy()
        feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)
        plt.imshow(feature_map, cmap='viridis')
        plt.title(f'Channel {i+1}')
        plt.axis('off')
    plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    feature_map_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_maps_{layer_name}.png'
    plt.savefig(feature_map_path)
    plt.close()
    print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}")


def _plot_activation_vector(activation, class_idx, layer_name, true_label):
    """Save a bar chart of a flattened activation, widening the figure with its length."""
    values = activation.flatten().cpu().numpy()
    num_features = len(values)
    # 0.025 inches per bar, at least 20 inches wide.
    fig_width = max(20, num_features * 0.025)
    plt.figure(figsize=(fig_width, 5))
    plt.bar(
        range(num_features),
        values,
        color='#1f77b4',
        edgecolor='#1f77b4',
        linewidth=0.05,  # thin borders for dense plots
        width=0.9,       # narrower than 1 so neighbouring bars stay separated
        align='center',
    )
    if num_features > 100:
        # Too many bars to tick individually: tick every 50th index plus the last.
        plt.xticks(list(range(0, num_features, 50)) + [num_features - 1])
    plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
    plt.xlabel('Feature')
    plt.ylabel('Activation Value')
    plt.grid(True, linestyle=':', alpha=0.5)
    plt.tight_layout()
    vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png'
    plt.savefig(vector_path, dpi=120, bbox_inches='tight', facecolor='white')
    plt.close()
    print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")


def _plot_class_probabilities(prob, classes, class_idx, true_label):
    """Save a bar chart of the softmax probabilities with values printed on the bars."""
    plt.figure(figsize=(8, 6))
    bars = plt.bar(classes, prob)
    plt.title(f'Classification Probabilities (True: {true_label})')
    plt.xlabel('Classes')
    plt.ylabel('Probability')
    plt.xticks(rotation=45)
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.3f}',
                 ha='center', va='bottom',
                 fontsize=10, fontweight='bold')
    plt.tight_layout()
    probs_path = f'/kaggle/working/visualizations/class_{class_idx}_classification_probs.png'
    plt.savefig(probs_path)
    plt.close()
    print(f"Saved classification probabilities to: {probs_path}")


def visualize_feature_extraction(model, dataloader, device, classes, num_samples_per_class=1):
    """Save per-class visualisations of intermediate activations and predicted probabilities.

    Forward hooks on the top-level modules listed in ``layer_names`` capture
    activations while batches ``(images, labels, features)`` from
    ``dataloader`` run through ``model`` in eval mode. Collection stops once
    ``num_samples_per_class`` samples of every class are found or the loader
    is exhausted. For the first collected sample of each class, the following
    files are written to ``/kaggle/working/visualizations/``:

      * ``class_<k>_input_image.png``: the de-normalised input image;
      * ``class_<k>_feature_maps_<layer>.png``: for 3-D (C, H, W) activations;
      * ``class_<k>_feature_vector_<layer>.png``: for any other activation;
      * ``class_<k>_classification_probs.png``: the softmax output.

    Each plot uses the activations of that same sample.
    """
    model.eval()
    feature_maps = {}
    layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2']
    print("\nModel structure (named modules):")
    for name, module in model.named_modules():
        print(f"Layer: {name}, Module: {type(module).__name__}")
    print("\nRegistering forward hooks for layers:", layer_names)
    hooks = _register_activation_hooks(model, layer_names, feature_maps)
    try:
        class_samples = _collect_class_samples(
            model, dataloader, device, len(classes), num_samples_per_class, feature_maps
        )
    finally:
        # Remove hooks even if a forward pass fails, so the model is left clean.
        print(f"Removing {len(hooks)} hooks")
        for hook in hooks:
            hook.remove()
    print(f"Feature maps captured: {list(feature_maps.keys())}")

    for class_idx in range(len(classes)):
        if not class_samples[class_idx]:
            print(f"No samples found for class {class_idx} ({classes[class_idx]})")
            continue
        img, label, prob, activations = class_samples[class_idx][0]
        true_label = classes[label]
        _plot_input_image(img, class_idx, true_label)
        for layer_name in layer_names:
            if layer_name not in activations:
                print(f"No feature map for {layer_name}, skipping visualization")
                continue
            activation = activations[layer_name]
            print(f"Visualizing {layer_name} for class {true_label}, feature shape: {activation.shape}")
            if activation.dim() == 3:
                _plot_channel_maps(activation, class_idx, layer_name, true_label)
            else:
                _plot_activation_vector(activation, class_idx, layer_name, true_label)
        _plot_class_probabilities(prob, classes, class_idx, true_label)

    print("\nListing saved visualization files:")
    # List in-process instead of shelling out to `ls`.
    for entry in sorted(os.listdir('/kaggle/working/visualizations')):
        print(entry)
# Training Function
def train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=0,
                patience=10, checkpoint_dir='/kaggle/working', restore_best=True):
    """Train ``model`` with mixed precision, LR-on-plateau scheduling and early stopping.

    Args:
        model: Module called as ``model(images, features)`` that returns class logits.
        dataloader: Mapping with ``'train'`` and ``'val'`` loaders, each yielding
            ``(images, labels, features)`` batches.
        criterion: Loss called as ``criterion(logits, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device the batches are moved to (``torch.device`` or string).
        epochs: Maximum number of epochs.
        model_idx: Index used in log lines and in the checkpoint file name.
        patience: Consecutive epochs without a validation-loss improvement after
            which training stops early.
        checkpoint_dir: Directory receiving ``best_model_{model_idx}.pth``.
        restore_best: If True, the weights from the lowest-validation-loss epoch are
            loaded back into ``model`` before returning, so callers evaluate the
            checkpointed model rather than the last (possibly over-fitted) epoch.

    Returns:
        ``(history, best_train_acc, best_val_acc)`` where ``history`` holds the
        per-epoch lists ``train_loss``, ``train_acc``, ``val_loss``, ``val_acc``
        (accuracies in percent).
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_train_acc = 0.0
    best_epoch = None
    best_state = None
    counter = 0
    # Mixed precision only applies on CUDA; disable scaler/autocast elsewhere instead of
    # relying on their "CUDA not available" fallback.
    use_amp = torch.device(device).type == 'cuda'
    scaler = torch.cuda.amp.GradScaler(enabled=use_amp)
    # The scheduler's `verbose` flag is deprecated, so LR reductions are logged manually below.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5)
    checkpoint_path = os.path.join(checkpoint_dir, f'best_model_{model_idx}.pth')
    train_batches = len(dataloader['train'])
    val_batches = len(dataloader['val'])
    print(f"Training dataset: {train_batches} batches")
    print(f"Validation dataset: {val_batches} batches")
    for epoch in range(epochs):
        print(f"\n--- Epoch {epoch+1}/{epochs} ---")
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        correct, total = 0, 0
        for batch_idx, (images, labels, features) in enumerate(dataloader['train']):
            print(f"Training epoch {epoch+1}, batch {batch_idx+1}/{train_batches}")
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            optimizer.zero_grad()
            with torch.cuda.amp.autocast(enabled=use_amp):
                outputs = model(images, features)
                loss = criterion(outputs, labels)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        epoch_loss = running_loss / train_batches if train_batches > 0 else 0.0
        epoch_acc = 100. * correct / total if total > 0 else 0.0
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        best_train_acc = max(best_train_acc, epoch_acc)
        # ---- validation pass ----
        model.eval()
        val_loss = 0.0
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for batch_idx, (val_images, val_labels, val_features) in enumerate(dataloader['val']):
                print(f"Validation epoch {epoch+1}, batch {batch_idx+1}/{val_batches}")
                val_images, val_labels, val_features = val_images.to(device), val_labels.to(device), val_features.to(device)
                with torch.cuda.amp.autocast(enabled=use_amp):
                    val_outputs = model(val_images, val_features)
                    loss = criterion(val_outputs, val_labels)
                val_loss += loss.item()
                _, predicted = val_outputs.max(1)
                val_total += val_labels.size(0)
                val_correct += predicted.eq(val_labels).sum().item()
        val_epoch_loss = val_loss / val_batches if val_batches > 0 else 0.0
        val_epoch_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc)
        best_val_acc = max(best_val_acc, val_epoch_acc)
        previous_lrs = [group['lr'] for group in optimizer.param_groups]
        scheduler.step(val_epoch_loss)
        for group_idx, (old_lr, group) in enumerate(zip(previous_lrs, optimizer.param_groups)):
            if group['lr'] < old_lr:
                print(f"Reducing learning rate of group {group_idx} to {group['lr']:.4e}.")
        if val_epoch_loss < best_val_loss:
            best_val_loss = val_epoch_loss
            best_epoch = epoch + 1
            counter = 0
            torch.save(model.state_dict(), checkpoint_path)
            # Keep a CPU copy of the best weights so they can be restored after training
            # without re-reading the checkpoint or holding extra GPU memory.
            best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
        else:
            counter += 1
        # Log this epoch before a possible early stop so the final epoch is reported too.
        print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
        print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.2f}%")
        if counter >= patience:
            print("Early stopping triggered")
            break
    if restore_best and best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best weights from epoch {best_epoch} (validation loss {best_val_loss:.4f})")
    print(f"Best Training Accuracy (Model {model_idx}): {best_train_acc:.2f}%")
    print(f"Best Validation Accuracy (Model {model_idx}): {best_val_acc:.2f}%")
    return history, best_train_acc, best_val_acc
# Function to Plot Training History
def plot_training_history(history, epochs, model_idx=0):
    """Save side-by-side loss and accuracy curves for a single training run.

    The x-axis length comes from ``history['train_loss']`` (so early-stopped runs
    plot correctly); ``epochs`` is accepted for call-site compatibility but is not
    used. The figure is written to
    ``/kaggle/working/visualizations/training_history_model_{model_idx}.png``.
    """
    x_values = range(1, len(history['train_loss']) + 1)
    # (train key, val key, train legend, val legend, y label, title) per subplot
    panels = (
        ('train_loss', 'val_loss', 'Training Loss', 'Validation Loss', 'Loss',
         f'Training and Validation Loss (Model {model_idx})'),
        ('train_acc', 'val_acc', 'Training Accuracy', 'Validation Accuracy', 'Accuracy (%)',
         f'Training and Validation Accuracy (Model {model_idx})'),
    )
    plt.figure(figsize=(12, 5))
    for position, (train_key, val_key, train_legend, val_legend, y_label, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, position)
        plt.plot(x_values, history[train_key], label=train_legend)
        plt.plot(x_values, history[val_key], label=val_legend)
        plt.xlabel('Epochs')
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/visualizations/training_history_model_{model_idx}.png')
    plt.close()
# Function to Evaluate Model
def evaluate_model(model, dataloader, device, split_name, classes, model_idx=None, use_tta=False):
    """Evaluate ``model`` on one split and save its confusion matrix, report and ROC curves.

    Args:
        model: Module called as ``model(images, features)`` that returns class logits.
        dataloader: Iterable of ``(images, labels, features)`` batches. The TTA path
            assumes images are normalized with the ImageNet mean/std used below.
        device: Device the batches are moved to.
        split_name: Used in plot titles and output file names.
        classes: Class names indexed by label id.
        model_idx: Model index for titles/file names; ``None`` names outputs "ensemble".
        use_tta: If True, softmax probabilities are averaged over three test-time
            views (identity, horizontal flip, random rotation within +/-10 degrees).

    Returns:
        ``(accuracy_percent, predictions, labels, probs, report_df)`` where ``probs``
        is an array with one row of softmax probabilities per sample.
    """
    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    all_probs = []
    # Fixed label ids keep the confusion matrix and report at len(classes) rows even when
    # a class never occurs in this split; otherwise sklearn infers fewer labels and
    # classification_report raises on the target_names length mismatch.
    label_ids = list(range(len(classes)))
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    tta_transforms = [
        transforms.Compose([transforms.ToTensor(), normalize]),
        transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), normalize]),
        # RandomRotation samples a fresh angle per call, so TTA results vary between runs.
        transforms.Compose([transforms.RandomRotation(10), transforms.ToTensor(), normalize]),
    ]
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            if use_tta:
                # Un-normalize back to uint8 PIL images once per batch (rounding rather than
                # truncating); every TTA view then starts from the same pictures.
                pixels = ((images * std + mean).clamp(0, 1) * 255).round().to(torch.uint8).cpu().numpy()
                pil_images = [Image.fromarray(img.transpose(1, 2, 0)) for img in pixels]
                batch_probs = []
                for transform in tta_transforms:
                    tta_images = torch.stack([transform(img) for img in pil_images]).to(device)
                    outputs = model(tta_images, features)
                    batch_probs.append(F.softmax(outputs, dim=1))
                avg_probs = torch.stack(batch_probs).mean(dim=0)
                _, predicted = torch.max(avg_probs, 1)
                all_probs.extend(avg_probs.cpu().numpy())
            else:
                outputs = model(images, features)
                _, predicted = torch.max(outputs, 1)
                all_probs.extend(F.softmax(outputs, dim=1).cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total if total > 0 else 0.0
    file_suffix = f'_model_{model_idx}' if model_idx is not None else '_ensemble'
    title_suffix = f' (Model {model_idx})' if model_idx is not None else ''
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
    plt.figure(figsize=(10, 8))
    # Heatmap with larger, bold cell annotations
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes,
                annot_kws={'size': 12, 'weight': 'bold'})
    plt.title(f'Confusion Matrix - {split_name}' + title_suffix, fontsize=14, fontweight='bold')
    plt.xlabel('Predicted', fontsize=12, fontweight='bold')
    plt.ylabel('True', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/visualizations/cm_{split_name.lower()}{file_suffix}.png')
    plt.close()
    report = classification_report(all_labels, all_predictions, labels=label_ids,
                                   target_names=classes, output_dict=True, zero_division=0)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(f'/kaggle/working/classification_report_{split_name.lower()}{file_suffix}.csv')
    all_probs = np.array(all_probs)
    plot_roc_curves(all_labels, all_probs, split_name, classes, model_idx)
    return accuracy, all_predictions, all_labels, all_probs, report_df
# Ensemble Voting Function
def ensemble_voting(models, dataloader, device, split_name, classes):
    """Majority-vote the predictions of ``models`` on one split and save evaluation artifacts.

    Each model casts one vote per sample (its argmax class). The class with the most
    votes wins, and ties go to the lowest class index (``np.bincount(...).argmax()``).
    Softmax probabilities are averaged across models and passed to ``plot_roc_curves``.
    A confusion-matrix PNG and a classification-report CSV suffixed ``_ensemble`` are saved.

    Returns:
        ``(accuracy_percent, predictions, labels, avg_probs, report_df)``.
    """
    for model in models:
        model.eval()
    # Fixed label ids keep the matrix/report aligned with ``classes`` even when a class is
    # missing from this split (classification_report would otherwise raise).
    label_ids = list(range(len(classes)))
    all_predictions = []
    all_labels = []
    all_probs = []
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, features = images.to(device), features.to(device)
            votes = []
            probs = []
            for model in models:
                outputs = model(images, features)
                _, predicted = torch.max(outputs, 1)
                votes.append(predicted.cpu().numpy())
                probs.append(F.softmax(outputs, dim=1).cpu().numpy())
            votes = np.array(votes)  # one row per model, one column per sample
            final_predictions = np.apply_along_axis(lambda column: np.bincount(column).argmax(), axis=0, arr=votes)
            all_predictions.extend(final_predictions)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(np.mean(probs, axis=0))
    total = len(all_labels)
    correct = int(np.sum(np.array(all_predictions) == np.array(all_labels)))
    # Guard against an empty loader, matching evaluate_model, instead of dividing by zero.
    accuracy = 100 * correct / total if total > 0 else 0.0
    cm = confusion_matrix(all_labels, all_predictions, labels=label_ids)
    plt.figure(figsize=(10, 8))
    # Heatmap with larger, bold cell annotations
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes,
                annot_kws={'size': 12, 'weight': 'bold'})
    plt.title(f'Confusion Matrix - {split_name} (Ensemble)', fontsize=14, fontweight='bold')
    plt.xlabel('Predicted', fontsize=12, fontweight='bold')
    plt.ylabel('True', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/visualizations/cm_{split_name.lower()}_ensemble.png')
    plt.close()
    report = classification_report(all_labels, all_predictions, labels=label_ids,
                                   target_names=classes, output_dict=True, zero_division=0)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(f'/kaggle/working/classification_report_{split_name.lower()}_ensemble.csv')
    all_probs = np.array(all_probs)
    plot_roc_curves(all_labels, all_probs, f'{split_name} (Ensemble)', classes)
    return accuracy, all_predictions, all_labels, all_probs, report_df
# Function to visualize voting process
def visualize_voting_process(models, dataloader, device, classes, num_samples=5):
    """Plot the first ``num_samples`` images with every model's vote and the ensemble vote.

    Works for any number of models (the title lists one entry per model). The ensemble
    label is the majority vote, with ties going to the lowest class index. Images are
    un-normalized with the ImageNet mean/std before display. If the loader yields fewer
    than ``num_samples`` items, only the available ones are plotted. The figure is saved
    to ``/kaggle/working/visualizations/voting_process.png``.
    """
    for model in models:
        model.eval()
    model_predictions = []
    true_labels = []
    images_list = []
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, features = images.to(device), features.to(device)
            batch_preds = [torch.max(model(images, features), 1)[1].cpu().numpy() for model in models]
            # Transpose to one row per sample holding each model's prediction.
            model_predictions.extend(np.array(batch_preds).T)
            true_labels.extend(labels.cpu().numpy())
            images_list.extend(images.cpu().numpy())
            if len(true_labels) >= num_samples:
                break
    num_samples = min(num_samples, len(true_labels))
    if num_samples == 0:
        print("No samples available to visualize the voting process.")
        return
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    plt.figure(figsize=(15, num_samples * 3))
    for i in range(num_samples):
        img = np.clip(images_list[i].transpose(1, 2, 0) * imagenet_std + imagenet_mean, 0, 1)
        plt.subplot(num_samples, 1, i + 1)
        plt.imshow(img)
        preds = [classes[p] for p in model_predictions[i]]
        ensemble_pred = classes[np.bincount(model_predictions[i]).argmax()]
        votes_text = ', '.join(f'Model {j}: {name}' for j, name in enumerate(preds, start=1))
        title = f'True: {classes[true_labels[i]]}\n' + votes_text + '\n' + f'Ensemble: {ensemble_pred}'
        plt.title(title)
        plt.axis('off')
    plt.tight_layout()
    plt.savefig('/kaggle/working/visualizations/voting_process.png')
    plt.close()
# Function to visualize predictions per class
# def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4):
# model.eval()
# class_images = {i: [] for i in range(len(classes))}
# class_preds = {i: [] for i in range(len(classes))}
# class_labels = {i: [] for i in range(len(classes))}
# with torch.no_grad():
# for images, labels, features in dataloader:
# images, labels, features = images.to(device), labels.to(device), features.to(device)
# outputs = model(images, features)
# _, predicted = torch.max(outputs.data, 1)
# for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
# if len(class_images[label]) < num_samples:
# class_images[label].append(img)
# class_preds[label].append(pred)
# class_labels[label].append(label)
# if all(len(class_images[i]) >= num_samples for i in range(len(classes))):
# break
# for class_idx, class_name in enumerate(classes):
# plt.figure(figsize=(15, 5))
# for i in range(min(num_samples, len(class_images[class_idx]))):
# img = class_images[class_idx][i].transpose(1, 2, 0)
# img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
# img = np.clip(img, 0, 1)
# plt.subplot(1, num_samples, i + 1)
# plt.imshow(img)
# plt.title(f'True: {class_name}\nPred: {CLASSES[class_preds[class_idx][i]]}')
# plt.axis('off')
# plt.suptitle(f'Predictions for {class_name} ({split_name})')
# plt.tight_layout(rect=[0, 0, 1, 0.95])
# filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
# plt.savefig(filename)
# plt.close()
def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4):
    """Save one figure per class with up to ``num_samples`` images of that true class.

    Each subplot is titled with the true and predicted class names, taken from the
    ``classes`` argument. Collection stops once every class has ``num_samples``
    examples or the loader is exhausted. Images are un-normalized with the ImageNet
    mean/std before display. Figures are written to
    ``/kaggle/working/visualizations/predictions_{class}_{split}{_model_N|_ensemble}.png``.
    """
    model.eval()
    imagenet_mean = np.array([0.485, 0.456, 0.406])
    imagenet_std = np.array([0.229, 0.224, 0.225])
    class_images = {i: [] for i in range(len(classes))}
    class_preds = {i: [] for i in range(len(classes))}
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, features = images.to(device), features.to(device)
            outputs = model(images, features)
            _, predicted = torch.max(outputs, 1)
            for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
                label = int(label)
                if len(class_images[label]) < num_samples:
                    class_images[label].append(img)
                    class_preds[label].append(pred)
            if all(len(class_images[i]) >= num_samples for i in range(len(classes))):
                break
    for class_idx, class_name in enumerate(classes):
        plt.figure(figsize=(15, 5))
        for i, (img_chw, pred) in enumerate(zip(class_images[class_idx], class_preds[class_idx])):
            img = np.clip(img_chw.transpose(1, 2, 0) * imagenet_std + imagenet_mean, 0, 1)
            plt.subplot(1, num_samples, i + 1)
            plt.imshow(img)
            plt.title(f'True: {class_name}\nPred: {classes[pred]}',
                      fontweight='bold')
            plt.axis('off')
        plt.suptitle(f'Predictions for {class_name} ({split_name})',
                     fontweight='bold',
                     fontsize=12)
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
        plt.savefig(filename)
        plt.close()
# # Function to visualize predictions combined per class
# def visualize_predictions_grid_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=2):
# import os
# os.makedirs("/kaggle/working/visualizations", exist_ok=True)
# model.eval()
# num_classes = len(classes)
# # Collect samples
# class_images = {i: [] for i in range(num_classes)}
# class_preds = {i: [] for i in range(num_classes)}
# class_labels = {i: [] for i in range(num_classes)}
# with torch.no_grad():
# for images, labels, features in dataloader:
# images, labels, features = images.to(device), labels.to(device), features.to(device)
# outputs = model(images, features)
# _, predicted = torch.max(outputs.data, 1)
# for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
# if len(class_images[label]) < num_samples:
# class_images[label].append(img)
# class_preds[label].append(pred)
# class_labels[label].append(label)
# if all(len(class_images[i]) >= num_samples for i in range(num_classes)):
# break
# # Plot: Grid of num_samples rows × num_classes columns
# plt.figure(figsize=(4 * num_classes, 4 * num_samples))
# for row in range(num_samples):
# for class_idx in range(num_classes):
# if row >= len(class_images[class_idx]):
# continue
# img = class_images[class_idx][row].transpose(1, 2, 0)
# img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) # unnormalize
# img = np.clip(img, 0, 1)
# ax_idx = row * num_classes + class_idx + 1
# plt.subplot(num_samples, num_classes, ax_idx)
# true_label = classes[class_labels[class_idx][row]]
# pred_label = classes[class_preds[class_idx][row]]
# plt.imshow(img)
# plt.title(f'True: {true_label}\nPred: {pred_label}')
# plt.axis('off')
# plt.suptitle(f'{split_name} Predictions Grid ({num_samples}×{num_classes})')
# filename = f'/kaggle/working/visualizations/predictions_grid_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
# plt.tight_layout(rect=[0, 0, 1, 0.95])
# plt.savefig(filename)
# plt.close()
# Main Execution
if __name__ == "__main__":
    # Length of the handcrafted feature vector carried by each sample; must match what
    # load_images produces and what the model's feature branch expects.
    handcrafted_feature_dim = 41
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")
    # Debug dataset directory
    print("Checking dataset directory structure:")
    for dirname, _, filenames in os.walk(DATASET_PATH):
        print(f"Directory: {dirname}, Files: {len(filenames)}")
        for filename in filenames[:5]:
            print(f" - {os.path.join(dirname, filename)}")
    # Check for CSV file
    csv_path = os.path.join(DATASET_PATH, "labels.csv")
    use_csv = os.path.exists(csv_path)
    if use_csv:
        print("Detected labels.csv, will load dataset from CSV")
    else:
        print("No labels.csv found, assuming directory-based structure")
    # Load dataset
    train_path = os.path.join(DATASET_PATH, "TRAIN")
    val_path = os.path.join(DATASET_PATH, "VALIDATION")
    test_path = os.path.join(DATASET_PATH, "TEST")
    train_images, train_labels, train_features = load_images(train_path, CLASSES, use_csv)
    val_images, val_labels, val_features = load_images(val_path, CLASSES, use_csv)
    test_images, test_labels, test_features = load_images(test_path, CLASSES, use_csv)
    # Check if datasets are empty
    if not train_images:
        raise ValueError("Training dataset is empty. Please check the dataset path, class names, image files, or CSV structure.")
    if not val_images:
        print("Warning: Validation dataset is empty. Creating validation split from training data.")
        train_images, val_images, train_labels, val_labels, train_features, val_features = train_test_split(
            train_images, train_labels, train_features, test_size=0.2, stratify=train_labels, random_state=42
        )
    if not test_images:
        print("Warning: Test dataset is empty.")
        test_images, test_labels, test_features = [], [], []
    # Visualize dataset
    visualize_data_distribution(train_path, "Train", CLASSES)
    visualize_data_distribution(val_path, "Validation", CLASSES)
    visualize_data_distribution(test_path, "Test", CLASSES)
    display_sample_images(train_images, train_labels, "Train", CLASSES)
    display_sample_images(val_images, val_labels, "Validation", CLASSES)
    display_sample_images(test_images, test_labels, "Test", CLASSES)
    visualize_handcrafted_features(train_images, train_labels, CLASSES)
    # Create datasets
    train_dataset = FootUlcerDataset(train_images, train_labels, train_features)
    val_dataset = FootUlcerDataset(val_images, val_labels, val_features)
    test_dataset = FootUlcerDataset(test_images, test_labels, test_features)
    # Create WeightedRandomSampler (inverse class frequency) to rebalance training batches
    train_labels_np = np.array(train_labels)
    class_counts = np.array([np.sum(train_labels_np == i) for i in range(len(CLASSES))])
    print(f"Class counts: {dict(zip(CLASSES, class_counts))}")
    if np.any(class_counts == 0):
        print("Warning: Some classes have zero samples in the training set.")
    class_weights = 1.0 / (class_counts + 1e-6)
    sample_weights = class_weights[train_labels_np]
    sampler = WeightedRandomSampler(sample_weights, len(train_labels), replacement=True)
    # Create DataLoaders
    batch_size = 32
    dataloader = {
        'train': DataLoader(train_dataset, batch_size=batch_size, sampler=sampler),
        'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False),
        # Sequential pass over the real training set for evaluation. The weighted sampler
        # above draws with replacement, so metrics computed through it would describe a
        # resampled, class-rebalanced set rather than the actual training data.
        'train_eval': DataLoader(train_dataset, batch_size=batch_size, shuffle=False),
    }
    # Train and evaluate models
    num_models = 3
    models_list = []
    test_accuracies = []
    best_train_accuracies = []
    best_val_accuracies = []
    for i in range(num_models):
        print(f"\nTraining Model {i+1}/{num_models}")
        # Distinct but fixed seed per ensemble member: reproducible runs, diverse members.
        torch.manual_seed(42 + i)
        model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=handcrafted_feature_dim).to(device)
        print(f"\nModel {i+1} Summary:")
        print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=handcrafted_feature_dim)
        criterion = FocalLoss(gamma=3.0, alpha=0.5)
        optimizer = optim.Adam(model.parameters(), lr=0.00005, weight_decay=0.001)
        history, best_train_acc, best_val_acc = train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=i)
        plot_training_history(history, len(history['train_loss']), model_idx=i)
        best_train_accuracies.append(best_train_acc)
        best_val_accuracies.append(best_val_acc)
        # Evaluation outputs are discarded with `_` so the dataset label lists are not clobbered.
        print(f"\nEvaluating Model {i+1} on Training Set")
        train_acc, _, _, _, _ = evaluate_model(model, dataloader['train_eval'], device, 'Train', CLASSES, i)
        print(f"Model {i+1} Train Accuracy: {train_acc:.2f}%")
        print(f"\nEvaluating Model {i+1} on Validation Set")
        val_acc, _, _, _, _ = evaluate_model(model, dataloader['val'], device, 'Validation', CLASSES, i)
        print(f"Model {i+1} Validation Accuracy: {val_acc:.2f}%")
        print(f"\nEvaluating Model {i+1} on Test Set")
        test_acc, _, _, _, _ = evaluate_model(model, dataloader['test'], device, 'Test', CLASSES, i)
        print(f"Model {i+1} Test Accuracy: {test_acc:.2f}%")
        test_accuracies.append(test_acc)
        visualize_predictions_per_class(model, dataloader['test'], device, CLASSES, 'Test', model_idx=i)
        if i == 0:
            print(f"\nVisualizing Feature Extraction for Model {i+1}")
            visualize_feature_extraction(model, dataloader['test'], device, CLASSES, num_samples_per_class=1)
        models_list.append(model)
    # Evaluate ensemble
    print("\nEvaluating Ensemble on Training Set")
    ensemble_train_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['train_eval'], device, 'Train', CLASSES)
    print(f"Ensemble Train Accuracy: {ensemble_train_acc:.2f}%")
    print("\nEvaluating Ensemble on Validation Set")
    ensemble_val_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['val'], device, 'Validation', CLASSES)
    print(f"Ensemble Validation Accuracy: {ensemble_val_acc:.2f}%")
    print("\nEvaluating Ensemble on Test Set")
    ensemble_test_acc, ensemble_test_preds, ensemble_test_labels, _, _ = ensemble_voting(models_list, dataloader['test'], device, 'Test', CLASSES)
    print(f"Ensemble Test Accuracy: {ensemble_test_acc:.2f}%")
    # NOTE(review): this shows Model 1's predictions even though files are suffixed
    # "_ensemble"; it does not visualize ensemble votes.
    visualize_predictions_per_class(models_list[0], dataloader['test'], device, CLASSES, 'Test', model_idx=None)
    visualize_voting_process(models_list, dataloader['test'], device, CLASSES)
    # Evaluate TTA on the member with the highest best-validation accuracy. Validation is
    # used for the selection, not test accuracy, to avoid leaking test information.
    best_idx = int(np.argmax(best_val_accuracies))
    print(f"\nEvaluating Best Model (Model {best_idx+1}) with TTA on Test Set")
    tta_acc, _, _, _, _ = evaluate_model(models_list[best_idx], dataloader['test'], device, 'Test_TTA', CLASSES, model_idx=best_idx, use_tta=True)
    print(f"TTA Test Accuracy: {tta_acc:.2f}%")
    # Statistical Analysis
    print("\nStatistical Analysis of Model Performance:")
    print(f"Mean Test Accuracy: {np.mean(test_accuracies):.2f}% ± {np.std(test_accuracies):.2f}%")
    print(f"Best Training Accuracies: {[f'{acc:.2f}%' for acc in best_train_accuracies]}")
    print(f"Best Validation Accuracies: {[f'{acc:.2f}%' for acc in best_val_accuracies]}")
    # Save predictions
    predictions_df = pd.DataFrame({
        'True_Label': [CLASSES[label] for label in ensemble_test_labels],
        'Predicted_Label': [CLASSES[pred] for pred in ensemble_test_preds]
    })
    predictions_df.to_csv('/kaggle/working/predictions/test_predictions_ensemble.csv', index=False)
    print("Predictions saved to /kaggle/working/predictions/test_predictions_ensemble.csv")
    # Save summary report
    summary_report = {
        'Model': [f'Model {i+1}' for i in range(num_models)] + ['Ensemble', f'TTA (Model {best_idx+1})'],
        'Test_Accuracy': test_accuracies + [ensemble_test_acc, tta_acc],
        'Best_Train_Accuracy': best_train_accuracies + [None, None],
        'Best_Val_Accuracy': best_val_accuracies + [None, None]
    }
    summary_df = pd.DataFrame(summary_report)
    summary_df.to_csv('/kaggle/working/summary_report.csv', index=False)
    print("\nSummary Report:")
    print(summary_df)