# ---------------------------------------------------------------------------
# Imports and global configuration
# ---------------------------------------------------------------------------
# System / image libraries
import os
import cv2
import numpy as np
import matplotlib.pyplot as plt
from glob import glob
from PIL import Image
from sklearn.model_selection import train_test_split

# Data handling tools
import pandas as pd
import seaborn as sns
sns.set_style('darkgrid')
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
from sklearn.utils import shuffle
from torch.utils.data import WeightedRandomSampler
from skimage.feature import local_binary_pattern

# Deep learning libraries
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader

# Dataset location and the four diabetic-foot-ulcer classes.
DATASET_PATH = "/kaggle/input/ms-dfu/DFU_CLASSES(4)"
CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]

# Ensure output directories exist before any artifact is written.
os.makedirs("/kaggle/working/logs", exist_ok=True)
os.makedirs("/kaggle/working/predictions", exist_ok=True)
os.makedirs("/kaggle/working/visualizations", exist_ok=True)


class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Globally average-pools each channel, passes the pooled vector through
    a two-layer bottleneck MLP, and rescales the input's channels by the
    resulting sigmoid gate.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
        self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        batch, channels, _, _ = x.size()
        y = self.global_pool(x).view(batch, channels)
        y = F.relu(self.fc1(y))
        y = torch.sigmoid(self.fc2(y)).view(batch, channels, 1, 1)
        return x * y


class FocalLoss(nn.Module):
    """Focal loss for class-imbalanced classification.

    `gamma` down-weights well-classified examples; `alpha` is a global
    scaling factor applied uniformly to every class.
    """

    def __init__(self, gamma=3.0, alpha=0.5):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        pt = torch.exp(-ce_loss)  # model's probability for the true class
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        # BUGFIX: this statement was previously broken across two physical
        # lines ("return" then "focal_loss.mean()"), so forward returned None.
        return focal_loss.mean()


class CCDGSBlock(nn.Module):
    """Channel-Centric Depth-wise Group Shuffle block.

    Pipeline: group conv -> channel shuffle -> depth-wise conv ->
    point-wise conv, each followed by BN + ReLU, then a global average
    pool.  NOTE(review): the final pool collapses spatial dims to 1x1, so
    any attention module applied afterwards only sees a 1x1 map — confirm
    this is intended.
    """

    def __init__(self, in_channels, group_size=4):
        super(CCDGSBlock, self).__init__()
        self.group_size = group_size
        self.group_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                    padding=1, groups=group_size, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3,
                                    padding=1, groups=in_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(in_channels)
        self.point_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(in_channels)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def channel_shuffle(self, x, groups):
        """Interleave channels across groups (ShuffleNet-style)."""
        batchsize, num_channels, height, width = x.size()
        channels_per_group = num_channels // groups
        x = x.view(batchsize, groups, channels_per_group, height, width)
        x = torch.transpose(x, 1, 2).contiguous()
        x = x.view(batchsize, -1, height, width)
        return x

    def forward(self, x):
        out = F.relu(self.bn1(self.group_conv(x)))
        out = self.channel_shuffle(out, self.group_size)
        out = F.relu(self.bn2(self.depth_conv(out)))
        out = F.relu(self.bn3(self.point_conv(out)))
        out = self.global_pool(out)
        return out


class TripletAttention(nn.Module):
    """Triplet attention: three branches capture cross-dimension
    interactions by rotating the tensor, applying a 2->1 channel spatial
    attention conv, rotating back, and averaging the three gated outputs.
    """

    def __init__(self, in_channels, kernel_size=7):
        super(TripletAttention, self).__init__()
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.conv2 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)
        self.conv3 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def z_pool(self, x):
        """Concatenate channel-wise max- and mean-pooled maps (2 channels)."""
        max_pool = torch.max(x, dim=1, keepdim=True)[0]
        avg_pool = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([max_pool, avg_pool], dim=1)

    def forward(self, x):
        # Branch 1: rotate in the (H, W) plane, attend, rotate back.
        x1 = torch.rot90(x, 1, [2, 3])
        x1 = torch.sigmoid(self.conv1(self.z_pool(x1)))
        x1 = torch.rot90(x1, -1, [2, 3])
        y1 = x * x1
        # Branch 2: rotate in the (C, W) plane, attend, rotate back.
        x2 = torch.rot90(x, 1, [1, 3])
        x2 = torch.sigmoid(self.conv2(self.z_pool(x2)))
        x2 = torch.rot90(x2, -1, [1, 3])
        y2 = x * x2
        # Branch 3: plain spatial attention, no rotation.
        x3 = torch.sigmoid(self.conv3(self.z_pool(x)))
        y3 = x * x3
        out = (y1 + y2 + y3) / 3.0
        return out


class DenseShuffleGCANet(nn.Module):
    """DenseNet-169 backbone + CCDGS + triplet attention + SE block,
    fused with a 41-dim handcrafted feature vector (LBP + color + edge
    histograms) before the classifier head.
    """

    def __init__(self, num_classes=4, handcrafted_feature_dim=41):
        super(DenseShuffleGCANet, self).__init__()
        densenet = models.densenet169(weights='IMAGENET1K_V1')
        self.features = densenet.features  # 1664-channel feature extractor
        self.ccdgs = CCDGSBlock(in_channels=1664, group_size=4)
        self.triplet_attention = TripletAttention(in_channels=1664)
        self.se_block = SEBlock(in_channels=1664)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.6)
        self.fc1 = nn.Linear(1664 + handcrafted_feature_dim, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x, handcrafted_features=None):
        x = self.features(x)
        x = self.ccdgs(x)
        x = self.triplet_attention(x)
        x = self.se_block(x)
        x = self.flatten(self.global_pool(x))
        if handcrafted_features is not None:
            # NOTE: fc1 expects 1664 + handcrafted_feature_dim inputs, so the
            # handcrafted vector is effectively required at inference time.
            x = torch.cat([x, handcrafted_features], dim=1)
        x = self.dropout(x)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)
        return x


def display_sample_images(images, labels, split_name, classes, num_samples=4):
    """Save a grid with up to `num_samples` sample images per class.

    `images` are BGR arrays (OpenCV convention); converted to RGB for
    display.  The figure is written under /kaggle/working/visualizations.
    """
    plt.figure(figsize=(15, 10))
    for class_idx, class_name in enumerate(classes):
        class_indices = [i for i, label in enumerate(labels) if label == class_idx]
        if not class_indices:
            continue
        for i, idx in enumerate(class_indices[:num_samples]):
            img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
            plt.subplot(len(classes), num_samples, class_idx * num_samples + i + 1)
            plt.imshow(img)
            plt.title(f'{class_name}')
            plt.axis('off')
    plt.suptitle(f'{split_name} Sample Images')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f'/kaggle/working/visualizations/{split_name.lower()}_samples.png')
    plt.close()

# Function to visualize handcrafted features
def _save_annotated_bar_plot(values, title, out_path, figsize=(8, 4)):
    """Save a bar plot with each bar's value printed in bold above it."""
    plt.figure(figsize=figsize)
    bars = plt.bar(range(len(values)), values)
    plt.title(title)
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.2f}',
                 ha='center', va='bottom',
                 fontsize=8, fontweight='bold')
    plt.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close()


def visualize_handcrafted_features(images, labels, classes, num_samples=1):
    """Save per-class visualizations of the handcrafted features.

    For up to `num_samples` images per class, writes five separate
    figures (original image, Canny edge map, LBP / color / edge
    histograms) into a class-specific subdirectory under
    visualizations/handcrafted_features/.
    """
    # Create main visualization directory
    main_dir = '/kaggle/working/visualizations/handcrafted_features'
    os.makedirs(main_dir, exist_ok=True)
    for class_idx, class_name in enumerate(classes):
        # Create class-specific subdirectory
        class_dir = os.path.join(main_dir, f"class_{class_idx}_{class_name}")
        os.makedirs(class_dir, exist_ok=True)
        class_indices = [i for i, label in enumerate(labels) if label == class_idx]
        if not class_indices:
            continue
        for idx in class_indices[:num_samples]:
            img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
            # Preprocessing: blur + histogram equalization before edges.
            gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
            gray_eq = cv2.equalizeHist(gray_blur)
            # Median-based automatic Canny thresholds.
            median = np.median(gray_eq)
            lower_threshold = int(max(0, 0.66 * median))
            upper_threshold = int(min(255, 1.33 * median))
            edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
            print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")
            # Feature extraction: 8-bin edge, 9-bin uniform LBP, 3x8-bin color.
            edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
            lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
            lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
            color_hist = []
            for channel in range(img.shape[2]):
                hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
                color_hist.extend(hist)
            # 1. Original image
            plt.figure(figsize=(5, 5))
            plt.imshow(img)
            plt.title(f'Original ({class_name})')
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, f'sample_{idx}_original.png'),
                        dpi=120, bbox_inches='tight')
            plt.close()
            # 2. Edge map
            plt.figure(figsize=(5, 5))
            plt.imshow(edges, cmap='gray')
            plt.title('Canny Edge Map')
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, f'sample_{idx}_edges.png'),
                        dpi=120, bbox_inches='tight')
            plt.close()
            # 3-5. Annotated histograms (shared helper removes triplication).
            _save_annotated_bar_plot(lbp_hist, 'LBP Histogram',
                                     os.path.join(class_dir, f'sample_{idx}_lbp_hist.png'))
            _save_annotated_bar_plot(color_hist, 'Color Histogram',
                                     os.path.join(class_dir, f'sample_{idx}_color_hist.png'),
                                     figsize=(10, 4))
            _save_annotated_bar_plot(edge_hist, 'Edge Histogram',
                                     os.path.join(class_dir, f'sample_{idx}_edge_hist.png'))
            print(f"Saved separate visualizations for class {class_name} sample {idx} in: {class_dir}")


def extract_handcrafted_features(image):
    """Return a 41-dim handcrafted feature tensor for one image.

    Layout: 9-bin uniform LBP histogram + 3x8-bin per-channel color
    histograms + 8-bin edge-intensity histogram.  Accepts either an
    HxWxC uint8 array or a normalized CHW tensor (de-normalized with the
    ImageNet mean/std first).  NOTE(review): array callers pass BGR
    frames but the gray conversion uses COLOR_RGB2GRAY — confirm whether
    the channel order mismatch is intended.
    """
    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
        image = np.transpose(image, (1, 2, 0))
        # Undo ImageNet normalization back to the 0-255 uint8 range.
        image = (image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])) * 255
        image = image.astype(np.uint8)
    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
    gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
    gray_eq = cv2.equalizeHist(gray_blur)
    # Median-based automatic Canny thresholds.
    median = np.median(gray_eq)
    lower_threshold = int(max(0, 0.66 * median))
    upper_threshold = int(min(255, 1.33 * median))
    edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
    print(f"Edge detection stats - Lower threshold: {lower_threshold}, Upper threshold: {upper_threshold}, Edge pixels: {np.sum(edges > 0)}")
    edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
    lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
    lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
    color_hist = []
    for channel in range(image.shape[2]):
        hist, _ = np.histogram(image[:, :, channel], bins=8, range=(0, 256), density=True)
        color_hist.extend(hist)
    features = np.concatenate([lbp_hist, color_hist, edge_hist])
    return torch.tensor(features, dtype=torch.float32)


def load_images(split_path, classes, use_csv=False):
    """Load images, labels and handcrafted features for one split.

    If `use_csv` is set and DATASET_PATH/labels.csv exists, reads
    (image_path, label) rows from it; otherwise walks one subdirectory
    per class under `split_path`.  Images are stored as BGR uint8 arrays.
    Returns (image_data, labels, handcrafted_features) lists.
    """
    image_data = []
    labels = []
    handcrafted_features = []
    print(f"Loading images from: {split_path}")
    csv_path = os.path.join(DATASET_PATH, "labels.csv")
    if use_csv and os.path.exists(csv_path):
        print("Found labels.csv, loading dataset from CSV")
        df = pd.read_csv(csv_path)
        print("CSV columns:", df.columns)
        for idx, row in df.iterrows():
            img_path = os.path.join(DATASET_PATH, row['image_path'])
            label_name = row['label']
            if label_name not in CLASSES:
                print(f"Warning: Label {label_name} not in CLASSES, skipping")
                continue
            label = CLASSES.index(label_name)
            try:
                img = Image.open(img_path).convert('RGB')
                img_array = np.array(img)
                img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
                features = extract_handcrafted_features(img_array)
            except Exception as e:
                # Best effort: skip unreadable files rather than aborting.
                print(f"Warning: Failed to load image {img_path}: {e}")
                continue
            image_data.append(img_array)
            labels.append(label)
            handcrafted_features.append(features)
    else:
        if not os.path.exists(split_path):
            print(f"Error: Directory {split_path} does not exist")
            return image_data, labels, handcrafted_features
        for class_idx, class_name in enumerate(classes):
            class_path = os.path.join(split_path, class_name)
            print(f"Checking class: {class_name} at {class_path}")
            if not os.path.exists(class_path):
                print(f"Warning: Class directory {class_path} does not exist")
                continue
            all_files = glob(os.path.join(class_path, '*'))
            print(f"All files in {class_path}: {all_files[:5]}")
            # Case-insensitive jpg/jpeg plus lowercase png.
            image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
                          glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
                          glob(os.path.join(class_path, '*.png'))
            print(f"Found {len(image_paths)} images for class {class_name}")
            for img_path in image_paths:
                try:
                    img = Image.open(img_path).convert('RGB')
                    img_array = np.array(img)
                    img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
                    features = extract_handcrafted_features(img_array)
                except Exception as e:
                    print(f"Warning: Failed to load image {img_path}: {e}")
                    continue
                image_data.append(img_array)
                labels.append(class_idx)
                handcrafted_features.append(features)
    print(f"Total images loaded: {len(image_data)}")
    return image_data, labels, handcrafted_features


def visualize_data_distribution(split_path, split_name, classes):
    """Save a bar chart of the per-class image counts for one split."""
    class_counts = []
    for class_name in classes:
        class_path = os.path.join(split_path, class_name)
        image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
                      glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
                      glob(os.path.join(class_path, '*.png'))
        class_counts.append(len(image_paths))
        print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")
    plt.figure(figsize=(10, 6))
    bars = plt.bar(classes, class_counts)
    plt.title(f'{split_name} Dataset Distribution')
    plt.xlabel('Classes')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45, ha='right')
    # Print the count on top of each bar.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height}',
                 ha='center', va='bottom',
                 fontsize=10, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
    plt.close()


class FootUlcerDataset(Dataset):
    """Dataset pairing BGR image arrays with labels and handcrafted features.

    __getitem__ applies heavier augmentation (rotation, jitter, random
    crop) to the minority classes 2 (ISCHAEMIA) and 3 (BOTH), and a
    light flip-only pipeline to the rest.  Returns
    (image_tensor, label, handcrafted_feature_tensor).
    """

    def __init__(self, images, labels, handcrafted_features):
        self.images = images
        self.labels = labels
        self.handcrafted_features = handcrafted_features

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        features = self.handcrafted_features[idx]
        # Stored BGR -> RGB PIL image for torchvision transforms.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        if label in [2, 3]:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomRotation(30),
                transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
                transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        else:
            transform = transforms.Compose([
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
            ])
        image = transform(image)
        return image, label, features


def print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41):
    """Print a per-module summary (name, output shape, param count).

    Runs one dummy forward pass with hooks on every submodule.  Each
    row shows that module's *total* parameters (containers include their
    children).  BUGFIX: the grand total is now computed once from
    model.parameters() instead of summing every row, which double-counted
    parameters owned by nested containers.
    """
    device = next(model.parameters()).device
    model.eval()
    print("\nModel Summary:")
    print("=" * 80)
    print(f"{'Layer':<30} {'Output Shape':<25} {'Param #':<15}")
    print("-" * 80)
    total_params = sum(p.numel() for p in model.parameters())
    x = torch.randn(1, *input_size).to(device)
    handcrafted_features = torch.randn(1, handcrafted_feature_dim).to(device)

    def register_hook(module, input, output):
        class_name = str(module.__class__.__name__)
        param_count = sum(p.numel() for p in module.parameters())
        output_shape = list(output.shape) if isinstance(output, torch.Tensor) else "N/A"
        print(f"{class_name:<30} {str(output_shape):<25} {param_count:<15}")

    hooks = []
    for name, module in model.named_modules():
        if module != model:
            hooks.append(module.register_forward_hook(register_hook))
    with torch.no_grad():
        model(x, handcrafted_features)
    for hook in hooks:
        hook.remove()
    print("-" * 80)
    print(f"Total Parameters: {total_params:,}")
    print("=" * 80)

#
Function to plot ROC curves def plot_roc_curves(labels, probabilities, split_name, classes, model_idx=None): plt.figure(figsize=(10, 8)) for i, class_name in enumerate(classes): fpr, tpr, _ = roc_curve(np.array(labels) == i, probabilities[:, i]) roc_auc = auc(fpr, tpr) plt.plot(fpr, tpr, label=f'{class_name} (AUC = {roc_auc:.2f})') plt.plot([0, 1], [0, 1], 'k--') plt.xlim([0.0, 1.0]) plt.ylim([0.0, 1.05]) plt.xlabel('False Positive Rate') plt.ylabel('True Positive Rate') plt.title(f'ROC Curves - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else '')) plt.legend(loc='lower right') plt.grid(True) filename = f'/kaggle/working/visualizations/roc_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png') plt.savefig(filename) plt.close() # Function to visualize feature extraction layer by layer # def visualize_feature_extraction(model, dataloader, device, classes, num_samples=1): # model.eval() # feature_maps = {} # layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2'] # print("\nModel structure (named modules):") # for name, module in model.named_modules(): # print(f"Layer: {name}, Module: {type(module).__name__}") # print("\nRegistering forward hooks for layers:", layer_names) # def get_hook(name): # def hook(module, input, output): # feature_maps[name] = output.detach() # print(f"Captured output for {name}, shape: {output.shape}") # return hook # hooks = [] # for name in layer_names: # module = getattr(model, name, None) # if module: # hooks.append(module.register_forward_hook(get_hook(name))) # print(f"Hook registered for {name}") # else: # print(f"Warning: Layer {name} not found in model") # images_list = [] # labels_list = [] # probs_list = [] # features_list = [] # with torch.no_grad(): # for images, labels, features in dataloader: # images, labels, features = images.to(device), labels.to(device), features.to(device) # print(f"Processing batch with {images.shape[0]} images, 
features shape: {features.shape}") # outputs = model(images, features) # probs = F.softmax(outputs, dim=1) # images_list.extend(images.cpu().numpy()) # labels_list.extend(labels.cpu().numpy()) # probs_list.extend(probs.cpu().numpy()) # features_list.extend(features.cpu().numpy()) # break # print(f"Removing {len(hooks)} hooks") # for hook in hooks: # hook.remove() # print(f"Feature maps captured: {list(feature_maps.keys())}") # for idx in range(min(num_samples, len(images_list))): # img = images_list[idx].transpose(1, 2, 0) # img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) # img = np.clip(img, 0, 1) # true_label = CLASSES[labels_list[idx]] # plt.figure(figsize=(5, 5)) # plt.imshow(img) # plt.title(f'Input Image (Class: {true_label})') # plt.axis('off') # input_img_path = f'/kaggle/working/visualizations/input_image_sample_{idx}.png' # plt.savefig(input_img_path) # plt.close() # print(f"Saved input image to: {input_img_path}") # for layer_name in layer_names: # if layer_name not in feature_maps: # print(f"No feature map for {layer_name}, skipping visualization") # continue # features = feature_maps[layer_name][idx] # print(f"Visualizing {layer_name}, feature shape: {features.shape}, feature dim:{features.dim()}") # if features.dim() == 3: # num_channels = min(features.shape[0], 16) # plt.figure(figsize=(15, 10)) # for i in range(num_channels): # plt.subplot(4, 4, i + 1) # feature_map = features[i].cpu().numpy() # feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8) # plt.imshow(feature_map, cmap='viridis') # plt.title(f'Channel {i+1}') # plt.axis('off') # plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})') # plt.tight_layout(rect=[0, 0, 1, 0.95]) # feature_map_path = f'/kaggle/working/visualizations/feature_maps_{layer_name}_{idx}.png' # plt.savefig(feature_map_path) # plt.close() # print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}") # else: # 
values = features.flatten().cpu().numpy() # plt.figure(figsize=(10, 5)) # plt.bar(range(len(values)), values, color='blue') # Changed bar color to blue # plt.title(f'Feature Vector - {layer_name} (Class: {true_label})') # plt.xlabel('Index') # plt.ylabel('Value') # ## Use only if it,s looking good, the grid part # plt.grid(True, linestyle='--', alpha=0.6) # Added light grid for better readability # plt.tight_layout() # vector_path = f'/kaggle/working/visualizations/feature_vector_{layer_name}_{idx}.png' # plt.savefig(vector_path) # plt.close() # print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}") # plt.figure(figsize=(8, 6)) # bars = plt.bar(classes, probs_list[idx]) # Store the bar objects # plt.title(f'Classification Probabilities (True: {true_label})') # plt.xlabel('Classes') # plt.ylabel('Probability') # plt.xticks(rotation=45) # # Add value labels on top of each bar # for bar in bars: # height = bar.get_height() # plt.text(bar.get_x() + bar.get_width()/2., height, # f'{height:.3f}', # ha='center', va='bottom', # fontsize=10, fontweight='bold') # plt.tight_layout() # probs_path = f'/kaggle/working/visualizations/classification_probs_{idx}.png' # plt.savefig(probs_path) # plt.close() # print(f"Saved classification probabilities to: {probs_path}") # print("\nListing saved visualization files:") # os.system('ls /kaggle/working/visualizations/') def visualize_feature_extraction(model, dataloader, device, classes, num_samples_per_class=1): model.eval() feature_maps = {} layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2'] print("\nModel structure (named modules):") for name, module in model.named_modules(): print(f"Layer: {name}, Module: {type(module).__name__}") print("\nRegistering forward hooks for layers:", layer_names) def get_hook(name): def hook(module, input, output): feature_maps[name] = output.detach() print(f"Captured output for {name}, shape: {output.shape}") return hook hooks = [] 
for name in layer_names: module = getattr(model, name, None) if module: hooks.append(module.register_forward_hook(get_hook(name))) print(f"Hook registered for {name}") else: print(f"Warning: Layer {name} not found in model") # Collect samples from each class class_samples = {class_idx: [] for class_idx in range(len(classes))} with torch.no_grad(): for images, labels, features in dataloader: images, labels, features = images.to(device), labels.to(device), features.to(device) outputs = model(images, features) probs = F.softmax(outputs, dim=1) for i in range(len(images)): class_idx = labels[i].item() if len(class_samples[class_idx]) < num_samples_per_class: class_samples[class_idx].append(( images[i].cpu().numpy(), labels[i].cpu().numpy(), probs[i].cpu().numpy(), features[i].cpu().numpy() )) # Check if we have enough samples from each class if all(len(samples) >= num_samples_per_class for samples in class_samples.values()): break print(f"Removing {len(hooks)} hooks") for hook in hooks: hook.remove() print(f"Feature maps captured: {list(feature_maps.keys())}") # Process one sample from each class for class_idx in range(len(classes)): if not class_samples[class_idx]: print(f"No samples found for class {class_idx} ({classes[class_idx]})") continue # Take the first sample for this class img, label, prob, features = class_samples[class_idx][0] true_label = classes[label] # Process input image img = img.transpose(1, 2, 0) img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) img = np.clip(img, 0, 1) plt.figure(figsize=(5, 5)) plt.imshow(img) plt.title(f'Input Image (Class: {true_label})') plt.axis('off') input_img_path = f'/kaggle/working/visualizations/class_{class_idx}_input_image.png' plt.savefig(input_img_path) plt.close() print(f"Saved input image for class {true_label} to: {input_img_path}") # Process feature maps for each layer for layer_name in layer_names: if layer_name not in feature_maps: print(f"No feature map for {layer_name}, skipping 
visualization") continue features = feature_maps[layer_name][class_idx] # Assuming feature maps are in order print(f"Visualizing {layer_name} for class {true_label}, feature shape: {features.shape}") if features.dim() == 3: num_channels = min(features.shape[0], 16) plt.figure(figsize=(15, 10)) for i in range(num_channels): plt.subplot(4, 4, i + 1) feature_map = features[i].cpu().numpy() feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8) plt.imshow(feature_map, cmap='viridis') plt.title(f'Channel {i+1}') plt.axis('off') plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})') plt.tight_layout(rect=[0, 0, 1, 0.95]) feature_map_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_maps_{layer_name}.png' plt.savefig(feature_map_path) plt.close() print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}") # else: # values = features.flatten().cpu().numpy() # plt.figure(figsize=(10, 5)) # plt.bar(range(len(values)), values, color='blue') # plt.title(f'Feature Vector - {layer_name} (Class: {true_label})') # plt.xlabel('Index') # plt.ylabel('Value') # plt.grid(True, linestyle='--', alpha=0.6) # plt.tight_layout() # vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png' # plt.savefig(vector_path) # plt.close() # print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}") else: values = features.flatten().cpu().numpy() num_features = len(values) # Adjust figure width based on number of features fig_width = max(20, num_features * 0.025) # 0.025 inches per bar (adjustable) plt.figure(figsize=(fig_width, 5)) # Wider for more bars # Plot bars with optimized width & spacing bars = plt.bar( range(num_features), values, color='#1f77b4', # Matplotlib default blue (better than 'blue') edgecolor='#1f77b4', linewidth=0.05, # Thinner border for dense plots width=0.9, # Slightly narrower to guarantee gaps align='center' ) 
# Hide x-axis labels if too many features if num_features > 100: ticks = list(range(0, num_features, 50)) + [num_features-1] # Add last feature plt.xticks(ticks) # Diagonal labels plt.title(f'Feature Vector - {layer_name} (Class: {true_label})') plt.xlabel('Feature') plt.ylabel('Activation Value') plt.grid(True, linestyle=':', alpha=0.5) plt.tight_layout() vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png' plt.savefig(vector_path, dpi=120, bbox_inches='tight', facecolor='white') plt.close() print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}") # Process classification probabilities plt.figure(figsize=(8, 6)) bars = plt.bar(classes, prob) plt.title(f'Classification Probabilities (True: {true_label})') plt.xlabel('Classes') plt.ylabel('Probability') plt.xticks(rotation=45) for bar in bars: height = bar.get_height() plt.text(bar.get_x() + bar.get_width()/2., height, f'{height:.3f}', ha='center', va='bottom', fontsize=10, fontweight='bold') plt.tight_layout() probs_path = f'/kaggle/working/visualizations/class_{class_idx}_classification_probs.png' plt.savefig(probs_path) plt.close() print(f"Saved classification probabilities to: {probs_path}") print("\nListing saved visualization files:") os.system('ls /kaggle/working/visualizations/') # Training Function def train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=0): history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []} best_val_loss = float('inf') best_val_acc = 0.0 best_train_acc = 0.0 patience = 10 counter = 0 scaler = torch.cuda.amp.GradScaler() scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True) train_batches = len(dataloader['train']) val_batches = len(dataloader['val']) print(f"Training dataset: {train_batches} batches") print(f"Validation dataset: {val_batches} batches") for epoch in range(epochs): print(f"\n--- 
Epoch {epoch+1}/{epochs} ---") model.train() running_loss = 0.0 correct, total = 0, 0 for batch_idx, (images, labels, features) in enumerate(dataloader['train']): print(f"Training epoch {epoch+1}, batch {batch_idx+1}/{train_batches}") images, labels, features = images.to(device), labels.to(device), features.to(device) optimizer.zero_grad() with torch.cuda.amp.autocast(): outputs = model(images, features) loss = criterion(outputs, labels) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update() running_loss += loss.item() _, predicted = outputs.max(1) total += labels.size(0) correct += predicted.eq(labels).sum().item() epoch_loss = running_loss / train_batches if train_batches > 0 else 0.0 epoch_acc = 100. * correct / total if total > 0 else 0.0 history['train_loss'].append(epoch_loss) history['train_acc'].append(epoch_acc) best_train_acc = max(best_train_acc, epoch_acc) model.eval() val_loss = 0.0 val_correct, val_total = 0, 0 with torch.no_grad(): for batch_idx, (val_images, val_labels, val_features) in enumerate(dataloader['val']): print(f"Validation epoch {epoch+1}, batch {batch_idx+1}/{val_batches}") val_images, val_labels, val_features = val_images.to(device), val_labels.to(device), val_features.to(device) with torch.cuda.amp.autocast(): val_outputs = model(val_images, val_features) loss = criterion(val_outputs, val_labels) val_loss += loss.item() _, predicted = val_outputs.max(1) val_total += val_labels.size(0) val_correct += predicted.eq(val_labels).sum().item() val_epoch_loss = val_loss / val_batches if val_batches > 0 else 0.0 val_epoch_acc = 100. 
* val_correct / val_total if val_total > 0 else 0.0 history['val_loss'].append(val_epoch_loss) history['val_acc'].append(val_epoch_acc) best_val_acc = max(best_val_acc, val_epoch_acc) scheduler.step(val_epoch_loss) if val_epoch_loss < best_val_loss: best_val_loss = val_epoch_loss counter = 0 torch.save(model.state_dict(), f'/kaggle/working/best_model_{model_idx}.pth') else: counter += 1 if counter >= patience: print("Early stopping triggered") break print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%") print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.2f}%") print(f"Best Training Accuracy (Model {model_idx}): {best_train_acc:.2f}%") print(f"Best Validation Accuracy (Model {model_idx}): {best_val_acc:.2f}%") return history, best_train_acc, best_val_acc # Function to Plot Training History def plot_training_history(history, epochs, model_idx=0): epochs_range = range(1, len(history['train_loss']) + 1) plt.figure(figsize=(12, 5)) plt.subplot(1, 2, 1) plt.plot(epochs_range, history['train_loss'], label='Training Loss') plt.plot(epochs_range, history['val_loss'], label='Validation Loss') plt.xlabel('Epochs') plt.ylabel('Loss') plt.title(f'Training and Validation Loss (Model {model_idx})') plt.legend() plt.grid(True) plt.subplot(1, 2, 2) plt.plot(epochs_range, history['train_acc'], label='Training Accuracy') plt.plot(epochs_range, history['val_acc'], label='Validation Accuracy') plt.xlabel('Epochs') plt.ylabel('Accuracy (%)') plt.title(f'Training and Validation Accuracy (Model {model_idx})') plt.legend() plt.grid(True) plt.tight_layout() plt.savefig(f'/kaggle/working/visualizations/training_history_model_{model_idx}.png') plt.close() # Function to Evaluate Model def evaluate_model(model, dataloader, device, split_name, classes, model_idx=None, use_tta=False): model.eval() correct = 0 total = 0 all_predictions = [] all_labels = [] all_probs = [] mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1) std = 
torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1) tta_transforms = [ transforms.Compose([ transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), transforms.Compose([ transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]), transforms.Compose([ transforms.RandomRotation(10), transforms.ToTensor(), transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) ]) ] with torch.no_grad(): for images, labels, features in dataloader: images, labels, features = images.to(device), labels.to(device), features.to(device) if use_tta: batch_probs = [] for transform in tta_transforms: denorm_images = images * std + mean denorm_images = denorm_images.clamp(0, 1) * 255 denorm_images = denorm_images.to(torch.uint8) tta_images = torch.stack([ transform(Image.fromarray(img.cpu().numpy().transpose(1, 2, 0))) for img in denorm_images ]).to(device) outputs = model(tta_images, features) batch_probs.append(F.softmax(outputs, dim=1)) avg_probs = torch.stack(batch_probs).mean(dim=0) _, predicted = torch.max(avg_probs, 1) all_probs.extend(avg_probs.cpu().numpy()) else: outputs = model(images, features) _, predicted = torch.max(outputs.data, 1) all_probs.extend(F.softmax(outputs, dim=1).cpu().numpy()) all_predictions.extend(predicted.cpu().numpy()) all_labels.extend(labels.cpu().numpy()) total += labels.size(0) correct += (predicted == labels).sum().item() accuracy = 100 * correct / total if total > 0 else 0.0 # cm = confusion_matrix(all_labels, all_predictions) # plt.figure(figsize=(10, 8)) # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes) # plt.title(f'Confusion Matrix - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else '')) # plt.xlabel('Predicted') # plt.ylabel('True') cm = confusion_matrix(all_labels, all_predictions) plt.figure(figsize=(10, 8)) # Create heatmap 
with custom annotation formatting sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes, annot_kws={'size': 12, 'weight': 'bold'}) # Larger bold annotations # Make title and axis labels bold plt.title(f'Confusion Matrix - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else ''), fontsize=14, fontweight='bold') # Bold title with larger font plt.xlabel('Predicted', fontsize=12, fontweight='bold') # Bold x-label plt.ylabel('True', fontsize=12, fontweight='bold') # Bold y-label plt.tight_layout() filename = f'/kaggle/working/visualizations/cm_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png') plt.savefig(filename) plt.close() report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True) report_df = pd.DataFrame(report).transpose() report_filename = f'/kaggle/working/classification_report_{split_name.lower()}' + (f'_model_{model_idx}.csv' if model_idx is not None else '_ensemble.csv') report_df.to_csv(report_filename) all_probs = np.array(all_probs) plot_roc_curves(all_labels, all_probs, split_name, classes, model_idx) return accuracy, all_predictions, all_labels, all_probs, report_df # Ensemble Voting Function def ensemble_voting(models, dataloader, device, split_name, classes): all_predictions = [] all_labels = [] all_probs = [] for model in models: model.eval() with torch.no_grad(): for images, labels, features in dataloader: images, labels, features = images.to(device), labels.to(device), features.to(device) votes = [] probs = [] for model in models: outputs = model(images, features) _, predicted = torch.max(outputs.data, 1) votes.append(predicted.cpu().numpy()) probs.append(F.softmax(outputs, dim=1).cpu().numpy()) votes = np.array(votes) final_predictions = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=votes) avg_probs = np.mean(probs, axis=0) all_predictions.extend(final_predictions) 
all_labels.extend(labels.cpu().numpy()) all_probs.extend(avg_probs) accuracy = 100 * sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels) # cm = confusion_matrix(all_labels, all_predictions) # plt.figure(figsize=(10, 8)) # sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes) # plt.title(f'Confusion Matrix - {split_name} (Ensemble)') # plt.xlabel('Predicted') # plt.ylabel('True') cm = confusion_matrix(all_labels, all_predictions) plt.figure(figsize=(10, 8)) # Create heatmap with custom annotation formatting sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes, annot_kws={'size': 12, 'weight': 'bold'}) # Larger bold annotations plt.title(f'Confusion Matrix - {split_name} (Ensemble)', fontsize=14, fontweight='bold') plt.xlabel('Predicted', fontsize=12, fontweight='bold') # Bold x-label plt.ylabel('True', fontsize=12, fontweight='bold') # Bold y-label plt.tight_layout() plt.savefig(f'/kaggle/working/visualizations/cm_{split_name.lower()}_ensemble.png') plt.close() report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True) report_df = pd.DataFrame(report).transpose() report_df.to_csv(f'/kaggle/working/classification_report_{split_name.lower()}_ensemble.csv') all_probs = np.array(all_probs) plot_roc_curves(all_labels, all_probs, f'{split_name} (Ensemble)', classes) return accuracy, all_predictions, all_labels, all_probs, report_df # Function to visualize voting process def visualize_voting_process(models, dataloader, device, classes, num_samples=5): model_predictions = [] true_labels = [] images_list = [] with torch.no_grad(): for images, labels, features in dataloader: images, labels, features = images.to(device), labels.to(device), features.to(device) batch_preds = [] for model in models: model.eval() outputs = model(images, features) _, predicted = torch.max(outputs.data, 1) batch_preds.append(predicted.cpu().numpy()) 
model_predictions.extend(np.array(batch_preds).T) true_labels.extend(labels.cpu().numpy()) images_list.extend(images.cpu().numpy()) if len(true_labels) >= num_samples: break model_predictions = model_predictions[:num_samples] true_labels = true_labels[:num_samples] images_list = images_list[:num_samples] plt.figure(figsize=(15, num_samples * 3)) for i in range(num_samples): img = images_list[i].transpose(1, 2, 0) img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) img = np.clip(img, 0, 1) plt.subplot(num_samples, 1, i + 1) plt.imshow(img) preds = [CLASSES[p] for p in model_predictions[i]] ensemble_pred = CLASSES[np.bincount(model_predictions[i]).argmax()] title = f'True: {CLASSES[true_labels[i]]}\n' + \ f'Model 1: {preds[0]}, Model 2: {preds[1]}, Model 3: {preds[2]}\n' + \ f'Ensemble: {ensemble_pred}' plt.title(title) plt.axis('off') plt.tight_layout() plt.savefig('/kaggle/working/visualizations/voting_process.png') plt.close() # Function to visualize predictions per class # def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4): # model.eval() # class_images = {i: [] for i in range(len(classes))} # class_preds = {i: [] for i in range(len(classes))} # class_labels = {i: [] for i in range(len(classes))} # with torch.no_grad(): # for images, labels, features in dataloader: # images, labels, features = images.to(device), labels.to(device), features.to(device) # outputs = model(images, features) # _, predicted = torch.max(outputs.data, 1) # for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()): # if len(class_images[label]) < num_samples: # class_images[label].append(img) # class_preds[label].append(pred) # class_labels[label].append(label) # if all(len(class_images[i]) >= num_samples for i in range(len(classes))): # break # for class_idx, class_name in enumerate(classes): # plt.figure(figsize=(15, 5)) # for i in range(min(num_samples, 
len(class_images[class_idx]))): # img = class_images[class_idx][i].transpose(1, 2, 0) # img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) # img = np.clip(img, 0, 1) # plt.subplot(1, num_samples, i + 1) # plt.imshow(img) # plt.title(f'True: {class_name}\nPred: {CLASSES[class_preds[class_idx][i]]}') # plt.axis('off') # plt.suptitle(f'Predictions for {class_name} ({split_name})') # plt.tight_layout(rect=[0, 0, 1, 0.95]) # filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png') # plt.savefig(filename) # plt.close() def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4): model.eval() class_images = {i: [] for i in range(len(classes))} class_preds = {i: [] for i in range(len(classes))} class_labels = {i: [] for i in range(len(classes))} with torch.no_grad(): for images, labels, features in dataloader: images, labels, features = images.to(device), labels.to(device), features.to(device) outputs = model(images, features) _, predicted = torch.max(outputs.data, 1) for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()): if len(class_images[label]) < num_samples: class_images[label].append(img) class_preds[label].append(pred) class_labels[label].append(label) if all(len(class_images[i]) >= num_samples for i in range(len(classes))): break for class_idx, class_name in enumerate(classes): plt.figure(figsize=(15, 5)) for i in range(min(num_samples, len(class_images[class_idx]))): img = class_images[class_idx][i].transpose(1, 2, 0) img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) img = np.clip(img, 0, 1) plt.subplot(1, num_samples, i + 1) plt.imshow(img) plt.title(f'True: {class_name}\nPred: {CLASSES[class_preds[class_idx][i]]}', fontweight='bold') # Bold title plt.axis('off') # Make suptitle bold and adjust font 
properties plt.suptitle(f'Predictions for {class_name} ({split_name})', fontweight='bold', fontsize=12) # Optional: slightly larger font plt.tight_layout(rect=[0, 0, 1, 0.95]) filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png') plt.savefig(filename) plt.close() # # Function to visualize predictions combined per class # def visualize_predictions_grid_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=2): # import os # os.makedirs("/kaggle/working/visualizations", exist_ok=True) # model.eval() # num_classes = len(classes) # # Collect samples # class_images = {i: [] for i in range(num_classes)} # class_preds = {i: [] for i in range(num_classes)} # class_labels = {i: [] for i in range(num_classes)} # with torch.no_grad(): # for images, labels, features in dataloader: # images, labels, features = images.to(device), labels.to(device), features.to(device) # outputs = model(images, features) # _, predicted = torch.max(outputs.data, 1) # for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()): # if len(class_images[label]) < num_samples: # class_images[label].append(img) # class_preds[label].append(pred) # class_labels[label].append(label) # if all(len(class_images[i]) >= num_samples for i in range(num_classes)): # break # # Plot: Grid of num_samples rows × num_classes columns # plt.figure(figsize=(4 * num_classes, 4 * num_samples)) # for row in range(num_samples): # for class_idx in range(num_classes): # if row >= len(class_images[class_idx]): # continue # img = class_images[class_idx][row].transpose(1, 2, 0) # img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406]) # unnormalize # img = np.clip(img, 0, 1) # ax_idx = row * num_classes + class_idx + 1 # plt.subplot(num_samples, num_classes, ax_idx) # true_label = classes[class_labels[class_idx][row]] # pred_label = 
classes[class_preds[class_idx][row]] # plt.imshow(img) # plt.title(f'True: {true_label}\nPred: {pred_label}') # plt.axis('off') # plt.suptitle(f'{split_name} Predictions Grid ({num_samples}×{num_classes})') # filename = f'/kaggle/working/visualizations/predictions_grid_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png') # plt.tight_layout(rect=[0, 0, 1, 0.95]) # plt.savefig(filename) # plt.close() # Main Execution if __name__ == "__main__": # Set device device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {device}") # Debug dataset directory print("Checking dataset directory structure:") for dirname, _, filenames in os.walk(DATASET_PATH): print(f"Directory: {dirname}, Files: {len(filenames)}") for filename in filenames[:5]: print(f" - {os.path.join(dirname, filename)}") # Check for CSV file csv_path = os.path.join(DATASET_PATH, "labels.csv") use_csv = os.path.exists(csv_path) if use_csv: print("Detected labels.csv, will load dataset from CSV") else: print("No labels.csv found, assuming directory-based structure") # Load dataset train_path = os.path.join(DATASET_PATH, "TRAIN") val_path = os.path.join(DATASET_PATH, "VALIDATION") test_path = os.path.join(DATASET_PATH, "TEST") train_images, train_labels, train_features = load_images(train_path, CLASSES, use_csv) val_images, val_labels, val_features = load_images(val_path, CLASSES, use_csv) test_images, test_labels, test_features = load_images(test_path, CLASSES, use_csv) # Check if datasets are empty if not train_images: raise ValueError("Training dataset is empty. Please check the dataset path, class names, image files, or CSV structure.") if not val_images: print("Warning: Validation dataset is empty. 
Creating validation split from training data.") train_images, val_images, train_labels, val_labels, train_features, val_features = train_test_split( train_images, train_labels, train_features, test_size=0.2, stratify=train_labels, random_state=42 ) if not test_images: print("Warning: Test dataset is empty.") test_images, test_labels, test_features = [], [], [] # Visualize dataset visualize_data_distribution(train_path, "Train", CLASSES) visualize_data_distribution(val_path, "Validation", CLASSES) visualize_data_distribution(test_path, "Test", CLASSES) display_sample_images(train_images, train_labels, "Train", CLASSES) display_sample_images(val_images, val_labels, "Validation", CLASSES) display_sample_images(test_images, test_labels, "Test", CLASSES) visualize_handcrafted_features(train_images, train_labels, CLASSES) # Create datasets train_dataset = FootUlcerDataset(train_images, train_labels, train_features) val_dataset = FootUlcerDataset(val_images, val_labels, val_features) test_dataset = FootUlcerDataset(test_images, test_labels, test_features) # Create WeightedRandomSampler train_labels_np = np.array(train_labels) class_counts = np.array([sum(train_labels_np == i) for i in range(len(CLASSES))]) print(f"Class counts: {dict(zip(CLASSES, class_counts))}") if np.any(class_counts == 0): print("Warning: Some classes have zero samples in the training set.") class_weights = 1.0 / (class_counts + 1e-6) sample_weights = class_weights[train_labels_np] sampler = WeightedRandomSampler(sample_weights, len(train_labels), replacement=True) # Create DataLoaders batch_size = 32 dataloader = { 'train': DataLoader(train_dataset, batch_size=batch_size, sampler=sampler), 'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False), 'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False) } # Train and evaluate models num_models = 3 models_list = [] test_accuracies = [] best_train_accuracies = [] best_val_accuracies = [] for i in range(num_models): 
print(f"\nTraining Model {i+1}/{num_models}") model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41).to(device) print(f"\nModel {i+1} Summary:") print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41) criterion = FocalLoss(gamma=3.0, alpha=0.5) optimizer = optim.Adam(model.parameters(), lr=0.00005, weight_decay=0.001) history, best_train_acc, best_val_acc = train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=i) plot_training_history(history, len(history['train_loss']), model_idx=i) best_train_accuracies.append(best_train_acc) best_val_accuracies.append(best_val_acc) print(f"\nEvaluating Model {i+1} on Training Set") train_acc, train_preds, train_labels, _, _ = evaluate_model(model, dataloader['train'], device, 'Train', CLASSES, i) print(f"Model {i+1} Train Accuracy: {train_acc:.2f}%") print(f"\nEvaluating Model {i+1} on Validation Set") val_acc, val_preds, val_labels, _, _ = evaluate_model(model, dataloader['val'], device, 'Validation', CLASSES, i) print(f"Model {i+1} Validation Accuracy: {val_acc:.2f}%") print(f"\nEvaluating Model {i+1} on Test Set") test_acc, test_preds, test_labels, _, _ = evaluate_model(model, dataloader['test'], device, 'Test', CLASSES, i) print(f"Model {i+1} Test Accuracy: {test_acc:.2f}%") test_accuracies.append(test_acc) visualize_predictions_per_class(model, dataloader['test'], device, CLASSES, 'Test', model_idx=i) if i == 0: print(f"\nVisualizing Feature Extraction for Model {i+1}") visualize_feature_extraction(model, dataloader['test'], device, CLASSES, num_samples_per_class=1) models_list.append(model) # Evaluate ensemble print("\nEvaluating Ensemble on Training Set") ensemble_train_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['train'], device, 'Train', CLASSES) print(f"Ensemble Train Accuracy: {ensemble_train_acc:.2f}%") print("\nEvaluating Ensemble on Validation Set") ensemble_val_acc, _, _, _, _ = ensemble_voting(models_list, 
dataloader['val'], device, 'Validation', CLASSES) print(f"Ensemble Validation Accuracy: {ensemble_val_acc:.2f}%") print("\nEvaluating Ensemble on Test Set") ensemble_test_acc, ensemble_test_preds, ensemble_test_labels, _, _ = ensemble_voting(models_list, dataloader['test'], device, 'Test', CLASSES) print(f"Ensemble Test Accuracy: {ensemble_test_acc:.2f}%") visualize_predictions_per_class(models_list[0], dataloader['test'], device, CLASSES, 'Test', model_idx=None) visualize_voting_process(models_list, dataloader['test'], device, CLASSES) # Evaluate TTA print("\nEvaluating Best Model with TTA on Test Set") tta_acc, _, _, _, _ = evaluate_model(models_list[0], dataloader['test'], device, 'Test_TTA', CLASSES, model_idx=0, use_tta=True) print(f"TTA Test Accuracy: {tta_acc:.2f}%") # Statistical Analysis print("\nStatistical Analysis of Model Performance:") print(f"Mean Test Accuracy: {np.mean(test_accuracies):.2f}% ± {np.std(test_accuracies):.2f}%") print(f"Best Training Accuracies: {[f'{acc:.2f}%' for acc in best_train_accuracies]}") print(f"Best Validation Accuracies: {[f'{acc:.2f}%' for acc in best_val_accuracies]}") # Save predictions predictions_df = pd.DataFrame({ 'True_Label': [CLASSES[label] for label in ensemble_test_labels], 'Predicted_Label': [CLASSES[pred] for pred in ensemble_test_preds] }) predictions_df.to_csv('/kaggle/working/predictions/test_predictions_ensemble.csv', index=False) print("Predictions saved to /kaggle/working/predictions/test_predictions_ensemble.csv") # Save summary report summary_report = { 'Model': [f'Model {i+1}' for i in range(num_models)] + ['Ensemble', 'TTA'], 'Test_Accuracy': test_accuracies + [ensemble_test_acc, tta_acc], 'Best_Train_Accuracy': best_train_accuracies + [None, None], 'Best_Val_Accuracy': best_val_accuracies + [None, None] } summary_df = pd.DataFrame(summary_report) summary_df.to_csv('/kaggle/working/summary_report.csv', index=False) print("\nSummary Report:") print(summary_df)