|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
from glob import glob |
|
|
from PIL import Image |
|
|
from sklearn.model_selection import train_test_split |
|
|
|
|
|
|
|
|
import pandas as pd |
|
|
import seaborn as sns |
|
|
sns.set_style('darkgrid') |
|
|
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc |
|
|
from sklearn.utils import shuffle |
|
|
from torch.utils.data import WeightedRandomSampler |
|
|
from skimage.feature import local_binary_pattern |
|
|
|
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
import torch.optim as optim |
|
|
import torchvision.transforms as transforms |
|
|
import torchvision.models as models |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
|
|
|
|
|
|
# Root of the Kaggle dataset; one sub-directory per class (see CLASSES).
DATASET_PATH = "/kaggle/input/ms-dfu/DFU_CLASSES(4)"

# Class names in label-index order: 0=NONE, 1=INFECTION, 2=ISCHAEMIA, 3=BOTH.
# Order matters: integer labels throughout the file are indices into this list.
CLASSES = ["NONE", "INFECTION", "ISCHAEMIA", "BOTH"]

# Output directories for training artifacts; exist_ok so re-runs don't fail.
os.makedirs("/kaggle/working/logs", exist_ok=True)

os.makedirs("/kaggle/working/predictions", exist_ok=True)

os.makedirs("/kaggle/working/visualizations", exist_ok=True)
|
|
|
|
|
|
|
|
class SEBlock(nn.Module):
    """Squeeze-and-Excitation channel attention.

    Globally average-pools each channel, passes the channel vector through a
    two-layer bottleneck (reduction factor ``reduction``), and rescales the
    input feature map by the resulting per-channel sigmoid gates.
    """

    def __init__(self, in_channels, reduction=16):
        super(SEBlock, self).__init__()
        # Construction order kept (fc1 before fc2) so parameter init and
        # state_dict keys match the original implementation.
        self.fc1 = nn.Linear(in_channels, in_channels // reduction, bias=False)
        self.fc2 = nn.Linear(in_channels // reduction, in_channels, bias=False)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        batch_size, num_channels = x.shape[0], x.shape[1]
        # Squeeze: (B, C, H, W) -> (B, C)
        squeezed = self.global_pool(x).reshape(batch_size, num_channels)
        # Excite: bottleneck MLP followed by sigmoid gating.
        gates = torch.sigmoid(self.fc2(F.relu(self.fc1(squeezed))))
        return x * gates.reshape(batch_size, num_channels, 1, 1)
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """Multi-class focal loss (Lin et al., 2017).

    Down-weights well-classified samples via the (1 - p_t)^gamma modulating
    factor; ``alpha`` is a single global scale applied uniformly (not a
    per-class weight vector).
    """

    def __init__(self, gamma=3.0, alpha=0.5):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        # Per-sample cross entropy; p_t recovered as exp(-CE).
        per_sample_ce = F.cross_entropy(inputs, targets, reduction='none')
        p_t = torch.exp(-per_sample_ce)
        modulator = (1 - p_t) ** self.gamma
        return (self.alpha * modulator * per_sample_ce).mean()
|
|
|
|
|
|
|
|
class CCDGSBlock(nn.Module):
    """Grouped 3x3 conv -> channel shuffle -> depthwise 3x3 conv -> pointwise
    1x1 conv, each followed by BN+ReLU, ending in global average pooling.

    The shuffle after the grouped stage lets the later stages see
    cross-group information (ShuffleNet-style).
    """

    def __init__(self, in_channels, group_size=4):
        super(CCDGSBlock, self).__init__()
        self.group_size = group_size
        # Stage 1: grouped conv mixes channels only within each group.
        self.group_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=group_size, bias=False)
        self.bn1 = nn.BatchNorm2d(in_channels)
        # Stage 2: depthwise conv (one 3x3 filter per channel).
        self.depth_conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, padding=1, groups=in_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(in_channels)
        # Stage 3: pointwise conv mixes across all channels.
        self.point_conv = nn.Conv2d(in_channels, in_channels, kernel_size=1, bias=False)
        self.bn3 = nn.BatchNorm2d(in_channels)
        self.global_pool = nn.AdaptiveAvgPool2d(1)

    def channel_shuffle(self, x, groups):
        """Interleave channels across groups: (B, g*c, H, W) stays the same
        shape but channel order becomes group-interleaved."""
        batch, channels, height, width = x.shape
        per_group = channels // groups
        shuffled = x.reshape(batch, groups, per_group, height, width)
        shuffled = shuffled.transpose(1, 2).contiguous()
        return shuffled.reshape(batch, -1, height, width)

    def forward(self, x):
        out = F.relu(self.bn1(self.group_conv(x)))
        # Shuffle so the depthwise stage sees cross-group information.
        out = self.channel_shuffle(out, self.group_size)
        out = F.relu(self.bn2(self.depth_conv(out)))
        out = F.relu(self.bn3(self.point_conv(out)))
        # Collapse spatial dims to 1x1.
        return self.global_pool(out)
|
|
|
|
|
|
|
|
class TripletAttention(nn.Module):
    """Triplet attention: three gating branches averaged together.

    Each branch builds a 2-channel (max+mean) pooled map, convolves it to a
    single-channel sigmoid gate, and multiplies the input by the gate.
    Branches 1 and 2 first rotate the tensor (spatial plane, and
    channel-width plane respectively) so the pooling captures different
    cross-dimension interactions; branch 3 works on the raw tensor.
    """

    def __init__(self, in_channels, kernel_size=7):
        super(TripletAttention, self).__init__()
        # One gate conv per branch; construction order preserved.
        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.conv2 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)
        self.conv3 = nn.Conv2d(2, 1, kernel_size, padding=kernel_size//2, bias=False)

    def z_pool(self, x):
        """Concatenate channel-wise max and mean into a 2-channel map."""
        pooled_max = torch.max(x, dim=1, keepdim=True)[0]
        pooled_avg = torch.mean(x, dim=1, keepdim=True)
        return torch.cat([pooled_max, pooled_avg], dim=1)

    def forward(self, x):
        def gated(rot_dims, conv):
            # Rotate (if requested), pool, gate, rotate the gate back, apply.
            if rot_dims is None:
                return x * torch.sigmoid(conv(self.z_pool(x)))
            turned = torch.rot90(x, 1, rot_dims)
            gate = torch.sigmoid(conv(self.z_pool(turned)))
            return x * torch.rot90(gate, -1, rot_dims)

        y1 = gated([2, 3], self.conv1)   # H-W plane rotation
        y2 = gated([1, 3], self.conv2)   # C-W plane rotation
        y3 = gated(None, self.conv3)     # identity branch
        return (y1 + y2 + y3) / 3.0
|
|
|
|
|
|
|
|
class DenseShuffleGCANet(nn.Module):
    """DenseNet-169 backbone followed by CCDGS, triplet attention and SE
    blocks, with optional handcrafted features fused before the classifier.

    ``forward`` accepts a batch of images plus an optional
    (B, handcrafted_feature_dim) tensor concatenated to the pooled deep
    features before the fully-connected head.
    """

    def __init__(self, num_classes=4, handcrafted_feature_dim=41):
        super(DenseShuffleGCANet, self).__init__()
        # ImageNet-pretrained backbone; its final feature map has 1664 channels.
        densenet = models.densenet169(weights='IMAGENET1K_V1')
        self.features = densenet.features
        self.ccdgs = CCDGSBlock(in_channels=1664, group_size=4)
        self.triplet_attention = TripletAttention(in_channels=1664)
        self.se_block = SEBlock(in_channels=1664)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.flatten = nn.Flatten()
        self.dropout = nn.Dropout(0.6)
        # +handcrafted_feature_dim accounts for the fused descriptor vector.
        self.fc1 = nn.Linear(1664 + handcrafted_feature_dim, 512)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x, handcrafted_features=None):
        deep = self.features(x)
        deep = self.se_block(self.triplet_attention(self.ccdgs(deep)))
        deep = self.flatten(self.global_pool(deep))
        # Fuse handcrafted descriptors with the deep features when provided.
        if handcrafted_features is not None:
            deep = torch.cat([deep, handcrafted_features], dim=1)
        hidden = F.relu(self.fc1(self.dropout(deep)))
        return self.fc2(self.dropout(hidden))
|
|
|
|
|
|
|
|
def display_sample_images(images, labels, split_name, classes, num_samples=4):
    """Save a grid of up to ``num_samples`` example images per class.

    ``images`` are BGR arrays (OpenCV convention used throughout this file);
    they are converted to RGB for display. The figure is written to
    /kaggle/working/visualizations/<split>_samples.png.
    """
    plt.figure(figsize=(15, 10))
    for row, class_name in enumerate(classes):
        matches = [i for i, lbl in enumerate(labels) if lbl == row]
        # Take the first few samples of this class (empty list -> no panels).
        for col, sample_idx in enumerate(matches[:num_samples]):
            rgb = cv2.cvtColor(images[sample_idx], cv2.COLOR_BGR2RGB)
            plt.subplot(len(classes), num_samples, row * num_samples + col + 1)
            plt.imshow(rgb)
            plt.title(f'{class_name}')
            plt.axis('off')
    plt.suptitle(f'{split_name} Sample Images')
    plt.tight_layout(rect=[0, 0, 1, 0.95])
    plt.savefig(f'/kaggle/working/visualizations/{split_name.lower()}_samples.png')
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def _save_labeled_bar_chart(values, title, figsize, out_path):
    """Render ``values`` as a bar chart with per-bar value labels and save it."""
    plt.figure(figsize=figsize)
    bars = plt.bar(range(len(values)), values)
    plt.title(title)
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height:.2f}',
                 ha='center', va='bottom',
                 fontsize=8, fontweight='bold')
    plt.savefig(out_path, dpi=120, bbox_inches='tight')
    plt.close()


def visualize_handcrafted_features(images, labels, classes, num_samples=1):
    """Save per-class visualizations of the handcrafted feature pipeline.

    For up to ``num_samples`` images of each class this writes, into
    /kaggle/working/visualizations/handcrafted_features/class_<i>_<name>/:
    the original image, its Canny edge map, and bar charts of the LBP,
    color and edge histograms (mirroring extract_handcrafted_features).
    ``images`` are BGR arrays.
    """
    main_dir = '/kaggle/working/visualizations/handcrafted_features'
    os.makedirs(main_dir, exist_ok=True)

    for class_idx, class_name in enumerate(classes):
        class_dir = os.path.join(main_dir, f"class_{class_idx}_{class_name}")
        os.makedirs(class_dir, exist_ok=True)

        matching = [i for i, label in enumerate(labels) if label == class_idx]
        for idx in matching[:num_samples]:
            img = cv2.cvtColor(images[idx], cv2.COLOR_BGR2RGB)
            gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

            # Denoise + equalize before edge detection, as in feature extraction.
            gray_eq = cv2.equalizeHist(cv2.GaussianBlur(gray, (5, 5), 0))

            # Median-based automatic Canny thresholds (0.66x / 1.33x median).
            median = np.median(gray_eq)
            lower_threshold = int(max(0, 0.66 * median))
            upper_threshold = int(min(255, 1.33 * median))
            edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)
            print(f"Visualization - Class: {class_name}, Image {idx}, Edge pixels: {np.sum(edges > 0)}")

            # Same three descriptor families as extract_handcrafted_features.
            edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)
            lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
            lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)
            color_hist = []
            for channel in range(img.shape[2]):
                hist, _ = np.histogram(img[:, :, channel], bins=8, range=(0, 256), density=True)
                color_hist.extend(hist)

            # Original image panel.
            plt.figure(figsize=(5, 5))
            plt.imshow(img)
            plt.title(f'Original ({class_name})')
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, f'sample_{idx}_original.png'),
                        dpi=120, bbox_inches='tight')
            plt.close()

            # Edge map panel.
            plt.figure(figsize=(5, 5))
            plt.imshow(edges, cmap='gray')
            plt.title('Canny Edge Map')
            plt.axis('off')
            plt.savefig(os.path.join(class_dir, f'sample_{idx}_edges.png'),
                        dpi=120, bbox_inches='tight')
            plt.close()

            # Histogram panels (shared rendering helper).
            _save_labeled_bar_chart(lbp_hist, 'LBP Histogram', (8, 4),
                                    os.path.join(class_dir, f'sample_{idx}_lbp_hist.png'))
            _save_labeled_bar_chart(color_hist, 'Color Histogram', (10, 4),
                                    os.path.join(class_dir, f'sample_{idx}_color_hist.png'))
            _save_labeled_bar_chart(edge_hist, 'Edge Histogram', (8, 4),
                                    os.path.join(class_dir, f'sample_{idx}_edge_hist.png'))

            print(f"Saved separate visualizations for class {class_name} sample {idx} in: {class_dir}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_handcrafted_features(image):
    """Compute a 41-D handcrafted descriptor for one RGB image.

    Concatenated features, in order:
      * 9-bin uniform-LBP histogram (texture; the 10 uniform codes fall
        into 9 bins, codes 8 and 9 sharing the last bin),
      * 24-bin color histogram (8 bins per channel),
      * 8-bin Canny edge-intensity histogram.

    Accepts either an HxWx3 uint8 array or a normalized CHW torch tensor
    (ImageNet mean/std), which is denormalized first. The array branch is
    treated as RGB by the cv2 conversion below; callers in this file pass
    BGR arrays — the per-channel histograms are unaffected by channel
    order, only the grayscale weighting differs slightly.

    Returns:
        torch.FloatTensor of shape (41,).
    """
    if isinstance(image, torch.Tensor):
        # Undo ToTensor + ImageNet normalization back to 8-bit HWC.
        image = image.cpu().numpy()
        image = np.transpose(image, (1, 2, 0))
        image = (image * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])) * 255
        # Clip before the uint8 cast: augmented pixels can land slightly
        # outside [0, 255] and a bare astype(np.uint8) would wrap around.
        image = np.clip(image, 0, 255).astype(np.uint8)

    gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)

    # Denoise + boost contrast so Canny thresholds are stable across images.
    gray_blur = cv2.GaussianBlur(gray, (5, 5), 0)
    gray_eq = cv2.equalizeHist(gray_blur)

    # Median-based automatic Canny thresholds (0.66x / 1.33x median).
    median = np.median(gray_eq)
    lower_threshold = int(max(0, 0.66 * median))
    upper_threshold = int(min(255, 1.33 * median))

    edges = cv2.Canny(gray_eq, lower_threshold, upper_threshold)

    edge_hist, _ = np.histogram(edges.ravel(), bins=8, range=(0, 256), density=True)

    lbp = local_binary_pattern(gray, P=8, R=1, method='uniform')
    lbp_hist, _ = np.histogram(lbp.ravel(), bins=np.arange(0, 10), density=True)

    color_hist = []
    for channel in range(image.shape[2]):
        hist, _ = np.histogram(image[:, :, channel], bins=8, range=(0, 256), density=True)
        color_hist.extend(hist)

    features = np.concatenate([lbp_hist, color_hist, edge_hist])
    return torch.tensor(features, dtype=torch.float32)
|
|
|
|
|
|
|
|
def load_images(split_path, classes, use_csv=False):
    """Load one dataset split fully into memory.

    Two loading modes:
      * CSV mode (use_csv=True and <DATASET_PATH>/labels.csv exists): reads
        relative image paths and label names from the CSV; rows whose label
        is not in the global CLASSES list are skipped.
      * Directory mode (default): expects one sub-directory per entry of
        `classes` under `split_path`, containing .jpg/.jpeg (any case) or
        lowercase .png files.

    Each image is decoded as RGB via PIL, converted to BGR (the OpenCV
    convention used throughout this file), and its handcrafted descriptor
    is computed up front via extract_handcrafted_features.

    Returns:
        (image_data, labels, handcrafted_features) — three parallel lists
        of BGR uint8 arrays, integer class indices, and float32 feature
        tensors. All three are empty if the split directory is missing.
    """
    image_data = []
    labels = []
    handcrafted_features = []
    print(f"Loading images from: {split_path}")
    csv_path = os.path.join(DATASET_PATH, "labels.csv")
    if use_csv and os.path.exists(csv_path):
        print("Found labels.csv, loading dataset from CSV")
        df = pd.read_csv(csv_path)
        print("CSV columns:", df.columns)
        for idx, row in df.iterrows():
            # Paths in the CSV are relative to the dataset root.
            img_path = os.path.join(DATASET_PATH, row['image_path'])
            label_name = row['label']
            if label_name not in CLASSES:
                print(f"Warning: Label {label_name} not in CLASSES, skipping")
                continue
            label = CLASSES.index(label_name)
            try:
                img = Image.open(img_path).convert('RGB')
                img_array = np.array(img)
                # Store images in BGR to match cv2 usage elsewhere.
                img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
                features = extract_handcrafted_features(img_array)
            except Exception as e:
                # Best-effort load: log and skip unreadable files.
                print(f"Warning: Failed to load image {img_path}: {e}")
                continue
            image_data.append(img_array)
            labels.append(label)
            handcrafted_features.append(features)
    else:
        if not os.path.exists(split_path):
            print(f"Error: Directory {split_path} does not exist")
            return image_data, labels, handcrafted_features
        for class_idx, class_name in enumerate(classes):
            class_path = os.path.join(split_path, class_name)
            print(f"Checking class: {class_name} at {class_path}")
            if not os.path.exists(class_path):
                print(f"Warning: Class directory {class_path} does not exist")
                continue
            all_files = glob(os.path.join(class_path, '*'))
            print(f"All files in {class_path}: {all_files[:5]}")
            # Case-insensitive jpg/jpeg matching; .png matched lowercase only.
            image_paths = glob(os.path.join(class_path, '*.[jJ][pP][gG]')) + \
                glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]')) + \
                glob(os.path.join(class_path, '*.png'))
            print(f"Found {len(image_paths)} images for class {class_name}")
            for img_path in image_paths:
                try:
                    img = Image.open(img_path).convert('RGB')
                    img_array = np.array(img)
                    img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
                    features = extract_handcrafted_features(img_array)
                except Exception as e:
                    print(f"Warning: Failed to load image {img_path}: {e}")
                    continue
                image_data.append(img_array)
                labels.append(class_idx)
                handcrafted_features.append(features)
    print(f"Total images loaded: {len(image_data)}")
    return image_data, labels, handcrafted_features
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_data_distribution(split_path, split_name, classes):
    """Count images per class directory and save an annotated bar chart.

    Matches .jpg/.jpeg case-insensitively and lowercase .png, mirroring
    load_images. Writes <split>_distribution.png under the visualizations
    directory.
    """
    class_counts = []
    for class_name in classes:
        class_path = os.path.join(split_path, class_name)
        image_paths = (glob(os.path.join(class_path, '*.[jJ][pP][gG]'))
                       + glob(os.path.join(class_path, '*.[jJ][pP][eE][gG]'))
                       + glob(os.path.join(class_path, '*.png')))
        class_counts.append(len(image_paths))
        print(f"Split: {split_name}, Class: {class_name}, Number of images: {len(image_paths)}")

    plt.figure(figsize=(10, 6))
    chart_bars = plt.bar(classes, class_counts)
    plt.title(f'{split_name} Dataset Distribution')
    plt.xlabel('Classes')
    plt.ylabel('Number of Images')
    plt.xticks(rotation=45, ha='right')

    # Annotate each bar with its exact count.
    for bar in chart_bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width()/2., height,
                 f'{height}',
                 ha='center', va='bottom',
                 fontsize=10, fontweight='bold')

    plt.tight_layout()
    plt.savefig(f"/kaggle/working/visualizations/{split_name.lower()}_distribution.png")
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
class FootUlcerDataset(Dataset):
    """Dataset of in-memory BGR image arrays with labels and precomputed
    41-D handcrafted feature tensors.

    Minority classes (label 2 ISCHAEMIA and 3 BOTH) receive heavier
    augmentation than the majority classes.

    Fix vs. original: both transform pipelines were being rebuilt on every
    single __getitem__ call; they are now constructed once in __init__,
    which is behaviorally identical (the random ops draw fresh parameters
    per call) but avoids the per-item construction overhead.
    """

    def __init__(self, images, labels, handcrafted_features):
        self.images = images
        self.labels = labels
        self.handcrafted_features = handcrafted_features
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Heavier augmentation for the minority classes (labels 2 and 3).
        self._minority_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomRotation(30),
            transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            transforms.RandomResizedCrop(224, scale=(0.8, 1.0)),
            transforms.ToTensor(),
            normalize,
        ])
        # Lighter pipeline for the majority classes (no resize: images are
        # assumed to already be the expected input size).
        self._majority_transform = transforms.Compose([
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.ToTensor(),
            normalize,
        ])

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        """Return (transformed image tensor, int label, feature tensor)."""
        image = self.images[idx]
        label = self.labels[idx]
        features = self.handcrafted_features[idx]
        # Stored arrays are BGR; PIL expects RGB.
        image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
        transform = self._minority_transform if label in (2, 3) else self._majority_transform
        return transform(image), label, features
|
|
|
|
|
|
|
|
def print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41):
    """Print a per-layer summary (output shape + parameter count) using a
    dummy forward pass with forward hooks.

    The model must accept ``model(images, handcrafted_features)`` and have
    at least one parameter (used to infer the device).

    Fix vs. original: hooks fire on *every* submodule, and the original
    counted ``module.parameters()`` (recursive), so parameters of nested
    layers were added once per enclosing container — inflating both the
    per-layer counts and the total. Counting with ``recurse=False`` tallies
    each parameter exactly once.
    """
    device = next(model.parameters()).device
    model.eval()
    print("\nModel Summary:")
    print("=" * 80)
    print(f"{'Layer':<30} {'Output Shape':<25} {'Param #':<15}")
    print("-" * 80)

    total_params = 0
    x = torch.randn(1, *input_size).to(device)
    handcrafted_features = torch.randn(1, handcrafted_feature_dim).to(device)

    def register_hook(module, input, output):
        nonlocal total_params
        class_name = str(module.__class__.__name__)
        # Only this module's own parameters; children report themselves.
        param_count = sum(p.numel() for p in module.parameters(recurse=False))
        total_params += param_count
        output_shape = list(output.shape) if isinstance(output, torch.Tensor) else "N/A"
        print(f"{class_name:<30} {str(output_shape):<25} {param_count:<15}")

    hooks = []
    for name, module in model.named_modules():
        if module != model:
            hooks.append(module.register_forward_hook(register_hook))

    with torch.no_grad():
        model(x, handcrafted_features)

    for hook in hooks:
        hook.remove()

    print("-" * 80)
    print(f"Total Parameters: {total_params:,}")
    print("=" * 80)
|
|
|
|
|
|
|
|
def plot_roc_curves(labels, probabilities, split_name, classes, model_idx=None):
    """Draw one-vs-rest ROC curves (with per-class AUC) and save the figure.

    ``labels`` are integer class indices and ``probabilities`` an (N, C)
    array of class scores. The output file name encodes the split and,
    when ``model_idx`` is given, the ensemble member; otherwise it is
    tagged '_ensemble'.
    """
    plt.figure(figsize=(10, 8))
    y_true = np.array(labels)
    for class_idx, class_name in enumerate(classes):
        # One-vs-rest: positives are samples of this class.
        fpr, tpr, _ = roc_curve(y_true == class_idx, probabilities[:, class_idx])
        plt.plot(fpr, tpr, label=f'{class_name} (AUC = {auc(fpr, tpr):.2f})')
    plt.plot([0, 1], [0, 1], 'k--')  # chance diagonal
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(f'ROC Curves - {split_name}' + (f' (Model {model_idx})' if model_idx is not None else ''))
    plt.legend(loc='lower right')
    plt.grid(True)
    suffix = f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png'
    filename = f'/kaggle/working/visualizations/roc_{split_name.lower()}' + suffix
    plt.savefig(filename)
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_feature_extraction(model, dataloader, device, classes, num_samples_per_class=1):
    """Visualize intermediate activations of DenseShuffleGCANet.

    Registers forward hooks on the named top-level layers, runs batches
    until at least ``num_samples_per_class`` samples per class are seen,
    then saves per-class images: the denormalized input, 2D feature-map
    grids (for 3D activations), bar charts of flattened feature vectors,
    and the softmax probability bars.

    NOTE(review): ``feature_maps`` only retains the activations of the LAST
    batch processed, and below they are indexed with ``class_idx`` as a
    BATCH position — so the visualized activations generally do NOT belong
    to the stored sample for that class (and will IndexError if the last
    batch is smaller than len(classes)). Looks like a bug — confirm intent.
    """
    model.eval()
    feature_maps = {}
    # Top-level attributes of DenseShuffleGCANet to hook.
    layer_names = ['features', 'ccdgs', 'triplet_attention', 'se_block', 'fc1', 'fc2']

    print("\nModel structure (named modules):")
    for name, module in model.named_modules():
        print(f"Layer: {name}, Module: {type(module).__name__}")

    print("\nRegistering forward hooks for layers:", layer_names)

    def get_hook(name):
        # Closure captures the layer name; the hook stores the detached output.
        def hook(module, input, output):
            feature_maps[name] = output.detach()
            print(f"Captured output for {name}, shape: {output.shape}")
        return hook

    hooks = []
    for name in layer_names:
        module = getattr(model, name, None)
        # NOTE(review): truthiness test on a Module; an empty nn.Sequential
        # would be skipped here — `is not None` would be the safer check.
        if module:
            hooks.append(module.register_forward_hook(get_hook(name)))
            print(f"Hook registered for {name}")
        else:
            print(f"Warning: Layer {name} not found in model")

    # Collect up to num_samples_per_class (image, label, probs, features)
    # tuples for every class, stopping early once all classes are filled.
    class_samples = {class_idx: [] for class_idx in range(len(classes))}
    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            outputs = model(images, features)
            probs = F.softmax(outputs, dim=1)

            for i in range(len(images)):
                class_idx = labels[i].item()
                if len(class_samples[class_idx]) < num_samples_per_class:
                    class_samples[class_idx].append((
                        images[i].cpu().numpy(),
                        labels[i].cpu().numpy(),
                        probs[i].cpu().numpy(),
                        features[i].cpu().numpy()
                    ))

            if all(len(samples) >= num_samples_per_class for samples in class_samples.values()):
                break

    print(f"Removing {len(hooks)} hooks")
    for hook in hooks:
        hook.remove()

    print(f"Feature maps captured: {list(feature_maps.keys())}")

    for class_idx in range(len(classes)):
        if not class_samples[class_idx]:
            print(f"No samples found for class {class_idx} ({classes[class_idx]})")
            continue

        # First stored sample for this class.
        img, label, prob, features = class_samples[class_idx][0]
        true_label = classes[label]

        # Denormalize CHW -> HWC for display (ImageNet mean/std).
        img = img.transpose(1, 2, 0)
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)

        plt.figure(figsize=(5, 5))
        plt.imshow(img)
        plt.title(f'Input Image (Class: {true_label})')
        plt.axis('off')
        input_img_path = f'/kaggle/working/visualizations/class_{class_idx}_input_image.png'
        plt.savefig(input_img_path)
        plt.close()
        print(f"Saved input image for class {true_label} to: {input_img_path}")

        for layer_name in layer_names:
            if layer_name not in feature_maps:
                print(f"No feature map for {layer_name}, skipping visualization")
                continue

            # NOTE(review): indexes the last batch by class_idx — see the
            # docstring; this is a batch position, not this class's sample.
            features = feature_maps[layer_name][class_idx]
            print(f"Visualizing {layer_name} for class {true_label}, feature shape: {features.shape}")

            if features.dim() == 3:
                # Spatial activation (C, H, W): show up to 16 channel maps.
                num_channels = min(features.shape[0], 16)
                plt.figure(figsize=(15, 10))
                for i in range(num_channels):
                    plt.subplot(4, 4, i + 1)
                    feature_map = features[i].cpu().numpy()
                    # Per-channel min-max normalization (epsilon avoids /0).
                    feature_map = (feature_map - feature_map.min()) / (feature_map.max() - feature_map.min() + 1e-8)
                    plt.imshow(feature_map, cmap='viridis')
                    plt.title(f'Channel {i+1}')
                    plt.axis('off')
                plt.suptitle(f'Feature Maps - {layer_name} (Class: {true_label})')
                plt.tight_layout(rect=[0, 0, 1, 0.95])
                feature_map_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_maps_{layer_name}.png'
                plt.savefig(feature_map_path)
                plt.close()
                print(f"Saved {num_channels} feature maps for {layer_name} to: {feature_map_path}")

            else:
                # Non-spatial activation: plot the flattened vector as bars.
                values = features.flatten().cpu().numpy()
                num_features = len(values)

                # Widen the figure with the vector length so bars stay visible.
                fig_width = max(20, num_features * 0.025)
                plt.figure(figsize=(fig_width, 5))

                bars = plt.bar(
                    range(num_features),
                    values,
                    color='#1f77b4',
                    edgecolor='#1f77b4',
                    linewidth=0.05,
                    width=0.9,
                    align='center'
                )

                # Sparse x ticks for long vectors.
                if num_features > 100:
                    ticks = list(range(0, num_features, 50)) + [num_features-1]
                    plt.xticks(ticks)

                plt.title(f'Feature Vector - {layer_name} (Class: {true_label})')
                plt.xlabel('Feature')
                plt.ylabel('Activation Value')
                plt.grid(True, linestyle=':', alpha=0.5)
                plt.tight_layout()
                vector_path = f'/kaggle/working/visualizations/class_{class_idx}_feature_vector_{layer_name}.png'
                plt.savefig(vector_path, dpi=120, bbox_inches='tight', facecolor='white')
                plt.close()
                print(f"Saved feature vector with {len(values)} elements for {layer_name} to: {vector_path}")

        # Softmax probability bars for this class's stored sample.
        plt.figure(figsize=(8, 6))
        bars = plt.bar(classes, prob)
        plt.title(f'Classification Probabilities (True: {true_label})')
        plt.xlabel('Classes')
        plt.ylabel('Probability')
        plt.xticks(rotation=45)

        for bar in bars:
            height = bar.get_height()
            plt.text(bar.get_x() + bar.get_width()/2., height,
                     f'{height:.3f}',
                     ha='center', va='bottom',
                     fontsize=10, fontweight='bold')

        plt.tight_layout()
        probs_path = f'/kaggle/working/visualizations/class_{class_idx}_classification_probs.png'
        plt.savefig(probs_path)
        plt.close()
        print(f"Saved classification probabilities to: {probs_path}")

    print("\nListing saved visualization files:")
    os.system('ls /kaggle/working/visualizations/')
|
|
|
|
|
|
|
|
def train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=0):
    """Train ``model`` with mixed precision, LR-on-plateau scheduling and
    early stopping on validation loss.

    Args:
        model: network accepting ``model(images, handcrafted_features)``.
        dataloader: dict with 'train' and 'val' DataLoaders yielding
            (images, labels, handcrafted_features) triples.
        criterion, optimizer, device: standard training objects.
        epochs: maximum number of epochs (early stopping may end sooner).
        model_idx: ensemble-member index used in the checkpoint filename.

    Side effects: saves the best-val-loss weights to
    ``/kaggle/working/best_model_{model_idx}.pth``.

    Returns:
        (history, best_train_acc, best_val_acc) where history holds the
        per-epoch train/val loss and accuracy lists.
    """
    history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
    best_val_loss = float('inf')
    best_val_acc = 0.0
    best_train_acc = 0.0
    patience = 10  # epochs without val-loss improvement before stopping
    counter = 0
    # NOTE(review): torch.cuda.amp.GradScaler/autocast are the legacy AMP
    # API (superseded by torch.amp in newer releases) — confirm the torch
    # version this runs on.
    scaler = torch.cuda.amp.GradScaler()
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, verbose=True)
    train_batches = len(dataloader['train'])
    val_batches = len(dataloader['val'])
    print(f"Training dataset: {train_batches} batches")
    print(f"Validation dataset: {val_batches} batches")
    for epoch in range(epochs):
        print(f"\n--- Epoch {epoch+1}/{epochs} ---")
        model.train()
        running_loss = 0.0
        correct, total = 0, 0
        for batch_idx, (images, labels, features) in enumerate(dataloader['train']):
            print(f"Training epoch {epoch+1}, batch {batch_idx+1}/{train_batches}")
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            optimizer.zero_grad()
            # Forward + loss under autocast (fp16 where safe).
            with torch.cuda.amp.autocast():
                outputs = model(images, features)
                loss = criterion(outputs, labels)
            # Scaled backward/step keeps fp16 gradients from underflowing.
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
        epoch_loss = running_loss / train_batches if train_batches > 0 else 0.0
        epoch_acc = 100. * correct / total if total > 0 else 0.0
        history['train_loss'].append(epoch_loss)
        history['train_acc'].append(epoch_acc)
        best_train_acc = max(best_train_acc, epoch_acc)
        # --- Validation pass ---
        model.eval()
        val_loss = 0.0
        val_correct, val_total = 0, 0
        with torch.no_grad():
            for batch_idx, (val_images, val_labels, val_features) in enumerate(dataloader['val']):
                print(f"Validation epoch {epoch+1}, batch {batch_idx+1}/{val_batches}")
                val_images, val_labels, val_features = val_images.to(device), val_labels.to(device), val_features.to(device)
                with torch.cuda.amp.autocast():
                    val_outputs = model(val_images, val_features)
                    loss = criterion(val_outputs, val_labels)
                val_loss += loss.item()
                _, predicted = val_outputs.max(1)
                val_total += val_labels.size(0)
                val_correct += predicted.eq(val_labels).sum().item()
        val_epoch_loss = val_loss / val_batches if val_batches > 0 else 0.0
        val_epoch_acc = 100. * val_correct / val_total if val_total > 0 else 0.0
        history['val_loss'].append(val_epoch_loss)
        history['val_acc'].append(val_epoch_acc)
        best_val_acc = max(best_val_acc, val_epoch_acc)
        # Scheduler tracks validation loss (mode='min').
        scheduler.step(val_epoch_loss)
        if val_epoch_loss < best_val_loss:
            # New best: reset patience and checkpoint the weights.
            best_val_loss = val_epoch_loss
            counter = 0
            torch.save(model.state_dict(), f'/kaggle/working/best_model_{model_idx}.pth')
        else:
            counter += 1
            if counter >= patience:
                # NOTE: breaking here skips this epoch's summary prints below.
                print("Early stopping triggered")
                break
        print(f"Training Loss: {epoch_loss:.4f}, Accuracy: {epoch_acc:.2f}%")
        print(f"Validation Loss: {val_epoch_loss:.4f}, Accuracy: {val_epoch_acc:.2f}%")
    print(f"Best Training Accuracy (Model {model_idx}): {best_train_acc:.2f}%")
    print(f"Best Validation Accuracy (Model {model_idx}): {best_val_acc:.2f}%")
    return history, best_train_acc, best_val_acc
|
|
|
|
|
|
|
|
def plot_training_history(history, epochs, model_idx=0):
    """Save side-by-side loss and accuracy curves for one trained model.

    ``epochs`` is accepted for API compatibility but the x-axis length is
    taken from the recorded history (early stopping may end sooner).
    """
    epochs_range = range(1, len(history['train_loss']) + 1)
    plt.figure(figsize=(12, 5))
    panels = [
        ('train_loss', 'val_loss', 'Training Loss', 'Validation Loss',
         'Loss', f'Training and Validation Loss (Model {model_idx})'),
        ('train_acc', 'val_acc', 'Training Accuracy', 'Validation Accuracy',
         'Accuracy (%)', f'Training and Validation Accuracy (Model {model_idx})'),
    ]
    for pos, (train_key, val_key, train_lbl, val_lbl, ylabel, title) in enumerate(panels, start=1):
        plt.subplot(1, 2, pos)
        plt.plot(epochs_range, history[train_key], label=train_lbl)
        plt.plot(epochs_range, history[val_key], label=val_lbl)
        plt.xlabel('Epochs')
        plt.ylabel(ylabel)
        plt.title(title)
        plt.legend()
        plt.grid(True)
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/visualizations/training_history_model_{model_idx}.png')
    plt.close()
|
|
|
|
|
|
|
|
def evaluate_model(model, dataloader, device, split_name, classes, model_idx=None, use_tta=False):
    """Evaluate a single model on one split; save confusion matrix, report and ROC curves.

    Args:
        model: network taking (images, handcrafted_features) and returning logits.
        dataloader: yields (images, labels, features) batches; images are
            assumed already normalized with ImageNet statistics.
        device: torch device to run inference on.
        split_name: label used in plot titles and output filenames.
        classes: ordered class-name list for reports and axis labels.
        model_idx: optional model index appended to output filenames;
            None marks the output as the ensemble's.
        use_tta: if True, average softmax probabilities over three
            test-time augmentations (identity, horizontal flip, rotation).

    Returns:
        (accuracy, all_predictions, all_labels, all_probs, report_df)
    """
    model.eval()
    correct = 0
    total = 0
    all_predictions = []
    all_labels = []
    all_probs = []

    # ImageNet normalization constants; needed to undo normalization so the
    # TTA pipelines (which start from a uint8 PIL image) can re-apply it.
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(3, 1, 1)

    # Build the TTA pipelines only when TTA is requested (the original
    # constructed them unconditionally on every call).
    tta_transforms = []
    if use_tta:
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        tta_transforms = [
            transforms.Compose([transforms.ToTensor(), normalize]),
            transforms.Compose([transforms.RandomHorizontalFlip(p=1.0), transforms.ToTensor(), normalize]),
            transforms.Compose([transforms.RandomRotation(10), transforms.ToTensor(), normalize]),
        ]

    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            if use_tta:
                # Denormalize once per batch: this is loop-invariant with
                # respect to the TTA transforms (the original recomputed it
                # inside the per-transform loop).
                denorm_images = (images * std + mean).clamp(0, 1) * 255
                denorm_images = denorm_images.to(torch.uint8)
                batch_probs = []
                for transform in tta_transforms:
                    # Round-trip through PIL so torchvision transforms apply.
                    tta_images = torch.stack([
                        transform(Image.fromarray(img.cpu().numpy().transpose(1, 2, 0)))
                        for img in denorm_images
                    ]).to(device)
                    outputs = model(tta_images, features)
                    batch_probs.append(F.softmax(outputs, dim=1))
                # Average probabilities across augmentations, then argmax.
                avg_probs = torch.stack(batch_probs).mean(dim=0)
                _, predicted = torch.max(avg_probs, 1)
                all_probs.extend(avg_probs.cpu().numpy())
            else:
                outputs = model(images, features)
                _, predicted = torch.max(outputs.data, 1)
                all_probs.extend(F.softmax(outputs, dim=1).cpu().numpy())
            all_predictions.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    accuracy = 100 * correct / total if total > 0 else 0.0

    # Confusion-matrix heatmap.
    cm = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes,
                annot_kws={'size': 12, 'weight': 'bold'})
    plt.title(f'Confusion Matrix - {split_name}' +
              (f' (Model {model_idx})' if model_idx is not None else ''),
              fontsize=14, fontweight='bold')
    plt.xlabel('Predicted', fontsize=12, fontweight='bold')
    plt.ylabel('True', fontsize=12, fontweight='bold')
    plt.tight_layout()
    filename = f'/kaggle/working/visualizations/cm_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
    plt.savefig(filename)
    plt.close()

    # Per-class precision/recall/F1 report saved as CSV.
    report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    report_filename = f'/kaggle/working/classification_report_{split_name.lower()}' + (f'_model_{model_idx}.csv' if model_idx is not None else '_ensemble.csv')
    report_df.to_csv(report_filename)

    all_probs = np.array(all_probs)
    plot_roc_curves(all_labels, all_probs, split_name, classes, model_idx)
    return accuracy, all_predictions, all_labels, all_probs, report_df
|
|
|
|
|
|
|
|
def ensemble_voting(models, dataloader, device, split_name, classes):
    """Evaluate a majority-vote ensemble on one split.

    Every model predicts each batch; the final per-sample label is the
    majority vote across models and the reported probabilities are the mean
    of the per-model softmax outputs. Saves a confusion-matrix heatmap, a
    classification-report CSV and ROC curves.

    Args:
        models: list of trained networks taking (images, features).
        dataloader: yields (images, labels, features) batches.
        device: torch device for inference.
        split_name: label used in plot titles and output filenames.
        classes: ordered class-name list for reports and axis labels.

    Returns:
        (accuracy, all_predictions, all_labels, all_probs, report_df)
    """
    all_predictions = []
    all_labels = []
    all_probs = []

    # Put every model in eval mode once up front; the data is then traversed
    # a single time. (The original wrapped the whole inference loop inside
    # `for model in models:`, so the dataloader was re-read and every model
    # re-run len(models) times, duplicating predictions and labels.)
    for model in models:
        model.eval()

    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            votes = []
            probs = []
            for model in models:
                outputs = model(images, features)
                _, predicted = torch.max(outputs.data, 1)
                votes.append(predicted.cpu().numpy())
                probs.append(F.softmax(outputs, dim=1).cpu().numpy())
            votes = np.array(votes)
            # Majority vote per sample; bincount().argmax() breaks ties in
            # favor of the lowest class index.
            final_predictions = np.apply_along_axis(lambda x: np.bincount(x).argmax(), axis=0, arr=votes)
            avg_probs = np.mean(probs, axis=0)
            all_predictions.extend(final_predictions)
            all_labels.extend(labels.cpu().numpy())
            all_probs.extend(avg_probs)

    accuracy = 100 * sum(np.array(all_predictions) == np.array(all_labels)) / len(all_labels)

    # Confusion-matrix heatmap.
    cm = confusion_matrix(all_labels, all_predictions)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes,
                annot_kws={'size': 12, 'weight': 'bold'})
    plt.title(f'Confusion Matrix - {split_name} (Ensemble)', fontsize=14, fontweight='bold')
    plt.xlabel('Predicted', fontsize=12, fontweight='bold')
    plt.ylabel('True', fontsize=12, fontweight='bold')
    plt.tight_layout()
    plt.savefig(f'/kaggle/working/visualizations/cm_{split_name.lower()}_ensemble.png')
    plt.close()

    # Per-class precision/recall/F1 report saved as CSV.
    report = classification_report(all_labels, all_predictions, target_names=classes, output_dict=True)
    report_df = pd.DataFrame(report).transpose()
    report_df.to_csv(f'/kaggle/working/classification_report_{split_name.lower()}_ensemble.csv')

    all_probs = np.array(all_probs)
    plot_roc_curves(all_labels, all_probs, f'{split_name} (Ensemble)', classes)
    return accuracy, all_predictions, all_labels, all_probs, report_df
|
|
|
|
|
|
|
|
def visualize_voting_process(models, dataloader, device, classes, num_samples=5):
    """Visualize how individual model votes combine into the ensemble label.

    Saves one figure with a row per sample: the image, each model's predicted
    class, and the majority-vote result.

    Args:
        models: list of trained networks taking (images, features).
        dataloader: yields (images, labels, features) batches.
        device: torch device for inference.
        classes: class-name list (NOTE(review): titles currently use the
            module-level CLASSES, matching the original behavior).
        num_samples: maximum number of samples to display.
    """
    model_predictions = []
    true_labels = []
    images_list = []

    # eval() is idempotent — set it once instead of once per batch.
    for model in models:
        model.eval()

    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            batch_preds = []
            for model in models:
                outputs = model(images, features)
                _, predicted = torch.max(outputs.data, 1)
                batch_preds.append(predicted.cpu().numpy())
            # Transpose so each row holds one sample's votes across models.
            model_predictions.extend(np.array(batch_preds).T)
            true_labels.extend(labels.cpu().numpy())
            images_list.extend(images.cpu().numpy())
            if len(true_labels) >= num_samples:
                break

    model_predictions = model_predictions[:num_samples]
    true_labels = true_labels[:num_samples]
    images_list = images_list[:num_samples]
    # Guard: the dataloader may contain fewer than num_samples items; the
    # original would have raised IndexError in the plotting loop below.
    num_samples = min(num_samples, len(true_labels))

    plt.figure(figsize=(15, num_samples * 3))
    for i in range(num_samples):
        # Undo ImageNet normalization for display.
        img = images_list[i].transpose(1, 2, 0)
        img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
        img = np.clip(img, 0, 1)
        plt.subplot(num_samples, 1, i + 1)
        plt.imshow(img)
        preds = [CLASSES[p] for p in model_predictions[i]]
        ensemble_pred = CLASSES[np.bincount(model_predictions[i]).argmax()]
        # Works for any ensemble size (the original hard-coded exactly three
        # models); for three models the rendered text is identical.
        per_model = ', '.join(f'Model {j + 1}: {p}' for j, p in enumerate(preds))
        title = f'True: {CLASSES[true_labels[i]]}\n' + per_model + f'\nEnsemble: {ensemble_pred}'
        plt.title(title)
        plt.axis('off')

    plt.tight_layout()
    plt.savefig('/kaggle/working/visualizations/voting_process.png')
    plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def visualize_predictions_per_class(model, dataloader, device, classes, split_name, model_idx=None, num_samples=4):
    """Save, for each class, a strip of sample images with true vs. predicted labels.

    Args:
        model: trained network taking (images, features) and returning logits.
        dataloader: yields (images, labels, features) batches.
        device: torch device for inference.
        classes: ordered class-name list used for titles and filenames.
        split_name: split label used in titles and filenames.
        model_idx: optional model index for the filename; None marks ensemble output.
        num_samples: images collected and shown per class.
    """
    model.eval()
    class_images = {i: [] for i in range(len(classes))}
    class_preds = {i: [] for i in range(len(classes))}
    class_labels = {i: [] for i in range(len(classes))}

    with torch.no_grad():
        for images, labels, features in dataloader:
            images, labels, features = images.to(device), labels.to(device), features.to(device)
            outputs = model(images, features)
            _, predicted = torch.max(outputs.data, 1)
            # Bucket each sample under its TRUE class until every bucket is full.
            for img, pred, label in zip(images.cpu().numpy(), predicted.cpu().numpy(), labels.cpu().numpy()):
                if len(class_images[label]) < num_samples:
                    class_images[label].append(img)
                    class_preds[label].append(pred)
                    class_labels[label].append(label)
            if all(len(class_images[i]) >= num_samples for i in range(len(classes))):
                break

    for class_idx, class_name in enumerate(classes):
        plt.figure(figsize=(15, 5))
        for i in range(min(num_samples, len(class_images[class_idx]))):
            # Undo ImageNet normalization for display.
            img = class_images[class_idx][i].transpose(1, 2, 0)
            img = img * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            img = np.clip(img, 0, 1)
            plt.subplot(1, num_samples, i + 1)
            plt.imshow(img)
            # Fix: use the `classes` parameter rather than the global CLASSES so
            # the function honors whatever label set the caller passes in
            # (identical output when called with CLASSES, as all current callers do).
            plt.title(f'True: {class_name}\nPred: {classes[class_preds[class_idx][i]]}',
                      fontweight='bold')
            plt.axis('off')

        plt.suptitle(f'Predictions for {class_name} ({split_name})',
                     fontweight='bold',
                     fontsize=12)
        plt.tight_layout(rect=[0, 0, 1, 0.95])
        filename = f'/kaggle/working/visualizations/predictions_{class_name}_{split_name.lower()}' + (f'_model_{model_idx}.png' if model_idx is not None else '_ensemble.png')
        plt.savefig(filename)
        plt.close()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # ---- Device selection -------------------------------------------------
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- Sanity-check the dataset layout (prints up to 5 files per dir) ---
    print("Checking dataset directory structure:")
    for dirname, _, filenames in os.walk(DATASET_PATH):
        print(f"Directory: {dirname}, Files: {len(filenames)}")
        for filename in filenames[:5]:
            print(f" - {os.path.join(dirname, filename)}")

    # ---- Detect CSV-based vs. directory-based labeling --------------------
    csv_path = os.path.join(DATASET_PATH, "labels.csv")
    use_csv = os.path.exists(csv_path)
    if use_csv:
        print("Detected labels.csv, will load dataset from CSV")
    else:
        print("No labels.csv found, assuming directory-based structure")

    # ---- Load images, labels and handcrafted features per split -----------
    train_path = os.path.join(DATASET_PATH, "TRAIN")
    val_path = os.path.join(DATASET_PATH, "VALIDATION")
    test_path = os.path.join(DATASET_PATH, "TEST")

    train_images, train_labels, train_features = load_images(train_path, CLASSES, use_csv)
    val_images, val_labels, val_features = load_images(val_path, CLASSES, use_csv)
    test_images, test_labels, test_features = load_images(test_path, CLASSES, use_csv)

    # Empty-split handling: training data is mandatory; a missing validation
    # split is carved out of training (stratified 80/20); a missing test set
    # just degrades to empty evaluation.
    if not train_images:
        raise ValueError("Training dataset is empty. Please check the dataset path, class names, image files, or CSV structure.")
    if not val_images:
        print("Warning: Validation dataset is empty. Creating validation split from training data.")
        train_images, val_images, train_labels, val_labels, train_features, val_features = train_test_split(
            train_images, train_labels, train_features, test_size=0.2, stratify=train_labels, random_state=42
        )
    if not test_images:
        print("Warning: Test dataset is empty.")
        test_images, test_labels, test_features = [], [], []

    # ---- Exploratory visualizations ---------------------------------------
    visualize_data_distribution(train_path, "Train", CLASSES)
    visualize_data_distribution(val_path, "Validation", CLASSES)
    visualize_data_distribution(test_path, "Test", CLASSES)
    display_sample_images(train_images, train_labels, "Train", CLASSES)
    display_sample_images(val_images, val_labels, "Validation", CLASSES)
    display_sample_images(test_images, test_labels, "Test", CLASSES)
    visualize_handcrafted_features(train_images, train_labels, CLASSES)

    # ---- Wrap splits in Dataset objects -----------------------------------
    train_dataset = FootUlcerDataset(train_images, train_labels, train_features)
    val_dataset = FootUlcerDataset(val_images, val_labels, val_features)
    test_dataset = FootUlcerDataset(test_images, test_labels, test_features)

    # ---- Class-balanced sampling for training -----------------------------
    # Sample weight per item is the inverse of its class frequency so that
    # each epoch draws classes roughly uniformly (with replacement).
    train_labels_np = np.array(train_labels)
    class_counts = np.array([sum(train_labels_np == i) for i in range(len(CLASSES))])
    print(f"Class counts: {dict(zip(CLASSES, class_counts))}")
    if np.any(class_counts == 0):
        print("Warning: Some classes have zero samples in the training set.")
    # 1e-6 guards against division by zero for empty classes.
    class_weights = 1.0 / (class_counts + 1e-6)
    sample_weights = class_weights[train_labels_np]
    sampler = WeightedRandomSampler(sample_weights, len(train_labels), replacement=True)

    # ---- DataLoaders -------------------------------------------------------
    batch_size = 32
    dataloader = {
        'train': DataLoader(train_dataset, batch_size=batch_size, sampler=sampler),
        'val': DataLoader(val_dataset, batch_size=batch_size, shuffle=False),
        'test': DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    }

    # ---- Train an ensemble of independently-initialized models ------------
    num_models = 3
    models_list = []
    test_accuracies = []
    best_train_accuracies = []
    best_val_accuracies = []

    for i in range(num_models):
        print(f"\nTraining Model {i+1}/{num_models}")
        # 41 = dimensionality of the handcrafted feature vector fed alongside images.
        model = DenseShuffleGCANet(num_classes=len(CLASSES), handcrafted_feature_dim=41).to(device)
        print(f"\nModel {i+1} Summary:")
        print_model_summary(model, input_size=(3, 224, 224), handcrafted_feature_dim=41)

        # Focal loss (see FocalLoss above) to focus training on hard examples;
        # small LR + weight decay for fine-tuning-style optimization.
        criterion = FocalLoss(gamma=3.0, alpha=0.5)
        optimizer = optim.Adam(model.parameters(), lr=0.00005, weight_decay=0.001)
        history, best_train_acc, best_val_acc = train_model(model, dataloader, criterion, optimizer, device, epochs=100, model_idx=i)
        plot_training_history(history, len(history['train_loss']), model_idx=i)
        best_train_accuracies.append(best_train_acc)
        best_val_accuracies.append(best_val_acc)

        # NOTE(review): the train loader uses the weighted sampler, so this
        # "Train" evaluation is over a resampled epoch, not the raw train set.
        print(f"\nEvaluating Model {i+1} on Training Set")
        # NOTE(review): this rebinds train_labels (originally the list loaded
        # above) to the labels returned by evaluate_model — harmless here since
        # the datasets were already built, but a naming hazard.
        train_acc, train_preds, train_labels, _, _ = evaluate_model(model, dataloader['train'], device, 'Train', CLASSES, i)
        print(f"Model {i+1} Train Accuracy: {train_acc:.2f}%")

        print(f"\nEvaluating Model {i+1} on Validation Set")
        val_acc, val_preds, val_labels, _, _ = evaluate_model(model, dataloader['val'], device, 'Validation', CLASSES, i)
        print(f"Model {i+1} Validation Accuracy: {val_acc:.2f}%")

        print(f"\nEvaluating Model {i+1} on Test Set")
        test_acc, test_preds, test_labels, _, _ = evaluate_model(model, dataloader['test'], device, 'Test', CLASSES, i)
        print(f"Model {i+1} Test Accuracy: {test_acc:.2f}%")
        test_accuracies.append(test_acc)

        visualize_predictions_per_class(model, dataloader['test'], device, CLASSES, 'Test', model_idx=i)
        # Feature-map visualization only for the first model to limit output volume.
        if i == 0:
            print(f"\nVisualizing Feature Extraction for Model {i+1}")
            visualize_feature_extraction(model, dataloader['test'], device, CLASSES, num_samples_per_class=1)
        models_list.append(model)

    # ---- Ensemble (majority vote) evaluation on all splits ----------------
    print("\nEvaluating Ensemble on Training Set")
    ensemble_train_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['train'], device, 'Train', CLASSES)
    print(f"Ensemble Train Accuracy: {ensemble_train_acc:.2f}%")

    print("\nEvaluating Ensemble on Validation Set")
    ensemble_val_acc, _, _, _, _ = ensemble_voting(models_list, dataloader['val'], device, 'Validation', CLASSES)
    print(f"Ensemble Validation Accuracy: {ensemble_val_acc:.2f}%")

    print("\nEvaluating Ensemble on Test Set")
    ensemble_test_acc, ensemble_test_preds, ensemble_test_labels, _, _ = ensemble_voting(models_list, dataloader['test'], device, 'Test', CLASSES)
    print(f"Ensemble Test Accuracy: {ensemble_test_acc:.2f}%")

    visualize_predictions_per_class(models_list[0], dataloader['test'], device, CLASSES, 'Test', model_idx=None)
    visualize_voting_process(models_list, dataloader['test'], device, CLASSES)

    # ---- Test-time augmentation on the first model ------------------------
    # NOTE(review): "Best Model" here is simply models_list[0], not the model
    # with the highest validation accuracy — confirm intent.
    print("\nEvaluating Best Model with TTA on Test Set")
    tta_acc, _, _, _, _ = evaluate_model(models_list[0], dataloader['test'], device, 'Test_TTA', CLASSES, model_idx=0, use_tta=True)
    print(f"TTA Test Accuracy: {tta_acc:.2f}%")

    # ---- Summary statistics across the ensemble members -------------------
    print("\nStatistical Analysis of Model Performance:")
    print(f"Mean Test Accuracy: {np.mean(test_accuracies):.2f}% ± {np.std(test_accuracies):.2f}%")
    print(f"Best Training Accuracies: {[f'{acc:.2f}%' for acc in best_train_accuracies]}")
    print(f"Best Validation Accuracies: {[f'{acc:.2f}%' for acc in best_val_accuracies]}")

    # ---- Persist ensemble test predictions --------------------------------
    predictions_df = pd.DataFrame({
        'True_Label': [CLASSES[label] for label in ensemble_test_labels],
        'Predicted_Label': [CLASSES[pred] for pred in ensemble_test_preds]
    })
    predictions_df.to_csv('/kaggle/working/predictions/test_predictions_ensemble.csv', index=False)
    print("Predictions saved to /kaggle/working/predictions/test_predictions_ensemble.csv")

    # ---- Persist a per-model / ensemble / TTA summary table ---------------
    summary_report = {
        'Model': [f'Model {i+1}' for i in range(num_models)] + ['Ensemble', 'TTA'],
        'Test_Accuracy': test_accuracies + [ensemble_test_acc, tta_acc],
        'Best_Train_Accuracy': best_train_accuracies + [None, None],
        'Best_Val_Accuracy': best_val_accuracies + [None, None]
    }
    summary_df = pd.DataFrame(summary_report)
    summary_df.to_csv('/kaggle/working/summary_report.csv', index=False)
    print("\nSummary Report:")
    print(summary_df)