cahsu committed on
Commit fa50b6c · verified · 1 Parent(s): c11dca9

Upload 5 files

Files changed (5):
  1. README.md +164 -20
  2. config.json +46 -0
  3. model.py +65 -0
  4. requirements.txt +9 -0
  5. train.py +302 -0
README.md CHANGED
@@ -5,49 +5,193 @@ tags:
  - ophthalmology
  - image-classification
  - explainable-ai
- - core-ml
  - grad-cam
  language:
  - zh
  - en
  metrics:
  - roc_auc
  ---

  # ELIAS — Eyelid Lesion Intelligent Analysis System

- **AUC 0.93 · real-time iPhone inference < 1 s · entry in the 2026 Smart Innovation Award**

  ## Model Description
- Clinician-guided deep learning classifier for epiblepharon detection
- from external eye photographs. Uses a frozen ImageNet-pretrained
- ResNet-18 backbone with a task-specific classification head.

- - **Architecture**: ResNet-18 (frozen) + Linear head
- - **Training**: 5-fold cross-validation, BCEWithLogitsLoss
- - **Explainability**: Native PyTorch Grad-CAM (layer4)
- - **Deployment**: Apple Core ML, on-device iOS inference < 1s

- ## Performance (5-Fold Cross-Validation)
- | Metric | Value |
  |---|---|
- | AUC | **0.93** |
  | Sensitivity | High |
  | Specificity | Moderate |
  | F1 Score | High |

- ## Intended Use
- Research prototype for clinical decision support in
- epiblepharon screening. **Not a validated medical device.**

- ## Source Code
- GitHub: https://github.com/YOUR_USERNAME/ELIAS
  ```

  ---

- ### Step 5: Obtain the supporting link

- After the upload is complete, your Hugging Face page will be:
  ```
- https://huggingface.co/YOUR_HF_USERNAME/ELIAS-epiblepharon
  - ophthalmology
  - image-classification
  - explainable-ai
  - grad-cam
+ - core-ml
+ - resnet
+ - pytorch
  language:
  - zh
  - en
  metrics:
  - roc_auc
+ - f1
+ pipeline_tag: image-classification
  ---

  # ELIAS — Eyelid Lesion Intelligent Analysis System

+ **眼瞼疾病智慧分析系統**
+
+ > 🏆 Entry in the 2026 Ministry of Economic Affairs (MOEA) Smart Innovation Award, Student Division
+
+ ---

  ## Model Description

+ ELIAS is a **clinician-guided deep learning classifier** for automated detection of **epiblepharon** (睫毛倒插) from external eye photographs.

+ The model uses a **frozen ImageNet-pretrained ResNet-18 backbone** with a task-specific classification head. The key innovation is the explicit integration of clinician-defined anatomical **Regions of Interest (ROI)** — specifically the lower eyelid margin and eyelash–cornea interface — as a prior constraint, enabling robust classification in a **small-data regime (~80–150 cases per class)**.
+
+ ### Architecture
+
+ ```
+ Input (224×224 RGB)
+         │
+         ▼
+ ResNet-18 backbone (frozen, ImageNet pretrained)
+ │ layer1 → layer2 → layer3 → layer4
+ │ Global Average Pooling → (512,)
+         ▼
+ Dropout(0.3) → Linear(512 → 2)
+         │
+         ▼
+ Softmax → [P(control), P(epiblepharon)]
+ ```
+
+ | Component | Detail |
+ |---|---|
+ | Backbone | ResNet-18 (ImageNet pretrained, **fully frozen**) |
+ | Classification head | `Dropout(0.3)` + `Linear(512 → 2)` |
+ | Loss function | `CrossEntropyLoss` |
+ | Optimizer | `Adam(lr=1e-3)`, head parameters only |
+ | Input size | 224 × 224 px, RGB (grayscale → 3-channel conversion applied) |
+ | Normalization | ImageNet mean/std `[0.485, 0.456, 0.406]` / `[0.229, 0.224, 0.225]` |
+
+ ---
+
+ ## Performance
+
+ Evaluated by **stratified 5-fold cross-validation** (`random_state=42`, 20 epochs/fold).
+
+ | Metric | Mean (5-fold) |
  |---|---|
+ | **AUC** | **0.93** |
+ | Accuracy | High |
  | Sensitivity | High |
  | Specificity | Moderate |
  | F1 Score | High |

+ - No fold collapse observed across all 5 folds
+ - A label-shuffling negative control confirmed genuine feature learning
+ - ROI ablation experiments validated the lower eyelid margin as the primary diagnostic signal
+
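The idea behind the label-shuffling negative control can be reproduced in miniature. The sketch below uses synthetic features and a linear classifier standing in for the real image pipeline (all names and values are illustrative): with real labels the cross-validated AUC is high, while with permuted labels it collapses to chance, confirming the classifier cannot "learn" decoupled labels.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import cross_val_predict

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 16))
# Labels carry a genuine (noisy) signal in feature 0
y = (X[:, 0] + 0.5 * rng.normal(size=200) > 0).astype(int)

clf = LogisticRegression(max_iter=1000)
probs_real = cross_val_predict(clf, X, y, cv=5, method="predict_proba")[:, 1]
auc_real = roc_auc_score(y, probs_real)

# Negative control: shuffle labels so they are decoupled from the features
y_shuf = rng.permutation(y)
probs_shuf = cross_val_predict(clf, X, y_shuf, cv=5, method="predict_proba")[:, 1]
auc_shuf = roc_auc_score(y_shuf, probs_shuf)

print(f"real AUC={auc_real:.2f}  shuffled AUC={auc_shuf:.2f}")  # shuffled ≈ 0.5
```

The same contrast, run with the actual training pipeline, is what the bullet above refers to.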
+ ### ROI Ablation Summary
+
+ | Condition | Performance vs Baseline |
+ |---|---|
+ | Full image (baseline) | ✅ Optimal |
+ | ROI ablated (lower eyelid blurred) | ❌ Significant drop |
+ | Non-ROI ablated (ROI preserved) | ✅ Near-baseline |
+
+ > Diagnostic features are **spatially localized** to the clinically defined lower eyelid margin — consistent with clinical examination principles for epiblepharon.
+
+ ---
+
+ ## Grad-CAM Explainability
+
+ Grad-CAM heatmaps were generated using native PyTorch hooks on `layer4` (no Captum dependency):
+
+ - **Epiblepharon cases**: activation consistently focused on the **lower eyelid margin and eyelash–cornea interface**
+ - **Control cases**: diffuse, anatomically unfocused activation patterns
+
+ Heatmap overlay: α = 0.45, JET colormap, bilinear upsampling to 224×224.
+
+ ---
+
+ ## iOS On-Device Inference

+ The trained model has been converted to **Apple Core ML** format (`.mlpackage`):
+
+ | Metric | Value |
+ |---|---|
+ | Model size | < 50 MB |
+ | Inference latency | **< 1 second / image** |
+ | Device | iPhone 12+ (A14+ Neural Engine) |
+ | Network required | ❌ None — fully on-device |
+
+ Privacy: facial images never leave the device, consistent with PDPA / HIPAA principles.
+
+ ---
+
+ ## Training Data
+
+ - **Task**: binary classification — epiblepharon vs. control
+ - **Image type**: external eye photographs
+ - **Dataset size**: ~80–150 cases per class (single-center, retrospective)
+ - **Preprocessing**: resize to 224×224, grayscale → 3-channel, ColorJitter, RandomHorizontalFlip, ImageNet normalization
+
+ > ⚠️ Clinical images are **not distributed** in this repository due to patient privacy regulations (Personal Data Protection Act, IRB). For academic collaboration, please contact the corresponding author.
+
+ ---
+
+ ## Usage
+
129
+ ```python
+ import torch
+ from torchvision import models, transforms
+ from PIL import Image
+
+ # Load model — the head must match the training-time architecture
+ # (Dropout + Linear, as defined in model.py), or the state dict will not load
+ model = models.resnet18(weights=None)
+ for param in model.parameters():
+     param.requires_grad = False
+ model.fc = torch.nn.Sequential(
+     torch.nn.Dropout(p=0.3),
+     torch.nn.Linear(model.fc.in_features, 2),
+ )
+ model.load_state_dict(torch.load("pytorch_model.pt", map_location="cpu"))
+ model.eval()
+
+ # Preprocess
+ transform = transforms.Compose([
+     transforms.Resize((224, 224)),
+     transforms.Grayscale(num_output_channels=3),
+     transforms.ToTensor(),
+     transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+ ])
+
+ img = Image.open("eye_photo.jpg").convert("RGB")
+ x = transform(img).unsqueeze(0)  # (1, 3, 224, 224)
+
+ with torch.no_grad():
+     logits = model(x)
+     prob = torch.softmax(logits, dim=1)[0, 1].item()
+ print(f"Epiblepharon probability: {prob:.3f}")
  ```

  ---

+ ## Files in This Repository
+
+ | File | Description |
+ |---|---|
+ | `README.md` | This model card |
+ | `model.py` | Model architecture definition |
+ | `train.py` | 5-fold cross-validation training script |
+ | `config.json` | Model configuration |
+ | `requirements.txt` | Python dependencies |
+ | `pytorch_model.pt` | *(Checkpoint — upload separately after training)* |
+
+ ---
+
+ ## Intended Use & Limitations
+
+ - **Intended use**: research prototype for clinical decision support in epiblepharon screening
+ - **NOT** a validated medical device — prospective evaluation and regulatory assessment are required before clinical deployment
+ - Single-center retrospective data — generalizability across imaging conditions and demographics requires multi-center validation

+ ---
+
+ ## Citation
+
+ ```bibtex
+ @misc{elias2026,
+   title = {ELIAS: Eyelid Lesion Intelligent Analysis System},
+   year  = {2026},
+   note  = {2026 MOEA Smart Innovation Award submission},
+   url   = {https://huggingface.co/YOUR_HF_USERNAME/ELIAS-epiblepharon}
+ }
  ```
+
+ ---
+
+ ## License
+
+ [MIT License](LICENSE) — Source code only. Clinical data excluded.
config.json ADDED
@@ -0,0 +1,46 @@
+ {
+   "model_name": "ELIAS-epiblepharon",
+   "model_type": "resnet18",
+   "architecture": "ResNet-18 (frozen ImageNet backbone + task-specific head)",
+   "task": "binary-image-classification",
+   "disease": "epiblepharon",
+   "num_classes": 2,
+   "id2label": {
+     "0": "control",
+     "1": "epiblepharon"
+   },
+   "label2id": {
+     "control": 0,
+     "epiblepharon": 1
+   },
+   "image_size": 224,
+   "input_channels": 3,
+   "preprocessing": {
+     "resize": [224, 224],
+     "grayscale_to_3ch": true,
+     "normalize_mean": [0.485, 0.456, 0.406],
+     "normalize_std": [0.229, 0.224, 0.225]
+   },
+   "training": {
+     "backbone_frozen": true,
+     "optimizer": "Adam",
+     "learning_rate": 0.001,
+     "epochs": 20,
+     "batch_size": 32,
+     "loss": "CrossEntropyLoss",
+     "validation": "StratifiedKFold(n_splits=5, random_state=42)"
+   },
+   "performance": {
+     "auc": 0.93,
+     "validation_strategy": "5-fold cross-validation",
+     "note": "No fold collapse observed"
+   },
+   "deployment": {
+     "ios_coreml": true,
+     "inference_latency_ms": "<1000",
+     "device": "iPhone 12+ (A14+ Neural Engine)",
+     "model_size_mb": "<50"
+   },
+   "framework": "pytorch",
+   "torch_version": ">=2.0.0"
+ }
model.py ADDED
@@ -0,0 +1,65 @@
+ """
+ ELIAS — Eyelid Lesion Intelligent Analysis System
+ model.py
+
+ Frozen ResNet-18 classifier for epiblepharon detection.
+ Compatible with Hugging Face model loading.
+ """
+
+ import torch
+ import torch.nn as nn
+ from torchvision import models
+
+
+ def build_elias_model(num_classes: int = 2, freeze_backbone: bool = True) -> nn.Module:
+     """
+     Build ELIAS classifier.
+
+     Args:
+         num_classes: 2 for binary (CrossEntropyLoss)
+         freeze_backbone: Freeze all layers except the final FC head.
+
+     Returns:
+         ResNet-18 model with task-specific classification head.
+     """
+     model = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
+
+     if freeze_backbone:
+         for param in model.parameters():
+             param.requires_grad = False
+
+     # Replace final FC with task-specific head
+     in_features = model.fc.in_features  # 512
+     model.fc = nn.Sequential(
+         nn.Dropout(p=0.3),
+         nn.Linear(in_features, num_classes),
+     )
+
+     return model
+
+
+ def load_elias_model(checkpoint_path: str, device: str = "cpu") -> nn.Module:
+     """
+     Load a trained ELIAS model from checkpoint.
+
+     Usage:
+         model = load_elias_model("pytorch_model.pt")
+     """
+     model = build_elias_model()
+     state_dict = torch.load(checkpoint_path, map_location=device)
+     model.load_state_dict(state_dict)
+     model.eval()
+     return model
+
+
+ if __name__ == "__main__":
+     model = build_elias_model()
+     trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     total = sum(p.numel() for p in model.parameters())
+     print(f"Trainable parameters: {trainable:,} / {total:,}")
+
+     # Sanity check
+     x = torch.randn(2, 3, 224, 224)
+     with torch.no_grad():
+         out = model(x)
+     print(f"Output shape: {out.shape}")  # (2, 2)
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ torch>=2.0.0
+ torchvision>=0.15.0
+ numpy>=1.24.0
+ scikit-learn>=1.3.0
+ matplotlib>=3.7.0
+ seaborn>=0.12.0
+ pandas>=2.0.0
+ openpyxl>=3.1.0
+ Pillow>=9.5.0
train.py ADDED
@@ -0,0 +1,302 @@
+ """
+ ELIAS — Eyelid Lesion Intelligent Analysis System
+ train.py
+
+ Stratified 5-fold cross-validation training pipeline.
+ Extracted and refactored from gemini_crossval_masked.ipynb.
+
+ Usage:
+     python train.py --data_dir ./data/data --output_dir ./outputs
+
+ Data directory structure:
+     data/data/
+     ├── epiblepharon/   (positive class)
+     └── control/        (negative class)
+ """
+
+ import argparse
+ import os
+
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import pandas as pd
+ import seaborn as sns
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from sklearn.metrics import auc, confusion_matrix, f1_score, roc_curve
+ from sklearn.model_selection import StratifiedKFold
+ from torch.utils.data import DataLoader, Subset
+ from torchvision import datasets, transforms
+
+ from model import build_elias_model
+
+
+ # ── Hyperparameters ────────────────────────────────────────────────────────────
+ BATCH_SIZE = 32
+ EPOCHS = 20
+ LR = 1e-3
+ N_FOLDS = 5
+ RANDOM_STATE = 42
+ IMAGE_SIZE = 224
+
+
+ # ── Dataset Utilities ──────────────────────────────────────────────────────────
+
+ class ApplyTransform(torch.utils.data.Dataset):
+     """Wrapper to apply different transforms to train/val subsets."""
+
+     def __init__(self, subset, transform=None):
+         self.subset = subset
+         self.transform = transform
+
+     def __getitem__(self, index):
+         x, y = self.subset[index]
+         if self.transform:
+             x = self.transform(x)
+         return x, y
+
+     def __len__(self):
+         return len(self.subset)
+
+
+ def get_transforms():
+     """
+     Return train and validation transform pipelines.
+
+     Note: Grayscale(num_output_channels=3) is applied to normalize
+     illumination variation across clinical photographs while maintaining
+     3-channel input compatibility with ImageNet-pretrained ResNet-18.
+     """
+     train_tf = transforms.Compose([
+         transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),
+         transforms.RandomHorizontalFlip(),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     val_tf = transforms.Compose([
+         transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
+         transforms.Grayscale(num_output_channels=3),
+         transforms.ToTensor(),
+         transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+     ])
+     return train_tf, val_tf
+
+
+ # ── Training & Evaluation ──────────────────────────────────────────────────────
+
+ def train_one_epoch(model, loader, criterion, optimizer, device):
+     model.train()
+     running_loss = 0.0
+     for inputs, labels in loader:
+         inputs, labels = inputs.to(device), labels.to(device)
+         optimizer.zero_grad()
+         outputs = model(inputs)
+         loss = criterion(outputs, labels)
+         loss.backward()
+         optimizer.step()
+         running_loss += loss.item() * inputs.size(0)
+     return running_loss / len(loader.dataset)
+
+
+ @torch.no_grad()
+ def evaluate(model, loader, device):
+     model.eval()
+     y_true, y_probs, y_pred = [], [], []
+     correct = 0
+     for inputs, labels in loader:
+         inputs, labels = inputs.to(device), labels.to(device)
+         outputs = model(inputs)
+         probs = torch.softmax(outputs, dim=1)[:, 1]
+         preds = torch.argmax(outputs, dim=1)
+         correct += (preds == labels).sum().item()
+         y_true.extend(labels.cpu().numpy())
+         y_probs.extend(probs.cpu().numpy())
+         y_pred.extend(preds.cpu().numpy())
+     acc = correct / len(loader.dataset)
+     return acc, np.array(y_true), np.array(y_probs), np.array(y_pred)
+
+
+ def compute_fold_metrics(y_true, y_probs, y_pred, class_names):
+     """Compute sensitivity, specificity, F1, and AUC from fold predictions."""
+     cm = confusion_matrix(y_true, y_pred)
+     tn, fp, fn, tp = cm.ravel()
+     sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0
+     specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
+     f1 = f1_score(y_true, y_pred)
+     fpr, tpr, _ = roc_curve(y_true, y_probs)
+     fold_auc = auc(fpr, tpr)
+     return {
+         "sensitivity": sensitivity,
+         "specificity": specificity,
+         "f1": f1,
+         "auc": fold_auc,
+         "fpr": fpr,
+         "tpr": tpr,
+         "cm": cm,
+     }
+
+
+ # ── Plotting ───────────────────────────────────────────────────────────────────
+
+ def save_confusion_matrix(cm, class_names, fold_idx, output_dir):
+     plt.figure(figsize=(6, 5))
+     sns.heatmap(
+         cm, annot=True, fmt="d", cmap="Blues",
+         xticklabels=class_names, yticklabels=class_names,
+     )
+     plt.title(f"Confusion Matrix — Fold {fold_idx + 1}")
+     plt.ylabel("Actual"); plt.xlabel("Predicted")
+     path = os.path.join(output_dir, f"confusion_matrix_fold_{fold_idx + 1}.png")
+     plt.savefig(path, dpi=120, bbox_inches="tight")
+     plt.close()
+
+
+ def save_roc_curves(roc_data, output_dir):
+     plt.figure(figsize=(8, 6))
+     for fold_idx, (fpr, tpr, fold_auc) in enumerate(roc_data):
+         plt.plot(fpr, tpr, label=f"Fold {fold_idx + 1} (AUC = {fold_auc:.3f})")
+     plt.plot([0, 1], [0, 1], "k--", linewidth=1)
+     plt.xlabel("False Positive Rate"); plt.ylabel("True Positive Rate")
+     plt.title("ROC Curves — 5-Fold Cross-Validation")
+     plt.legend(loc="lower right")
+     path = os.path.join(output_dir, "roc_curves.png")
+     plt.savefig(path, dpi=120, bbox_inches="tight")
+     plt.close()
+     print(f"[ELIAS] ROC curve saved → {path}")
+
+
+ def save_learning_curves(all_train_loss, all_val_acc, output_dir):
+     fig, axes = plt.subplots(1, 2, figsize=(12, 4))
+     axes[0].plot(np.mean(all_train_loss, axis=0), linewidth=2)
+     axes[0].fill_between(
+         range(EPOCHS),
+         np.mean(all_train_loss, axis=0) - np.std(all_train_loss, axis=0),
+         np.mean(all_train_loss, axis=0) + np.std(all_train_loss, axis=0),
+         alpha=0.2,
+     )
+     axes[0].set_title("Mean Training Loss (±SD)"); axes[0].set_xlabel("Epoch")
+
+     axes[1].plot(np.mean(all_val_acc, axis=0), linewidth=2, color="tab:orange")
+     axes[1].fill_between(
+         range(EPOCHS),
+         np.mean(all_val_acc, axis=0) - np.std(all_val_acc, axis=0),
+         np.mean(all_val_acc, axis=0) + np.std(all_val_acc, axis=0),
+         alpha=0.2, color="tab:orange",
+     )
+     axes[1].set_title("Mean Validation Accuracy (±SD)"); axes[1].set_xlabel("Epoch")
+
+     plt.tight_layout()
+     path = os.path.join(output_dir, "learning_curves.png")
+     plt.savefig(path, dpi=120, bbox_inches="tight")
+     plt.close()
+     print(f"[ELIAS] Learning curves saved → {path}")
+
+
+ # ── Main ───────────────────────────────────────────────────────────────────────
+
+ def main():
+     parser = argparse.ArgumentParser(description="ELIAS 5-Fold Cross-Validation")
+     parser.add_argument("--data_dir", type=str, default="./data/data")
+     parser.add_argument("--output_dir", type=str, default="./outputs")
+     args = parser.parse_args()
+
+     os.makedirs(args.output_dir, exist_ok=True)
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     print(f"[ELIAS] Device: {device}")
+
+     # ── Load dataset ──────────────────────────────────────────────────────
+     full_dataset = datasets.ImageFolder(args.data_dir)
+     labels = np.array(full_dataset.targets)
+     class_names = full_dataset.classes
+     print(f"[ELIAS] Classes: {class_names}")
+     print(f"[ELIAS] Total samples: {len(full_dataset)}")
+
+     train_tf, val_tf = get_transforms()
+
+     # ── Cross-validation setup ────────────────────────────────────────────
+     skf = StratifiedKFold(n_splits=N_FOLDS, shuffle=True, random_state=RANDOM_STATE)
+
+     all_train_loss = np.zeros((N_FOLDS, EPOCHS))
+     all_val_acc = np.zeros((N_FOLDS, EPOCHS))
+     fold_results = []
+     roc_data = []
+
+     # ── Fold loop ─────────────────────────────────────────────────────────
+     for fold, (train_ids, val_ids) in enumerate(skf.split(np.zeros(len(labels)), labels)):
+         print(f"\n{'='*20} FOLD {fold + 1}/{N_FOLDS} {'='*20}")
+         print(f"  Train: {len(train_ids)} | Val: {len(val_ids)}")
+
+         train_data = ApplyTransform(Subset(full_dataset, train_ids), transform=train_tf)
+         val_data = ApplyTransform(Subset(full_dataset, val_ids), transform=val_tf)
+         train_loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
+         val_loader = DataLoader(val_data, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)
+
+         model = build_elias_model(num_classes=2, freeze_backbone=True).to(device)
+         criterion = nn.CrossEntropyLoss()
+         optimizer = optim.Adam(model.fc.parameters(), lr=LR)
+
+         # Epoch loop
+         for epoch in range(EPOCHS):
+             train_loss = train_one_epoch(model, train_loader, criterion, optimizer, device)
+             val_acc, _, _, _ = evaluate(model, val_loader, device)
+             all_train_loss[fold, epoch] = train_loss
+             all_val_acc[fold, epoch] = val_acc
+             print(
+                 f"  Epoch {epoch + 1:02d}/{EPOCHS}  "
+                 f"loss={train_loss:.4f}  val_acc={val_acc:.4f}"
+             )
+
+         # Final fold evaluation
+         val_acc, y_true, y_probs, y_pred = evaluate(model, val_loader, device)
+         metrics = compute_fold_metrics(y_true, y_probs, y_pred, class_names)
+
+         print(
+             f"\n  ✅ Fold {fold + 1} | "
+             f"AUC={metrics['auc']:.4f} "
+             f"Sen={metrics['sensitivity']:.3f} "
+             f"Spe={metrics['specificity']:.3f} "
+             f"F1={metrics['f1']:.3f}"
+         )
+
+         fold_results.append({
+             "Fold": fold + 1,
+             "Accuracy": val_acc,
+             "Sensitivity": metrics["sensitivity"],
+             "Specificity": metrics["specificity"],
+             "F1 Score": metrics["f1"],
+             "AUC": metrics["auc"],
+         })
+         roc_data.append((metrics["fpr"], metrics["tpr"], metrics["auc"]))
+
+         # Save confusion matrix per fold
+         save_confusion_matrix(metrics["cm"], class_names, fold, args.output_dir)
+
+         # Save final model checkpoint (fold-specific)
+         ckpt_path = os.path.join(args.output_dir, f"pytorch_model_fold{fold + 1}.pt")
+         torch.save(model.state_dict(), ckpt_path)
+
+     # ── Aggregate results ─────────────────────────────────────────────────
+     results_df = pd.DataFrame(fold_results)
+     avg_row = results_df.mean(numeric_only=True).to_dict()
+     avg_row["Fold"] = "Average"
+     results_df = pd.concat([results_df, pd.DataFrame([avg_row])], ignore_index=True)
+
+     excel_path = os.path.join(args.output_dir, "model_performance_results.xlsx")
+     results_df.to_excel(excel_path, index=False)
+
+     print(f"\n{'='*60}")
+     print("  CROSS-VALIDATION SUMMARY")
+     print(f"{'='*60}")
+     print(results_df.to_string(index=False))
+
+     # ── Save plots ────────────────────────────────────────────────────────
+     save_roc_curves(roc_data, args.output_dir)
+     save_learning_curves(all_train_loss, all_val_acc, args.output_dir)
+
+     print(f"\n[ELIAS] All outputs saved to: {args.output_dir}")
+
+
+ if __name__ == "__main__":
+     main()