ErdemAtak commited on
Commit
86f4754
·
verified ·
1 Parent(s): c8e5239

Upload 4 files

Browse files
Files changed (4) hide show
  1. art_trainer-mixup.py +234 -0
  2. model_evaluator.py +325 -0
  3. model_evaluator_kfold.py +379 -0
  4. trainer.py +556 -0
art_trainer-mixup.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Fine-tune an art-movement image classifier (ResNet34) with mixup.

Expects the data laid out as ``<root>/<class_name>/<image>`` and uses the
Apple MPS backend when available, otherwise CPU.
"""

import os
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms as T
from torchvision.transforms import v2
from PIL import Image
from pathlib import Path
from tqdm.auto import tqdm
import random
import numpy as np

# A.1. Check device availability and setup MPS optimizations
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
torch.set_float32_matmul_precision('high')  # MPS performance optimization

# Hyperparameters (Tested optimal values)
CFG = {
    'img_size': 224,           # input resolution fed to the network
    'batch_size': 32,
    'lr': 3e-5,                # Lower learning rate
    'weight_decay': 0.05,      # Stronger L2 regularization
    'dropout': 0.5,            # Increased dropout
    'epochs': 30,
    'mixup_alpha': 0.4,        # Beta(alpha, alpha) parameter for mixup
    'cutmix_prob': 0.3,        # NOTE(review): defined but never read in this file
    'label_smoothing': 0.15,
    'patience': 5              # For early stopping
}
34
+
35
# A.2.4. Define data transformations with advanced augmentation pipeline
def create_transforms():
    """Build the training and validation transform pipelines.

    Returns a dict with keys ``'train'`` (heavy stochastic augmentation)
    and ``'val'`` (deterministic resize + center crop). Both pipelines
    normalize with the ImageNet statistics.
    """
    imagenet_norm = v2.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])

    # Presizing: random-crop per item first, then augment, then
    # tensorize/normalize (and erase) on the normalized tensor.
    train_steps = [
        v2.RandomResizedCrop(CFG['img_size'], scale=(0.6, 1.0)),
        v2.RandomHorizontalFlip(p=0.7),
        v2.RandomVerticalFlip(p=0.3),
        v2.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.3),
        v2.RandomRotation(35),
        v2.RandomAffine(degrees=0, translate=(0.2, 0.2)),
        v2.RandomPerspective(distortion_scale=0.4, p=0.6),
        v2.GaussianBlur(kernel_size=(5, 9)),
        v2.RandomSolarize(threshold=0.3, p=0.2),
        v2.ToTensor(),
        imagenet_norm,
        v2.RandomErasing(p=0.5, scale=(0.02, 0.2), value='random'),
    ]

    val_steps = [
        v2.Resize(CFG['img_size'] + 32),
        v2.CenterCrop(CFG['img_size']),
        v2.ToTensor(),
        imagenet_norm,
    ]

    return {'train': v2.Compose(train_steps), 'val': v2.Compose(val_steps)}
63
+
64
# A.2.2. Define the means of getting data into DataBlock
class ArtDataset(Dataset):
    """Image-folder dataset: one subdirectory of *data_dir* per class.

    Yields ``(image, label)`` pairs where *image* is an RGB PIL image
    (optionally transformed) and *label* is the integer class index.
    """

    def __init__(self, data_dir, transform=None):
        root = Path(data_dir)
        self.classes = sorted(entry.name for entry in root.iterdir() if entry.is_dir())
        self.class_to_idx = {name: idx for idx, name in enumerate(self.classes)}
        # Every file directly under a class directory counts as a sample.
        self.samples = [
            (path, self.class_to_idx[name])
            for name in self.classes
            for path in (root / name).glob('*')
        ]
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, label = self.samples[idx]
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
85
+
86
# B.4. Implement mixup data augmentation - part of discriminative learning rates
def mixup_data(x, y, alpha=1.0):
    """Mix a batch with a random permutation of itself (mixup).

    Args:
        x: input batch tensor (first dim is batch size).
        y: target tensor aligned with ``x``.
        alpha: Beta(alpha, alpha) concentration; ``alpha <= 0`` disables
            mixing (lam == 1, inputs returned unchanged).

    Returns:
        (mixed_x, y_a, y_b, lam) where ``mixed_x = lam*x + (1-lam)*x[perm]``,
        ``y_a`` is the original targets and ``y_b`` the permuted ones.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1.0
    batch_size = x.size(0)
    # FIX: create the permutation on the input's own device instead of the
    # module-level `device` global, so the function also works for CPU
    # tensors and any other device the batch happens to live on.
    index = torch.randperm(batch_size, device=x.device)
    mixed_x = lam * x + (1 - lam) * x[index, :]
    y_a, y_b = y, y[index]
    return mixed_x, y_a, y_b, lam
97
+
98
# A.4. Define training step
def train_step(model, data_loader, criterion, optimizer):
    """Run one mixup training epoch; return ``(mean_loss, accuracy_pct)``."""
    model.train()
    running_loss = 0.0
    weighted_correct = 0.0

    for xb, yb in tqdm(data_loader, desc='Training', leave=False):
        xb, yb = xb.to(device), yb.to(device)

        # B.4. Advanced Mixup - part of discriminative learning rates
        xb, y_first, y_second, lam = mixup_data(xb, yb, CFG['mixup_alpha'])

        optimizer.zero_grad()
        logits = model(xb)
        # Mixup loss: convex combination of the two targets' losses.
        loss = criterion(logits, y_first) * lam + criterion(logits, y_second) * (1 - lam)

        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # Gradient clipping
        optimizer.step()

        running_loss += loss.item()
        preds = logits.max(1)[1]
        # Accuracy is interpolated with lam, mirroring the loss.
        weighted_correct += (lam * preds.eq(y_first).sum().item()
                             + (1 - lam) * preds.eq(y_second).sum().item())

    return (running_loss / len(data_loader),
            100. * weighted_correct / len(data_loader.dataset))
126
+
127
# A.3. Define validation step to inspect the DataBlock
def validate(model, data_loader, criterion):
    """Evaluate *model* on *data_loader*; return ``(mean_loss, accuracy_pct)``."""
    model.eval()
    running_loss = 0.0
    n_correct = 0

    with torch.no_grad():
        for xb, yb in tqdm(data_loader, desc='Validation', leave=False):
            xb, yb = xb.to(device), yb.to(device)
            logits = model(xb)
            running_loss += criterion(logits, yb).item()
            n_correct += logits.max(1)[1].eq(yb).sum().item()

    return (running_loss / len(data_loader),
            100. * n_correct / len(data_loader.dataset))
146
+
147
def main():
    """End-to-end fine-tuning: data, ResNet34 transfer learning, mixup
    training loop with cosine LR schedule and early stopping.

    Reads the checkpoint from ``models/model_final.pth`` and writes the
    best model to ``results/best_model.pth``.
    """
    # A.1. Load data
    transforms = create_transforms()

    # Set directory paths according to your structure
    art_dataset_dir = 'Art Dataset'

    # A.2.1. Define the blocks (dataset creation)
    train_dataset = ArtDataset(art_dataset_dir, transform=transforms['train'])
    val_dataset = ArtDataset(art_dataset_dir, transform=transforms['val'])

    # FIX: the original code trained AND validated on the same full sample
    # list, so "validation" accuracy was measured on images seen during
    # training and early stopping tracked training performance. Hold out a
    # deterministic 20% of the samples for validation instead.
    indices = list(range(len(train_dataset.samples)))
    random.Random(42).shuffle(indices)
    n_train = int(0.8 * len(indices))
    all_samples = train_dataset.samples
    train_dataset.samples = [all_samples[i] for i in indices[:n_train]]
    val_dataset.samples = [all_samples[i] for i in indices[n_train:]]

    # A.2.2. Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=CFG['batch_size'],
                              shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=CFG['batch_size'],
                            num_workers=4, pin_memory=True)

    # B.3. Transfer Learning - Load model
    model_path = 'models/model_final.pth'

    # Load model state dictionary
    state_dict = torch.load(model_path)

    # Create ResNet34 model (weights come from the checkpoint below)
    from torchvision import models
    model = models.resnet34(weights=None)

    # Number of classes
    num_classes = len(train_dataset.classes)

    # B.3. Update the final fully-connected layer
    model.fc = nn.Linear(512, num_classes)

    # Load state dictionary
    model.load_state_dict(state_dict)
    model = model.to(device)

    # B.6. Model Capacity - Measures to prevent overfitting
    # NOTE(review): a stock torchvision ResNet34 contains no nn.Dropout
    # modules, so this loop is a no-op unless the checkpoint architecture
    # was customized - confirm.
    for name, module in model.named_modules():
        if isinstance(module, nn.Dropout):
            module.p = CFG['dropout']  # Increase dropout rate

    # B.1. Learning Rate Finder - Optimizer and Loss setup
    optimizer = optim.AdamW(model.parameters(), lr=CFG['lr'],
                            weight_decay=CFG['weight_decay'])
    criterion = nn.CrossEntropyLoss(label_smoothing=CFG['label_smoothing'])
    scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer,
                                                               T_0=10, T_mult=2)

    # Create results directory
    results_dir = 'results'
    os.makedirs(results_dir, exist_ok=True)

    # B.5. Early Stopping - Deciding the Number of Training Epochs
    best_val_acc = 0
    patience_counter = 0

    # A.4. Train a simple model
    for epoch in range(CFG['epochs']):
        print(f"\nEpoch {epoch+1}/{CFG['epochs']}")

        # Training
        train_loss, train_acc = train_step(model, train_loader, criterion, optimizer)
        # Validation
        val_loss, val_acc = validate(model, val_loader, criterion)

        # Learning rate update (per-epoch cosine warm restarts)
        scheduler.step()

        # Monitor results
        print(f"Train Loss: {train_loss:.4f} | Acc: {train_acc:.2f}%")
        print(f"Val Loss: {val_loss:.4f} | Acc: {val_acc:.2f}%")

        # B.5. Early stopping check: keep the best-on-val checkpoint.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            best_model_path = os.path.join(results_dir, 'best_model.pth')
            torch.save(model.state_dict(), best_model_path)
            print(f"New best model saved ({val_acc:.2f}%)")
        else:
            patience_counter += 1
            if patience_counter >= CFG['patience']:
                print(f"Early stopping! No improvement for {CFG['patience']} epochs.")
                break

if __name__ == "__main__":
    main()
model_evaluator.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Evaluate saved art-movement classifiers on a sampled test set.

Loads every ``.pth`` checkpoint from ``models/``, scores it on a
per-class test sample drawn from ``Art Dataset``, and writes plots and
CSV reports to ``evaluation_results/``.
"""

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import models, transforms
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from tqdm import tqdm
import pandas as pd
import random
from collections import defaultdict

# MPS (Metal Performance Shaders) check - Apple GPU
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print(f"Using Metal GPU: {DEVICE}")
else:
    # Fall back to CUDA, then CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Metal GPU not found, using device: {DEVICE}")

# Constants
IMG_SIZE = 224
BATCH_SIZE = 64  # Batch size increased for GPU
NUM_WORKERS = 6  # Number of threads increased
MAX_SAMPLES_PER_CLASS = 30  # Maximum number of samples per class (for quick testing)

# Transformation for test dataset (deterministic; ImageNet normalization)
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE + 32),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
38
+
39
class ArtDataset(Dataset):
    """Dataset over pre-collected ``(image_path, class_name)`` samples.

    If *class_to_idx* is not supplied, the class vocabulary is inferred
    from each sample path's parent directory name.
    """

    def __init__(self, samples, transform=None, class_to_idx=None):
        self.samples = samples
        self.transform = transform

        if class_to_idx is None:
            # Derive the classes from the parent-directory names.
            inferred = {Path(str(path)).parent.name for path, _ in samples}
            self.classes = sorted(inferred)
            self.class_to_idx = {name: pos for pos, name in enumerate(self.classes)}
        else:
            self.class_to_idx = class_to_idx
            # Keep the class list ordered by the provided indices.
            self.classes = sorted(class_to_idx, key=class_to_idx.get)

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        path, class_name = self.samples[idx]
        label = self.class_to_idx[class_name]
        image = Image.open(path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label
63
+
64
def create_test_set(data_dir, test_ratio=0.2, max_per_class=None):
    """Sample a per-class test split from an image-folder dataset.

    For every class subdirectory of *data_dir*, shuffles its files and
    keeps ``max(1, floor(len * test_ratio))`` of them, optionally capped
    at *max_per_class*. Returns ``(test_samples, class_to_idx)`` where
    each sample is an ``(image_path, class_name)`` tuple.
    """
    per_class = defaultdict(list)

    # Gather every file under each class directory.
    for entry in Path(data_dir).iterdir():
        if not entry.is_dir():
            continue
        for image_path in entry.glob('*'):
            per_class[entry.name].append((image_path, entry.name))

    test_samples = []
    for files in per_class.values():
        random.shuffle(files)
        # At least one sample per class, at most max_per_class (if set).
        keep = max(1, int(len(files) * test_ratio))
        if max_per_class:
            keep = min(keep, max_per_class)
        test_samples.extend(files[:keep])

    print(f"Total of {len(test_samples)} test samples selected from {len(per_class)} different art movements.")

    # Create class-index mapping over the sorted class names.
    class_to_idx = {name: idx for idx, name in enumerate(sorted(per_class))}

    return test_samples, class_to_idx
94
+
95
def load_model(model_path, num_classes):
    """Load a ResNet34 checkpoint and return it in eval mode on DEVICE."""
    print(f"Loading model: {model_path}")

    # ResNet34 without pretrained weights; the checkpoint supplies them.
    classifier = models.resnet34(weights=None)
    classifier.fc = nn.Linear(512, num_classes)  # replace the final FC layer

    # map_location keeps loading working regardless of Metal GPU availability.
    weights = torch.load(model_path, map_location=DEVICE)
    classifier.load_state_dict(weights)
    classifier = classifier.to(DEVICE)
    classifier.eval()

    return classifier
110
+
111
def evaluate_model(model, test_loader, classes):
    """Run inference on *test_loader* and compute weighted metrics.

    Returns a dict with overall accuracy / precision / recall / F1, a
    per-class accuracy mapping, the confusion matrix, and the raw
    predictions and ground-truth labels.
    """
    predictions = []
    ground_truth = []

    with torch.no_grad():
        for batch_inputs, batch_labels in tqdm(test_loader, desc="Evaluation"):
            batch_inputs = batch_inputs.to(DEVICE)
            batch_labels = batch_labels.to(DEVICE)

            # Run directly on the MPS device (no autocast).
            logits = model(batch_inputs)
            batch_preds = torch.max(logits, 1)[1]

            # Collect the results on the CPU for sklearn.
            predictions.extend(batch_preds.cpu().numpy())
            ground_truth.extend(batch_labels.cpu().numpy())

    conf_matrix = confusion_matrix(ground_truth, predictions)

    # Per-class accuracy: diagonal of the confusion matrix over support.
    labels_arr = np.array(ground_truth)
    class_accuracy = {}
    for idx, name in enumerate(classes):
        support = np.sum(labels_arr == idx)
        if support > 0:
            class_accuracy[name] = conf_matrix[idx, idx] / support

    return {
        'accuracy': accuracy_score(ground_truth, predictions),
        'f1_score': f1_score(ground_truth, predictions, average='weighted'),
        'precision': precision_score(ground_truth, predictions, average='weighted'),
        'recall': recall_score(ground_truth, predictions, average='weighted'),
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'predictions': predictions,
        'ground_truth': ground_truth
    }
157
+
158
def plot_confusion_matrix(conf_matrix, classes, model_name, save_dir):
    """Render the confusion matrix as a heatmap PNG under *save_dir*."""
    plt.figure(figsize=(12, 10))
    sns.heatmap(conf_matrix, annot=False, fmt='d', cmap='Blues',
                xticklabels=classes, yticklabels=classes)
    plt.ylabel('True Class')
    plt.xlabel('Predicted Class')
    plt.title(f'Confusion Matrix - {model_name}')
    plt.tight_layout()

    # Persist and release the figure.
    stem = Path(model_name).stem
    out_file = Path(save_dir) / f"conf_matrix_{stem}.png"
    plt.savefig(out_file, dpi=300)
    plt.close()
172
+
173
def plot_class_accuracy(class_acc, model_name, save_dir):
    """Bar-plot per-class accuracy (sorted descending) and save as PNG."""
    plt.figure(figsize=(14, 8))

    # Highest-accuracy classes first.
    ranked = sorted(class_acc.items(), key=lambda kv: kv[1], reverse=True)
    names = [kv[0] for kv in ranked]
    scores = [kv[1] for kv in ranked]

    bars = plt.bar(names, scores)
    plt.xlabel('Art Movement')
    plt.ylabel('Accuracy')
    plt.title(f'Class-Based Accuracy - {model_name}')
    plt.xticks(rotation=90)
    plt.ylim(0, 1.0)

    # Annotate every bar with its accuracy value.
    for bar in bars:
        h = bar.get_height()
        plt.text(bar.get_x() + bar.get_width() / 2., h,
                 f'{h:.2f}', ha='center', va='bottom', rotation=0)

    plt.tight_layout()

    # Persist and release the figure.
    out_file = Path(save_dir) / f"class_accuracy_{Path(model_name).stem}.png"
    plt.savefig(out_file, dpi=300)
    plt.close()
201
+
202
def plot_model_comparison(all_results, save_dir):
    """Grouped bar chart comparing accuracy/F1/precision/recall per model."""
    model_names = list(all_results.keys())
    metrics = ['accuracy', 'f1_score', 'precision', 'recall']

    # Collect one value series per metric, in model order.
    metric_data = {m: [all_results[name][m] for name in model_names] for m in metrics}

    plt.figure(figsize=(12, 7))
    x = np.arange(len(model_names))
    width = 0.2

    # One bar group per metric, shifted by its slot index.
    for slot, (metric, values) in enumerate(metric_data.items()):
        bars = plt.bar(x + width * slot, values, width, label=metric)

        # Value labels above each bar.
        for bar in bars:
            h = bar.get_height()
            plt.annotate(f'{h:.3f}',
                         xy=(bar.get_x() + bar.get_width() / 2, h),
                         xytext=(0, 3),  # 3 points vertical offset
                         textcoords="offset points",
                         ha='center', va='bottom')

    plt.xlabel('Model')
    plt.ylabel('Score')
    plt.title('Model Performance Comparison')
    plt.xticks(x + width, model_names)
    plt.legend(loc='lower right')
    plt.ylim(0, 1.0)

    plt.tight_layout()

    # Persist and release the figure.
    out_file = Path(save_dir) / "model_comparison.png"
    plt.savefig(out_file, dpi=300)
    plt.close()
244
+
245
def main():
    """Evaluate every ``.pth`` checkpoint in ``models/`` on a sampled test set.

    Writes confusion matrices, per-class accuracy plots, classification
    reports, and a comparison CSV into ``evaluation_results/``.

    NOTE(review): the test samples are drawn from the full 'Art Dataset'
    directory; if these checkpoints were trained on the same directory,
    the "test" images may overlap their training data - confirm.
    """
    # Data directory and results directory
    art_dataset_dir = 'Art Dataset'
    models_dir = 'models'
    results_dir = 'evaluation_results'

    # Create results directory
    os.makedirs(results_dir, exist_ok=True)

    # Create test data - limit maximum number of examples from each class
    test_samples, class_to_idx = create_test_set(art_dataset_dir, test_ratio=0.2, max_per_class=MAX_SAMPLES_PER_CLASS)
    test_dataset = ArtDataset(test_samples, transform=test_transform, class_to_idx=class_to_idx)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

    classes = test_dataset.classes
    num_classes = len(classes)
    print(f"Art classes: {len(classes)}")

    # Find model files (exclude files like .DS_Store)
    model_paths = [os.path.join(models_dir, f) for f in os.listdir(models_dir)
                   if f.endswith('.pth') and not f.startswith('.')]

    # Dictionary to store results, keyed by checkpoint filename
    all_results = {}

    # Evaluate each model on the same test loader
    for model_path in model_paths:
        model_name = Path(model_path).name
        print(f"\nEvaluating {model_name}...")

        # Load model
        model = load_model(model_path, num_classes)

        # Evaluate model
        results = evaluate_model(model, test_loader, classes)
        all_results[model_name] = results

        print(f"Accuracy: {results['accuracy']:.4f}")
        print(f"F1 Score: {results['f1_score']:.4f}")
        print(f"Precision: {results['precision']:.4f}")
        print(f"Recall: {results['recall']:.4f}")

        # Plot confusion matrix graph
        plot_confusion_matrix(results['confusion_matrix'], classes, model_name, results_dir)

        # Plot class-based accuracy graph
        plot_class_accuracy(results['class_accuracy'], model_name, results_dir)

        # Save detailed class report as CSV (one row per class + averages)
        report = classification_report(results['ground_truth'], results['predictions'],
                                       target_names=classes, output_dict=True)
        report_df = pd.DataFrame(report).transpose()
        report_df.to_csv(f"{results_dir}/classification_report_{Path(model_name).stem}.csv")

    # Compare models (only meaningful with more than one checkpoint)
    if len(all_results) > 1:
        plot_model_comparison(all_results, results_dir)

    # Save results to CSV file
    results_summary = []
    for model_name, results in all_results.items():
        row = {
            'model': model_name,
            'accuracy': results['accuracy'],
            'f1_score': results['f1_score'],
            'precision': results['precision'],
            'recall': results['recall']
        }
        results_summary.append(row)

    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(f"{results_dir}/model_comparison_summary.csv", index=False)

    print(f"\nEvaluation completed. Results are in '{results_dir}' directory.")

if __name__ == "__main__":
    # Set seed for reproducibility (affects the random test-set sampling)
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    main()
model_evaluator_kfold.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""K-fold evaluation of saved art-movement classifiers (Turkish UI strings).

Builds a balanced sample set, splits it into K folds, and scores every
checkpoint in ``models/`` on each fold's test subset.
"""

import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import models, transforms
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
from sklearn.model_selection import KFold
from tqdm import tqdm
import pandas as pd
import random
from collections import defaultdict

# MPS (Metal Performance Shaders) check - Apple GPU
if torch.backends.mps.is_available():
    DEVICE = torch.device("mps")
    print(f"Metal GPU kullanılıyor: {DEVICE}")
else:
    # Fall back to CUDA, then CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Metal GPU bulunamadı, şu cihaz kullanılıyor: {DEVICE}")

# Constants
IMG_SIZE = 224
BATCH_SIZE = 64
NUM_WORKERS = 6
MAX_SAMPLES_PER_CLASS = 20  # Maximum samples per class (for quick testing)
K_FOLDS = 5  # 5-fold cross validation

# Transform applied to the test dataset (deterministic; ImageNet normalization)
test_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE + 32),
    transforms.CenterCrop(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
40
+
41
class ArtDataset(Dataset):
    """Wraps a list of ``(image_path, class_name)`` pairs as a torch Dataset.

    When *class_to_idx* is omitted, the label vocabulary is inferred from
    each path's parent directory name.
    """

    def __init__(self, samples, transform=None, class_to_idx=None):
        self.samples = samples
        self.transform = transform

        if class_to_idx is not None:
            self.class_to_idx = class_to_idx
            # Order the class list by the supplied indices.
            self.classes = sorted(class_to_idx.keys(), key=lambda name: class_to_idx[name])
        else:
            # Infer the label vocabulary from the sample paths.
            found = sorted({Path(str(item[0])).parent.name for item in samples})
            self.classes = found
            self.class_to_idx = dict((name, pos) for pos, name in enumerate(found))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        image_path, class_name = self.samples[index]
        label = self.class_to_idx[class_name]
        picture = Image.open(image_path).convert('RGB')
        if self.transform:
            picture = self.transform(picture)
        return picture, label
65
+
66
def create_balanced_dataset(data_dir, max_per_class=None):
    """Build a class-balanced sample list from an image-folder dataset.

    Collects every file under each class subdirectory of *data_dir*,
    shuffles within each class, and keeps at most *max_per_class* files
    per class. Returns ``(samples, class_to_idx)`` where each sample is
    an ``(image_path, class_name)`` tuple.
    """
    groups = defaultdict(list)

    # Group all image files by their class directory.
    for class_dir in Path(data_dir).iterdir():
        if class_dir.is_dir():
            for img_path in class_dir.glob('*'):
                groups[class_dir.name].append((img_path, class_dir.name))

    # Cap each class at max_per_class randomly chosen samples.
    selected = []
    for files in groups.values():
        random.shuffle(files)
        if max_per_class and len(files) > max_per_class:
            files = files[:max_per_class]
        selected.extend(files)

    print(f"Toplam {len(selected)} örnek, {len(groups)} farklı sanat akımından seçildi.")

    # Class-to-index mapping over the sorted class names.
    names = sorted(groups.keys())
    class_to_idx = {name: pos for pos, name in enumerate(names)}

    return selected, class_to_idx
95
+
96
def load_model(model_path, num_classes):
    """Instantiate a head-replaced ResNet34 and load the checkpoint weights."""
    print(f"Model yükleniyor: {model_path}")

    # ResNet34 backbone without pretrained weights; checkpoint supplies them.
    net = models.resnet34(weights=None)
    net.fc = nn.Linear(512, num_classes)  # replace the final FC layer

    # map_location keeps loading working regardless of Metal GPU availability.
    checkpoint = torch.load(model_path, map_location=DEVICE)
    net.load_state_dict(checkpoint)
    net = net.to(DEVICE)
    net.eval()

    return net
111
+
112
def evaluate_model(model, test_loader, classes):
    """Run inference and return weighted metrics plus per-class accuracy."""
    y_pred = []
    y_true = []

    with torch.no_grad():
        for images, labels in tqdm(test_loader, desc="Değerlendirme", leave=False):
            images, labels = images.to(DEVICE), labels.to(DEVICE)

            # Forward pass on the selected device.
            logits = model(images)

            # Move predictions/labels to the CPU for sklearn.
            y_pred.extend(torch.max(logits, 1)[1].cpu().numpy())
            y_true.extend(labels.cpu().numpy())

    # zero_division=1 suppresses warnings for classes absent from a fold.
    metrics = {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1_score': f1_score(y_true, y_pred, average='weighted', zero_division=1),
        'precision': precision_score(y_true, y_pred, average='weighted', zero_division=1),
        'recall': recall_score(y_true, y_pred, average='weighted', zero_division=1),
    }

    conf_matrix = confusion_matrix(y_true, y_pred)
    true_arr = np.array(y_true)

    # Per-class accuracy; the index guard covers folds where the matrix is
    # smaller than len(classes) because some classes are missing.
    class_accuracy = {}
    for pos, name in enumerate(classes):
        support = np.sum(true_arr == pos)
        hits = conf_matrix[pos, pos] if pos < len(conf_matrix) else 0
        if support > 0:
            class_accuracy[name] = hits / support

    metrics.update({
        'class_accuracy': class_accuracy,
        'confusion_matrix': conf_matrix,
        'predictions': y_pred,
        'ground_truth': y_true,
    })
    return metrics
158
+
159
def k_fold_cross_validation(dataset, model_paths, num_classes, k=5):
    """Evaluate every checkpoint in *model_paths* on k test folds of *dataset*.

    NOTE(review): this is evaluation-only k-fold - the models are fixed
    checkpoints and are NOT retrained per fold; each fold simply scores
    the same weights on a different subset. Confirm this is intended.

    Returns a dict mapping model name -> mean/std accuracy, F1, precision
    and recall across folds, plus the raw per-fold accuracy/F1 lists.
    """

    # Create the K-fold splitter (fixed seed for reproducible folds)
    kfold = KFold(n_splits=k, shuffle=True, random_state=42)

    # Per-model accumulators for the per-fold results
    all_model_results = {}
    for model_path in model_paths:
        model_name = Path(model_path).name
        all_model_results[model_name] = {
            'fold_results': [],
            'accuracy': [],
            'f1_score': [],
            'precision': [],
            'recall': []
        }

    # K-fold cross validation: only each split's test indices are used
    for fold, (_, test_indices) in enumerate(kfold.split(dataset)):
        print(f"\nFold {fold+1}/{k} değerlendiriliyor...")

        # Build this fold's test subset and loader
        test_subset = Subset(dataset, test_indices)
        test_loader = DataLoader(test_subset, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS, pin_memory=True)

        # Score every model on this fold
        for model_path in model_paths:
            model_name = Path(model_path).name
            print(f" {model_name} değerlendiriliyor...")

            # Load the checkpoint (re-loaded for every fold)
            model = load_model(model_path, num_classes)

            # Evaluate on the fold's test subset
            results = evaluate_model(model, test_loader, dataset.classes)

            # Record this fold's metrics
            all_model_results[model_name]['fold_results'].append(results)
            all_model_results[model_name]['accuracy'].append(results['accuracy'])
            all_model_results[model_name]['f1_score'].append(results['f1_score'])
            all_model_results[model_name]['precision'].append(results['precision'])
            all_model_results[model_name]['recall'].append(results['recall'])

            print(f" Fold {fold+1} - Doğruluk: {results['accuracy']:.4f}, F1: {results['f1_score']:.4f}")

    # Aggregate mean/std per model across folds
    summary_results = {}
    for model_name, results in all_model_results.items():
        summary_results[model_name] = {
            'mean_accuracy': np.mean(results['accuracy']),
            'std_accuracy': np.std(results['accuracy']),
            'mean_f1': np.mean(results['f1_score']),
            'std_f1': np.std(results['f1_score']),
            'mean_precision': np.mean(results['precision']),
            'std_precision': np.std(results['precision']),
            'mean_recall': np.mean(results['recall']),
            'std_recall': np.std(results['recall']),
            'fold_accuracy': results['accuracy'],
            'fold_f1': results['f1_score']
        }

    return summary_results
222
+
223
def plot_kfold_results(summary_results, save_dir):
    """Create plots summarizing the k-fold cross-validation results.

    Saves two PNGs into *save_dir*: a mean±std bar chart per model, and
    per-fold accuracy/F1 line plots.

    NOTE(review): the per-fold plots use the module-level K_FOLDS constant
    for the x-axis; it must match the k actually used - confirm.
    """

    # Plot the mean Accuracy / F1 values
    plt.figure(figsize=(14, 7))

    # Extract model names and mean values
    model_names = list(summary_results.keys())
    model_names = [Path(name).stem for name in model_names]  # strip the .pth extension

    # Accuracy and F1 scores (means and standard deviations per model)
    mean_accuracy = [summary_results[model]['mean_accuracy'] for model in summary_results]
    std_accuracy = [summary_results[model]['std_accuracy'] for model in summary_results]
    mean_f1 = [summary_results[model]['mean_f1'] for model in summary_results]
    std_f1 = [summary_results[model]['std_f1'] for model in summary_results]

    # X-axis positions
    x = np.arange(len(model_names))
    width = 0.35

    # Bar charts (error bars show the std across folds)
    fig, ax = plt.subplots(figsize=(12, 8))
    rects1 = ax.bar(x - width/2, mean_accuracy, width, yerr=std_accuracy,
                    label='Accuracy', capsize=5, color='cornflowerblue')
    rects2 = ax.bar(x + width/2, mean_f1, width, yerr=std_f1,
                    label='F1 Score', capsize=5, color='lightcoral')

    # Chart properties
    ax.set_ylabel('Skor')
    ax.set_title('5-Fold Cross Validation Ortalama Performans (Ortalama ± Std)')
    ax.set_xticks(x)
    ax.set_xticklabels(model_names)
    ax.legend()
    ax.set_ylim(0, 1.0)

    # Add the values on top of the bars
    def add_labels(rects):
        for rect in rects:
            height = rect.get_height()
            ax.annotate(f'{height:.3f}',
                        xy=(rect.get_x() + rect.get_width() / 2, height),
                        xytext=(0, 3),  # 3 points vertical offset
                        textcoords="offset points",
                        ha='center', va='bottom')

    add_labels(rects1)
    add_labels(rects2)

    plt.tight_layout()

    # Save the chart
    save_path = Path(save_dir) / "kfold_mean_performance.png"
    plt.savefig(save_path, dpi=300)
    plt.close()

    # Plot the performance for every individual fold
    plt.figure(figsize=(18, 12))

    # Accuracy subplot
    plt.subplot(2, 1, 1)
    for model_name in summary_results:
        model_stem = Path(model_name).stem
        plt.plot(range(1, K_FOLDS + 1), summary_results[model_name]['fold_accuracy'],
                 marker='o', linestyle='-', label=model_stem)

    plt.title('Her Fold için Accuracy Değerleri')
    plt.xlabel('Fold')
    plt.ylabel('Accuracy')
    plt.xticks(range(1, K_FOLDS + 1))
    plt.ylim(0, 1.0)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    # F1 score subplot
    plt.subplot(2, 1, 2)
    for model_name in summary_results:
        model_stem = Path(model_name).stem
        plt.plot(range(1, K_FOLDS + 1), summary_results[model_name]['fold_f1'],
                 marker='o', linestyle='-', label=model_stem)

    plt.title('Her Fold için F1 Değerleri')
    plt.xlabel('Fold')
    plt.ylabel('F1 Score')
    plt.xticks(range(1, K_FOLDS + 1))
    plt.ylim(0, 1.0)
    plt.grid(True, linestyle='--', alpha=0.7)
    plt.legend()

    plt.tight_layout()

    # Save the chart
    save_path = Path(save_dir) / "kfold_all_folds_performance.png"
    plt.savefig(save_path, dpi=300)
    plt.close()
317
+
318
def main():
    """Entry point: k-fold-evaluate every saved model and report the results."""
    # Input/output locations.
    art_dataset_dir = 'Art Dataset'
    models_dir = 'models'
    results_dir = 'kfold_evaluation_results'

    # Make sure the output directory exists.
    os.makedirs(results_dir, exist_ok=True)

    # Build the balanced dataset (caps the number of samples per class).
    samples, class_to_idx = create_balanced_dataset(art_dataset_dir, max_per_class=MAX_SAMPLES_PER_CLASS)
    dataset = ArtDataset(samples, transform=test_transform, class_to_idx=class_to_idx)

    num_classes = len(dataset.classes)
    print(f"Sanat sınıfları: {len(dataset.classes)}")

    # Collect checkpoint files, skipping hidden entries such as .DS_Store.
    model_paths = [os.path.join(models_dir, f) for f in os.listdir(models_dir)
                   if f.endswith('.pth') and not f.startswith('.')]

    # Evaluate every model with k-fold cross validation, then visualize.
    summary_results = k_fold_cross_validation(dataset, model_paths, num_classes, k=K_FOLDS)
    plot_kfold_results(summary_results, results_dir)

    # Console report: per-model mean ± std for each metric.
    print("\n5-Fold Cross Validation Sonuçları:")
    for model_name, results in summary_results.items():
        print(f"\n{model_name}:")
        print(f" Ortalama Accuracy: {results['mean_accuracy']:.4f} ± {results['std_accuracy']:.4f}")
        print(f" Ortalama F1 Score: {results['mean_f1']:.4f} ± {results['std_f1']:.4f}")
        print(f" Ortalama Precision: {results['mean_precision']:.4f} ± {results['std_precision']:.4f}")
        print(f" Ortalama Recall: {results['mean_recall']:.4f} ± {results['std_recall']:.4f}")

    # Persist the same summary as CSV for later comparison.
    metric_keys = ('mean_accuracy', 'std_accuracy', 'mean_f1', 'std_f1',
                   'mean_precision', 'std_precision', 'mean_recall', 'std_recall')
    results_summary = [
        {'model': model_name, **{key: results[key] for key in metric_keys}}
        for model_name, results in summary_results.items()
    ]
    summary_df = pd.DataFrame(results_summary)
    summary_df.to_csv(f"{results_dir}/kfold_model_comparison_summary.csv", index=False)

    print(f"\nDeğerlendirme tamamlandı. Sonuçlar '{results_dir}' dizininde.")
373
+
374
if __name__ == "__main__":
    # Seed all RNGs (Python, NumPy, PyTorch) before running so repeated
    # evaluations are reproducible.
    random.seed(42)
    np.random.seed(42)
    torch.manual_seed(42)
    main()
trainer.py ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import random
6
+ import numpy as np
7
+ import pandas as pd
8
+ import matplotlib.pyplot as plt
9
+ import time
10
+ from tqdm.auto import tqdm
11
+ from pathlib import Path
12
+ from collections import Counter
13
+ from PIL import Image
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ from torch.utils.data import DataLoader, Dataset
18
+ import torchvision
19
+ import torchvision.transforms as T
20
+ from torchvision.datasets import ImageFolder
21
+ from torchvision.models import resnet34, ResNet34_Weights
22
+
23
# A.1. Enable CPU fallback for MPS device.
# Must be set before any MPS op runs: unsupported operators are then routed
# to the CPU instead of raising an error.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

# Enable MPS optimizations for PyTorch 2.2+
# NOTE(review): `enable_workflow_compiling` does not appear in the documented
# torch.backends.mps API; the hasattr guard means this branch is likely a
# no-op on current PyTorch builds — verify against the installed version.
if hasattr(torch.backends.mps, 'enable_workflow_compiling'):
    print("Enabling MPS workflow compiling...")
    torch.backends.mps.enable_workflow_compiling = True
+
31
# A.1. Check Metal 3 / MPS support
def setup_device():
    """Checks Metal 3 / MPS support and returns appropriate device"""
    print("PyTorch version:", torch.__version__)

    mps_ready = torch.backends.mps.is_available() and torch.backends.mps.is_built()
    if not mps_ready:
        # No Metal backend available: fall back to the CPU.
        print("MPS not available, using CPU.")
        device = torch.device("cpu")
        print(f"Training device: {device}")
        return device

    print("Metal Performance Shaders (MPS) available.")
    print("PYTORCH_ENABLE_MPS_FALLBACK=1 set - CPU will be used for unsupported operations.")
    device = torch.device("mps")

    # Smoke-test the backend with a tiny tensor op and confirm the result
    # actually lives on the MPS device.
    result = torch.ones(1, device=device) + 1
    if result.device.type == 'mps':
        print(f"MPS successfully tested: {result}")
        print(f"Training device: {device}")
        return device

    print("MPS is available but simple operation failed, using CPU.")
    return torch.device("cpu")
+
59
# A.1.1. Dataset analysis
def analyze_dataset(data_path):
    """Analyze the dataset and count the samples per class.

    Parameters
    ----------
    data_path : str or Path
        Root directory whose immediate subdirectories are the classes;
        only files matching '*.jpg' inside each class folder are counted.

    Returns
    -------
    (pandas.DataFrame, list[str])
        DataFrame with 'Class' and 'Number of Samples' columns sorted by
        count (descending), plus the raw list of class folder names.

    Side effects: prints summary statistics and saves a bar chart to
    'results/class_distribution.png'.
    """
    data_path = Path(data_path)
    classes = [d.name for d in data_path.iterdir() if d.is_dir()]
    class_counts = {}

    # Calculate the number of samples in each class
    for cls in tqdm(classes, desc="Analyzing classes"):
        class_path = data_path / cls
        class_counts[cls] = len(list(class_path.glob('*.jpg')))

    # Build a table sorted from the most to the least populated class.
    df = pd.DataFrame({'Class': list(class_counts.keys()),
                       'Number of Samples': list(class_counts.values())})
    df = df.sort_values('Number of Samples', ascending=False).reset_index(drop=True)

    # Summary statistics; df is sorted descending, so iloc[0] is the
    # largest class and iloc[-1] the smallest.
    total_samples = df['Number of Samples'].sum()
    mean_samples = df['Number of Samples'].mean()
    min_samples = df['Number of Samples'].min()
    max_samples = df['Number of Samples'].max()

    print(f"Total number of samples: {total_samples}")
    print(f"Average number of samples: {mean_samples:.1f}")
    print(f"Minimum number of samples: {min_samples} ({df.iloc[-1]['Class']})")
    print(f"Maximum number of samples: {max_samples} ({df.iloc[0]['Class']})")

    # Visualize class distribution.
    # Fix: create the output directory here so the function also works when
    # called before main() has set it up (savefig would otherwise fail).
    os.makedirs('results', exist_ok=True)
    plt.figure(figsize=(14, 8))
    plt.bar(df['Class'], df['Number of Samples'])
    plt.xticks(rotation=90)
    plt.title('Art Styles - Sample Distribution')
    plt.xlabel('Class')
    plt.ylabel('Number of Samples')
    plt.tight_layout()
    plt.savefig('results/class_distribution.png')
    plt.close()

    return df, classes
99
+
100
# A.2.2. Custom dataset class - Performs data augmentation on CPU
class ArtStyleDataset(Dataset):
    """Image-folder dataset with a deterministic train/validation split.

    The root directory is expected to contain one subdirectory per class,
    each holding .jpg images. The same (seed, valid_pct) pair yields
    complementary splits for train=True and train=False instances.
    """

    def __init__(self, root_dir, transform=None, target_transform=None, train=True, valid_pct=0.2, seed=42):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.target_transform = target_transform
        self.train = train

        # Map each class folder name to a stable integer label (sorted order).
        class_names = [d.name for d in self.root_dir.iterdir() if d.is_dir()]
        self.class_to_idx = {cls_name: i for i, cls_name in enumerate(sorted(class_names))}

        # Gather every (path, label) pair across all class folders.
        all_imgs = [
            (str(img_path), self.class_to_idx[cls_name])
            for cls_name in class_names
            for img_path in (self.root_dir / cls_name).glob('*.jpg')
        ]

        # Deterministic shuffle so the split is reproducible and identical
        # for the train and validation instances built with the same seed.
        random.seed(seed)
        random.shuffle(all_imgs)

        # First valid_pct of the shuffled list is validation; rest is training.
        n_valid = int(len(all_imgs) * valid_pct)
        self.imgs = all_imgs[n_valid:] if train else all_imgs[:n_valid]

        self.classes = sorted(class_names)

    def __len__(self):
        return len(self.imgs)

    def __getitem__(self, idx):
        img_path, label = self.imgs[idx]
        img = Image.open(img_path).convert('RGB')

        if self.transform:
            img = self.transform(img)
        if self.target_transform:
            label = self.target_transform(label)

        return img, label
147
+
148
# A.2. Creating DataLoaders using PyTorch native structures
def create_dataloaders(data_path, batch_size=32, img_size=224, augment=True,
                       balance_method='weighted', valid_pct=0.2, seed=42):
    """Creates PyTorch DataLoaders"""

    # A.2.4. Data transformations (executed on the CPU inside the workers).
    # ImageNet statistics, matching the pretrained backbone.
    normalize = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    # Deterministic resize/crop pipeline, used for validation and also for
    # training when augmentation is disabled. Transforms are stateless, so
    # sharing one Compose object is safe.
    eval_transforms = T.Compose([
        T.Resize(int(img_size*1.14)),
        T.CenterCrop(img_size),
        T.ToTensor(),
        normalize,
    ])

    if augment:
        # Presizing: crop larger first, then augment, so augmentations such
        # as rotation do not introduce border artifacts.
        train_transforms = T.Compose([
            T.RandomResizedCrop(img_size, scale=(0.8, 1.0)),
            T.RandomHorizontalFlip(),
            T.RandomRotation(10),
            T.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
            T.ToTensor(),
            normalize,
        ])
    else:
        train_transforms = eval_transforms

    valid_transforms = eval_transforms

    # A.2.1. Build the train/validation datasets (same seed -> same split).
    train_dataset = ArtStyleDataset(data_path, transform=train_transforms, train=True, valid_pct=valid_pct, seed=seed)
    valid_dataset = ArtStyleDataset(data_path, transform=valid_transforms, train=False, valid_pct=valid_pct, seed=seed)

    # A.2.2. Sampling strategy for class imbalance: with 'weighted', samples
    # from rare classes are drawn more often via a WeightedRandomSampler.
    if balance_method == 'weighted' and train_dataset:
        label_freq = Counter(label for _, label in train_dataset.imgs)
        total = sum(label_freq.values())

        # Inverse-frequency weight per sample: rarer class -> larger weight.
        weights = [total / label_freq[label] for _, label in train_dataset.imgs]
        sampler = torch.utils.data.WeightedRandomSampler(weights, len(weights))
        train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler, num_workers=2, pin_memory=True)
    else:
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    valid_loader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    class_names = train_dataset.classes

    # Data loader summary.
    print(f"Training dataset: {len(train_dataset)} images")
    print(f"Validation dataset: {len(valid_dataset)} images")
    print(f"Classes: {len(class_names)}")

    return train_loader, valid_loader, class_names
213
+
214
# PyTorch native training loop
def train_epoch(model, dataloader, criterion, optimizer, device):
    """Run one training epoch and return (average loss, accuracy in %).

    Loss is accumulated per-sample (weighted by batch size) so the mean is
    correct even with a ragged final batch. On an MPS device, memory usage
    is printed at the start and end of the epoch.
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    batch_times = []

    # Show progress with tqdm
    progress_bar = tqdm(dataloader, desc="Training", leave=False)

    # Monitor MPS memory usage
    if device.type == 'mps':
        print(f"MPS memory usage (start): {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")

    start_time = time.time()
    for inputs, labels in progress_bar:
        batch_start = time.time()

        # Move data to device
        inputs, labels = inputs.to(device), labels.to(device)

        # Verify training device (first batch only: total is still 0)
        if total == 0:
            print(f"Training tensor device: {inputs.device}, Model device: {next(model.parameters()).device}")

        # Standard optimization step: zero grads, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        # Measure per-batch processing time
        batch_times.append(time.time() - batch_start)

        # Update statistics
        running_loss += loss.item() * inputs.size(0)
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        progress_bar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    # Final statistics.
    # Fix: guard the divisions so an empty dataloader reports zeros instead
    # of raising ZeroDivisionError.
    n_samples = len(dataloader.dataset)
    avg_loss = running_loss / n_samples if n_samples else 0.0
    avg_acc = 100 * correct / total if total else 0.0
    avg_time = sum(batch_times) / len(batch_times) if batch_times else 0.0
    total_time = time.time() - start_time

    # Monitoring memory usage
    if device.type == 'mps':
        print(f"MPS memory usage (end): {torch.mps.current_allocated_memory() / 1024**2:.2f} MB")

    # Print statistics
    print(f"Training - Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%, Time: {total_time:.1f}s, Avg batch: {avg_time:.3f}s")

    return avg_loss, avg_acc
279
+
280
# A.3. Inspect the DataBlock via dataloader
def validate_epoch(model, dataloader, criterion, device):
    """Evaluate the model over one pass of `dataloader`; returns (loss, acc %)."""
    model.eval()  # disable dropout, use running BN statistics
    running_loss = 0.0
    correct = 0
    total = 0

    # No gradients are needed during evaluation.
    with torch.no_grad():
        progress_bar = tqdm(dataloader, desc="Validation", leave=False)

        for inputs, labels in progress_bar:
            # Move data to the evaluation device.
            inputs = inputs.to(device)
            labels = labels.to(device)

            # Forward pass only.
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            # Accumulate per-sample loss and top-1 accuracy counters.
            running_loss += loss.item() * inputs.size(0)
            predicted = outputs.max(1)[1]
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            progress_bar.set_postfix({'loss': loss.item(), 'acc': 100 * correct / total})

    # Averages over the full validation set.
    avg_loss = running_loss / len(dataloader.dataset)
    avg_acc = 100 * correct / total

    print(f"Validation - Loss: {avg_loss:.4f}, Acc: {avg_acc:.2f}%")

    return avg_loss, avg_acc
317
+
318
# A.4. Train a simple model
def train_model(train_loader, valid_loader, class_names, device,
                model_name="resnet34", lr=1e-3, epochs=10,
                freeze_epochs=3, unfreeze_epochs=7):
    """Trains a model using transfer learning with discriminative learning rates.

    Phase 1 trains only the new classification head for `freeze_epochs`
    epochs; phase 2 unfreezes the backbone and fine-tunes for
    `unfreeze_epochs` epochs with layer-wise learning rates and a one-cycle
    schedule.

    Returns (trained model, history dict of train/val loss and accuracy).
    Side effects: saves weights to models/model_final.pth and writes
    training-history / confusion-matrix plots under results/.
    """
    print(f"\nTraining {model_name} model for {epochs} epochs (freeze: {freeze_epochs}, unfreeze: {unfreeze_epochs})")

    # B.3. Transfer Learning setup
    if model_name == "resnet34":
        model = resnet34(weights=ResNet34_Weights.DEFAULT)

        # Replace the final layer with a new one for our classes.
        # Fix: use in_features instead of the hard-coded 512 so the head
        # stays correct if the backbone choice ever changes.
        num_classes = len(class_names)
        model.fc = nn.Linear(model.fc.in_features, num_classes)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Move model to device
    model = model.to(device)

    # B.3. Freeze all weights except the final layer
    for param in model.parameters():
        param.requires_grad = False
    for param in model.fc.parameters():
        param.requires_grad = True

    # Set up loss function
    criterion = nn.CrossEntropyLoss()

    # Training history for plotting
    history = {
        'train_loss': [],
        'train_acc': [],
        'val_loss': [],
        'val_acc': []
    }

    # Training in two phases: first frozen, then unfrozen
    total_start_time = time.time()

    # Phase 1: Train with frozen layers (only the head is optimized)
    if freeze_epochs > 0:
        print("\n=== Phase 1: Training with frozen feature extractor ===")
        optimizer = torch.optim.Adam(model.fc.parameters(), lr=lr)

        for epoch in range(freeze_epochs):
            print(f"\nEpoch {epoch+1}/{freeze_epochs}")
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = validate_epoch(model, valid_loader, criterion, device)

            # Record history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

    # Phase 2: Unfreeze and train with discriminative learning rates
    if unfreeze_epochs > 0:
        print("\n=== Phase 2: Fine-tuning with discriminative learning rates ===")

        # B.3. Unfreeze all weights for fine-tuning
        for param in model.parameters():
            param.requires_grad = True

        # B.4. Discriminative learning rates: earlier layers are already
        # well-trained and get smaller rates; later layers adapt more.
        # Fix: include the stem (conv1/bn1) — it is unfrozen above but was
        # missing from the optimizer groups, so it was never updated even
        # though its gradients were computed.
        layer_params = [
            {'params': list(model.conv1.parameters()) + list(model.bn1.parameters()), 'lr': lr/9},
            {'params': model.layer1.parameters(), 'lr': lr/9},
            {'params': model.layer2.parameters(), 'lr': lr/3},
            {'params': model.layer3.parameters(), 'lr': lr/3},
            {'params': model.layer4.parameters(), 'lr': lr},
            {'params': model.fc.parameters(), 'lr': lr*3}
        ]

        optimizer = torch.optim.Adam(layer_params, lr=lr)

        # Fix: the original built a per-batch OneCycleLR but never called
        # scheduler.step(), so the schedule had no effect. train_epoch's
        # interface takes no scheduler, so step once per epoch instead and
        # size the cycle accordingly. max_lr is given per param group so
        # the discriminative ratios are preserved through the cycle (a
        # scalar max_lr would override them all).
        scheduler = torch.optim.lr_scheduler.OneCycleLR(
            optimizer,
            max_lr=[group['lr'] * 3 for group in layer_params],
            total_steps=unfreeze_epochs
        )

        for epoch in range(unfreeze_epochs):
            print(f"\nEpoch {freeze_epochs+epoch+1}/{epochs}")
            train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
            val_loss, val_acc = validate_epoch(model, valid_loader, criterion, device)
            scheduler.step()  # advance the one-cycle schedule (per epoch)

            # Record history
            history['train_loss'].append(train_loss)
            history['train_acc'].append(train_acc)
            history['val_loss'].append(val_loss)
            history['val_acc'].append(val_acc)

    total_time = time.time() - total_start_time
    print(f"\nTotal training time: {total_time:.1f} seconds ({total_time/60:.1f} minutes)")

    # Save model
    os.makedirs('models', exist_ok=True)
    torch.save(model.state_dict(), 'models/model_final.pth')
    print(f"Model saved to models/model_final.pth")

    # A.4.2. Visualize training history
    plt.figure(figsize=(12, 5))
    plt.subplot(1, 2, 1)
    plt.plot(history['train_loss'], label='Train')
    plt.plot(history['val_loss'], label='Validation')
    plt.title('Loss')
    plt.xlabel('Epoch')
    plt.legend()

    plt.subplot(1, 2, 2)
    plt.plot(history['train_acc'], label='Train')
    plt.plot(history['val_acc'], label='Validation')
    plt.title('Accuracy')
    plt.xlabel('Epoch')
    plt.legend()

    plt.tight_layout()
    plt.savefig('results/training_history.png')
    plt.close()

    # A.4.3. Create confusion matrix over the validation set
    model.eval()
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for inputs, labels in tqdm(valid_loader, desc="Creating confusion matrix"):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, preds = outputs.max(1)

            all_preds.extend(preds.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    # Create and plot confusion matrix
    from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay
    cm = confusion_matrix(all_labels, all_preds)

    # Fix: draw on an explicit Axes. ConfusionMatrixDisplay.plot creates its
    # own figure when no `ax` is given, so the original's 20x20 plt.figure
    # was ignored (and leaked as an empty open figure).
    fig, ax = plt.subplots(figsize=(20, 20))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=class_names)
    disp.plot(cmap='Blues', values_format='d', ax=ax)
    plt.title('Confusion Matrix')
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.savefig('results/confusion_matrix.png')
    plt.close()

    return model, history
467
+
468
def main():
    """Run the full pipeline: device setup, dataset analysis, dataloaders,
    batch inspection, and model training.

    Side effects: creates the results/ directory and writes plots into it;
    train_model additionally saves weights under models/.
    """
    # Setup environment
    device = setup_device()

    # A.1. Download and analyze the data
    data_path = "Art Dataset"
    os.makedirs('results', exist_ok=True)

    # A.1.1. Inspect the data layout
    print("\n===== A.1.1. Inspecting data layout =====")
    df, classes = analyze_dataset(data_path)

    # A.2. Create the DataBlock and dataloaders
    print("\n===== A.2. Creating DataLoaders =====")
    train_loader, valid_loader, class_names = create_dataloaders(
        data_path, batch_size=32, img_size=224, augment=True,
        balance_method='weighted', valid_pct=0.2
    )

    # A.3. Inspect the DataBlock via dataloader
    print("\n===== A.3. Inspecting DataBlock =====")

    # A.3.1. Show batch
    def visualize_batch(dataloader, num_images=16):
        """Display a batch of images from the dataloader.

        NOTE(review): the grid below is hard-coded to 4x4, so this assumes
        num_images <= 16; larger values would index past the axes array.
        """
        # Get a batch
        images, labels = next(iter(dataloader))
        images = images[:num_images]
        labels = labels[:num_images]

        # Convert tensors back to images
        # (unnormalize first — these are the ImageNet statistics used by
        # the dataloader transforms)
        mean = torch.tensor([0.485, 0.456, 0.406])
        std = torch.tensor([0.229, 0.224, 0.225])

        # Create a grid of images
        fig, axes = plt.subplots(nrows=4, ncols=4, figsize=(12, 12))
        for i, (img, label) in enumerate(zip(images, labels)):
            # Unnormalize
            img = img.cpu() * std[:, None, None] + mean[:, None, None]
            # Convert to numpy (CHW -> HWC for imshow)
            img = img.permute(1, 2, 0).numpy()
            # Clip values to valid range
            img = np.clip(img, 0, 1)

            # Get class name (underscores become spaces for display)
            class_name = class_names[label]
            class_name = class_name.replace('_', ' ')

            # Plot
            row, col = i // 4, i % 4
            axes[row, col].imshow(img)
            axes[row, col].set_title(class_name)
            axes[row, col].axis('off')

        plt.tight_layout()
        plt.savefig('results/batch_preview.png')
        plt.close()
        print("Batch preview saved to results/batch_preview.png")

    # A.3.1. Show batch: dataloader.show_batch()
    print("\n===== A.3.1. Showing batch =====")
    visualize_batch(train_loader)

    # A.3.2. Check the labels
    print("\n===== A.3.2. Checking labels =====")
    print(f"Class names: {class_names}")

    # A.3.3. Summarize the DataBlock
    print("\n===== A.3.3. Summarizing DataBlock =====")
    print(f"Number of classes: {len(class_names)}")
    print(f"Training batches: {len(train_loader)}")
    print(f"Validation batches: {len(valid_loader)}")
    print(f"Batch size: {train_loader.batch_size}")
    print(f"Total training samples: {len(train_loader.dataset)}")
    print(f"Total validation samples: {len(valid_loader.dataset)}")

    # A.4. Train a simple model
    print("\n===== A.4. Training a simple model =====")
    model, history = train_model(
        train_loader, valid_loader, class_names, device,
        model_name="resnet34", lr=1e-3,
        epochs=10, freeze_epochs=3, unfreeze_epochs=7
    )

    print("\nTraining complete!")
554
+
555
if __name__ == "__main__":
    # Script entry point: run the full analyze/train pipeline.
    main()