Spaces:
Running
Running
Upload 4 files
Browse files- art_trainer-mixup.py +234 -0
- model_evaluator.py +325 -0
- model_evaluator_kfold.py +379 -0
- trainer.py +556 -0
art_trainer-mixup.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
+
import torch.optim as optim
|
| 8 |
+
from torch.utils.data import DataLoader, Dataset
|
| 9 |
+
from torchvision import transforms as T
|
| 10 |
+
from torchvision.transforms import v2
|
| 11 |
+
from PIL import Image
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
import random
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
# A.1. Check device availability and setup MPS optimizations
|
| 18 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 19 |
+
torch.set_float32_matmul_precision('high') # MPS performance optimization
|
| 20 |
+
|
| 21 |
+
# Hyperparameters (Tested optimal values)
|
| 22 |
+
CFG = {
|
| 23 |
+
'img_size': 224,
|
| 24 |
+
'batch_size': 32,
|
| 25 |
+
'lr': 3e-5, # Lower learning rate
|
| 26 |
+
'weight_decay': 0.05, # Stronger L2 regularization
|
| 27 |
+
'dropout': 0.5, # Increased dropout
|
| 28 |
+
'epochs': 30,
|
| 29 |
+
'mixup_alpha': 0.4,
|
| 30 |
+
'cutmix_prob': 0.3,
|
| 31 |
+
'label_smoothing': 0.15,
|
| 32 |
+
'patience': 5 # For early stopping
|
| 33 |
+
}
|
| 34 |
+
|
| 35 |
+
# A.2.4. Define data transformations with advanced augmentation pipeline
|
| 36 |
+
def create_transforms():
    """Build the train/val transform pipelines.

    Returns:
        dict: ``{'train': ..., 'val': ...}`` torchvision v2 pipelines.
        Train applies heavy geometric/photometric augmentation; val only
        resizes, center-crops and normalizes (ImageNet statistics).
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    return {
        'train': v2.Compose([
            # A word on presizing:
            # 1. Increase the size (item by item)
            v2.RandomResizedCrop(CFG['img_size'], scale=(0.6, 1.0)),
            # 2. Apply augmentation (batch by batch)
            v2.RandomHorizontalFlip(p=0.7),
            v2.RandomVerticalFlip(p=0.3),
            v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3),
            v2.RandomRotation(35),
            v2.RandomAffine(degrees=0, translate=(0.2, 0.2)),
            v2.RandomPerspective(distortion_scale=0.4, p=0.6),
            v2.GaussianBlur(kernel_size=(5, 9)),
            v2.ToTensor(),
            # BUGFIX: RandomSolarize previously ran before ToTensor, i.e. on
            # the PIL image with 0-255 pixel values, where threshold=0.3
            # inverted virtually every pixel. It now runs on the [0, 1]
            # float tensor, where threshold=0.3 is meaningful.
            v2.RandomSolarize(threshold=0.3, p=0.2),
            # 3. Decrease the size (batch by batch)
            v2.Normalize(mean=imagenet_mean, std=imagenet_std),
            v2.RandomErasing(p=0.5, scale=(0.02, 0.2), value='random')
        ]),
        'val': v2.Compose([
            v2.Resize(CFG['img_size'] + 32),
            v2.CenterCrop(CFG['img_size']),
            v2.ToTensor(),
            v2.Normalize(mean=imagenet_mean, std=imagenet_std)
        ])
    }
|
| 63 |
+
|
| 64 |
+
# A.2.2. Define the means of getting data into DataBlock
|
| 65 |
+
class ArtDataset(Dataset):
    """Image-folder dataset: one sub-directory per class under ``data_dir``.

    Hidden files (e.g. .DS_Store) and nested directories are skipped so that
    every stored sample can actually be opened as an image.
    """

    def __init__(self, data_dir, transform=None):
        root = Path(data_dir)
        self.classes = sorted(d.name for d in root.iterdir() if d.is_dir())
        self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        self.samples = []
        for cls in self.classes:
            cls_dir = root / cls
            for img_path in cls_dir.glob('*'):
                # BUGFIX: only keep regular, non-hidden files; the old code
                # also picked up sub-directories and files like .DS_Store,
                # which crashed Image.open() at training time.
                if img_path.is_file() and not img_path.name.startswith('.'):
                    self.samples.append((img_path, self.class_to_idx[cls]))
        self.transform = transform

    def __len__(self):
        # Total number of (path, label) samples collected at init time.
        return len(self.samples)

    def __getitem__(self, idx):
        # Load lazily; convert to RGB so grayscale/CMYK images don't break
        # the 3-channel normalization downstream.
        img_path, label = self.samples[idx]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label
|
| 85 |
+
|
| 86 |
+
# B.4. Implement mixup data augmentation - part of discriminative learning rates
|
| 87 |
+
def mixup_data(x, y, alpha=1.0):
    """Mix a batch of inputs/targets (mixup, Zhang et al. 2018).

    Args:
        x: input batch tensor, shape ``(B, ...)``.
        y: target tensor, shape ``(B, ...)``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables
            mixing (lam == 1, inputs returned unchanged).

    Returns:
        ``(mixed_x, y_a, y_b, lam)``: mixed inputs, the two target sets and
        the mixing coefficient ``lam`` in [0, 1].
    """
    if alpha > 0:
        lam = np.random.beta(alpha, alpha)
    else:
        lam = 1
    batch_size = x.size(0)
    # Draw the permutation directly on x's device instead of relying on the
    # module-level `device` global, so the helper works on CPU/CUDA/MPS alike.
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
|
| 97 |
+
|
| 98 |
+
# A.4. Define training step
|
| 99 |
+
def train_step(model, data_loader, criterion, optimizer):
    """Run one mixup training epoch; return (average loss, accuracy %)."""
    model.train()
    running_loss = 0.0
    running_correct = 0.0

    for batch_x, batch_y in tqdm(data_loader, desc='Training', leave=False):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # B.4. Advanced Mixup - part of discriminative learning rates
        batch_x, y_a, y_b, lam = mixup_data(batch_x, batch_y, CFG['mixup_alpha'])

        optimizer.zero_grad()
        logits = model(batch_x)
        # Mixup loss: convex combination of the losses w.r.t. both target sets.
        loss = lam * criterion(logits, y_a) + (1 - lam) * criterion(logits, y_b)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()

        running_loss += loss.item()
        preds = logits.argmax(dim=1)
        # Mixup accuracy: credit each prediction proportionally to lam.
        running_correct += (lam * preds.eq(y_a).sum().item()
                            + (1 - lam) * preds.eq(y_b).sum().item())

    avg_loss = running_loss / len(data_loader)
    acc = 100.0 * running_correct / len(data_loader.dataset)
    return avg_loss, acc
|
| 126 |
+
|
| 127 |
+
# A.3. Define validation step to inspect the DataBlock
|
| 128 |
+
def validate(model, data_loader, criterion):
    """Evaluate the model on ``data_loader``; return (average loss, accuracy %)."""
    model.eval()
    loss_sum = 0.0
    n_correct = 0

    with torch.no_grad():
        for xb, yb in tqdm(data_loader, desc='Validation', leave=False):
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            loss_sum += criterion(logits, yb).item()
            n_correct += logits.argmax(dim=1).eq(yb).sum().item()

    avg_loss = loss_sum / len(data_loader)
    acc = 100.0 * n_correct / len(data_loader.dataset)
    return avg_loss, acc
|
| 146 |
+
|
| 147 |
+
def main():
    """Fine-tune a pretrained ResNet34 on the art dataset with mixup,
    cosine-annealing warm restarts and early stopping; best checkpoint is
    written to ``results/best_model.pth``.
    """
    # A.1. Load data
    transforms = create_transforms()

    # Set directory paths according to your structure
    art_dataset_dir = 'Art Dataset'

    # A.2.1. Define the blocks (dataset creation)
    # Two dataset views over the same files so train/val use different
    # transforms; the disjoint index split below prevents leakage.
    train_dataset = ArtDataset(art_dataset_dir, transform=transforms['train'])
    val_dataset = ArtDataset(art_dataset_dir, transform=transforms['val'])

    # BUGFIX: the original code validated on the FULL training set, so the
    # validation metrics (and early stopping) measured training performance.
    # Hold out 20% of the samples for validation instead.
    from torch.utils.data import Subset
    n_total = len(train_dataset)
    perm = torch.randperm(n_total, generator=torch.Generator().manual_seed(42)).tolist()
    n_val = max(1, int(0.2 * n_total))
    val_subset = Subset(val_dataset, perm[:n_val])
    train_subset = Subset(train_dataset, perm[n_val:])

    # A.2.2. Create data loaders
    train_loader = DataLoader(train_subset, batch_size=CFG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_subset, batch_size=CFG['batch_size'],
                            num_workers=4, pin_memory=True)

    # B.3. Transfer Learning - Load model
    model_path = 'models/model_final.pth'

    # Load model state dictionary; map_location lets a checkpoint saved on
    # one device (e.g. CUDA) load cleanly on MPS/CPU.
    state_dict = torch.load(model_path, map_location=device)

    # Create ResNet34 model
    from torchvision import models
    model = models.resnet34(weights=None)

    # Number of classes
    num_classes = len(train_dataset.classes)

    # B.3. Update the final fully-connected layer; use the layer's own
    # input width rather than a hard-coded 512.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # Load state dictionary
    model.load_state_dict(state_dict)
    model = model.to(device)

    # B.6. Model Capacity - Measures to prevent overfitting
    for _, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            module.p = CFG['dropout']  # Increase dropout rate

    # B.1. Learning Rate Finder - Optimizer and Loss setup
    optimizer = optim.AdamW(model.parameters(), lr=CFG['lr'],
                            weight_decay=CFG['weight_decay'])
    criterion = nn.CrossEntropyLoss(label_smoothing=CFG['label_smoothing'])
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=10, T_mult=2)

    # Create results directory
    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)

    # B.5. Early Stopping - Deciding the Number of Training Epochs
    best_val_acc = 0
    patience_counter = 0

    # A.4. Train a simple model
    for epoch in range(CFG['epochs']):
        print(f"\nEpoch {epoch+1}/{CFG['epochs']}")

        # Training
        train_loss, train_acc = train_step(model, train_loader, criterion, optimizer)
        # Validation
        val_loss, val_acc = validate(model, val_loader, criterion)

        # Learning rate update
        scheduler.step()

        # Monitor results
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")

        # B.5. Early stopping check
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_model_path = os.path.join(results_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved ({val_acc:.2f}%)")
        else:
            patience_counter += 1
            if patience_counter >= CFG['patience']:
                print(f"Early stopping! No improvement for {CFG['patience']} epochs.")
                break
|
| 232 |
+
|
| 233 |
+
if __name__ == "__main__":
|
| 234 |
+
main()
|
model_evaluator.py
ADDED
|
@@ -0,0 +1,325 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader, random_split
|
| 6 |
+
from torchvision import models, transforms
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
import pandas as pd
|
| 14 |
+
import random
|
| 15 |
+
from collections import defaultdict
|
| 16 |
+
|
| 17 |
+
# MPS (Metal Performance Shaders) check - Apple GPU
|
| 18 |
+
if torch.backends.mps.is_available():
|
| 19 |
+
DEVICE = torch.device("mps")
|
| 20 |
+
print(f"Using Metal GPU: {DEVICE}")
|
| 21 |
+
else:
|
| 22 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 23 |
+
print(f"Metal GPU not found, using device: {DEVICE}")
|
| 24 |
+
|
| 25 |
+
# Constants
|
| 26 |
+
IMG_SIZE = 224
|
| 27 |
+
BATCH_SIZE = 64 # Batch size increased for GPU
|
| 28 |
+
NUM_WORKERS = 6 # Number of threads increased
|
| 29 |
+
MAX_SAMPLES_PER_CLASS = 30 # Maximum number of samples per class (for quick testing)
|
| 30 |
+
|
| 31 |
+
# Transformation for test dataset
|
| 32 |
+
test_transform = transforms.Compose([
|
| 33 |
+
transforms.Resize(IMG_SIZE + 32),
|
| 34 |
+
transforms.CenterCrop(IMG_SIZE),
|
| 35 |
+
transforms.ToTensor(),
|
| 36 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 37 |
+
])
|
| 38 |
+
|
| 39 |
+
class ArtDataset(Dataset):
    """Dataset over a pre-built list of ``(image_path, class_name)`` samples.

    When ``class_to_idx`` is not supplied, the class vocabulary is derived
    from the parent directory name of each sample path.
    """

    def __init__(self, samples, transform=None, class_to_idx=None):
        self.samples = samples
        self.transform = transform

        if class_to_idx is None:
            # Derive the class vocabulary from the samples themselves.
            names = {Path(str(item[0])).parent.name for item in samples}
            self.classes = sorted(names)
            self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        else:
            self.class_to_idx = class_to_idx
            self.classes = sorted(class_to_idx.keys(), key=class_to_idx.get)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # Resolve label from the stored class name, then load the image.
        path, class_name = self.samples[idx]
        target = self.class_to_idx[class_name]
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, target
|
| 63 |
+
|
| 64 |
+
def create_test_set(data_dir, test_ratio=0.2, max_per_class=None):
    """Create a test set by sampling a fraction of each class's images.

    Args:
        data_dir: root folder with one sub-directory per class.
        test_ratio: fraction of each class to select (at least 1 sample).
        max_per_class: optional hard cap on samples taken per class.

    Returns:
        ``(test_samples, class_to_idx)``: list of ``(path, class_name)``
        pairs and the class-name -> index mapping over ALL classes found.
    """
    class_samples = defaultdict(list)

    # Collect all examples by their classes
    for class_dir in Path(data_dir).iterdir():
        if class_dir.is_dir():
            class_name = class_dir.name
            for img_path in class_dir.glob('*'):
                # BUGFIX: skip sub-directories and hidden files such as
                # .DS_Store, which are not images and would crash PIL later.
                if img_path.is_file() and not img_path.name.startswith('.'):
                    class_samples[class_name].append((img_path, class_name))

    # Select a certain percentage and maximum number of examples from each class
    test_samples = []
    for class_name, samples in class_samples.items():
        random.shuffle(samples)
        n_test = max(1, int(len(samples) * test_ratio))

        # Limit the maximum number of examples
        if max_per_class and n_test > max_per_class:
            n_test = max_per_class

        test_samples.extend(samples[:n_test])

    print(f"Total of {len(test_samples)} test samples selected from {len(class_samples)} different art movements.")

    # Create class-index mapping
    classes = sorted(class_samples.keys())
    class_to_idx = {cls: i for i, cls in enumerate(classes)}

    return test_samples, class_to_idx
|
| 94 |
+
|
| 95 |
+
def load_model(model_path, num_classes):
    """Load a fine-tuned ResNet34 checkpoint and prepare it for inference.

    Args:
        model_path: path to a state-dict ``.pth`` file.
        num_classes: size of the classifier head (must match the checkpoint).

    Returns:
        The model moved to DEVICE, in ``eval()`` mode.
    """
    print(f"Loading model: {model_path}")
    # Create ResNet34 model (no pretrained weights; we load our own).
    model = models.resnet34(weights=None)
    # Update the last fully-connected layer; use the layer's own input
    # width instead of a hard-coded 512 so this survives a backbone swap.
    model.fc = nn.Linear(model.fc.in_features, num_classes)

    # map_location lets a checkpoint saved on another device load on DEVICE.
    state_dict = torch.load(model_path, map_location=DEVICE)
    model.load_state_dict(state_dict)
    model = model.to(DEVICE)
    model.eval()

    return model
|
| 110 |
+
|
| 111 |
+
def evaluate_model(model, test_loader, classes):
    """Evaluate model and return metrics."""
    predictions = []
    ground_truth = []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(test_loader, desc="Evaluation"):
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            # Forward pass directly on the evaluation device (no autocast).
            logits = model(batch_inputs)
            batch_preds = logits.argmax(dim=1)

            # Collect results on the CPU.
            predictions.extend(batch_preds.cpu().numpy())
            ground_truth.extend(batch_labels.cpu().numpy())

    # Per-class accuracy from the confusion-matrix diagonal.
    conf_matrix = confusion_matrix(ground_truth, predictions)
    labels_arr = np.array(ground_truth)
    class_accuracy = {}
    for idx, class_name in enumerate(classes):
        n_in_class = np.sum(labels_arr == idx)
        if n_in_class > 0:
            class_accuracy[class_name] = conf_matrix[idx, idx] / n_in_class

    # Aggregate metrics; 'weighted' averaging accounts for class imbalance.
    return {
        'accuracy': accuracy_score(ground_truth, predictions),
        'f1_score': f1_score(ground_truth, predictions, average='weighted'),
        'precision': precision_score(ground_truth, predictions, average='weighted'),
        'recall': recall_score(ground_truth, predictions, average='weighted'),
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'predictions': predictions,
        'ground_truth': ground_truth,
    }
|
| 157 |
+
|
| 158 |
+
def plot_confusion_matrix(conf_matrix, classes, model_name, save_dir):
    """Render the confusion matrix as a heat map and save it as a PNG."""
    plt.figure(figsize=(12, 10))
    sns.heatmap(conf_matrix, annot=False, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.xlabel('Predicted Class')
    plt.ylabel('True Class')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.tight_layout()

    # Write the figure next to the other evaluation artifacts.
    out_file = Path(save_dir) / f"conf_matrix_{Path(model_name).stem}.png"
    plt.savefig(out_file, dpi=300)
    plt.close()
|
| 172 |
+
|
| 173 |
+
def plot_class_accuracy(class_acc, model_name, save_dir):
    """Plot per-class accuracy as a sorted bar chart and save it as a PNG."""
    plt.figure(figsize=(14, 8))

    # Rank classes from best to worst accuracy.
    ranked = sorted(class_acc.items(), key=lambda pair: pair[1], reverse=True)
    names = [name for name, _ in ranked]
    values = [value for _, value in ranked]

    bars = plt.bar(names, values)
    plt.xlabel('Art Movement')
    plt.ylabel('Accuracy')
    plt.title(f'Class-Based Accuracy - {model_name}')
    plt.xticks(rotation=90)
    plt.ylim(0, 1.0)

    # Annotate each bar with its accuracy value.
    for bar in bars:
        height = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., height,
                 f'{height:.2f}', ha='center', va='bottom', rotation=0)

    plt.tight_layout()

    # Write the figure next to the other evaluation artifacts.
    out_file = Path(save_dir) / f"class_accuracy_{Path(model_name).stem}.png"
    plt.savefig(out_file, dpi=300)
    plt.close()
|
| 201 |
+
|
| 202 |
+
def plot_model_comparison(all_results, save_dir):
    """Plot a grouped bar chart comparing all models across the main metrics.

    Args:
        all_results: mapping model_name -> results dict (as returned by
            ``evaluate_model``).
        save_dir: directory where ``model_comparison.png`` is written.
    """
    model_names = list(all_results.keys())
    metrics = ['accuracy', 'f1_score', 'precision', 'recall']

    # Collect metrics
    metric_data = {metric: [all_results[model][metric] for model in model_names]
                   for metric in metrics}

    # Compare metrics
    plt.figure(figsize=(12, 7))
    x = np.arange(len(model_names))
    width = 0.2
    multiplier = 0

    for metric, values in metric_data.items():
        offset = width * multiplier
        bars = plt.bar(x + offset, values, width, label=metric)

        # Add values on top of bars
        for bar in bars:
            height = bar.get_height()
            plt.annotate(f'{height:.3f}',
                         xy=(bar.get_x() + bar.get_width() / 2, height),
                         xytext=(0, 3),  # 3 points vertical offset
                         textcoords="offset points",
                         ha='center', va='bottom')

        multiplier += 1

    plt.xlabel('Model')
    plt.ylabel('Score')
    plt.title('Model Performance Comparison')
    # BUGFIX: center the tick labels under each group of bars; the old
    # offset (x + width) sat under the second of the four bars.
    plt.xticks(x + width * (len(metrics) - 1) / 2, model_names)
    plt.legend(loc='lower right')
    plt.ylim(0, 1.0)

    plt.tight_layout()

    # Save the graph
    save_path = Path(save_dir) / "model_comparison.png"
    plt.savefig(save_path, dpi=300)
    plt.close()
|
| 244 |
+
|
| 245 |
+
def main():
    """Evaluate every ``.pth`` checkpoint in ``models`` on a sampled test set.

    Writes confusion-matrix and per-class-accuracy plots, per-model
    classification reports, and a summary CSV into ``evaluation_results``.
    """
    # Data directory and results directory
    art_dataset_dir = 'Art Dataset'
    models_dir = 'models'
    results_dir = 'evaluation_results'

    # Create results directory
    os.makedirs(results_dir, exist_ok=True)

    # Create test data - limit maximum number of examples from each class.
    # NOTE(review): the test set is re-sampled from the full dataset each
    # run; if these images were also used for training, the reported scores
    # are optimistic - confirm against the training split.
    test_samples, class_to_idx = create_test_set(art_dataset_dir, test_ratio=0.2, max_per_class=MAX_SAMPLES_PER_CLASS)
    test_dataset = ArtDataset(test_samples, transform=test_transform, class_to_idx=class_to_idx)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    classes = test_dataset.classes
    num_classes = len(classes)
    print(f"Art classes: {len(classes)}")

    # Find model files (exclude files like .DS_Store)
    model_paths = [os.path.join(models_dir, f) for f in os.listdir(models_dir)
                   if f.endswith('.pth') and not f.startswith('.')]

    # Dictionary to store results, keyed by model file name
    all_results = {}

    # Evaluate each model against the same fixed test loader
    for model_path in model_paths:
        model_name = Path(model_path).name
        print(f"\nEvaluating {model_name}...")

        # Load model (all checkpoints are assumed to share num_classes)
        model = load_model(model_path, num_classes)

        # Evaluate model
        results = evaluate_model(model, test_loader, classes)
        all_results[model_name] = results

        print(f"Accuracy: {results['accuracy']:.4f}")
        print(f"F1 Score: {results['f1_score']:.4f}")
        print(f"Precision: {results['precision']:.4f}")
        print(f"Recall: {results['recall']:.4f}")

        # Plot confusion matrix graph
        plot_confusion_matrix(results['confusion_matrix'], classes, model_name, results_dir)

        # Plot class-based accuracy graph
        plot_class_accuracy(results['class_accuracy'], model_name, results_dir)

        # Save detailed per-class report as CSV
        report = classification_report(results['ground_truth'], results['predictions'],
                                       target_names=classes, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        report_df.to_csv(f"{results_dir}/classification_report_{Path(model_name).stem}.csv")

    # Compare models (only meaningful with at least two checkpoints)
    if len(all_results) > 1:
        plot_model_comparison(all_results, results_dir)

    # Save results to CSV file
    results_summary = []
    for model_name, results in all_results.items():
        row = {
            'model': model_name,
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'precision': results['precision'],
            'recall': results['recall']
        }
        results_summary.append(row)

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(f"{results_dir}/model_comparison_summary.csv", index=False)

    print(f"\nEvaluation completed. Results are in '{results_dir}' directory.")
|
| 319 |
+
|
| 320 |
+
if __name__ == "__main__":
|
| 321 |
+
# Set seed for reproducibility
|
| 322 |
+
random.seed(42)
|
| 323 |
+
np.random.seed(42)
|
| 324 |
+
torch.manual_seed(42)
|
| 325 |
+
main()
|
model_evaluator_kfold.py
ADDED
|
@@ -0,0 +1,379 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
from torch.utils.data import Dataset, DataLoader, Subset
|
| 6 |
+
from torchvision import models, transforms
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from PIL import Image
|
| 9 |
+
import matplotlib.pyplot as plt
|
| 10 |
+
import seaborn as sns
|
| 11 |
+
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
|
| 12 |
+
from sklearn.model_selection import KFold
|
| 13 |
+
from tqdm import tqdm
|
| 14 |
+
import pandas as pd
|
| 15 |
+
import random
|
| 16 |
+
from collections import defaultdict
|
| 17 |
+
|
| 18 |
+
# MPS (Metal Performance Shaders) kontrolü - Apple GPU
|
| 19 |
+
if torch.backends.mps.is_available():
|
| 20 |
+
DEVICE = torch.device("mps")
|
| 21 |
+
print(f"Metal GPU kullanılıyor: {DEVICE}")
|
| 22 |
+
else:
|
| 23 |
+
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 24 |
+
print(f"Metal GPU bulunamadı, şu cihaz kullanılıyor: {DEVICE}")
|
| 25 |
+
|
| 26 |
+
# Sabit değerler
|
| 27 |
+
IMG_SIZE = 224
|
| 28 |
+
BATCH_SIZE = 64
|
| 29 |
+
NUM_WORKERS = 6
|
| 30 |
+
MAX_SAMPLES_PER_CLASS = 20 # Her sınıftan maksimum örnek sayısı (hızlı test için)
|
| 31 |
+
K_FOLDS = 5 # 5-fold cross validation
|
| 32 |
+
|
| 33 |
+
# Test veri seti için dönüşüm
|
| 34 |
+
test_transform = transforms.Compose([
|
| 35 |
+
transforms.Resize(IMG_SIZE + 32),
|
| 36 |
+
transforms.CenterCrop(IMG_SIZE),
|
| 37 |
+
transforms.ToTensor(),
|
| 38 |
+
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
|
| 39 |
+
])
|
| 40 |
+
|
| 41 |
+
class ArtDataset(Dataset):
    """Dataset over a pre-built list of ``(image_path, class_name)`` samples.

    If ``class_to_idx`` is not given, the class vocabulary is derived from
    the parent directory name of each sample path.
    """

    def __init__(self, samples, transform=None, class_to_idx=None):
        self.samples = samples
        self.transform = transform

        if class_to_idx is None:
            # Extract the classes from the samples themselves
            classes = set([Path(str(s[0])).parent.name for s in samples])
            self.classes = sorted(list(classes))
            self.class_to_idx = {cls: i for i, cls in enumerate(self.classes)}
        else:
            # Reuse a caller-supplied mapping; classes ordered by index
            self.class_to_idx = class_to_idx
            self.classes = sorted(class_to_idx.keys(), key=lambda x: class_to_idx[x])

    def __len__(self):
        # Number of samples in the dataset.
        return len(self.samples)

    def __getitem__(self, idx):
        # Map the stored class name to its integer label, then load the image
        # (RGB conversion keeps grayscale/CMYK inputs 3-channel).
        img_path, class_name = self.samples[idx]
        label = self.class_to_idx[class_name]
        img = Image.open(img_path).convert('RGB')
        if self.transform:
            img = self.transform(img)
        return img, label
|
| 65 |
+
|
| 66 |
+
def create_balanced_dataset(data_dir, max_per_class=None):
    """Her sınıftan eşit sayıda örnek içeren dengeli bir veri seti oluştur.

    Scans one sub-directory per class, shuffles each class independently and
    truncates it to `max_per_class` samples. Returns (samples, class_to_idx)
    where samples is a list of (path, class_name) pairs.
    """
    per_class = defaultdict(list)

    # Bucket every image path under the name of its class directory.
    for sub in Path(data_dir).iterdir():
        if not sub.is_dir():
            continue
        for img_path in sub.glob('*'):
            per_class[sub.name].append((img_path, sub.name))

    # Shuffle within each class, then cap the class size.
    selected = []
    for cls_name, items in per_class.items():
        random.shuffle(items)
        if max_per_class and len(items) > max_per_class:
            items = items[:max_per_class]
        selected.extend(items)

    print(f"Toplam {len(selected)} örnek, {len(per_class)} farklı sanat akımından seçildi.")

    # Stable class -> index mapping (alphabetical).
    class_to_idx = {cls: i for i, cls in enumerate(sorted(per_class.keys()))}

    return selected, class_to_idx
|
| 95 |
+
|
| 96 |
+
def load_model(model_path, num_classes):
    """Load a saved ResNet34 state dict and prepare it for inference on DEVICE."""
    print(f"Model yükleniyor: {model_path}")

    # Rebuild the architecture, then restore the trained weights into it.
    net = models.resnet34(weights=None)
    net.fc = nn.Linear(512, num_classes)  # classifier head sized to our classes

    # map_location keeps loading working regardless of the device the
    # checkpoint was saved on (CPU / CUDA / MPS).
    net.load_state_dict(torch.load(model_path, map_location=DEVICE))
    net = net.to(DEVICE)
    net.eval()

    return net
|
| 111 |
+
|
| 112 |
+
def evaluate_model(model, test_loader, classes):
    """Run `model` over `test_loader` and compute aggregate and per-class metrics.

    Returns a dict with accuracy, weighted F1/precision/recall, per-class
    accuracy, the confusion matrix, raw predictions and ground-truth labels.
    """
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc="Değerlendirme", leave=False):
            inputs, labels = inputs.to(DEVICE), labels.to(DEVICE)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            # Move results back to CPU for numpy/sklearn processing.
            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # zero_division=1 silences warnings for classes with no predicted samples.
    accuracy = accuracy_score(all_labels, all_preds)
    f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=1)
    precision = precision_score(all_labels, all_preds, average='weighted', zero_division=1)
    recall = recall_score(all_labels, all_preds, average='weighted', zero_division=1)

    # Per-class accuracy from the confusion-matrix diagonal.
    class_accuracy = {}
    # FIX: pin the label order so that row/column i always corresponds to
    # classes[i]. Without `labels=`, sklearn builds the matrix only from the
    # labels observed in this fold, so when a class is absent the diagonal
    # lookup conf_matrix[i, i] silently refers to the wrong class.
    conf_matrix = confusion_matrix(all_labels, all_preds, labels=list(range(len(classes))))

    for i, class_name in enumerate(classes):
        class_samples = np.sum(np.array(all_labels) == i)
        class_correct = conf_matrix[i, i] if i < len(conf_matrix) else 0
        if class_samples > 0:
            class_accuracy[class_name] = class_correct / class_samples

    results = {
        'accuracy': accuracy,
        'f1_score': f1,
        'precision': precision,
        'recall': recall,
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'predictions': all_preds,
        'ground_truth': all_labels
    }

    return results
|
| 158 |
+
|
| 159 |
+
def k_fold_cross_validation(dataset, model_paths, num_classes, k=5):
    """Evaluate every saved model on each of k test folds.

    Each fold's test indices are used only for evaluation (no training here);
    every model in `model_paths` is loaded fresh and scored on the same fold.
    Returns a dict keyed by model file name with mean/std of each metric and
    the raw per-fold accuracy and F1 lists.
    """

    # Fixed random_state keeps fold assignment reproducible across runs.
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)

    # Accumulators, one entry per model file.
    all_model_results = {}
    for model_path in model_paths:
        model_name = Path(model_path).name
        all_model_results[model_name] = {
            'fold_results': [],
            'accuracy': [],
            'f1_score': [],
            'precision': [],
            'recall': []
        }

    # K-fold loop: only the test split of each fold is used.
    for fold, (_, test_indices) in enumerate(kfold.split(dataset)):
        print(f"\nFold {fold+1}/{k} değerlendiriliyor...")

        # Build the fold's evaluation loader (no shuffling needed for metrics).
        test_subset = Subset(dataset, test_indices)
        test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

        # Score every model on this fold.
        for model_path in model_paths:
            model_name = Path(model_path).name
            print(f" {model_name} değerlendiriliyor...")

            # Re-load the checkpoint each fold so state never carries over.
            model = load_model(model_path, num_classes)

            results = evaluate_model(model, test_loader, dataset.classes)

            # Append this fold's metrics to the model's accumulators.
            all_model_results[model_name]['fold_results'].append(results)
            all_model_results[model_name]['accuracy'].append(results['accuracy'])
            all_model_results[model_name]['f1_score'].append(results['f1_score'])
            all_model_results[model_name]['precision'].append(results['precision'])
            all_model_results[model_name]['recall'].append(results['recall'])

            print(f" Fold {fold+1} - Doğruluk: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    # Reduce the per-fold lists to mean ± std for each model.
    summary_results = {}
    for model_name, results in all_model_results.items():
        summary_results[model_name] = {
            'mean_accuracy': np.mean(results['accuracy']),
            'std_accuracy': np.std(results['accuracy']),
            'mean_f1': np.mean(results['f1_score']),
            'std_f1': np.std(results['f1_score']),
            'mean_precision': np.mean(results['precision']),
            'std_precision': np.std(results['precision']),
            'mean_recall': np.mean(results['recall']),
            'std_recall': np.std(results['recall']),
            'fold_accuracy': results['accuracy'],
            'fold_f1': results['f1_score']
        }

    return summary_results
|
| 222 |
+
|
| 223 |
+
def plot_kfold_results(summary_results, save_dir):
    """Render and save the k-fold comparison charts.

    Writes two PNGs under `save_dir`: a grouped bar chart of mean accuracy/F1
    (with std error bars) per model, and per-fold accuracy/F1 line charts.
    """

    # Model display names: strip the .pth extension for axis labels.
    model_names = list(summary_results.keys())
    model_names = [Path(name).stem for name in model_names]

    # Mean/std metrics, in the same iteration order as model_names.
    mean_accuracy = [summary_results[model]['mean_accuracy'] for model in summary_results]
    std_accuracy = [summary_results[model]['std_accuracy'] for model in summary_results]
    mean_f1 = [summary_results[model]['mean_f1'] for model in summary_results]
    std_f1 = [summary_results[model]['std_f1'] for model in summary_results]

    # Bar positions.
    x = np.arange(len(model_names))
    width = 0.35

    # FIX: the original called plt.figure(figsize=(14, 7)) here and then
    # immediately created a second figure via plt.subplots(); the first figure
    # was never drawn on or closed and leaked on every call. It was removed.
    fig, ax = plt.subplots(figsize=(12, 8))
    rects1 = ax.bar(x - width/2, mean_accuracy, width, yerr=std_accuracy,
                    label='Accuracy', capsize=5, color='cornflowerblue')
    rects2 = ax.bar(x + width/2, mean_f1, width, yerr=std_f1,
                    label='F1 Score', capsize=5, color='lightcoral')

    # Axes cosmetics.
    ax.set_ylabel('Skor')
    ax.set_title('5-Fold Cross Validation Ortalama Performans (Ortalama ± Std)')
    ax.set_xticks(x)
    ax.set_xticklabels(model_names)
    ax.legend()
    ax.set_ylim(0, 1.0)

    def add_labels(rects):
        # Annotate each bar with its numeric value.
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f'{height:.3f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),  # 3 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom')

    add_labels(rects1)
    add_labels(rects2)

    plt.tight_layout()

    save_path = Path(save_dir) / "kfold_mean_performance.png"
    plt.savefig(save_path, dpi=300)
    plt.close()

    # Second chart: per-fold trajectories, accuracy on top, F1 below.
    plt.figure(figsize=(18, 12))

    plt.subplot(2, 1, 1)
    for model_name in summary_results:
        model_stem = Path(model_name).stem
        plt.plot(range(1, K_FOLDS + 1), summary_results[model_name]['fold_accuracy'],
                 marker='o', linestyle='-', label=model_stem)

    plt.title('Her Fold için Accuracy Değerleri')
    plt.xlabel('Fold')
    plt.ylabel('Accuracy')
    plt.xticks(range(1, K_FOLDS + 1))
    plt.ylim(0, 1.0)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    plt.subplot(2, 1, 2)
    for model_name in summary_results:
        model_stem = Path(model_name).stem
        plt.plot(range(1, K_FOLDS + 1), summary_results[model_name]['fold_f1'],
                 marker='o', linestyle='-', label=model_stem)

    plt.title('Her Fold için F1 Değerleri')
    plt.xlabel('Fold')
    plt.ylabel('F1 Score')
    plt.xticks(range(1, K_FOLDS + 1))
    plt.ylim(0, 1.0)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    plt.tight_layout()

    save_path = Path(save_dir) / "kfold_all_folds_performance.png"
    plt.savefig(save_path, dpi=300)
    plt.close()
|
| 317 |
+
|
| 318 |
+
def main():
    """Entry point: build a balanced dataset, run k-fold evaluation on every
    saved checkpoint, then plot and persist the model comparison."""
    # Input/output locations.
    art_dataset_dir = 'Art Dataset'
    models_dir = 'models'
    results_dir = 'kfold_evaluation_results'

    # Create the output directory up front (idempotent).
    os.makedirs(results_dir, exist_ok=True)

    # Balanced dataset: cap samples per class so the run stays fast.
    samples, class_to_idx = create_balanced_dataset(art_dataset_dir, max_per_class=MAX_SAMPLES_PER_CLASS)
    dataset = ArtDataset(samples, transform=test_transform, class_to_idx=class_to_idx)

    num_classes = len(dataset.classes)
    print(f"Sanat sınıfları: {len(dataset.classes)}")

    # Collect checkpoint files; skip hidden files such as .DS_Store.
    model_paths = [os.path.join(models_dir, f) for f in os.listdir(models_dir)
                   if f.endswith('.pth') and not f.startswith('.')]

    # Score every checkpoint with k-fold cross validation.
    summary_results = k_fold_cross_validation(dataset, model_paths, num_classes, k=K_FOLDS)

    # Charts.
    plot_kfold_results(summary_results, results_dir)

    # Console report.
    print("\n5-Fold Cross Validation Sonuçları:")
    for model_name, results in summary_results.items():
        print(f"\n{model_name}:")
        print(f" Ortalama Accuracy: {results['mean_accuracy']:.4f} ± {results['std_accuracy']:.4f}")
        print(f" Ortalama F1 Score: {results['mean_f1']:.4f} ± {results['std_f1']:.4f}")
        print(f" Ortalama Precision: {results['mean_precision']:.4f} ± {results['std_precision']:.4f}")
        print(f" Ortalama Recall: {results['mean_recall']:.4f} ± {results['std_recall']:.4f}")

    # Persist a flat summary table as CSV.
    results_summary = []
    for model_name, results in summary_results.items():
        row = {
            'model': model_name,
            'mean_accuracy': results['mean_accuracy'],
            'std_accuracy': results['std_accuracy'],
            'mean_f1': results['mean_f1'],
            'std_f1': results['std_f1'],
            'mean_precision': results['mean_precision'],
            'std_precision': results['std_precision'],
            'mean_recall': results['mean_recall'],
            'std_recall': results['std_recall']
        }
        results_summary.append(row)

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(f"{results_dir}/kfold_model_comparison_summary.csv", index=False)

    print(f"\nDeğerlendirme tamamlandı. Sonuçlar '{results_dir}' dizininde.")
|
| 373 |
+
|
| 374 |
+
if __name__ == "__main__":
    # Seed every RNG in play (python, numpy, torch) so sampling and fold
    # assignment are reproducible across runs.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    main()
|
trainer.py
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
# -*- coding: utf-8 -*-
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
import random
|
| 6 |
+
import numpy as np
|
| 7 |
+
import pandas as pd
|
| 8 |
+
import matplotlib.pyplot as plt
|
| 9 |
+
import time
|
| 10 |
+
from tqdm.auto import tqdm
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
from collections import Counter
|
| 13 |
+
from PIL import Image
|
| 14 |
+
|
| 15 |
+
import torch
|
| 16 |
+
import torch.nn as nn
|
| 17 |
+
from torch.utils.data import DataLoader, Dataset
|
| 18 |
+
import torchvision
|
| 19 |
+
import torchvision.transforms as T
|
| 20 |
+
from torchvision.datasets import ImageFolder
|
| 21 |
+
from torchvision.models import resnet34, ResNet34_Weights
|
| 22 |
+
|
| 23 |
+
# A.1. Enable CPU fallback for the MPS device: operations not implemented by
# the Metal backend are transparently executed on CPU instead of raising.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Enable MPS optimizations for PyTorch 2.2+ — the attribute only exists on
# newer builds, so probe for it before setting.
if hasattr(torch.backends.mps, 'enable_workflow_compiling'):
    print("Enabling MPS workflow compiling...")
    torch.backends.mps.enable_workflow_compiling = True
|
| 30 |
+
|
| 31 |
+
# A.1. Check Metal 3 / MPS support
|
| 32 |
+
def setup_device():
    """Checks Metal 3 / MPS support and returns appropriate device"""
    print("PyTorch version:", torch.__version__)

    mps_ready = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    if not mps_ready:
        # No Metal backend at all — run on CPU.
        print("MPS not available, using CPU.")
        device = torch.device("cpu")
        print(f"Training device: {device}")
        return device

    print("Metal Performance Shaders (MPS) available.")
    print("PYTORCH_ENABLE_MPS_FALLBACK=1 set - CPU will be used for unsupported operations.")
    device = torch.device("mps")

    # Run a trivial op to confirm the backend actually executes on the GPU.
    probe = torch.ones(1, device=device)
    result = probe + 1

    if result.device.type == 'mps':
        print(f"MPS successfully tested: {result}")
        print(f"Training device: {device}")
        return device

    # Backend reports available but the op landed elsewhere — fall back.
    print("MPS is available but simple operation failed, using CPU.")
    return torch.device("cpu")
|
| 58 |
+
|
| 59 |
+
# A.1.1. Dataset analysis
|
| 60 |
+
def analyze_dataset(data_path):
    """Analyzes the dataset and calculates the number of samples per class.

    Returns (DataFrame sorted by sample count descending, list of class
    names) and saves a class-distribution bar chart to
    results/class_distribution.png.
    """
    data_path = Path(data_path)
    classes = [d.name for d in data_path.iterdir() if d.is_dir()]
    class_counts = {}

    # Count images per class directory.
    # NOTE(review): only *.jpg files are counted — .jpeg/.png are ignored;
    # confirm this matches the dataset's actual file extensions.
    for cls in tqdm(classes, desc="Analyzing classes"):
        class_path = data_path / cls
        class_counts[cls] = len(list(class_path.glob('*.jpg')))

    # Tabulate and sort by frequency.
    df = pd.DataFrame({'Class': list(class_counts.keys()),
                       'Number of Samples': list(class_counts.values())})
    df = df.sort_values('Number of Samples', ascending=False).reset_index(drop=True)

    # Summary statistics.
    total_samples = df['Number of Samples'].sum()
    mean_samples = df['Number of Samples'].mean()
    min_samples = df['Number of Samples'].min()
    max_samples = df['Number of Samples'].max()

    print(f"Total number of samples: {total_samples}")
    print(f"Average number of samples: {mean_samples:.1f}")
    print(f"Minimum number of samples: {min_samples} ({df.iloc[-1]['Class']})")
    print(f"Maximum number of samples: {max_samples} ({df.iloc[0]['Class']})")

    # FIX: ensure the output directory exists — plt.savefig raises
    # FileNotFoundError when 'results/' is missing on a fresh checkout.
    os.makedirs('results', exist_ok=True)

    # Visualize class distribution.
    plt.figure(figsize=(14, 8))
    plt.bar(df['Class'], df['Number of Samples'])
    plt.xticks(rotation=90)
    plt.title('Art Styles - Sample Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of Samples')
    plt.tight_layout()
    plt.savefig('results/class_distribution.png')
    plt.close()

    return df, classes
|
| 99 |
+
|
| 100 |
+
# A.2.2. Custom dataset class - Performs data augmentation on CPU
|
| 101 |
+
class ArtStyleDataset(Dataset):
    """Image-folder dataset with an internal deterministic train/valid split.

    Both the train and the validation instance must be created with the same
    `seed` and `valid_pct` so they partition the same shuffled file list.
    """

    def __init__(self, root_dir, transform=None, target_transform=None, train=True, valid_pct=0.2, seed=42):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        # Index class sub-directories; labels follow sorted name order.
        dir_names = [entry.name for entry in self.root_dir.iterdir() if entry.is_dir()]
        self.class_to_idx = {name: pos for pos, name in enumerate(sorted(dir_names))}

        # Collect (path, label) pairs for every .jpg under each class folder.
        pairs = []
        for name in dir_names:
            label = self.class_to_idx[name]
            for img_path in (self.root_dir / name).glob('*.jpg'):
                pairs.append((str(img_path), label))

        # Deterministic shuffle so train/valid instances agree on the split.
        random.seed(seed)
        random.shuffle(pairs)

        # First `valid_pct` of the shuffled list is validation, rest is train.
        cutoff = int(len(pairs) * valid_pct)
        self.imgs = pairs[cutoff:] if train else pairs[:cutoff]

        self.classes = sorted(dir_names)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        path, label = self.imgs[idx]
        sample = Image.open(path).convert('RGB')

        if self.transform:
            sample = self.transform(sample)
        if self.target_transform:
            label = self.target_transform(label)

        return sample, label
|
| 147 |
+
|
| 148 |
+
# A.2. Creating DataLoaders using PyTorch native structures
|
| 149 |
+
def create_dataloaders(data_path, batch_size=32, img_size=224, augment=True,
                       balance_method='weighted', valid_pct=0.2, seed=42):
    """Creates PyTorch DataLoaders.

    Returns (train_loader, valid_loader, class_names). When balance_method is
    'weighted', training batches are drawn with a WeightedRandomSampler so
    under-represented classes are over-sampled.
    """

    # A.2.4. Data transformations (run on CPU inside DataLoader workers).
    if augment:
        # Presizing idea: crop larger than needed first, augment, then
        # normalize — rotating an already-tight crop would introduce border
        # artifacts; the oversized RandomResizedCrop avoids that.
        train_transforms = T.Compose([
            T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),  # random crop+resize per item
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet stats
        ])
    else:
        # No augmentation: deterministic resize + center crop.
        train_transforms = T.Compose([
            T.Resize(int(img_size*1.14)),
            T.CenterCrop(img_size),
            T.ToTensor(),
            T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Validation is always deterministic.
    valid_transforms = T.Compose([
        T.Resize(int(img_size*1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    # A.2.1. Same seed + valid_pct on both instances => consistent split.
    train_dataset = ArtStyleDataset(data_path, transform=train_transforms, train=True, valid_pct=valid_pct, seed=seed)
    valid_dataset = ArtStyleDataset(data_path, transform=valid_transforms, train=False, valid_pct=valid_pct, seed=seed)

    # A.2.2. Class balancing via weighted sampling.
    if balance_method == 'weighted' and train_dataset:
        # Count how many training samples each class has.
        class_counts = Counter([label for _, label in train_dataset.imgs])
        total = sum(class_counts.values())

        # Per-sample weight = total / class_count, so rarer classes get
        # proportionally higher draw probability.
        weights = [total / class_counts[train_dataset.imgs[i][1]] for i in range(len(train_dataset))]
        sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2, pin_memory=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    class_names = train_dataset.classes

    # Summary of what was built.
    print(f"Training dataset: {len(train_dataset)} images")
    print(f"Validation dataset: {len(valid_dataset)} images")
    print(f"Classes: {len(class_names)}")

    return train_loader, valid_loader, class_names
|
| 213 |
+
|
| 214 |
+
# PyTorch native training loop
|
| 215 |
+
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch; returns (avg_loss, accuracy_percent)."""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batch_times = []  # per-batch wall-clock times for throughput reporting

    # Progress bar over the epoch's batches.
    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    # Report MPS memory usage at epoch start (Apple GPU only).
    if device.type == 'mps':
        print(f"MPS memory usage (start): {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")

    start_time = time.time()
    for inputs, labels in progress_bar:
        batch_start = time.time()

        # Move the batch to the training device.
        inputs, labels = inputs.to(device), labels.to(device)

        # Log device placement once, on the first batch only (total is still 0).
        if total == 0:
            print(f"Training tensor device: {inputs.device}, Model device: {next(model.parameters()).device}")

        # Standard optimization step: zero -> forward -> backward -> update.
        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # Per-batch timing.
        batch_end = time.time()
        batch_time = batch_end - batch_start
        batch_times.append(batch_time)

        # loss.item() is the batch mean; weight by batch size so the final
        # division by the dataset size yields a true dataset-level mean.
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Live metrics on the progress bar.
        progress_bar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    # Epoch-level statistics.
    avg_loss = running_loss / len(dataloader.dataset)
    avg_acc = 100 * correct / total
    avg_time = sum(batch_times) / len(batch_times)
    total_time = time.time() - start_time

    # Report MPS memory usage at epoch end.
    if device.type == 'mps':
        print(f"MPS memory usage (end): {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")

    print(f"Training - Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%, Time: {total_time:.1f}s, Avg batch: {avg_time:.3f}s")

    return avg_loss, avg_acc
|
| 279 |
+
|
| 280 |
+
# A.3. Inspect the DataBlock via dataloader
|
| 281 |
+
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate `model` over `dataloader`; returns (avg_loss, accuracy_pct)."""
    # Evaluation mode disables dropout/batch-norm updates.
    model.eval()
    loss_sum = 0.0
    n_correct = 0
    n_seen = 0

    # No gradients needed for validation.
    with torch.no_grad():
        bar = tqdm(dataloader, desc="Validation", leave=False)

        for batch_x, batch_y in bar:
            # Move the batch to the evaluation device.
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Weight the per-batch mean loss by batch size so the final
            # division yields a dataset-level mean.
            loss_sum += batch_loss.item() * batch_x.size(0)
            _, guesses = logits.max(1)
            n_seen += batch_y.size(0)
            n_correct += guesses.eq(batch_y).sum().item()

            bar.set_postfix({'loss': batch_loss.item(), 'acc': 100 * n_correct / n_seen})

    avg_loss = loss_sum / len(dataloader.dataset)
    avg_acc = 100 * n_correct / n_seen

    print(f"Validation - Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%")

    return avg_loss, avg_acc
|
| 317 |
+
|
| 318 |
+
# A.4. Train a simple model
def train_model(train_loader, valid_loader, class_names, device,
                model_name="resnet34", lr=1e-3, epochs=10,
                freeze_epochs=3, unfreeze_epochs=7):
    """Train a classifier via two-phase transfer learning.

    Phase 1 trains only the new classification head with the pretrained
    backbone frozen; phase 2 unfreezes everything and fine-tunes with
    discriminative (per-layer-group) learning rates. Afterwards the model
    weights, loss/accuracy curves, and a confusion matrix are saved to disk.

    Args:
        train_loader, valid_loader: DataLoaders for training / validation.
        class_names: ordered class labels; their count sets the output size.
        device: torch.device the model and batches are moved to.
        model_name: only "resnet34" is supported.
        lr: base learning rate (head in phase 1; layer4 in phase 2).
        epochs: total epoch count — used only in progress messages.
        freeze_epochs: epochs for phase 1 (skipped if 0).
        unfreeze_epochs: epochs for phase 2 (skipped if 0).

    Returns:
        (model, history) where history maps 'train_loss', 'train_acc',
        'val_loss', 'val_acc' to per-epoch lists across both phases.

    Raises:
        ValueError: if model_name is not "resnet34".
    """
    print(f"\nTraining {model_name} model for {epochs} epochs (freeze: {freeze_epochs}, unfreeze: {unfreeze_epochs})")

    # B.3. Transfer Learning setup
    # Create ResNet34 model with pretrained weights
    if model_name == "resnet34":
        model = resnet34(weights=ResNet34_Weights.DEFAULT)

        # Replace the final layer with a new one for our classes
        # (512 is ResNet34's penultimate feature width)
        num_classes = len(class_names)
        model.fc = nn.Linear(512, num_classes)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Move model to device
    model = model.to(device)

    # B.3. Freeze all weights except the final (new) layer
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True

    # Set up loss function
    criterion = nn.CrossEntropyLoss()

    # Training history for plotting
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    # Training in two phases: first frozen, then unfrozen
    total_start_time = time.time()

    # Phase 1: Train with frozen feature extractor — only the head learns
    if freeze_epochs > 0:
        print("\n=== Phase 1: Training with frozen feature extractor ===")
        optimizer = torch.optim.Adam(model.fc.parameters(), lr=lr)

        for epoch in range(freeze_epochs):
            print(f"\nEpoch {epoch+1}/{freeze_epochs}")
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = validate_epoch(model, valid_loader, criterion, device)

            # Record history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

    # Phase 2: Unfreeze and train with discriminative learning rates
    if unfreeze_epochs > 0:
        print("\n=== Phase 2: Fine-tuning with discriminative learning rates ===")

        # B.3. Unfreeze all weights for fine-tuning
        for param in model.parameters():
            param.requires_grad = True

        # B.4. Discriminative learning rates
        # Earlier layers get smaller learning rates (already well-trained),
        # later layers get higher ones (need more adaptation).
        # NOTE(review): the stem (conv1/bn1) is not in any param group, so
        # although unfrozen it is never updated — confirm this is intended.
        layer_params = [
            {'params': model.layer1.parameters(), 'lr': lr/9},  # earliest block - smallest LR
            {'params': model.layer2.parameters(), 'lr': lr/3},
            {'params': model.layer3.parameters(), 'lr': lr/3},
            {'params': model.layer4.parameters(), 'lr': lr},    # latest block - base LR
            {'params': model.fc.parameters(), 'lr': lr*3}       # new head - highest LR
        ]

        optimizer = torch.optim.Adam(layer_params, lr=lr)
        # NOTE(review): this scheduler is created but never stepped —
        # train_epoch() does not receive it, so phase-2 learning rates stay
        # constant. To actually apply the one-cycle policy, pass the
        # scheduler into the batch loop and call scheduler.step() per batch.
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer, max_lr=lr*3, total_steps=unfreeze_epochs * len(train_loader)
        )

        for epoch in range(unfreeze_epochs):
            print(f"\nEpoch {freeze_epochs+epoch+1}/{epochs}")
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = validate_epoch(model, valid_loader, criterion, device)

            # Record history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

    total_time = time.time() - total_start_time
    print(f"\nTotal training time: {total_time:.1f} seconds ({total_time/60:.1f} minutes)")

    # Save model weights
    os.makedirs('models', exist_ok=True)
    torch.save(model.state_dict(), 'models/model_final.pth')
    print("Model saved to models/model_final.pth")

    # Ensure the plot output directory exists even when train_model is
    # called standalone (main() normally creates it, but don't rely on that).
    os.makedirs('results', exist_ok=True)

    # A.4.2. Visualize training history (loss + accuracy curves)
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Validation')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.legend()

    plt.tight_layout()
    plt.savefig('results/training_history.png')
    plt.close()

    # A.4.3. Create confusion matrix over the validation set
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(valid_loader, desc="Creating confusion matrix"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Create and plot confusion matrix
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    cm = confusion_matrix(all_labels, all_preds)

    plt.figure(figsize=(20, 20))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap='Blues', values_format='d')
    plt.title('Confusion Matrix')
    plt.xticks(rotation=90)  # long class names need vertical tick labels
    plt.tight_layout()
    plt.savefig('results/confusion_matrix.png')
    plt.close()

    return model, history
|
| 467 |
+
|
| 468 |
+
def main():
    """Run the full pipeline (sections A.1-A.4): analyze the dataset,
    build dataloaders, inspect a batch, and train the model."""
    # Setup environment
    device = setup_device()

    # A.1. Download and analyze the data
    data_path = "Art Dataset"
    os.makedirs('results', exist_ok=True)

    # A.1.1. Inspect the data layout
    print("\n===== A.1.1. Inspecting data layout =====")
    df, classes = analyze_dataset(data_path)

    # A.2. Create the DataBlock and dataloaders
    print("\n===== A.2. Creating DataLoaders =====")
    train_loader, valid_loader, class_names = create_dataloaders(
        data_path, batch_size=32, img_size=224, augment=True,
        balance_method='weighted', valid_pct=0.2
    )

    # A.3. Inspect the DataBlock via dataloader
    print("\n===== A.3. Inspecting DataBlock =====")

    # A.3.1. Show batch
    def visualize_batch(dataloader, num_images=16):
        """Save a grid preview of up to num_images samples from dataloader.

        The grid is sized from the number of images actually drawn, so a
        batch smaller than num_images (or num_images != 16) no longer
        breaks the previously hard-coded 4x4 subplot indexing. Defaults
        reproduce the original 4x4 / 12x12-inch layout.
        """
        # Get a batch (may contain fewer than num_images samples)
        images, labels = next(iter(dataloader))
        images = images[:num_images]
        labels = labels[:num_images]

        # ImageNet normalization stats, used to unnormalize for display
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])

        # Grid dimensions derived from what we actually show
        n_shown = len(images)
        ncols = 4
        nrows = max(1, -(-n_shown // ncols))  # ceiling division, at least one row
        fig, axes = plt.subplots(nrows=nrows, ncols=ncols, figsize=(12, 3 * nrows))
        axes = np.atleast_2d(axes)  # keep 2-D indexing even when nrows == 1

        for i, (img, label) in enumerate(zip(images, labels)):
            # Unnormalize, move channels last, clip into the valid [0, 1] range
            img = img.cpu() * std[:, None, None] + mean[:, None, None]
            img = img.permute(1, 2, 0).numpy()
            img = np.clip(img, 0, 1)

            # Human-readable class name
            class_name = class_names[label].replace('_', ' ')

            # Plot into the next grid cell
            row, col = divmod(i, ncols)
            axes[row, col].imshow(img)
            axes[row, col].set_title(class_name)
            axes[row, col].axis('off')

        # Blank out any unused trailing grid cells
        for j in range(n_shown, nrows * ncols):
            row, col = divmod(j, ncols)
            axes[row, col].axis('off')

        plt.tight_layout()
        plt.savefig('results/batch_preview.png')
        plt.close()
        print("Batch preview saved to results/batch_preview.png")

    # A.3.1. Show batch: dataloader.show_batch()
    print("\n===== A.3.1. Showing batch =====")
    visualize_batch(train_loader)

    # A.3.2. Check the labels
    print("\n===== A.3.2. Checking labels =====")
    print(f"Class names: {class_names}")

    # A.3.3. Summarize the DataBlock
    print("\n===== A.3.3. Summarizing DataBlock =====")
    print(f"Number of classes: {len(class_names)}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(valid_loader)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Total training samples: {len(train_loader.dataset)}")
    print(f"Total validation samples: {len(valid_loader.dataset)}")

    # A.4. Train a simple model
    print("\n===== A.4. Training a simple model =====")
    model, history = train_model(
        train_loader, valid_loader, class_names, device,
        model_name="resnet34", lr=1e-3,
        epochs=10, freeze_epochs=3, unfreeze_epochs=7
    )

    print("\nTraining complete!")
|
| 554 |
+
|
| 555 |
+
# Script entry point: run the full pipeline only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|