lijuanliao commited on
Commit
2cbeb83
·
verified ·
1 Parent(s): 7c337ab

Upload rop_patient_grouped.py

Browse files
Files changed (1) hide show
  1. rop_patient_grouped.py +1747 -0
rop_patient_grouped.py ADDED
@@ -0,0 +1,1747 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import math
4
+ import os
5
+ import random
6
+ from contextlib import nullcontext
7
+ from pathlib import Path
8
+ from typing import Dict, List, Optional, Tuple, Any
9
+
10
+ import cv2
11
+ import numpy as np
12
+ import pandas as pd
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+ from PIL import Image
17
+ from sklearn.metrics import (
18
+ accuracy_score,
19
+ balanced_accuracy_score,
20
+ classification_report,
21
+ f1_score,
22
+ precision_score,
23
+ recall_score,
24
+ )
25
+ from sklearn.model_selection import GroupKFold, StratifiedKFold
26
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
27
+ from torchvision import models, transforms
28
+ from torchvision.transforms import InterpolationMode
29
+ from tqdm import tqdm
30
+
31
+ try:
32
+ from sklearn.model_selection import StratifiedGroupKFold
33
+ HAS_STRATIFIED_GROUP_KFOLD = True
34
+ except Exception:
35
+ StratifiedGroupKFold = None
36
+ HAS_STRATIFIED_GROUP_KFOLD = False
37
+
38
+ try:
39
+ import timm
40
+ HAS_TIMM = True
41
+ except ImportError:
42
+ HAS_TIMM = False
43
+ print("Warning: timm is not installed. timm-based models will be skipped.")
44
+
45
+ PRIMARY_METRIC = "macro_f1"
46
+ DEFAULT_INPUT_SIZE = 512
47
+
48
+ # Utilities
49
+
50
+ def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
51
+ random.seed(seed)
52
+ np.random.seed(seed)
53
+ torch.manual_seed(seed)
54
+ torch.cuda.manual_seed_all(seed)
55
+ os.environ["PYTHONHASHSEED"] = str(seed)
56
+
57
+ if deterministic:
58
+ torch.backends.cudnn.benchmark = False
59
+ torch.backends.cudnn.deterministic = True
60
+ try:
61
+ torch.use_deterministic_algorithms(True, warn_only=True)
62
+ except Exception:
63
+ pass
64
+ else:
65
+ torch.backends.cudnn.benchmark = True
66
+ torch.backends.cudnn.deterministic = False
67
+
68
+ def ensure_dir(path: Path) -> None:
69
+ path.mkdir(parents=True, exist_ok=True)
70
+
71
+ def to_jsonable(obj: Any):
72
+ if isinstance(obj, dict):
73
+ return {k: to_jsonable(v) for k, v in obj.items()}
74
+ if isinstance(obj, list):
75
+ return [to_jsonable(v) for v in obj]
76
+ if isinstance(obj, tuple):
77
+ return [to_jsonable(v) for v in obj]
78
+ if isinstance(obj, (np.integer, np.floating)):
79
+ return obj.item()
80
+ return obj
81
+
82
+ def name_matches_keywords(name: str, keywords: List[str]) -> bool:
83
+ if not name:
84
+ return False
85
+ for kw in keywords:
86
+ plain_kw = kw.rstrip(".")
87
+ if kw in name or name == plain_kw or name.startswith(plain_kw + "."):
88
+ return True
89
+ return False
90
+
91
+
92
+ # Device
93
+
94
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
95
+ print(f"Using device: {device}")
96
+ if device.type == "cuda":
97
+ print(f"GPU: {torch.cuda.get_device_name(0)}")
98
+ print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
99
+ else:
100
+ print("Warning: CUDA is not available. Training will be much slower on CPU.")
101
+
102
+
103
+ # Model
104
+
105
+ _VIT_KEYWORDS = [
106
+ "ViT", "Swin", "Transformer", "DeiT", "MaxViT", "CoAtNet",
107
+ "EfficientFormer", "FastViT", "CaFormer",
108
+ ]
109
+
110
+ def _is_vit_family(model_name: str) -> bool:
111
+ return any(kw.lower() in model_name.lower() for kw in _VIT_KEYWORDS)
112
+
113
+ def _is_timm_model(model: nn.Module) -> bool:
114
+ return hasattr(model, "get_classifier") and hasattr(model, "num_features")
115
+
116
+ MODEL_INPUT_SIZES: Dict[str, int] = {
117
+ "inception_v3": 299,
118
+ }
119
+
120
+ def get_model_input_size(model_name: str) -> int:
121
+ return MODEL_INPUT_SIZES.get(model_name, DEFAULT_INPUT_SIZE)
122
+
123
+
124
+ # Metrics / IO
125
+ def compute_metrics(
126
+ y_true: List[int],
127
+ y_pred: List[int],
128
+ num_classes: int,
129
+ class_names: List[str],
130
+ ) -> Tuple[Dict, Dict]:
131
+ labels = list(range(num_classes))
132
+ report = classification_report(
133
+ y_true,
134
+ y_pred,
135
+ labels=labels,
136
+ target_names=class_names,
137
+ output_dict=True,
138
+ zero_division=0,
139
+ )
140
+ metrics = {
141
+ "accuracy": 100.0 * accuracy_score(y_true, y_pred),
142
+ "balanced_accuracy": 100.0 * balanced_accuracy_score(y_true, y_pred),
143
+ "macro_f1": 100.0 * f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
144
+ "macro_precision": 100.0 * precision_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
145
+ "macro_recall": 100.0 * recall_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
146
+ "weighted_f1": 100.0 * f1_score(y_true, y_pred, labels=labels, average="weighted", zero_division=0),
147
+ }
148
+ return metrics, report
149
+
150
+ def save_fold_results(results: Dict, save_dir: Path, tag: str = "best") -> None:
151
+ ensure_dir(save_dir)
152
+
153
+ report_df = pd.DataFrame(results["classification_report"]).transpose()
154
+ with open(save_dir / f"test_report_{tag}.txt", "w", encoding="utf-8") as f:
155
+ f.write(f"Primary Metric ({PRIMARY_METRIC}): {results['metrics'][PRIMARY_METRIC]:.4f}\n")
156
+ f.write(f"Accuracy: {results['metrics']['accuracy']:.4f}\n")
157
+ f.write(f"Balanced Accuracy: {results['metrics']['balanced_accuracy']:.4f}\n")
158
+ f.write(f"Macro F1: {results['metrics']['macro_f1']:.4f}\n")
159
+ f.write(f"Macro Recall: {results['metrics']['macro_recall']:.4f}\n")
160
+ f.write(f"Macro Precision: {results['metrics']['macro_precision']:.4f}\n\n")
161
+ f.write("Classification Report:\n")
162
+ f.write(report_df.to_string())
163
+
164
+ pred_df = pd.DataFrame({
165
+ "patient": results["patients"],
166
+ "image_name": results["image_names"],
167
+ "True": results["targets"],
168
+ "Predicted": results["predictions"],
169
+ "path": results["image_path"],
170
+ })
171
+ for c in range(results["num_classes"]):
172
+ pred_df[f"prob_class{c}"] = [row[c] for row in results["probabilities"]]
173
+ pred_df.to_csv(save_dir / f"predictions_{tag}.csv", index=False)
174
+
175
+ payload = {
176
+ "best_epoch": results["best_epoch"],
177
+ "primary_metric": PRIMARY_METRIC,
178
+ "metrics": results["metrics"],
179
+ "per_class": [
180
+ results["classification_report"].get(
181
+ f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0}
182
+ )
183
+ for i in range(results["num_classes"])
184
+ ],
185
+ }
186
+ with open(save_dir / f"{tag}_metrics.json", "w", encoding="utf-8") as f:
187
+ json.dump(to_jsonable(payload), f, indent=2, ensure_ascii=False)
188
+
189
+ def save_kfold_summary(
190
+ model_name: str,
191
+ fold_results: List[Dict],
192
+ num_classes: int,
193
+ save_dir: Path,
194
+ ) -> Tuple[float, float]:
195
+ ensure_dir(save_dir)
196
+
197
+ metric_names = [
198
+ "accuracy",
199
+ "balanced_accuracy",
200
+ "macro_f1",
201
+ "macro_recall",
202
+ "macro_precision",
203
+ "weighted_f1",
204
+ ]
205
+ summary = {}
206
+ for name in metric_names:
207
+ values = [r["metrics"][name] for r in fold_results]
208
+ summary[name] = {
209
+ "mean": float(np.mean(values)),
210
+ "std": float(np.std(values)),
211
+ }
212
+
213
+ lines = [
214
+ "=" * 70,
215
+ f"Model: {model_name}",
216
+ "5-Fold Cross-Validation Summary",
217
+ f"Primary Metric: {PRIMARY_METRIC}",
218
+ "=" * 70,
219
+ "",
220
+ ]
221
+ for i, r in enumerate(fold_results, 1):
222
+ lines.append(
223
+ f"Fold {i}: Macro-F1={r['metrics']['macro_f1']:.2f}% | "
224
+ f"BA={r['metrics']['balanced_accuracy']:.2f}% | "
225
+ f"Acc={r['metrics']['accuracy']:.2f}% | "
226
+ f"BestEpoch={r['best_epoch']}"
227
+ )
228
+ lines.append("")
229
+ for name in metric_names:
230
+ lines.append(f"{name}: {summary[name]['mean']:.2f}% +/- {summary[name]['std']:.2f}%")
231
+
232
+ lines.append("")
233
+ lines.append("Per-class metrics (mean +/- std)")
234
+ lines.append(f"{'class':<10} {'precision':>18} {'recall':>18} {'f1-score':>18}")
235
+
236
+ per_class_summary = {}
237
+ for c in range(num_classes):
238
+ ps = [r["per_class"][c]["precision"] for r in fold_results]
239
+ rs = [r["per_class"][c]["recall"] for r in fold_results]
240
+ fs = [r["per_class"][c]["f1-score"] for r in fold_results]
241
+ per_class_summary[c] = {
242
+ "precision_mean": float(np.mean(ps)),
243
+ "precision_std": float(np.std(ps)),
244
+ "recall_mean": float(np.mean(rs)),
245
+ "recall_std": float(np.std(rs)),
246
+ "f1_mean": float(np.mean(fs)),
247
+ "f1_std": float(np.std(fs)),
248
+ }
249
+ lines.append(
250
+ f"class{c:<5} "
251
+ f"{np.mean(ps):.4f}+/-{np.std(ps):.4f}"
252
+ f"{np.mean(rs):>18.4f}+/-{np.std(rs):.4f}"
253
+ f"{np.mean(fs):>18.4f}+/-{np.std(fs):.4f}"
254
+ )
255
+
256
+ text = "\n".join(lines)
257
+ print(text)
258
+ with open(save_dir / "kfold_summary.txt", "w", encoding="utf-8") as f:
259
+ f.write(text)
260
+
261
+ with open(save_dir / "kfold_summary.json", "w", encoding="utf-8") as f:
262
+ json.dump(
263
+ to_jsonable({
264
+ "model": model_name,
265
+ "primary_metric": PRIMARY_METRIC,
266
+ "summary": summary,
267
+ "per_class": per_class_summary,
268
+ }),
269
+ f,
270
+ indent=2,
271
+ ensure_ascii=False,
272
+ )
273
+
274
+ all_targets, all_predictions, all_paths = [], [], []
275
+ all_patients, all_image_names = [], []
276
+ all_probabilities = []
277
+
278
+ pooled_ready = all(
279
+ "targets" in r and "predictions" in r and "image_path" in r and "probabilities" in r
280
+ for r in fold_results
281
+ )
282
+ if pooled_ready:
283
+ for r in fold_results:
284
+ all_targets.extend(r["targets"])
285
+ all_predictions.extend(r["predictions"])
286
+ all_paths.extend(r["image_path"])
287
+ all_patients.extend(r["patients"])
288
+ all_image_names.extend(r["image_names"])
289
+ all_probabilities.extend(r["probabilities"])
290
+
291
+ class_names = [f"class{i}" for i in range(num_classes)]
292
+ pooled_metrics, pooled_report = compute_metrics(
293
+ all_targets,
294
+ all_predictions,
295
+ num_classes,
296
+ class_names,
297
+ )
298
+
299
+ with open(save_dir / "oof_report.txt", "w", encoding="utf-8") as f:
300
+ f.write("Pooled out-of-fold metrics\n")
301
+ f.write(f"Primary Metric ({PRIMARY_METRIC}): {pooled_metrics[PRIMARY_METRIC]:.4f}\n")
302
+ for k, v in pooled_metrics.items():
303
+ f.write(f"{k}: {v:.4f}\n")
304
+ f.write("\nClassification Report:\n")
305
+ f.write(pd.DataFrame(pooled_report).transpose().to_string())
306
+
307
+ oof_df = pd.DataFrame({
308
+ "patient": all_patients,
309
+ "image_name": all_image_names,
310
+ "True": all_targets,
311
+ "Predicted": all_predictions,
312
+ "path": all_paths,
313
+ })
314
+ for c in range(num_classes):
315
+ oof_df[f"prob_class{c}"] = [row[c] for row in all_probabilities]
316
+ oof_df.to_csv(save_dir / "oof_predictions.csv", index=False)
317
+
318
+ return summary[PRIMARY_METRIC]["mean"], summary[PRIMARY_METRIC]["std"]
319
+
320
+
321
+ # 去掉黑边裁切,在 CLAHE + 绿色增强后增加眼底区域蒙版
322
+
323
+ class BlackBorderCrop:
324
+ """Crop black borders and obvious invalid background around the fundus."""
325
+ def __init__(self, threshold: int = 10, margin_ratio: float = 0.02):
326
+ self.threshold = threshold
327
+ self.margin_ratio = margin_ratio
328
+
329
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
330
+ img = np.array(pil_img.convert("RGB"))
331
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
332
+ mask = gray > self.threshold
333
+
334
+ if mask.sum() < 64:
335
+ return pil_img.convert("RGB")
336
+
337
+ ys, xs = np.where(mask)
338
+ y1, y2 = ys.min(), ys.max()
339
+ x1, x2 = xs.min(), xs.max()
340
+
341
+ margin_y = int((y2 - y1 + 1) * self.margin_ratio)
342
+ margin_x = int((x2 - x1 + 1) * self.margin_ratio)
343
+
344
+ y1 = max(0, y1 - margin_y)
345
+ y2 = min(img.shape[0], y2 + margin_y + 1)
346
+ x1 = max(0, x1 - margin_x)
347
+ x2 = min(img.shape[1], x2 + margin_x + 1)
348
+
349
+ cropped = img[y1:y2, x1:x2]
350
+ return Image.fromarray(cropped)
351
+
352
+ class FundusCircularCrop:
353
+
354
+ def __init__(self, threshold: int = 8, radius_pad_ratio: float = 0.03):
355
+ self.threshold = threshold
356
+ self.radius_pad_ratio = radius_pad_ratio
357
+
358
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
359
+ img = np.array(pil_img.convert("RGB"))
360
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
361
+
362
+ mask = (gray > self.threshold).astype(np.uint8) * 255
363
+ kernel = np.ones((5, 5), np.uint8)
364
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
365
+ mask = cv2.medianBlur(mask, 5)
366
+
367
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
368
+
369
+ if not contours:
370
+ return Image.fromarray(img)
371
+
372
+ largest = max(contours, key=cv2.contourArea)
373
+ (cx, cy), radius = cv2.minEnclosingCircle(largest)
374
+
375
+ if radius < 10:
376
+ return Image.fromarray(img)
377
+
378
+ radius = radius * (1.0 + self.radius_pad_ratio)
379
+ cx, cy, radius = float(cx), float(cy), float(radius)
380
+
381
+ x1 = max(0, int(cx - radius))
382
+ y1 = max(0, int(cy - radius))
383
+ x2 = min(img.shape[1], int(cx + radius))
384
+ y2 = min(img.shape[0], int(cy + radius))
385
+
386
+ cropped = img[y1:y2, x1:x2]
387
+ h, w = cropped.shape[:2]
388
+ if h < 2 or w < 2:
389
+ return Image.fromarray(img)
390
+
391
+ local_cx = cx - x1
392
+ local_cy = cy - y1
393
+ rr = max(1, min(int(radius), min(h, w) // 2))
394
+
395
+ yy, xx = np.ogrid[:h, :w]
396
+ circle_mask = ((xx - local_cx) ** 2 + (yy - local_cy) ** 2) <= (rr ** 2)
397
+
398
+ out = np.zeros_like(cropped)
399
+ out[circle_mask] = cropped[circle_mask]
400
+ return Image.fromarray(out)
401
+
402
+ class ResizeToSquare:
403
+ def __init__(self, size: int):
404
+ self.size = size
405
+
406
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
407
+ return pil_img.resize((self.size, self.size), resample=Image.BILINEAR)
408
+
409
+ class LightCLAHE:
410
+
411
+ def __init__(self, clip_limit: float = 2.0, grid: Tuple[int, int] = (8, 8)):
412
+ self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid)
413
+
414
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
415
+ img = np.array(pil_img.convert("RGB"))
416
+ lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
417
+ l, a, b = cv2.split(lab)
418
+ l = self.clahe.apply(l)
419
+ out = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
420
+ return Image.fromarray(out)
421
+
422
+ class GreenChannelEnhancement:
423
+
424
+ def __init__(
425
+ self,
426
+ clip_limit: float = 2.5,
427
+ grid: Tuple[int, int] = (8, 8),
428
+ blend_alpha: float = 0.30,
429
+ ):
430
+ self.clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=grid)
431
+ self.blend_alpha = blend_alpha
432
+
433
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
434
+ img = np.array(pil_img.convert("RGB"))
435
+ r, g, b = cv2.split(img)
436
+ g_eq = self.clahe.apply(g)
437
+ g_new = cv2.addWeighted(g, 1.0 - self.blend_alpha, g_eq, self.blend_alpha, 0.0)
438
+ out = cv2.merge([r, g_new, b])
439
+ return Image.fromarray(out)
440
+
441
+
442
+ class FundusEyeMask:
443
+
444
+ def __init__(
445
+ self,
446
+ threshold: int = 8,
447
+ radius_pad_ratio: float = 0.03,
448
+ morph_kernel: int = 7,
449
+ blur_kernel: int = 5,
450
+ ):
451
+ self.threshold = threshold
452
+ self.radius_pad_ratio = radius_pad_ratio
453
+ self.morph_kernel = morph_kernel
454
+ self.blur_kernel = blur_kernel
455
+
456
+ def __call__(self, pil_img: Image.Image) -> Image.Image:
457
+ img = np.array(pil_img.convert("RGB"))
458
+ gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
459
+
460
+ # Robust threshold against dark background after CLAHE + green enhancement
461
+ _, mask = cv2.threshold(gray, self.threshold, 255, cv2.THRESH_BINARY)
462
+
463
+ kernel = np.ones((self.morph_kernel, self.morph_kernel), np.uint8)
464
+ mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
465
+ mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)
466
+
467
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
468
+ if not contours:
469
+ return Image.fromarray(img)
470
+
471
+ largest = max(contours, key=cv2.contourArea)
472
+ (cx, cy), radius = cv2.minEnclosingCircle(largest)
473
+ if radius < 10:
474
+ return Image.fromarray(img)
475
+
476
+ radius = radius * (1.0 + self.radius_pad_ratio)
477
+ yy, xx = np.ogrid[:img.shape[0], :img.shape[1]]
478
+ circle_mask = (((xx - cx) ** 2 + (yy - cy) ** 2) <= (radius ** 2)).astype(np.uint8) * 255
479
+
480
+ if self.blur_kernel and self.blur_kernel > 1:
481
+ k = self.blur_kernel if self.blur_kernel % 2 == 1 else self.blur_kernel + 1
482
+ circle_mask = cv2.GaussianBlur(circle_mask, (k, k), 0)
483
+
484
+ circle_mask_f = (circle_mask.astype(np.float32) / 255.0)[..., None]
485
+ out = (img.astype(np.float32) * circle_mask_f).clip(0, 255).astype(np.uint8)
486
+ return Image.fromarray(out)
487
+
488
+ _light_clahe = LightCLAHE()
489
+ _green_enhance = GreenChannelEnhancement()
490
+ _eye_mask = FundusEyeMask()
491
+ _transform_cache: Dict[int, Tuple[transforms.Compose, transforms.Compose]] = {}
492
+
493
+ def build_transforms(input_size: int = DEFAULT_INPUT_SIZE):
494
+ """
495
+ 预处理流程:
496
+ - 不再使用 BlackBorderCrop
497
+ - 缩放到 input_size → CLAHE → 绿色通道增强 → 眼底区域蒙版
498
+ - 蒙版仅保留眼睛区域,屏蔽眼底边缘外的无关像素
499
+ 训练增强:
500
+ - 水平翻转 + 垂直翻转
501
+ - 小角度随机旋转 (±15°) + 轻微平移 + 尺度扰动 (0.85~1.15)
502
+ - 适度 ColorJitter
503
+ - 轻微高斯模糊
504
+ """
505
+ if input_size in _transform_cache:
506
+ return _transform_cache[input_size]
507
+
508
+ preprocess = [
509
+ ResizeToSquare(input_size),
510
+ _light_clahe,
511
+ _green_enhance,
512
+ _eye_mask,
513
+ ]
514
+
515
+ train_tf = transforms.Compose(
516
+ preprocess
517
+ + [
518
+ transforms.RandomHorizontalFlip(p=0.5),
519
+ transforms.RandomVerticalFlip(p=0.5),
520
+ transforms.RandomAffine(
521
+ degrees=15,
522
+ translate=(0.05, 0.05),
523
+ scale=(0.85, 1.15),
524
+ interpolation=InterpolationMode.BILINEAR,
525
+ fill=0,
526
+ ),
527
+ transforms.ColorJitter(
528
+ brightness=0.20,
529
+ contrast=0.20,
530
+ saturation=0.10,
531
+ hue=0.02,
532
+ ),
533
+ transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.8)),
534
+ transforms.ToTensor(),
535
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
536
+ ]
537
+ )
538
+
539
+ val_tf = transforms.Compose(
540
+ preprocess
541
+ + [
542
+ transforms.ToTensor(),
543
+ transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
544
+ ]
545
+ )
546
+
547
+ _transform_cache[input_size] = (train_tf, val_tf)
548
+ return train_tf, val_tf
549
+
550
+
551
+ # TTA (Test-Time Augmentation)
552
+ # 增:4 路 TTA — 原图 / 水平翻转 / 垂直翻转 / 双向翻转
553
+
554
+ def predict_with_tta(
555
+ model: nn.Module,
556
+ inputs: torch.Tensor,
557
+ amp_enabled: bool = False,
558
+ ) -> torch.Tensor:
559
+
560
+ amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
561
+ aug_variants = [
562
+ inputs, # 原图
563
+ inputs.flip(-1), # 水平翻转
564
+ inputs.flip(-2), # 垂直翻转
565
+ inputs.flip(-1).flip(-2), # 双向翻转
566
+ ]
567
+ probs_list = []
568
+ for aug in aug_variants:
569
+ with amp_ctx():
570
+ out = model(aug)
571
+ logits = _extract_logits(out)
572
+ probs_list.append(torch.softmax(logits, dim=1))
573
+
574
+ return torch.stack(probs_list, dim=0).mean(dim=0)
575
+
576
+ # Dataset
577
+
578
+ class ImageDataset(Dataset):
579
+ def __init__(self, df: pd.DataFrame, transform=None):
580
+ self.df = df.reset_index(drop=True).copy()
581
+ self.transform = transform
582
+
583
+ self.paths = self.df["path"].astype(str).tolist()
584
+ self.labels = self.df["label"].astype(int).tolist()
585
+ self.patients = self.df["patient"].astype(str).tolist()
586
+ if "image_name" in self.df.columns:
587
+ self.image_names = self.df["image_name"].astype(str).tolist()
588
+ else:
589
+ self.image_names = [Path(p).name for p in self.paths]
590
+
591
+ def __len__(self) -> int:
592
+ return len(self.paths)
593
+
594
+ def __getitem__(self, idx: int):
595
+ img_path = self.paths[idx]
596
+ label = self.labels[idx]
597
+
598
+ try:
599
+ image = Image.open(img_path).convert("RGB")
600
+ except Exception as exc:
601
+ raise RuntimeError(f"Failed to open image: {img_path}") from exc
602
+
603
+ if self.transform is not None:
604
+ image = self.transform(image)
605
+
606
+ meta = {
607
+ "path": img_path,
608
+ "patient": self.patients[idx],
609
+ "image_name": self.image_names[idx],
610
+ }
611
+ return image, torch.tensor(label, dtype=torch.long), meta
612
+
613
+
614
+ # Data loading / grouped splitting
615
+
616
+ def validate_image_paths(df: pd.DataFrame, path_col: str = "path") -> pd.DataFrame:
617
+ total = len(df)
618
+ mask = df[path_col].apply(os.path.isfile)
619
+ missing = total - int(mask.sum())
620
+ if missing > 0:
621
+ print(f"Warning: {missing}/{total} paths do not exist and will be removed.")
622
+ df = df.loc[mask].reset_index(drop=True)
623
+ else:
624
+ print(f"All {total} image paths are valid.")
625
+ return df
626
+
627
+ def load_and_prepare_data(excel_path: str, group_col: str = "patient") -> pd.DataFrame:
628
+ df = pd.read_excel(excel_path, engine="openpyxl")
629
+
630
+ required_cols = {"path", "label", group_col}
631
+ missing_cols = required_cols - set(df.columns)
632
+ if missing_cols:
633
+ raise KeyError(f"Missing required columns in Excel: {sorted(missing_cols)}")
634
+
635
+ df = df.copy()
636
+ df[group_col] = df[group_col].astype(str).str.strip()
637
+ if df[group_col].isin(["", "nan", "None"]).any():
638
+ bad_rows = int(df[group_col].isin(["", "nan", "None"]).sum())
639
+ raise ValueError(f"Found {bad_rows} rows with invalid patient/group identifiers in column '{group_col}'.")
640
+
641
+ df["label"] = df["label"].replace({"AROP": 5})
642
+ df["label"] = pd.to_numeric(df["label"], errors="raise").astype(int)
643
+
644
+ if df["label"].min() == 1:
645
+ df["label"] = df["label"] - 1
646
+
647
+ # Merge old labels 4 and 5 into class 3 -> final 4-class setup
648
+ df["label"] = df["label"].replace({4: 3, 5: 3})
649
+
650
+ df = validate_image_paths(df, path_col="path")
651
+
652
+ if "patient" != group_col:
653
+ df["patient"] = df[group_col].astype(str)
654
+
655
+ unique_labels = sorted(df["label"].unique().tolist())
656
+ print(f"Dataset size: {len(df)} images")
657
+ print(f"Unique patients: {df[group_col].nunique()}")
658
+ print(f"Class distribution: {dict(df['label'].value_counts().sort_index())}")
659
+ print(f"Observed labels: {unique_labels}")
660
+ return df
661
+
662
+ def _approximate_group_stratified_splits(
663
+ df: pd.DataFrame,
664
+ n_folds: int,
665
+ random_seed: int,
666
+ group_col: str,
667
+ ):
668
+
669
+ group_df = (
670
+ df.groupby(group_col)["label"]
671
+ .agg(lambda x: x.value_counts().index[0])
672
+ .reset_index()
673
+ )
674
+ if group_df[group_col].nunique() < n_folds:
675
+ raise ValueError(
676
+ f"Number of unique groups ({group_df[group_col].nunique()}) is smaller than n_folds={n_folds}."
677
+ )
678
+
679
+ skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
680
+ splits = []
681
+ group_ids = group_df[group_col].values
682
+ group_labels = group_df["label"].values
683
+
684
+ for group_train_idx, group_val_idx in skf.split(group_ids, group_labels):
685
+ train_groups = set(group_ids[group_train_idx])
686
+ val_groups = set(group_ids[group_val_idx])
687
+
688
+ train_idx = df.index[df[group_col].isin(train_groups)].to_numpy()
689
+ val_idx = df.index[df[group_col].isin(val_groups)].to_numpy()
690
+ splits.append((train_idx, val_idx))
691
+
692
+ return splits
693
+
694
+ def build_fold_splits(
695
+ df: pd.DataFrame,
696
+ n_folds: int,
697
+ random_seed: int,
698
+ group_col: str = "patient",
699
+ ):
700
+ groups = df[group_col].astype(str).values
701
+ labels = df["label"].values
702
+
703
+ if len(np.unique(groups)) < n_folds:
704
+ raise ValueError(
705
+ f"Unique groups in '{group_col}' = {len(np.unique(groups))}, which is smaller than n_folds={n_folds}."
706
+ )
707
+
708
+ if HAS_STRATIFIED_GROUP_KFOLD:
709
+ print(
710
+ f"Using StratifiedGroupKFold with group_col='{group_col}', n_folds={n_folds}, seed={random_seed}."
711
+ )
712
+ try:
713
+ splitter = StratifiedGroupKFold(
714
+ n_splits=n_folds,
715
+ shuffle=True,
716
+ random_state=random_seed,
717
+ )
718
+ splits = list(splitter.split(df, y=labels, groups=groups))
719
+ except ValueError as exc:
720
+ print(f"StratifiedGroupKFold failed: {exc}")
721
+ print("Falling back to approximate grouped stratification using patient-majority labels.")
722
+ splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)
723
+ else:
724
+ print("StratifiedGroupKFold is unavailable. Falling back to approximate grouped stratification.")
725
+ splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)
726
+
727
+ for fold_id, (train_idx, val_idx) in enumerate(splits, 1):
728
+ train_groups = set(df.iloc[train_idx][group_col].astype(str).tolist())
729
+ val_groups = set(df.iloc[val_idx][group_col].astype(str).tolist())
730
+ overlap = train_groups & val_groups
731
+ if overlap:
732
+ raise RuntimeError(
733
+ f"Data leakage detected in fold {fold_id}: {len(overlap)} overlapping groups."
734
+ )
735
+ return splits
736
+
737
+ def _compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
738
+ counts = train_df["label"].value_counts().sort_index()
739
+ total = len(train_df)
740
+ weights = [total / (num_classes * counts.get(c, 1)) for c in range(num_classes)]
741
+ return torch.tensor(weights, dtype=torch.float32, device=device)
742
+
743
+ def _make_weighted_sampler(train_df: pd.DataFrame) -> WeightedRandomSampler:
744
+ counts = train_df["label"].value_counts().to_dict()
745
+ sample_weights = train_df["label"].map(lambda x: 1.0 / counts[x]).astype(float).values
746
+ sample_weights = torch.as_tensor(sample_weights, dtype=torch.double)
747
+ return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)
748
+
749
+ def create_fold_loaders(
750
+ train_df: pd.DataFrame,
751
+ val_df: pd.DataFrame,
752
+ input_size: int = DEFAULT_INPUT_SIZE,
753
+ batch_size: int = 8,
754
+ num_classes: int = 4,
755
+ balance_mode: str = "loss",
756
+ num_workers: int = 4,
757
+ ):
758
+ train_tf, val_tf = build_transforms(input_size)
759
+
760
+ sampler = None
761
+ class_weights = None
762
+
763
+ if balance_mode == "sampler":+
764
+ sampler = _make_weighted_sampler(train_df)
765
+ print("Training loader uses WeightedRandomSampler for class balancing.")
766
+ elif balance_mode == "loss":
767
+ class_weights = _compute_class_weights(train_df, num_classes)
768
+ print(f"Training loss uses class weights: {class_weights.detach().cpu().numpy().tolist()}")
769
+ else:
770
+ print("No imbalance correction is used.")
771
+
772
+ drop_last = (batch_size > 1) and (len(train_df) % batch_size == 1)
773
+ if drop_last:
774
+ print(
775
+ f"Training loader will drop the last singleton batch "
776
+ f"(train_size={len(train_df)}, batch_size={batch_size}) to avoid BatchNorm issues."
777
+ )
778
+
779
+ pin_memory = device.type == "cuda"
780
+
781
+ train_loader = DataLoader(
782
+ ImageDataset(train_df, train_tf),
783
+ batch_size=batch_size,
784
+ shuffle=(sampler is None),
785
+ sampler=sampler,
786
+ num_workers=num_workers,
787
+ pin_memory=pin_memory,
788
+ drop_last=drop_last,
789
+ persistent_workers=(num_workers > 0),
790
+ )
791
+ val_loader = DataLoader(
792
+ ImageDataset(val_df, val_tf),
793
+ batch_size=batch_size,
794
+ shuffle=False,
795
+ num_workers=num_workers,
796
+ pin_memory=pin_memory,
797
+ persistent_workers=(num_workers > 0),
798
+ )
799
+ return train_loader, val_loader, class_weights
800
+
801
+ # ViT positional embedding interpolation
802
+
803
+ def patch_vit_for_large_input(
804
+ model: nn.Module,
805
+ model_name: str,
806
+ input_size: int,
807
+ ) -> nn.Module:
808
+ if "ViT" not in model_name:
809
+ return model
810
+
811
+ if not (hasattr(model, "encoder") and hasattr(model.encoder, "pos_embedding")):
812
+ print(f"Warning: cannot find pos_embedding for {model_name}, skip interpolation.")
813
+ return model
814
+
815
+ patch_size = model.patch_size
816
+ expected_patches = (input_size // patch_size) ** 2
817
+ pos_embed = model.encoder.pos_embedding
818
+ current_patches = pos_embed.shape[1] - 1
819
+
820
+ if current_patches == expected_patches:
821
+ print(f"[ViT] pos_embedding already matches input_size={input_size}, no interpolation needed.")
822
+ return model
823
+
824
+ print(
825
+ f"[ViT] Interpolating pos_embedding: {current_patches} -> {expected_patches} patches "
826
+ f"for input_size={input_size}."
827
+ )
828
+
829
+ cls_token = pos_embed[:, :1, :]
830
+ patch_tokens = pos_embed[:, 1:, :]
831
+ dim = patch_tokens.shape[-1]
832
+
833
+ h_old = w_old = int(math.sqrt(current_patches))
834
+ h_new = w_new = int(math.sqrt(expected_patches))
835
+
836
+ patch_tokens = (
837
+ patch_tokens
838
+ .reshape(1, h_old, w_old, dim)
839
+ .permute(0, 3, 1, 2)
840
+ .float()
841
+ )
842
+ patch_tokens = F.interpolate(
843
+ patch_tokens,
844
+ size=(h_new, w_new),
845
+ mode="bicubic",
846
+ align_corners=False,
847
+ )
848
+ patch_tokens = patch_tokens.permute(0, 2, 3, 1).reshape(1, expected_patches, dim)
849
+
850
+ model.encoder.pos_embedding = nn.Parameter(torch.cat([cls_token, patch_tokens], dim=1))
851
+
852
+ if hasattr(model, "image_size"):
853
+ model.image_size = input_size
854
+
855
+ return model
856
+
857
+ # Classifier replacement
858
+
859
+ def _find_last_linear(module: nn.Module):
860
+ if isinstance(module, nn.Linear):
861
+ return module
862
+ if isinstance(module, nn.Sequential):
863
+ for child in reversed(list(module.children())):
864
+ result = _find_last_linear(child)
865
+ if result is not None:
866
+ return result
867
+ if hasattr(module, "head") and isinstance(module.head, (nn.Linear, nn.Sequential)):
868
+ return _find_last_linear(module.head)
869
+ return None
870
+
871
+ def _verify_classifier(model: nn.Module, model_name: str, expected_classes: int) -> None:
872
+ for attr_name in ["fc", "head", "classifier", "heads"]:
873
+ if not hasattr(model, attr_name):
874
+ continue
875
+ layer = getattr(model, attr_name)
876
+ last_linear = _find_last_linear(layer)
877
+ if last_linear is not None:
878
+ if last_linear.out_features != expected_classes:
879
+ raise RuntimeError(
880
+ f"Classifier replacement failed for {model_name}: "
881
+ f"out_features={last_linear.out_features}, expected={expected_classes}"
882
+ )
883
+ print(f"Verified {model_name}: classifier -> {expected_classes} classes (in={last_linear.in_features})")
884
+ return
885
+ print(f"Warning: failed to automatically verify classifier for {model_name}")
886
+
887
+ def replace_classifier(
888
+ model_name: str,
889
+ model: nn.Module,
890
+ num_classes: int,
891
+ dropout: float = 0.3,
892
+ ) -> nn.Module:
893
+ if _is_timm_model(model):
894
+ in_feat = model.num_features
895
+ orig_classifier = model.get_classifier()
896
+ print(f"[timm] {model_name}: original classifier={type(orig_classifier).__name__}, num_features={in_feat}")
897
+
898
+ model.reset_classifier(num_classes)
899
+ new_fc = model.get_classifier()
900
+
901
+ wrapped = False
902
+ if isinstance(new_fc, nn.Linear):
903
+ for parent_attr, child_attr in [
904
+ ("head", "fc"),
905
+ ("head", "head"),
906
+ (None, "head"),
907
+ (None, "classifier"),
908
+ (None, "fc"),
909
+ ]:
910
+ try:
911
+ parent = getattr(model, parent_attr) if parent_attr else model
912
+ child = getattr(parent, child_attr)
913
+ if child is new_fc:
914
+ setattr(
915
+ parent,
916
+ child_attr,
917
+ nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)),
918
+ )
919
+ wrapped = True
920
+ break
921
+ except AttributeError:
922
+ continue
923
+
924
+ if not wrapped:
925
+ print(f"[timm] {model_name}: reset_classifier({num_classes}) applied (no Dropout wrapper).")
926
+
927
+ _verify_classifier(model, model_name, num_classes)
928
+ return model
929
+
930
+ n = model_name
931
+
932
+ if "VGG" in n:
933
+ in_feat = model.classifier[6].in_features
934
+ model.classifier[6] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
935
+
936
+ elif n == "inception_v3":
937
+ aux_in = model.AuxLogits.fc.in_features
938
+ model.AuxLogits.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux_in, num_classes))
939
+ fc_in = model.fc.in_features
940
+ model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(fc_in, num_classes))
941
+
942
+ elif "GoogLeNet" in n:
943
+ in_feat = model.fc.in_features
944
+ model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
945
+ if hasattr(model, "aux1") and model.aux1 is not None and hasattr(model.aux1, "fc2"):
946
+ aux1_in = model.aux1.fc2.in_features
947
+ model.aux1.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux1_in, num_classes))
948
+ if hasattr(model, "aux2") and model.aux2 is not None and hasattr(model.aux2, "fc2"):
949
+ aux2_in = model.aux2.fc2.in_features
950
+ model.aux2.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux2_in, num_classes))
951
+
952
+ elif "ResNe" in n:
953
+ in_feat = model.fc.in_features
954
+ model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
955
+
956
+ elif "DenseNet" in n:
957
+ in_feat = model.classifier.in_features
958
+ model.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
959
+
960
+ elif "MobileNet" in n:
961
+ in_feat = model.classifier[-1].in_features
962
+ model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
963
+
964
+ elif "MnasNet" in n:
965
+ in_feat = model.classifier[-1].in_features
966
+ model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
967
+
968
+ elif "EfficientNet" in n:
969
+ in_feat = model.classifier[-1].in_features
970
+ model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
971
+
972
+ elif "ConvNeXt" in n:
973
+ in_feat = model.classifier[-1].in_features
974
+ model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
975
+
976
+ elif "RegNet" in n or "ShuffleNet" in n:
977
+ in_feat = model.fc.in_features
978
+ model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
979
+
980
+ elif "ViT" in n:
981
+ if hasattr(model, "heads") and hasattr(model.heads, "head"):
982
+ in_feat = model.heads.head.in_features
983
+ model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
984
+ elif hasattr(model, "head") and isinstance(model.head, nn.Linear):
985
+ in_feat = model.head.in_features
986
+ model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
987
+ else:
988
+ raise ValueError(f"Cannot find classifier head for {n}")
989
+
990
+ elif "Swin" in n:
991
+ if hasattr(model, "head") and isinstance(model.head, nn.Linear):
992
+ in_feat = model.head.in_features
993
+ model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
994
+ elif hasattr(model, "heads") and hasattr(model.heads, "head"):
995
+ in_feat = model.heads.head.in_features
996
+ model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
997
+ else:
998
+ raise ValueError(f"Cannot find classifier head for {n}")
999
+
1000
+ elif _is_vit_family(n):
1001
+ replaced = False
1002
+ for attr in ["heads.head", "head", "classifier"]:
1003
+ parts = attr.split(".")
1004
+ obj = model
1005
+ try:
1006
+ for p in parts:
1007
+ obj = getattr(obj, p)
1008
+ if isinstance(obj, nn.Linear):
1009
+ in_feat = obj.in_features
1010
+ parent = model
1011
+ for p in parts[:-1]:
1012
+ parent = getattr(parent, p)
1013
+ setattr(parent, parts[-1], nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)))
1014
+ replaced = True
1015
+ break
1016
+ except AttributeError:
1017
+ continue
1018
+ if not replaced:
1019
+ raise ValueError(f"Cannot find classifier head for {n}")
1020
+
1021
+ else:
1022
+ replaced = False
1023
+ for attr_name in ["fc", "head", "classifier"]:
1024
+ if not hasattr(model, attr_name):
1025
+ continue
1026
+ layer = getattr(model, attr_name)
1027
+ if isinstance(layer, nn.Linear):
1028
+ in_feat = layer.in_features
1029
+ setattr(model, attr_name, nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)))
1030
+ replaced = True
1031
+ break
1032
+ if isinstance(layer, nn.Sequential) and len(layer) > 0 and isinstance(layer[-1], nn.Linear):
1033
+ in_feat = layer[-1].in_features
1034
+ layer[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
1035
+ replaced = True
1036
+ break
1037
+ if not replaced:
1038
+ raise ValueError(f"Cannot automatically replace classifier for {n}")
1039
+
1040
+ _verify_classifier(model, model_name, num_classes)
1041
+ return model
1042
+
1043
+
1044
+ # Optimizer groups / freezing
1045
+
1046
+ def _get_head_keywords(model_name: str) -> List[str]:
1047
+ n = model_name
1048
+ if "VGG" in n:
1049
+ return ["classifier.6"]
1050
+ if n == "inception_v3":
1051
+ return ["fc.", "AuxLogits.fc"]
1052
+ if "GoogLeNet" in n:
1053
+ return ["fc.", "aux1.fc2", "aux2.fc2"]
1054
+ if "ResNe" in n:
1055
+ return ["fc."]
1056
+ if "DenseNet" in n:
1057
+ return ["classifier."]
1058
+ if "MobileNet" in n:
1059
+ return ["classifier.3", "classifier.2", "classifier."]
1060
+ if "MnasNet" in n:
1061
+ return ["classifier.1", "classifier."]
1062
+ if "EfficientNet" in n or "ConvNeXt" in n:
1063
+ return ["classifier.", "head.fc"]
1064
+ if "RegNet" in n or "ShuffleNet" in n:
1065
+ return ["fc."]
1066
+ if "ViT" in n or _is_vit_family(n):
1067
+ return ["heads.", "head.", "classifier."]
1068
+ return ["fc.", "classifier.", "head.", "heads."]
1069
+
1070
+ def get_parameter_groups(
1071
+ model_name: str,
1072
+ model: nn.Module,
1073
+ backbone_lr: float = 3e-5,
1074
+ head_lr: float = 1e-3,
1075
+ ):
1076
+ head_kw = _get_head_keywords(model_name)
1077
+ head_p, back_p = [], []
1078
+ for name, param in model.named_parameters():
1079
+ if name_matches_keywords(name, head_kw):
1080
+ head_p.append(param)
1081
+ else:
1082
+ back_p.append(param)
1083
+
1084
+ if not head_p:
1085
+ print(f"Warning: no head parameters matched for {model_name}; all params use head_lr.")
1086
+ return [{"params": list(model.parameters()), "lr": head_lr}]
1087
+
1088
+ print(
1089
+ f"Parameter groups | backbone: {sum(p.numel() for p in back_p):,} (lr={backbone_lr}) | "
1090
+ f"head: {sum(p.numel() for p in head_p):,} (lr={head_lr})"
1091
+ )
1092
+ return [{"params": back_p, "lr": backbone_lr}, {"params": head_p, "lr": head_lr}]
1093
+
1094
+ def set_backbone_trainable(model_name: str, model: nn.Module, train_backbone: bool) -> None:
1095
+ head_kw = _get_head_keywords(model_name)
1096
+ for name, param in model.named_parameters():
1097
+ is_head = name_matches_keywords(name, head_kw)
1098
+ param.requires_grad = train_backbone or is_head
1099
+
1100
+ def set_frozen_backbone_bn_eval(model_name: str, model: nn.Module) -> None:
1101
+ head_kw = _get_head_keywords(model_name)
1102
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
1103
+
1104
+ for name, module in model.named_modules():
1105
+ if isinstance(module, bn_types) and not name_matches_keywords(name, head_kw):
1106
+ module.eval()
1107
+ for param in module.parameters():
1108
+ param.requires_grad = False
1109
+
1110
+ def configure_small_batch_behavior(model_name: str, model: nn.Module, batch_size: int) -> nn.Module:
1111
+ if batch_size >= 2:
1112
+ return model
1113
+
1114
+ if model_name == "inception_v3":
1115
+ print("batch_size=1 detected: disabling Inception auxiliary classifier.")
1116
+ if hasattr(model, "aux_logits"):
1117
+ model.aux_logits = False
1118
+ if hasattr(model, "AuxLogits"):
1119
+ model.AuxLogits = None
1120
+
1121
+ elif "GoogLeNet" in model_name:
1122
+ print("batch_size=1 detected: disabling GoogLeNet auxiliary classifiers.")
1123
+ if hasattr(model, "aux_logits"):
1124
+ model.aux_logits = False
1125
+ if hasattr(model, "aux1"):
1126
+ model.aux1 = None
1127
+ if hasattr(model, "aux2"):
1128
+ model.aux2 = None
1129
+
1130
+ return model
1131
+
1132
+ # Forward helpers
1133
+
1134
+ def _extract_logits(output):
1135
+ if torch.is_tensor(output):
1136
+ return output
1137
+ if hasattr(output, "logits") and torch.is_tensor(output.logits):
1138
+ return output.logits
1139
+ if isinstance(output, (tuple, list)) and len(output) > 0 and torch.is_tensor(output[0]):
1140
+ return output[0]
1141
+ raise TypeError("Unable to extract logits from model output.")
1142
+
1143
+ def _extract_aux_outputs(output):
1144
+ aux_outputs = []
1145
+ if isinstance(output, (tuple, list)):
1146
+ aux_outputs.extend([o for o in output[1:] if torch.is_tensor(o)])
1147
+ else:
1148
+ for attr in ["aux_logits", "aux_logits2", "aux_logits1"]:
1149
+ if hasattr(output, attr):
1150
+ aux = getattr(output, attr)
1151
+ if torch.is_tensor(aux):
1152
+ aux_outputs.append(aux)
1153
+ return aux_outputs
1154
+
1155
+ def forward_with_loss(
1156
+ model: nn.Module,
1157
+ inputs: torch.Tensor,
1158
+ labels: torch.Tensor,
1159
+ criterion,
1160
+ aux_weight: float = 0.3,
1161
+ ):
1162
+ output = model(inputs)
1163
+ logits = _extract_logits(output)
1164
+ aux_outputs = _extract_aux_outputs(output)
1165
+
1166
+ loss = criterion(logits, labels)
1167
+ if model.training and aux_outputs:
1168
+ for aux in aux_outputs:
1169
+ loss = loss + aux_weight * criterion(aux, labels)
1170
+ return logits, loss
1171
+
1172
+ # Losses
1173
+
1174
+ class FocalLoss(nn.Module):
1175
+
1176
+ def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
1177
+ super().__init__()
1178
+ self.alpha = alpha
1179
+ self.gamma = gamma
1180
+ self.reduction = reduction
1181
+
1182
+ def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
1183
+ log_probs = F.log_softmax(inputs, dim=1)
1184
+ log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
1185
+ pt = log_pt.exp()
1186
+
1187
+ loss = -((1.0 - pt) ** self.gamma) * log_pt
1188
+
1189
+ if self.alpha is not None:
1190
+ alpha_t = self.alpha.to(inputs.device)[targets]
1191
+ loss = alpha_t * loss
1192
+
1193
+ if self.reduction == "mean":
1194
+ return loss.mean()
1195
+ if self.reduction == "sum":
1196
+ return loss.sum()
1197
+ return loss
1198
+
1199
+ def build_criterion(
1200
+ loss_type: str,
1201
+ class_weights: Optional[torch.Tensor] = None,
1202
+ focal_gamma: float = 2.0,
1203
+ label_smoothing: float = 0.0,
1204
+ ):
1205
+ loss_type = loss_type.lower()
1206
+ if loss_type == "focal":
1207
+ return FocalLoss(alpha=class_weights, gamma=focal_gamma, reduction="mean")
1208
+ if loss_type == "weighted_ce":
1209
+ return nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
1210
+ if loss_type == "ce":
1211
+ return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
1212
+ raise ValueError(f"Unsupported loss_type: {loss_type}")
1213
+
1214
+
1215
+ # Training
1216
+ # - epochs 90
1217
+ # - freeze_backbone_epochs
1218
+ # - warmup_ep = freeze_backbone_epochs
1219
+ # - 验证循环使用 TTA
1220
+ def train_one_fold(
1221
+ model_name: str,
1222
+ model: nn.Module,
1223
+ train_loader,
1224
+ val_loader,
1225
+ epochs: int = 90,
1226
+ num_classes: int = 4,
1227
+ backbone_lr: float = 3e-5,
1228
+ head_lr: float = 1e-3,
1229
+ class_weights: Optional[torch.Tensor] = None,
1230
+ fold_id: int = 1,
1231
+ save_dir: Optional[Path] = None,
1232
+ freeze_backbone_epochs: int = 8,
1233
+ max_grad_norm: float = 1.0,
1234
+ primary_metric: str = PRIMARY_METRIC,
1235
+ loss_type: str = "weighted_ce",
1236
+ focal_gamma: float = 2.0,
1237
+ label_smoothing: float = 0.0,
1238
+ use_tta: bool = True,
1239
+ ):
1240
+ if save_dir is None:
1241
+ save_dir = Path(model_name)
1242
+ else:
1243
+ save_dir = Path(save_dir)
1244
+ ensure_dir(save_dir)
1245
+
1246
+ criterion = build_criterion(
1247
+ loss_type=loss_type,
1248
+ class_weights=class_weights,
1249
+ focal_gamma=focal_gamma,
1250
+ label_smoothing=label_smoothing,
1251
+ )
1252
+ print(
1253
+ f"Fold {fold_id}: Using loss_type='{loss_type}'"
1254
+ f"{' with class weights' if class_weights is not None else ''}."
1255
+ )
1256
+ print(
1257
+ f"Fold {fold_id}: backbone_lr={backbone_lr}, head_lr={head_lr}, "
1258
+ f"freeze_backbone_epochs={freeze_backbone_epochs}, "
1259
+ f"epochs={epochs}, use_tta={use_tta}."
1260
+ )
1261
+
1262
+ param_groups = get_parameter_groups(model_name, model, backbone_lr, head_lr)
1263
+ optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999), weight_decay=5e-4)
1264
+
1265
+
1266
+ warmup_ep = freeze_backbone_epochs
1267
+ sched_main = torch.optim.lr_scheduler.CosineAnnealingLR(
1268
+ optimizer,
1269
+ T_max=max(1, epochs - warmup_ep),
1270
+ eta_min=1e-7,
1271
+ )
1272
+ sched_warm = torch.optim.lr_scheduler.LinearLR(
1273
+ optimizer,
1274
+ start_factor=0.1,
1275
+ end_factor=1.0,
1276
+ total_iters=warmup_ep,
1277
+ )
1278
+ scheduler = torch.optim.lr_scheduler.SequentialLR(
1279
+ optimizer,
1280
+ schedulers=[sched_warm, sched_main],
1281
+ milestones=[warmup_ep],
1282
+ )
1283
+
1284
+ amp_enabled = device.type == "cuda"
1285
+ scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
1286
+ class_names = [f"class{i}" for i in range(num_classes)]
1287
+
1288
+ best_monitor = -float("inf")
1289
+ best_results = None
1290
+ was_backbone_trainable = None
1291
+ start_epoch = 0
1292
+
1293
+ ckpt_path = save_dir / f"fold{fold_id}_checkpoint.pth"
1294
+ if ckpt_path.is_file():
1295
+ print(f"Fold {fold_id}: found epoch-level checkpoint, attempting to resume...")
1296
+ try:
1297
+ ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
1298
+ model.load_state_dict(ckpt["model_state_dict"])
1299
+ optimizer.load_state_dict(ckpt["optimizer_state_dict"])
1300
+ scheduler.load_state_dict(ckpt["scheduler_state_dict"])
1301
+ scaler.load_state_dict(ckpt["scaler_state_dict"])
1302
+ start_epoch = ckpt["epoch"] + 1
1303
+ best_monitor = ckpt["best_monitor"]
1304
+ best_results = ckpt.get("best_results", None)
1305
+ print(
1306
+ f"Fold {fold_id}: resumed from epoch {start_epoch}/{epochs} "
1307
+ f"(best {primary_metric}={best_monitor:.2f})."
1308
+ )
1309
+ except Exception as exc:
1310
+ print(f"Fold {fold_id}: failed to load checkpoint ({exc}), training from scratch.")
1311
+ start_epoch = 0
1312
+ best_monitor = -float("inf")
1313
+ best_results = None
1314
+
1315
+ for epoch in tqdm(
1316
+ range(start_epoch, epochs),
1317
+ desc=f"Fold {fold_id}",
1318
+ leave=False,
1319
+ initial=start_epoch,
1320
+ total=epochs,
1321
+ ):
1322
+ train_backbone = epoch >= freeze_backbone_epochs
1323
+ if was_backbone_trainable is None or was_backbone_trainable != train_backbone:
1324
+ set_backbone_trainable(model_name, model, train_backbone=train_backbone)
1325
+ stage = "unfrozen" if train_backbone else "frozen"
1326
+ print(f"Fold {fold_id}: backbone is now {stage} (epoch {epoch + 1}).")
1327
+ was_backbone_trainable = train_backbone
1328
+
1329
+ model.train()
1330
+ if not train_backbone:
1331
+ set_frozen_backbone_bn_eval(model_name, model)
1332
+
1333
+ run_loss = 0.0
1334
+
1335
+ for inputs, labels, _meta in train_loader:
1336
+ inputs = inputs.to(device, non_blocking=(device.type == "cuda"))
1337
+ labels = labels.to(device, non_blocking=(device.type == "cuda"))
1338
+
1339
+ optimizer.zero_grad(set_to_none=True)
1340
+ amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
1341
+ with amp_ctx():
1342
+ logits, loss = forward_with_loss(model, inputs, labels, criterion, aux_weight=0.3)
1343
+
1344
+ scaler.scale(loss).backward()
1345
+ if max_grad_norm is not None and max_grad_norm > 0:
1346
+ scaler.unscale_(optimizer)
1347
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
1348
+ scaler.step(optimizer)
1349
+ scaler.update()
1350
+
1351
+ run_loss += loss.item() * inputs.size(0)
1352
+
1353
+ scheduler.step()
1354
+ ep_loss = run_loss / len(train_loader.dataset)
1355
+
1356
+ # ---- 验证阶段:可选 TTA ----
1357
+ model.eval()
1358
+ all_t, all_p, all_paths, all_probs = [], [], [], []
1359
+ all_patients, all_image_names = [], []
1360
+
1361
+ with torch.no_grad():
1362
+ for inputs, labels, meta in val_loader:
1363
+ inputs = inputs.to(device, non_blocking=(device.type == "cuda"))
1364
+ labels = labels.to(device, non_blocking=(device.type == "cuda"))
1365
+
1366
+ if use_tta:
1367
+ # V7: 使用 TTA 推断
1368
+ probs = predict_with_tta(model, inputs, amp_enabled=amp_enabled)
1369
+ else:
1370
+ amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
1371
+ with amp_ctx():
1372
+ output = model(inputs)
1373
+ logits = _extract_logits(output)
1374
+ probs = torch.softmax(logits, dim=1)
1375
+
1376
+ pred = probs.argmax(dim=1)
1377
+
1378
+ all_t.extend(labels.cpu().numpy().tolist())
1379
+ all_p.extend(pred.cpu().numpy().tolist())
1380
+ all_probs.extend(probs.cpu().numpy().tolist())
1381
+ all_paths.extend(list(meta["path"]))
1382
+ all_patients.extend(list(meta["patient"]))
1383
+ all_image_names.extend(list(meta["image_name"]))
1384
+
1385
+ metrics, report = compute_metrics(all_t, all_p, num_classes, class_names)
1386
+ monitor = metrics[primary_metric]
1387
+
1388
+ if (epoch + 1) % 5 == 0 or epoch == epochs - 1 or epoch == start_epoch:
1389
+ print(
1390
+ f"F{fold_id} E{epoch + 1}/{epochs} "
1391
+ f"Loss={ep_loss:.4f} "
1392
+ f"Macro-F1={metrics['macro_f1']:.2f}% "
1393
+ f"BA={metrics['balanced_accuracy']:.2f}% "
1394
+ f"Acc={metrics['accuracy']:.2f}%"
1395
+ f"{' [TTA]' if use_tta else ''}"
1396
+ )
1397
+
1398
+ improved = (monitor > best_monitor) or (
1399
+ np.isclose(monitor, best_monitor)
1400
+ and best_results is not None
1401
+ and metrics["balanced_accuracy"] > best_results["metrics"]["balanced_accuracy"]
1402
+ )
1403
+
1404
+ if improved:
1405
+ best_monitor = monitor
1406
+ best_results = {
1407
+ "best_epoch": epoch + 1,
1408
+ "metrics": metrics,
1409
+ "classification_report": report,
1410
+ "predictions": all_p,
1411
+ "targets": all_t,
1412
+ "image_path": all_paths,
1413
+ "patients": all_patients,
1414
+ "image_names": all_image_names,
1415
+ "probabilities": all_probs,
1416
+ "num_classes": num_classes,
1417
+ "per_class": [
1418
+ report.get(f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0})
1419
+ for i in range(num_classes)
1420
+ ],
1421
+ }
1422
+ save_fold_results(best_results, save_dir, tag=f"fold{fold_id}_best")
1423
+ torch.save(model.state_dict(), save_dir / f"fold{fold_id}_best.pth")
1424
+ torch.save({
1425
+ "epoch": epoch,
1426
+ "model_state_dict": model.state_dict(),
1427
+ "optimizer_state_dict": optimizer.state_dict(),
1428
+ "scheduler_state_dict": scheduler.state_dict(),
1429
+ "scaler_state_dict": scaler.state_dict(),
1430
+ "best_monitor": best_monitor,
1431
+ "best_results": best_results,
1432
+ }, ckpt_path)
1433
+
1434
+ if ckpt_path.is_file():
1435
+ ckpt_path.unlink()
1436
+ print(f"Fold {fold_id}: removed epoch-level checkpoint (training complete).")
1437
+
1438
+ if best_results is None:
1439
+ raise RuntimeError(f"Fold {fold_id}: no valid result was produced.")
1440
+
1441
+ return best_results
1442
+
1443
+
1444
+
1445
+ def build_model_registry():
1446
+ reg = {}
1447
+
1448
+ reg["DenseNet161"] = lambda: models.densenet161(weights=models.DenseNet161_Weights.DEFAULT)
1449
+ reg["ConvNeXt_Tiny"] = lambda: models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
1450
+ reg["ViT_B_16"] = lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)
1451
+
1452
+ if HAS_TIMM:
1453
+ reg["SwinV2_T"] = lambda: timm.create_model(
1454
+ "swinv2_tiny_window8_256",
1455
+ pretrained=True,
1456
+ img_size=512,
1457
+ )
1458
+ reg["DeiT3_S"] = lambda: timm.create_model(
1459
+ "deit3_small_patch16_224",
1460
+ pretrained=True,
1461
+ img_size=512,
1462
+ )
1463
+ else:
1464
+ print("Skipping timm models because timm is not installed.")
1465
+
1466
+ return reg
1467
+
1468
+
1469
+ def parse_args():
1470
+ parser = argparse.ArgumentParser(description="ROP benchmark training with patient-grouped 5-fold CV (v7).")
1471
+
1472
+ boolean_action = getattr(argparse, "BooleanOptionalAction", None)
1473
+
1474
+ parser.add_argument(
1475
+ "--excel_path",
1476
+ type=str,
1477
+ default="/media/fang/9fc99a7b-15d6-4e22-ab05-fe46e6058c39/felicia/Downloads/医生审核之后第一版12-17/部分公开数据集/公开数据集训练表_调整数据1.xlsx",
1478
+ help="Path to Excel with at least columns: patient, path, label.",
1479
+ )
1480
+ parser.add_argument("--group_col", type=str, default="patient", help="Grouping column for leakage-free split.")
1481
+ parser.add_argument("--num_classes", type=int, default=4)
1482
+ # V7: epochs 90
1483
+ parser.add_argument("--epochs", type=int, default=90)
1484
+ parser.add_argument("--n_folds", type=int, default=5)
1485
+ parser.add_argument("--batch_size", type=int, default=8)
1486
+ # V7: backbone_lr 提升至 3e-5
1487
+ parser.add_argument("--backbone_lr", type=float, default=3e-5)
1488
+ # V7: head_lr 提升至 1e-3
1489
+ parser.add_argument("--head_lr", type=float, default=1e-3)
1490
+ parser.add_argument("--random_seed", type=int, default=42)
1491
+ parser.add_argument("--num_workers", type=int, default=min(8, os.cpu_count() or 2))
1492
+
1493
+ parser.add_argument(
1494
+ "--balance_mode",
1495
+ type=str,
1496
+ default="loss",
1497
+ choices=["none", "loss", "sampler"],
1498
+ help="Imbalance handling. 'loss' computes class weights; 'sampler' uses WeightedRandomSampler.",
1499
+ )
1500
+ parser.add_argument(
1501
+ "--loss_type",
1502
+ type=str,
1503
+ default="weighted_ce",
1504
+ choices=["weighted_ce", "focal", "ce"],
1505
+ help="weighted_ce is the recommended default.",
1506
+ )
1507
+ parser.add_argument("--focal_gamma", type=float, default=2.0)
1508
+ parser.add_argument("--label_smoothing", type=float, default=0.0)
1509
+ # V7: freeze_backbone_epochs 提升至 8
1510
+ parser.add_argument("--freeze_backbone_epochs", type=int, default=8)
1511
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
1512
+ parser.add_argument("--output_root", type=str, default="runs_rop_V7_old")
1513
+
1514
+ if boolean_action is not None:
1515
+ parser.add_argument("--use_tta", action=boolean_action, default=True,
1516
+ help="Enable 4-way TTA (flip) during validation.")
1517
+ parser.add_argument("--deterministic", action=boolean_action, default=False)
1518
+ else:
1519
+ parser.add_argument("--use_tta", action="store_true", default=True,
1520
+ help="Enable 4-way TTA (flip) during validation.")
1521
+ parser.add_argument("--no_tta", dest="use_tta", action="store_false")
1522
+ parser.add_argument("--deterministic", action="store_true", default=False)
1523
+
1524
+ parser.add_argument(
1525
+ "--models",
1526
+ nargs="*",
1527
+ default=None,
1528
+ help="Optional subset of model names to train (e.g. --models DenseNet161 ViT_B_16).",
1529
+ )
1530
+
1531
+ return parser.parse_args()
1532
+
1533
+
1534
+ def main():
1535
+ args = parse_args()
1536
+ seed_everything(args.random_seed, deterministic=args.deterministic)
1537
+
1538
+ print("\nLoading data...")
1539
+ df = load_and_prepare_data(args.excel_path, group_col=args.group_col)
1540
+
1541
+ observed_num_classes = int(df["label"].nunique())
1542
+ if observed_num_classes != args.num_classes:
1543
+ raise ValueError(
1544
+ f"num_classes mismatch: args.num_classes={args.num_classes}, "
1545
+ f"but observed labels in Excel imply {observed_num_classes} classes after remapping."
1546
+ )
1547
+
1548
+ fold_splits = build_fold_splits(
1549
+ df=df,
1550
+ n_folds=args.n_folds,
1551
+ random_seed=args.random_seed,
1552
+ group_col=args.group_col,
1553
+ )
1554
+
1555
+ model_registry = build_model_registry()
1556
+ if args.models:
1557
+ selected = {k: v for k, v in model_registry.items() if k in set(args.models)}
1558
+ missing = [m for m in args.models if m not in model_registry]
1559
+ if missing:
1560
+ print(f"Warning: these models were not found and will be ignored: {missing}")
1561
+ model_registry = selected
1562
+
1563
+ print(f"\nTotal models to train: {len(model_registry)}")
1564
+ for i, name in enumerate(model_registry, 1):
1565
+ print(f"{i:2d}. {name}")
1566
+
1567
+ output_root = Path(args.output_root)
1568
+ ensure_dir(output_root)
1569
+
1570
+ global_results = {}
1571
+
1572
+ for model_idx, (model_name, model_fn) in enumerate(model_registry.items(), 1):
1573
+ print("\n" + "=" * 70)
1574
+ print(f"[{model_idx}/{len(model_registry)}] Model: {model_name}")
1575
+ print("=" * 70)
1576
+
1577
+ model_dir = output_root / model_name
1578
+ ensure_dir(model_dir)
1579
+
1580
+ summary_path = model_dir / "kfold_summary.json"
1581
+ if summary_path.is_file():
1582
+ try:
1583
+ with open(summary_path, "r", encoding="utf-8") as f:
1584
+ old = json.load(f)
1585
+ old_summary = old.get("summary", {})
1586
+ if old_summary:
1587
+ mean_primary = old_summary[PRIMARY_METRIC]["mean"]
1588
+ std_primary = old_summary[PRIMARY_METRIC]["std"]
1589
+ print(
1590
+ f"[Skip] Found existing {args.n_folds}-fold summary: "
1591
+ f"{PRIMARY_METRIC}={mean_primary:.2f}% +/- {std_primary:.2f}%"
1592
+ )
1593
+ global_results[model_name] = (mean_primary, std_primary)
1594
+ continue
1595
+ except Exception:
1596
+ pass
1597
+
1598
+ input_size = get_model_input_size(model_name)
1599
+ print(f"Input size: {input_size}x{input_size}")
1600
+
1601
+ fold_results = []
1602
+
1603
+ for fold_idx in range(args.n_folds):
1604
+ fold_id = fold_idx + 1
1605
+ print(f"\n-- Fold {fold_id}/{args.n_folds} --")
1606
+
1607
+ metrics_json = model_dir / f"fold{fold_id}_best_metrics.json"
1608
+ weight_path = model_dir / f"fold{fold_id}_best.pth"
1609
+ if metrics_json.is_file() and weight_path.is_file():
1610
+ try:
1611
+ with open(metrics_json, "r", encoding="utf-8") as f:
1612
+ cached = json.load(f)
1613
+ fold_results.append({
1614
+ "best_epoch": cached["best_epoch"],
1615
+ "metrics": cached["metrics"],
1616
+ "per_class": cached["per_class"],
1617
+ })
1618
+ print(
1619
+ f"Fold {fold_id}: cached result found "
1620
+ f"(Macro-F1={cached['metrics']['macro_f1']:.2f}%, "
1621
+ f"BA={cached['metrics']['balanced_accuracy']:.2f}%), skipped."
1622
+ )
1623
+ continue
1624
+ except Exception:
1625
+ pass
1626
+
1627
+ train_idx, val_idx = fold_splits[fold_idx]
1628
+ train_df = df.iloc[train_idx].reset_index(drop=True)
1629
+ val_df = df.iloc[val_idx].reset_index(drop=True)
1630
+
1631
+ train_patients = set(train_df[args.group_col].astype(str).tolist())
1632
+ val_patients = set(val_df[args.group_col].astype(str).tolist())
1633
+ overlap = train_patients & val_patients
1634
+ if overlap:
1635
+ raise RuntimeError(
1636
+ f"Leakage detected in fold {fold_id}: {len(overlap)} overlapping patients/groups."
1637
+ )
1638
+
1639
+ print(f"Train: {len(train_df)} | Validation: {len(val_df)}")
1640
+ print(
1641
+ f"Train patients: {train_df[args.group_col].nunique()} | "
1642
+ f"Validation patients: {val_df[args.group_col].nunique()}"
1643
+ )
1644
+ print(
1645
+ f"Train class dist: {dict(train_df['label'].value_counts().sort_index())} | "
1646
+ f"Val class dist: {dict(val_df['label'].value_counts().sort_index())}"
1647
+ )
1648
+
1649
+ train_loader, val_loader, class_weights = create_fold_loaders(
1650
+ train_df=train_df,
1651
+ val_df=val_df,
1652
+ input_size=input_size,
1653
+ batch_size=args.batch_size,
1654
+ num_classes=args.num_classes,
1655
+ balance_mode=args.balance_mode,
1656
+ num_workers=args.num_workers,
1657
+ )
1658
+
1659
+ try:
1660
+ model = model_fn()
1661
+ except Exception as exc:
1662
+ print(f"Model creation failed for {model_name}: {exc}")
1663
+ break
1664
+
1665
+ model = replace_classifier(model_name, model, args.num_classes)
1666
+ model = patch_vit_for_large_input(model, model_name, input_size)
1667
+ model = configure_small_batch_behavior(model_name, model, args.batch_size)
1668
+ model = model.to(device)
1669
+
1670
+ dummy = torch.randn(1, 3, input_size, input_size, device=device)
1671
+ model.eval()
1672
+ with torch.no_grad():
1673
+ out = model(dummy)
1674
+ out = _extract_logits(out)
1675
+ out_dim = out.shape[-1]
1676
+ if out_dim != args.num_classes:
1677
+ raise RuntimeError(
1678
+ f"Fatal: classifier replacement failed for {model_name}. "
1679
+ f"Output dim={out_dim}, expected={args.num_classes}."
1680
+ )
1681
+ print(f"Forward sanity check passed: output dim={out_dim}")
1682
+ del dummy, out
1683
+ if device.type == "cuda":
1684
+ torch.cuda.empty_cache()
1685
+
1686
+ result = train_one_fold(
1687
+ model_name=model_name,
1688
+ model=model,
1689
+ train_loader=train_loader,
1690
+ val_loader=val_loader,
1691
+ epochs=args.epochs,
1692
+ num_classes=args.num_classes,
1693
+ backbone_lr=args.backbone_lr,
1694
+ head_lr=args.head_lr,
1695
+ class_weights=class_weights,
1696
+ fold_id=fold_id,
1697
+ save_dir=model_dir,
1698
+ freeze_backbone_epochs=args.freeze_backbone_epochs,
1699
+ max_grad_norm=args.max_grad_norm,
1700
+ primary_metric=PRIMARY_METRIC,
1701
+ loss_type=args.loss_type,
1702
+ focal_gamma=args.focal_gamma,
1703
+ label_smoothing=args.label_smoothing,
1704
+ use_tta=args.use_tta,
1705
+ )
1706
+ fold_results.append(result)
1707
+
1708
+ del model
1709
+ if device.type == "cuda":
1710
+ torch.cuda.empty_cache()
1711
+
1712
+ if len(fold_results) == args.n_folds:
1713
+ mean_primary, std_primary = save_kfold_summary(
1714
+ model_name,
1715
+ fold_results,
1716
+ args.num_classes,
1717
+ model_dir,
1718
+ )
1719
+ global_results[model_name] = (mean_primary, std_primary)
1720
+ else:
1721
+ print(f"Warning: {model_name} completed only {len(fold_results)}/{args.n_folds} folds.")
1722
+
1723
+ print("\n" + "=" * 70)
1724
+ print(f"Global leaderboard ({args.n_folds}-Fold CV)")
1725
+ print(f"Sorted by: {PRIMARY_METRIC}")
1726
+ print("=" * 70)
1727
+
1728
+ sorted_results = sorted(global_results.items(), key=lambda x: x[1][0], reverse=True)
1729
+ print(f"{'Rank':<6} {'Model':<25} {PRIMARY_METRIC:>12} {'Std':>10}")
1730
+ print("-" * 62)
1731
+ for rank, (name, (mean_primary, std_primary)) in enumerate(sorted_results, 1):
1732
+ print(f"{rank:<6} {name:<25} {mean_primary:>11.2f}% {std_primary:>9.2f}%")
1733
+
1734
+ leaderboard_path = output_root / f"global_leaderboard_{PRIMARY_METRIC}.csv"
1735
+ pd.DataFrame([
1736
+ {
1737
+ "rank": idx + 1,
1738
+ "model": name,
1739
+ f"mean_{PRIMARY_METRIC}": mean_primary,
1740
+ f"std_{PRIMARY_METRIC}": std_primary,
1741
+ }
1742
+ for idx, (name, (mean_primary, std_primary)) in enumerate(sorted_results)
1743
+ ]).to_csv(leaderboard_path, index=False)
1744
+ print(f"\nLeaderboard saved to: {leaderboard_path}")
1745
+
1746
+ if __name__ == "__main__":
1747
+ main()