Multi-View-ROP / rop_patient_grouped.py
lijuanliao's picture
Upload rop_patient_grouped.py
2cbeb83 verified
import argparse
import json
import math
import os
import random
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any
import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import (
accuracy_score,
balanced_accuracy_score,
classification_report,
f1_score,
precision_score,
recall_score,
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import models, transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm
try:
from sklearn.model_selection import StratifiedGroupKFold
HAS_STRATIFIED_GROUP_KFOLD = True
except Exception:
StratifiedGroupKFold = None
HAS_STRATIFIED_GROUP_KFOLD = False
try:
import timm
HAS_TIMM = True
except ImportError:
HAS_TIMM = False
print("Warning: timm is not installed. timm-based models will be skipped.")
# Headline metric reported first in fold / k-fold summaries and returned by save_kfold_summary.
PRIMARY_METRIC: str = "macro_f1"
# Square side length (pixels) for any model without an entry in MODEL_INPUT_SIZES.
DEFAULT_INPUT_SIZE: int = 512
# Utilities
def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Seed shared by every RNG.
        deterministic: If True, disable cuDNN autotuning, force deterministic
            cuDNN kernels and request deterministic algorithms (warn-only).
            If False, enable cuDNN benchmark mode instead.

    Note:
        ``PYTHONHASHSEED`` is exported for child processes; the running
        interpreter's hash seed was fixed at startup and is not affected.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = not deterministic
    cudnn.deterministic = deterministic
    if not deterministic:
        return
    try:
        torch.use_deterministic_algorithms(True, warn_only=True)
    except Exception:
        # Best effort only: older PyTorch builds lack this API / warn_only.
        pass
def ensure_dir(path: Path) -> None:
    """Create directory ``path`` including missing parents; existing directories are fine."""
    path.mkdir(exist_ok=True, parents=True)
def to_jsonable(obj: Any):
    """Recursively convert ``obj`` into values accepted by ``json.dump``.

    - dicts are rebuilt with converted values; NumPy-scalar keys become
      native Python scalars (json rejects e.g. ``np.int64`` keys);
    - lists and tuples become lists;
    - NumPy scalars (integer, floating, bool) become native Python scalars;
    - NumPy arrays and torch tensors become (nested) lists;
    - ``pathlib.Path`` becomes ``str``.

    Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {
            (k.item() if isinstance(k, np.generic) else k): to_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    # np.bool_ is not a subclass of np.integer, so it must be listed explicitly.
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return to_jsonable(obj.tolist())
    if isinstance(obj, torch.Tensor):
        return to_jsonable(obj.detach().cpu().tolist())
    if isinstance(obj, Path):
        return str(obj)
    return obj
def name_matches_keywords(name: str, keywords: List[str]) -> bool:
    """Return True if ``name`` matches any entry of ``keywords``.

    A keyword matches when it is a substring of ``name``, when ``name`` equals
    the keyword with trailing dots stripped, or when ``name`` starts with the
    dot-stripped keyword followed by ``"."``. An empty ``name`` never matches.
    """
    if not name:
        return False

    def _hit(keyword: str) -> bool:
        stem = keyword.rstrip(".")
        return keyword in name or name == stem or name.startswith(f"{stem}.")

    return any(_hit(kw) for kw in keywords)
# Device
# Global compute device, used for class-weight tensors and pin_memory decisions.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type != "cuda":
    print("Warning: CUDA is not available. Training will be much slower on CPU.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
# Model
# Substrings (matched case-insensitively) that mark transformer / hybrid architectures.
_VIT_KEYWORDS = [
    "ViT", "Swin", "Transformer", "DeiT", "MaxViT", "CoAtNet",
    "EfficientFormer", "FastViT", "CaFormer",
]


def _is_vit_family(model_name: str) -> bool:
    """Return True if any ``_VIT_KEYWORDS`` entry occurs in ``model_name`` (case-insensitive)."""
    lowered = model_name.lower()
    for keyword in _VIT_KEYWORDS:
        if keyword.lower() in lowered:
            return True
    return False
def _is_timm_model(model: nn.Module) -> bool:
    """Heuristic timm check: the model exposes both ``get_classifier`` and ``num_features``."""
    required_attrs = ("get_classifier", "num_features")
    return all(hasattr(model, attr) for attr in required_attrs)
# Per-model input resolution overrides; anything not listed uses DEFAULT_INPUT_SIZE.
MODEL_INPUT_SIZES: Dict[str, int] = {"inception_v3": 299}


def get_model_input_size(model_name: str) -> int:
    """Return the square input side length (pixels) to use for ``model_name``."""
    override = MODEL_INPUT_SIZES.get(model_name)
    return DEFAULT_INPUT_SIZE if override is None else override
# Metrics / IO
def compute_metrics(
    y_true: List[int],
    y_pred: List[int],
    num_classes: int,
    class_names: List[str],
) -> Tuple[Dict, Dict]:
    """Compute summary classification metrics (in percent) plus a per-class report.

    Args:
        y_true: Ground-truth class ids.
        y_pred: Predicted class ids.
        num_classes: Classes ``0..num_classes-1`` are always included in the
            macro / weighted averages and the report, even if absent.
        class_names: Display names for those class ids (same order).

    Returns:
        ``(metrics, report)``: ``metrics`` maps metric name to a percentage;
        ``report`` is sklearn's ``classification_report`` as a dict
        (zero_division=0).
    """
    labels = list(range(num_classes))
    report = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )

    def _averaged(score_fn, average: str = "macro") -> float:
        return 100.0 * score_fn(y_true, y_pred, labels=labels, average=average, zero_division=0)

    metrics = {
        "accuracy": 100.0 * accuracy_score(y_true, y_pred),
        "balanced_accuracy": 100.0 * balanced_accuracy_score(y_true, y_pred),
        "macro_f1": _averaged(f1_score),
        "macro_precision": _averaged(precision_score),
        "macro_recall": _averaged(recall_score),
        "weighted_f1": _averaged(f1_score, average="weighted"),
    }
    return metrics, report
def save_fold_results(results: Dict, save_dir: Path, tag: str = "best") -> None:
    """Write one fold's evaluation artefacts into ``save_dir``.

    Files written:
      - ``test_report_{tag}.txt``: headline metrics plus the classification report;
      - ``predictions_{tag}.csv``: one row per sample with patient, image name,
        true/predicted labels, path and ``prob_class{c}`` columns;
      - ``{tag}_metrics.json``: best epoch, metrics and per-class report entries
        (looked up as ``class{i}``, zero-filled when missing).
    """
    ensure_dir(save_dir)
    report_df = pd.DataFrame(results["classification_report"]).transpose()
    m = results["metrics"]
    header = [
        f"Primary Metric ({PRIMARY_METRIC}): {m[PRIMARY_METRIC]:.4f}",
        f"Accuracy: {m['accuracy']:.4f}",
        f"Balanced Accuracy: {m['balanced_accuracy']:.4f}",
        f"Macro F1: {m['macro_f1']:.4f}",
        f"Macro Recall: {m['macro_recall']:.4f}",
        f"Macro Precision: {m['macro_precision']:.4f}",
        "",
        "Classification Report:",
    ]
    with open(save_dir / f"test_report_{tag}.txt", "w", encoding="utf-8") as fh:
        fh.write("\n".join(header) + "\n" + report_df.to_string())

    columns = {
        "patient": results["patients"],
        "image_name": results["image_names"],
        "True": results["targets"],
        "Predicted": results["predictions"],
        "path": results["image_path"],
    }
    pred_df = pd.DataFrame(columns)
    probabilities = results["probabilities"]
    for c in range(results["num_classes"]):
        pred_df[f"prob_class{c}"] = [row[c] for row in probabilities]
    pred_df.to_csv(save_dir / f"predictions_{tag}.csv", index=False)

    report = results["classification_report"]
    missing = {"precision": 0, "recall": 0, "f1-score": 0}
    payload = {
        "best_epoch": results["best_epoch"],
        "primary_metric": PRIMARY_METRIC,
        "metrics": m,
        "per_class": [report.get(f"class{i}", missing) for i in range(results["num_classes"])],
    }
    with open(save_dir / f"{tag}_metrics.json", "w", encoding="utf-8") as fh:
        json.dump(to_jsonable(payload), fh, indent=2, ensure_ascii=False)
# Metrics aggregated (mean +/- std) across folds by save_kfold_summary.
_KFOLD_METRIC_NAMES = (
    "accuracy",
    "balanced_accuracy",
    "macro_f1",
    "macro_recall",
    "macro_precision",
    "weighted_f1",
)
# Per-sample keys every fold result must carry for pooled out-of-fold outputs.
_OOF_KEYS = ("targets", "predictions", "image_path", "probabilities", "patients", "image_names")


def _fold_mean_std(values) -> Tuple[float, float]:
    """Return the mean and population standard deviation (ddof=0) of ``values``."""
    arr = np.asarray(values, dtype=float)
    return float(arr.mean()), float(arr.std())


def _per_class_table(fold_results: List[Dict], num_classes: int):
    """Aggregate per-class precision/recall/F1 across folds.

    Returns:
        ``(stats, lines)``: ``stats[c]`` holds ``precision/recall/f1`` ``_mean``
        and ``_std`` floats for class ``c``; ``lines`` is an aligned text table.
    """
    lines = [
        "Per-class metrics (mean +/- std)",
        f"{'class':<10} {'precision':>18} {'recall':>18} {'f1-score':>18}",
    ]
    stats: Dict[int, Dict[str, float]] = {}
    for c in range(num_classes):
        row: Dict[str, float] = {}
        cells = []
        for key, prefix in (("precision", "precision"), ("recall", "recall"), ("f1-score", "f1")):
            mean, std = _fold_mean_std([fr["per_class"][c][key] for fr in fold_results])
            row[f"{prefix}_mean"] = mean
            row[f"{prefix}_std"] = std
            cells.append(f"{mean:.4f}+/-{std:.4f}")
        stats[c] = row
        label = f"class{c}"
        # Right-align every cell to the same widths as the header row.
        lines.append(f"{label:<10} {cells[0]:>18} {cells[1]:>18} {cells[2]:>18}")
    return stats, lines


def _write_pooled_oof(fold_results: List[Dict], num_classes: int, save_dir: Path) -> None:
    """Concatenate every fold's predictions and write the pooled OOF report and CSV."""

    def _concat(key: str) -> list:
        return [value for fr in fold_results for value in fr[key]]

    targets = _concat("targets")
    predictions = _concat("predictions")
    probabilities = _concat("probabilities")
    class_names = [f"class{i}" for i in range(num_classes)]
    pooled_metrics, pooled_report = compute_metrics(targets, predictions, num_classes, class_names)

    with open(save_dir / "oof_report.txt", "w", encoding="utf-8") as f:
        f.write("Pooled out-of-fold metrics\n")
        f.write(f"Primary Metric ({PRIMARY_METRIC}): {pooled_metrics[PRIMARY_METRIC]:.4f}\n")
        for k, v in pooled_metrics.items():
            f.write(f"{k}: {v:.4f}\n")
        f.write("\nClassification Report:\n")
        f.write(pd.DataFrame(pooled_report).transpose().to_string())

    oof_df = pd.DataFrame({
        "patient": _concat("patients"),
        "image_name": _concat("image_names"),
        "True": targets,
        "Predicted": predictions,
        "path": _concat("image_path"),
    })
    for c in range(num_classes):
        oof_df[f"prob_class{c}"] = [row[c] for row in probabilities]
    oof_df.to_csv(save_dir / "oof_predictions.csv", index=False)


def save_kfold_summary(
    model_name: str,
    fold_results: List[Dict],
    num_classes: int,
    save_dir: Path,
) -> Tuple[float, float]:
    """Summarize cross-validation results and write them to ``save_dir``.

    Writes ``kfold_summary.txt`` / ``kfold_summary.json`` (per-fold lines,
    mean +/- population std of each metric, per-class table) and, when every
    fold carries all per-sample keys in ``_OOF_KEYS``, pooled out-of-fold
    ``oof_report.txt`` and ``oof_predictions.csv``.

    Args:
        model_name: Name shown in the header and stored in the JSON.
        fold_results: One dict per fold with ``metrics`` (containing every name
            in ``_KFOLD_METRIC_NAMES``), ``best_epoch`` and ``per_class`` (list
            indexed by class id with ``precision``/``recall``/``f1-score``).
        num_classes: Number of classes (ids ``0..num_classes-1``).
        save_dir: Output directory; created if missing.

    Returns:
        ``(mean, std)`` of ``PRIMARY_METRIC`` across folds.

    Raises:
        ValueError: If ``fold_results`` is empty.
    """
    if not fold_results:
        raise ValueError("fold_results is empty; nothing to summarize.")
    ensure_dir(save_dir)

    summary = {}
    for name in _KFOLD_METRIC_NAMES:
        mean, std = _fold_mean_std([r["metrics"][name] for r in fold_results])
        summary[name] = {"mean": mean, "std": std}

    lines = [
        "=" * 70,
        f"Model: {model_name}",
        # Report the actual number of folds instead of a hard-coded count.
        f"{len(fold_results)}-Fold Cross-Validation Summary",
        f"Primary Metric: {PRIMARY_METRIC}",
        "=" * 70,
        "",
    ]
    for i, r in enumerate(fold_results, 1):
        m = r["metrics"]
        lines.append(
            f"Fold {i}: Macro-F1={m['macro_f1']:.2f}% | "
            f"BA={m['balanced_accuracy']:.2f}% | "
            f"Acc={m['accuracy']:.2f}% | "
            f"BestEpoch={r['best_epoch']}"
        )
    lines.append("")
    for name in _KFOLD_METRIC_NAMES:
        lines.append(f"{name}: {summary[name]['mean']:.2f}% +/- {summary[name]['std']:.2f}%")
    lines.append("")

    per_class_summary, per_class_lines = _per_class_table(fold_results, num_classes)
    lines.extend(per_class_lines)

    text = "\n".join(lines)
    print(text)
    with open(save_dir / "kfold_summary.txt", "w", encoding="utf-8") as f:
        f.write(text)
    with open(save_dir / "kfold_summary.json", "w", encoding="utf-8") as f:
        json.dump(
            to_jsonable({
                "model": model_name,
                "primary_metric": PRIMARY_METRIC,
                "summary": summary,
                "per_class": per_class_summary,
            }),
            f,
            indent=2,
            ensure_ascii=False,
        )

    # Check every key the pooled writer reads, so a partial result skips
    # cleanly instead of raising KeyError halfway through.
    if all(all(key in r for key in _OOF_KEYS) for r in fold_results):
        _write_pooled_oof(fold_results, num_classes, save_dir)
    else:
        print("Skipping pooled out-of-fold outputs: some folds lack per-sample predictions.")

    return summary[PRIMARY_METRIC]["mean"], summary[PRIMARY_METRIC]["std"]
# Black-border cropping is no longer part of the pipeline; instead, a fundus-region
# mask is applied after CLAHE + green-channel enhancement (see build_transforms).
class BlackBorderCrop:
    """Crop black borders and obvious invalid background around the fundus.

    Pixels whose grayscale value exceeds ``threshold`` count as content. The
    image is cropped to their bounding box, widened on each side by
    ``margin_ratio`` of the box size and clipped to the image bounds. With
    fewer than 64 content pixels the RGB image is returned uncropped.
    """

    def __init__(self, threshold: int = 10, margin_ratio: float = 0.02):
        self.threshold = threshold
        self.margin_ratio = margin_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        rgb = np.array(pil_img.convert("RGB"))
        content = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > self.threshold
        if content.sum() < 64:
            return pil_img.convert("RGB")
        rows, cols = np.where(content)
        top, bottom = rows.min(), rows.max()
        left, right = cols.min(), cols.max()
        pad_y = int((bottom - top + 1) * self.margin_ratio)
        pad_x = int((right - left + 1) * self.margin_ratio)
        height, width = rgb.shape[0], rgb.shape[1]
        top, bottom = max(0, top - pad_y), min(height, bottom + pad_y + 1)
        left, right = max(0, left - pad_x), min(width, right + pad_x + 1)
        return Image.fromarray(rgb[top:bottom, left:right])
class FundusCircularCrop:
    """Crop to the circular fundus region and blank everything outside it.

    The grayscale image is thresholded at ``threshold``, morphologically closed
    and median-filtered. The largest external contour is enclosed in its
    minimum circle, whose radius is enlarged by ``radius_pad_ratio``. The
    image is cropped to that circle's bounding box (clipped to the image) and
    pixels outside the circle are set to zero. The RGB image is returned
    unchanged when no contour exists, the circle radius is below 10 px, or
    the crop would be smaller than 2 px on a side.
    """

    def __init__(self, threshold: int = 8, radius_pad_ratio: float = 0.03):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = (gray > self.threshold).astype(np.uint8) * 255
        kernel = np.ones((5, 5), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.medianBlur(mask, 5)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return Image.fromarray(img)
        largest = max(contours, key=cv2.contourArea)
        (cx, cy), radius = cv2.minEnclosingCircle(largest)
        if radius < 10:
            return Image.fromarray(img)
        cx, cy = float(cx), float(cy)
        radius = float(radius) * (1.0 + self.radius_pad_ratio)
        x1 = max(0, int(cx - radius))
        y1 = max(0, int(cy - radius))
        x2 = min(img.shape[1], int(cx + radius))
        y2 = min(img.shape[0], int(cy + radius))
        cropped = img[y1:y2, x1:x2]
        h, w = cropped.shape[:2]
        if h < 2 or w < 2:
            return Image.fromarray(img)
        # Mask with the true (padded) radius around the exact centre in crop
        # coordinates. Clamping the radius to min(h, w) // 2 would shrink the
        # disc whenever the fundus is cut off by the image border (the crop is
        # then shorter than the circle along one axis) and would erase valid
        # retina along the other axis.
        yy, xx = np.ogrid[:h, :w]
        inside = (xx - (cx - x1)) ** 2 + (yy - (cy - y1)) ** 2 <= radius ** 2
        out = np.zeros_like(cropped)
        out[inside] = cropped[inside]
        return Image.fromarray(out)
class ResizeToSquare:
    """Resize a PIL image to ``size`` x ``size`` with bilinear resampling (aspect ratio is not kept)."""

    def __init__(self, size: int):
        self.size = size

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        target = (self.size,) * 2
        return pil_img.resize(target, resample=Image.BILINEAR)
class LightCLAHE:
    """Apply CLAHE to the L channel of the LAB representation of an RGB image.

    Args:
        clip_limit: CLAHE contrast clip limit.
        grid: CLAHE tile grid size.
    """

    def __init__(self, clip_limit: float = 2.0, grid: Tuple[int, int] = (8, 8)):
        self.clip_limit = clip_limit
        self.grid = tuple(grid)
        self.clahe = self._make_clahe()

    def _make_clahe(self):
        """Create the native OpenCV CLAHE operator from the stored parameters."""
        return cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)

    # The cv2.CLAHE handle is a native OpenCV object that cannot be pickled.
    # Drop it from the pickled state and rebuild it on load, so the transform
    # survives DataLoader workers started with 'spawn' and copy.deepcopy.
    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("clahe", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.clahe = self._make_clahe()

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        l = self.clahe.apply(l)
        out = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
        return Image.fromarray(out)
class GreenChannelEnhancement:
    """Blend the green channel with its CLAHE-equalised copy; red and blue are unchanged.

    The new green channel is ``(1 - blend_alpha) * g + blend_alpha * clahe(g)``.

    Args:
        clip_limit: CLAHE contrast clip limit.
        grid: CLAHE tile grid size.
        blend_alpha: Weight of the equalised green channel in the blend.
    """

    def __init__(
        self,
        clip_limit: float = 2.5,
        grid: Tuple[int, int] = (8, 8),
        blend_alpha: float = 0.30,
    ):
        self.clip_limit = clip_limit
        self.grid = tuple(grid)
        self.blend_alpha = blend_alpha
        self.clahe = self._make_clahe()

    def _make_clahe(self):
        """Create the native OpenCV CLAHE operator from the stored parameters."""
        return cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)

    # The cv2.CLAHE handle cannot be pickled; exclude it from the pickled state
    # and rebuild it on load (needed for 'spawn' DataLoader workers / deepcopy).
    def __getstate__(self):
        state = self.__dict__.copy()
        state.pop("clahe", None)
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.clahe = self._make_clahe()

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        r, g, b = cv2.split(img)
        g_eq = self.clahe.apply(g)
        g_new = cv2.addWeighted(g, 1.0 - self.blend_alpha, g_eq, self.blend_alpha, 0.0)
        out = cv2.merge([r, g_new, b])
        return Image.fromarray(out)
class FundusEyeMask:
    """Multiply the image by a (optionally feathered) circular mask around the fundus.

    The grayscale image is binarised at ``threshold`` and cleaned with a
    closing then an opening (square ``morph_kernel`` structuring element).
    The largest external contour's minimum enclosing circle, enlarged by
    ``radius_pad_ratio``, defines the kept region. When ``blur_kernel > 1``
    the circular mask is Gaussian-blurred (kernel forced odd) to soften its
    edge. The RGB image is returned unchanged if no contour is found or the
    circle radius is below 10 px.
    """

    def __init__(
        self,
        threshold: int = 8,
        radius_pad_ratio: float = 0.03,
        morph_kernel: int = 7,
        blur_kernel: int = 5,
    ):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio
        self.morph_kernel = morph_kernel
        self.blur_kernel = blur_kernel

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        rgb = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
        # Fixed threshold separating the dark background from the retina
        # (applied after CLAHE + green enhancement in the pipeline).
        _, binary = cv2.threshold(gray, self.threshold, 255, cv2.THRESH_BINARY)
        element = np.ones((self.morph_kernel,) * 2, np.uint8)
        for op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            binary = cv2.morphologyEx(binary, op, element)
        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return Image.fromarray(rgb)
        (cx, cy), radius = cv2.minEnclosingCircle(max(contours, key=cv2.contourArea))
        if radius < 10:
            return Image.fromarray(rgb)
        radius *= 1.0 + self.radius_pad_ratio
        height, width = rgb.shape[0], rgb.shape[1]
        yy, xx = np.ogrid[:height, :width]
        disc = (((xx - cx) ** 2 + (yy - cy) ** 2) <= (radius ** 2)).astype(np.uint8) * 255
        if self.blur_kernel and self.blur_kernel > 1:
            k = self.blur_kernel + (1 - self.blur_kernel % 2)  # GaussianBlur needs an odd size
            disc = cv2.GaussianBlur(disc, (k, k), 0)
        weight = disc.astype(np.float32)[..., None] / 255.0
        masked = np.clip(rgb.astype(np.float32) * weight, 0, 255).astype(np.uint8)
        return Image.fromarray(masked)
# Preprocessing stages shared by every pipeline that build_transforms creates.
_light_clahe = LightCLAHE()
_green_enhance = GreenChannelEnhancement()
_eye_mask = FundusEyeMask()
# Memoised (train, val) pipelines keyed by input size.
_transform_cache: Dict[int, Tuple[transforms.Compose, transforms.Compose]] = {}


def build_transforms(input_size: int = DEFAULT_INPUT_SIZE):
    """Return the cached ``(train_tf, val_tf)`` pipelines for ``input_size``.

    Shared preprocessing (BlackBorderCrop is no longer used):
      resize to ``input_size`` x ``input_size`` -> CLAHE -> green-channel
      enhancement -> fundus-region mask, which keeps the eye and blanks the
      pixels outside the fundus edge.

    Training augmentation, applied after preprocessing:
      - horizontal and vertical flips (p=0.5 each);
      - RandomAffine: rotation up to +/-15 deg, translation up to 5%, scale
        0.85-1.15, bilinear, fill 0;
      - moderate ColorJitter;
      - mild Gaussian blur (kernel 3, sigma 0.1-0.8).

    Both pipelines end with ToTensor + Normalize using the ImageNet mean/std.
    """
    cached = _transform_cache.get(input_size)
    if cached is not None:
        return cached

    def _to_normalized_tensor():
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    preprocess = [ResizeToSquare(input_size), _light_clahe, _green_enhance, _eye_mask]
    augment = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(
            degrees=15,
            translate=(0.05, 0.05),
            scale=(0.85, 1.15),
            interpolation=InterpolationMode.BILINEAR,
            fill=0,
        ),
        transforms.ColorJitter(brightness=0.20, contrast=0.20, saturation=0.10, hue=0.02),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.8)),
    ]
    train_tf = transforms.Compose(preprocess + augment + _to_normalized_tensor())
    val_tf = transforms.Compose(preprocess + _to_normalized_tensor())
    _transform_cache[input_size] = (train_tf, val_tf)
    return train_tf, val_tf
# TTA (Test-Time Augmentation)
# 4-way TTA: original / horizontal flip / vertical flip / both flips.
def predict_with_tta(
    model: nn.Module,
    inputs: torch.Tensor,
    amp_enabled: bool = False,
) -> torch.Tensor:
    """Average softmax probabilities over four flip variants of ``inputs``.

    Args:
        model: Classifier whose output is reduced to logits by ``_extract_logits``.
        inputs: Image batch; flips act on the last two dims (presumably H and W
            of an (N, C, H, W) batch — TODO confirm).
        amp_enabled: Run the forward passes under CUDA autocast.

    Returns:
        float32 probabilities averaged over the variants; softmax is taken over
        dim 1 (assumes logits are (N, num_classes)).
    """
    flip_dims = ((), (-1,), (-2,), (-1, -2))
    probs_sum = None
    # Build each flipped view lazily instead of materialising all four at once.
    for dims in flip_dims:
        aug = inputs.flip(dims) if dims else inputs
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        with torch.autocast(device_type="cuda", enabled=amp_enabled):
            out = model(aug)
        # Softmax in float32: fp16 logits under autocast lose precision.
        probs = torch.softmax(_extract_logits(out).float(), dim=1)
        probs_sum = probs if probs_sum is None else probs_sum + probs
    return probs_sum / len(flip_dims)
# Dataset
class ImageDataset(Dataset):
    """Map-style dataset over a DataFrame with ``path``, ``label`` and ``patient`` columns.

    The frame is re-indexed and copied. ``image_name`` comes from the optional
    column of that name, else from the file name of ``path``. Each item is
    ``(image, label_tensor, meta)``, where ``image`` is the (optionally
    transformed) RGB PIL image and ``meta`` holds ``path``, ``patient`` and
    ``image_name`` strings.
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        self.df = df.reset_index(drop=True).copy()
        self.transform = transform
        frame = self.df
        self.paths = frame["path"].astype(str).tolist()
        self.labels = frame["label"].astype(int).tolist()
        self.patients = frame["patient"].astype(str).tolist()
        self.image_names = (
            frame["image_name"].astype(str).tolist()
            if "image_name" in frame.columns
            else [Path(p).name for p in self.paths]
        )

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        img_path, label = self.paths[idx], self.labels[idx]
        try:
            image = Image.open(img_path).convert("RGB")
        except Exception as exc:
            raise RuntimeError(f"Failed to open image: {img_path}") from exc
        if self.transform is not None:
            image = self.transform(image)
        meta = {
            "path": img_path,
            "patient": self.patients[idx],
            "image_name": self.image_names[idx],
        }
        return image, torch.tensor(label, dtype=torch.long), meta
# Data loading / grouped splitting
def validate_image_paths(df: pd.DataFrame, path_col: str = "path") -> pd.DataFrame:
    """Drop rows whose ``path_col`` is not an existing file and print a summary.

    Returns the same ``df`` object when every path exists; otherwise a
    filtered frame with a fresh RangeIndex.
    """
    exists = df[path_col].apply(os.path.isfile)
    n_total = len(df)
    n_missing = n_total - int(exists.sum())
    if n_missing <= 0:
        print(f"All {n_total} image paths are valid.")
        return df
    print(f"Warning: {n_missing}/{n_total} paths do not exist and will be removed.")
    return df.loc[exists].reset_index(drop=True)
def load_and_prepare_data(excel_path: str, group_col: str = "patient") -> pd.DataFrame:
    """Load the annotation spreadsheet and normalise labels and group ids.

    Steps:
      1. Read ``excel_path`` (openpyxl) and require ``path``, ``label`` and
         ``group_col`` columns.
      2. Strip group ids and reject empty / "nan" / "None" ids.
      3. Map the label "AROP" to 5, convert labels to int, shift 1-based labels
         to 0-based (when the minimum label is 1), then merge labels 4 and 5
         into 3 for a 4-class setup.
      4. Drop rows whose image file is missing and, if ``group_col`` is not
         "patient", copy it into a ``patient`` column.

    Raises:
        KeyError: If a required column is missing.
        ValueError: On invalid group ids or non-numeric labels.
    """
    raw = pd.read_excel(excel_path, engine="openpyxl")
    missing_cols = {"path", "label", group_col} - set(raw.columns)
    if missing_cols:
        raise KeyError(f"Missing required columns in Excel: {sorted(missing_cols)}")

    df = raw.copy()
    df[group_col] = df[group_col].astype(str).str.strip()
    invalid = df[group_col].isin(["", "nan", "None"])
    if invalid.any():
        bad_rows = int(invalid.sum())
        raise ValueError(f"Found {bad_rows} rows with invalid patient/group identifiers in column '{group_col}'.")

    labels = pd.to_numeric(df["label"].replace({"AROP": 5}), errors="raise").astype(int)
    if labels.min() == 1:
        labels = labels - 1
    # Merge old labels 4 and 5 into class 3 -> final 4-class setup.
    df["label"] = labels.replace({4: 3, 5: 3})

    df = validate_image_paths(df, path_col="path")
    if group_col != "patient":
        df["patient"] = df[group_col].astype(str)

    print(f"Dataset size: {len(df)} images")
    print(f"Unique patients: {df[group_col].nunique()}")
    print(f"Class distribution: {dict(df['label'].value_counts().sort_index())}")
    print(f"Observed labels: {sorted(df['label'].unique().tolist())}")
    return df
def _approximate_group_stratified_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str,
):
    """Group-aware stratified K-fold using each group's majority label.

    Every group (compared as a string, like ``build_fold_splits`` does) gets
    its most frequent image label (ties resolved by ``value_counts`` order).
    Groups are split with ``StratifiedKFold`` and every image follows its
    group, so no group spans train and validation.

    Returns:
        List of ``(train_idx, val_idx)`` arrays of *positional* row indices,
        like ``StratifiedGroupKFold.split``, so callers can index with
        ``df.iloc`` regardless of ``df``'s index labels.

    Raises:
        ValueError: If there are fewer unique groups than ``n_folds``.
    """
    keys = df[group_col].astype(str).to_numpy()
    group_df = (
        pd.DataFrame({"group": keys, "label": df["label"].to_numpy()})
        .groupby("group")["label"]
        .agg(lambda x: x.value_counts().index[0])
        .reset_index()
    )
    n_groups = len(group_df)
    if n_groups < n_folds:
        raise ValueError(
            f"Number of unique groups ({n_groups}) is smaller than n_folds={n_folds}."
        )
    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
    group_ids = group_df["group"].to_numpy()
    group_labels = group_df["label"].to_numpy()
    splits = []
    for _, group_val_idx in skf.split(group_ids, group_labels):
        # Train groups are the complement of validation groups, so one mask suffices.
        val_mask = np.isin(keys, group_ids[group_val_idx])
        splits.append((np.flatnonzero(~val_mask), np.flatnonzero(val_mask)))
    return splits
def build_fold_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str = "patient",
):
    """Build group-disjoint, label-stratified CV splits and verify there is no leakage.

    Uses ``StratifiedGroupKFold`` when available. If it is unavailable or
    raises ValueError, falls back to ``_approximate_group_stratified_splits``.
    Every fold is then checked for groups present in both train and validation.

    Returns:
        List of ``(train_idx, val_idx)`` index arrays, used with ``df.iloc``.

    Raises:
        ValueError: If there are fewer unique groups than ``n_folds``.
        RuntimeError: If any fold shares a group between train and validation.
    """
    groups = df[group_col].astype(str).values
    labels = df["label"].values
    n_groups = len(np.unique(groups))
    if n_groups < n_folds:
        raise ValueError(
            f"Unique groups in '{group_col}' = {n_groups}, which is smaller than n_folds={n_folds}."
        )

    splits = None
    if not HAS_STRATIFIED_GROUP_KFOLD:
        print("StratifiedGroupKFold is unavailable. Falling back to approximate grouped stratification.")
    else:
        print(
            f"Using StratifiedGroupKFold with group_col='{group_col}', n_folds={n_folds}, seed={random_seed}."
        )
        try:
            splitter = StratifiedGroupKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
            splits = list(splitter.split(df, y=labels, groups=groups))
        except ValueError as exc:
            print(f"StratifiedGroupKFold failed: {exc}")
            print("Falling back to approximate grouped stratification using patient-majority labels.")
    if splits is None:
        splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)

    group_series = df[group_col].astype(str)
    for fold_id, (train_idx, val_idx) in enumerate(splits, 1):
        overlap = set(group_series.iloc[train_idx].tolist()) & set(group_series.iloc[val_idx].tolist())
        if overlap:
            raise RuntimeError(
                f"Data leakage detected in fold {fold_id}: {len(overlap)} overlapping groups."
            )
    return splits
def _compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """Inverse-frequency class weights ``N / (num_classes * n_c)`` on the global ``device``.

    Classes absent from ``train_df`` are treated as having a count of 1.
    """
    counts = train_df["label"].value_counts()
    n_rows = len(train_df)
    weights = []
    for cls in range(num_classes):
        weights.append(n_rows / (num_classes * counts.get(cls, 1)))
    return torch.tensor(weights, dtype=torch.float32, device=device)
def _make_weighted_sampler(train_df: pd.DataFrame) -> WeightedRandomSampler:
    """Sampler that draws each row with weight 1 / (its class count).

    Draws ``len(train_df)`` samples per epoch with replacement, so classes are
    sampled roughly uniformly.
    """
    label_col = train_df["label"]
    inverse_freq = 1.0 / label_col.map(label_col.value_counts())
    weights = torch.as_tensor(inverse_freq.astype(float).values, dtype=torch.double)
    return WeightedRandomSampler(weights, num_samples=len(weights), replacement=True)
def create_fold_loaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    input_size: int = DEFAULT_INPUT_SIZE,
    batch_size: int = 8,
    num_classes: int = 4,
    balance_mode: str = "loss",
    num_workers: int = 4,
):
    """Build the training and validation DataLoaders for one CV fold.

    Args:
        train_df: Training rows (format accepted by ``ImageDataset``).
        val_df: Validation rows.
        input_size: Side length passed to ``build_transforms``.
        batch_size: Batch size for both loaders.
        num_classes: Number of classes, used for the loss class weights.
        balance_mode: ``"sampler"`` -> class-balanced ``WeightedRandomSampler``;
            ``"loss"`` -> inverse-frequency class weights returned for the loss;
            any other value -> no imbalance correction.
        num_workers: DataLoader worker processes (kept persistent when > 0).

    Returns:
        ``(train_loader, val_loader, class_weights)``; ``class_weights`` is
        None unless ``balance_mode == "loss"``.
    """
    train_tf, val_tf = build_transforms(input_size)
    sampler = None
    class_weights = None
    if balance_mode == "sampler":
        sampler = _make_weighted_sampler(train_df)
        print("Training loader uses WeightedRandomSampler for class balancing.")
    elif balance_mode == "loss":
        class_weights = _compute_class_weights(train_df, num_classes)
        print(f"Training loss uses class weights: {class_weights.detach().cpu().numpy().tolist()}")
    else:
        print("No imbalance correction is used.")

    # A trailing batch of a single sample breaks BatchNorm in training mode.
    drop_last = (batch_size > 1) and (len(train_df) % batch_size == 1)
    if drop_last:
        print(
            f"Training loader will drop the last singleton batch "
            f"(train_size={len(train_df)}, batch_size={batch_size}) to avoid BatchNorm issues."
        )

    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        ImageDataset(train_df, train_tf),
        batch_size=batch_size,
        shuffle=(sampler is None),  # DataLoader forbids shuffle together with a sampler
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=(num_workers > 0),
    )
    val_loader = DataLoader(
        ImageDataset(val_df, val_tf),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),
    )
    return train_loader, val_loader, class_weights
# ViT positional embedding interpolation
def patch_vit_for_large_input(
    model: nn.Module,
    model_name: str,
    input_size: int,
) -> nn.Module:
    """Resize a torchvision-style ViT positional embedding to fit ``input_size``.

    Acts only when "ViT" is in ``model_name`` and the model has
    ``encoder.pos_embedding`` and ``patch_size``. The embedding is indexed as
    (1, 1 + N, D): a class-token entry followed by N patch entries on a square
    grid. The patch entries are bicubically resized to the
    ``(input_size // patch_size)``-square grid, while the class-token entry is
    kept. ``image_size`` is updated when present. The model is modified in
    place and returned.

    Raises:
        ValueError: If the current number of patch positions is not a perfect square.
    """
    if "ViT" not in model_name:
        return model
    if not (hasattr(model, "encoder") and hasattr(model.encoder, "pos_embedding")):
        print(f"Warning: cannot find pos_embedding for {model_name}, skip interpolation.")
        return model

    patch_size = model.patch_size
    grid_new = input_size // patch_size
    expected_patches = grid_new ** 2
    pos_embed = model.encoder.pos_embedding
    current_patches = pos_embed.shape[1] - 1
    if current_patches == expected_patches:
        print(f"[ViT] pos_embedding already matches input_size={input_size}, no interpolation needed.")
        return model

    grid_old = math.isqrt(current_patches)
    if grid_old * grid_old != current_patches:
        raise ValueError(
            f"[ViT] pos_embedding has {current_patches} patch positions, which is not a square grid."
        )
    print(
        f"[ViT] Interpolating pos_embedding: {current_patches} -> {expected_patches} patches "
        f"for input_size={input_size}."
    )

    # Work outside autograd so the new Parameter is a clean leaf tensor.
    with torch.no_grad():
        dim = pos_embed.shape[-1]
        cls_token = pos_embed[:, :1, :]
        grid = (
            pos_embed[:, 1:, :]
            .reshape(1, grid_old, grid_old, dim)
            .permute(0, 3, 1, 2)
            .float()
        )
        grid = F.interpolate(grid, size=(grid_new, grid_new), mode="bicubic", align_corners=False)
        # Cast back so the concatenation keeps the embedding's original dtype.
        patch_tokens = grid.permute(0, 2, 3, 1).reshape(1, expected_patches, dim).to(pos_embed.dtype)
        new_embed = torch.cat([cls_token, patch_tokens], dim=1)

    model.encoder.pos_embedding = nn.Parameter(new_embed, requires_grad=pos_embed.requires_grad)
    if hasattr(model, "image_size"):
        model.image_size = input_size
    return model
# Classifier replacement
def _find_last_linear(module: nn.Module):
    """Return the last ``nn.Linear`` reachable from ``module``, or None.

    ``module`` itself is returned if it is a Linear. For a Sequential, its
    children are searched recursively from the end. Otherwise (or if that
    search finds nothing), a ``head`` attribute that is a Linear or Sequential
    is searched.
    """
    if isinstance(module, nn.Linear):
        return module
    if isinstance(module, nn.Sequential):
        children_last_first = reversed(list(module.children()))
        hit = next(
            (found for found in map(_find_last_linear, children_last_first) if found is not None),
            None,
        )
        if hit is not None:
            return hit
    if not hasattr(module, "head"):
        return None
    head = module.head
    return _find_last_linear(head) if isinstance(head, (nn.Linear, nn.Sequential)) else None
def _verify_classifier(model: nn.Module, model_name: str, expected_classes: int) -> None:
    """Check that the model's classifier ends in a Linear with ``expected_classes`` outputs.

    The attributes fc, head, classifier and heads are tried in that order. The
    first one in which ``_find_last_linear`` finds a Linear decides the result.

    Raises:
        RuntimeError: If that Linear has the wrong ``out_features``.

    Prints a warning (no error) when no attribute yields a Linear.
    """
    present = (
        getattr(model, attr) for attr in ("fc", "head", "classifier", "heads") if hasattr(model, attr)
    )
    for layer in present:
        last_linear = _find_last_linear(layer)
        if last_linear is None:
            continue
        if last_linear.out_features != expected_classes:
            raise RuntimeError(
                f"Classifier replacement failed for {model_name}: "
                f"out_features={last_linear.out_features}, expected={expected_classes}"
            )
        print(f"Verified {model_name}: classifier -> {expected_classes} classes (in={last_linear.in_features})")
        return
    print(f"Warning: failed to automatically verify classifier for {model_name}")
def _dropout_linear(in_features: int, num_classes: int, dropout: float) -> nn.Sequential:
    """Build the replacement head used throughout: Dropout followed by Linear."""
    return nn.Sequential(nn.Dropout(dropout), nn.Linear(in_features, num_classes))


def _replace_timm_classifier(
    model_name: str,
    model: nn.Module,
    num_classes: int,
    dropout: float,
) -> nn.Module:
    """Reset a timm classifier to ``num_classes`` and wrap it with Dropout when it is a Linear."""
    in_feat = model.num_features
    orig_classifier = model.get_classifier()
    print(f"[timm] {model_name}: original classifier={type(orig_classifier).__name__}, num_features={in_feat}")
    model.reset_classifier(num_classes)
    new_fc = model.get_classifier()
    wrapped = False
    if isinstance(new_fc, nn.Linear):
        # Size the wrapper from the Linear being replaced: for models with a
        # pre-logits / hidden head layer its input width can differ from
        # ``num_features``.
        fc_in = new_fc.in_features
        for parent_attr, child_attr in [
            ("head", "fc"),
            ("head", "head"),
            (None, "head"),
            (None, "classifier"),
            (None, "fc"),
        ]:
            try:
                parent = getattr(model, parent_attr) if parent_attr else model
                child = getattr(parent, child_attr)
                if child is new_fc:
                    setattr(parent, child_attr, _dropout_linear(fc_in, num_classes, dropout))
                    wrapped = True
                    break
            except AttributeError:
                continue
    if not wrapped:
        print(f"[timm] {model_name}: reset_classifier({num_classes}) applied (no Dropout wrapper).")
    _verify_classifier(model, model_name, num_classes)
    return model


def replace_classifier(
    model_name: str,
    model: nn.Module,
    num_classes: int,
    dropout: float = 0.3,
) -> nn.Module:
    """Replace the model's classification head(s) with ``Dropout -> Linear(num_classes)``.

    timm models (duck-typed via ``_is_timm_model``) use ``reset_classifier``.
    Otherwise the head is located by architecture family from ``model_name``
    (checked in a fixed order), falling back to a generic search over
    ``fc``/``head``/``classifier``. Auxiliary heads of Inception-v3 and
    GoogLeNet are replaced too when present. The result is checked with
    ``_verify_classifier``.

    Returns:
        The same model, modified in place.

    Raises:
        ValueError: If no classifier head can be located.
        RuntimeError: From ``_verify_classifier`` on an output-size mismatch.
    """
    if _is_timm_model(model):
        return _replace_timm_classifier(model_name, model, num_classes, dropout)

    def _head(in_features: int) -> nn.Sequential:
        return _dropout_linear(in_features, num_classes, dropout)

    n = model_name
    if "VGG" in n:
        model.classifier[6] = _head(model.classifier[6].in_features)
    elif n == "inception_v3":
        # AuxLogits is None when the network was built with aux_logits=False.
        # The aux head is replaced before fc to keep the original init order.
        aux = getattr(model, "AuxLogits", None)
        if aux is not None:
            aux.fc = _head(aux.fc.in_features)
        model.fc = _head(model.fc.in_features)
    elif "GoogLeNet" in n:
        model.fc = _head(model.fc.in_features)
        for aux_name in ("aux1", "aux2"):
            aux = getattr(model, aux_name, None)
            if aux is not None and hasattr(aux, "fc2"):
                aux.fc2 = _head(aux.fc2.in_features)
    elif "ResNe" in n:
        model.fc = _head(model.fc.in_features)
    elif "DenseNet" in n:
        model.classifier = _head(model.classifier.in_features)
    elif "MobileNet" in n or "MnasNet" in n or "EfficientNet" in n or "ConvNeXt" in n:
        model.classifier[-1] = _head(model.classifier[-1].in_features)
    elif "RegNet" in n or "ShuffleNet" in n:
        model.fc = _head(model.fc.in_features)
    elif "ViT" in n:
        if hasattr(model, "heads") and hasattr(model.heads, "head"):
            model.heads = _head(model.heads.head.in_features)
        elif hasattr(model, "head") and isinstance(model.head, nn.Linear):
            model.head = _head(model.head.in_features)
        else:
            raise ValueError(f"Cannot find classifier head for {n}")
    elif "Swin" in n:
        if hasattr(model, "head") and isinstance(model.head, nn.Linear):
            model.head = _head(model.head.in_features)
        elif hasattr(model, "heads") and hasattr(model.heads, "head"):
            model.heads = _head(model.heads.head.in_features)
        else:
            raise ValueError(f"Cannot find classifier head for {n}")
    elif _is_vit_family(n):
        replaced = False
        for dotted in ("heads.head", "head", "classifier"):
            *parent_parts, leaf = dotted.split(".")
            try:
                parent = model
                for part in parent_parts:
                    parent = getattr(parent, part)
                target = getattr(parent, leaf)
                if isinstance(target, nn.Linear):
                    setattr(parent, leaf, _head(target.in_features))
                    replaced = True
                    break
            except AttributeError:
                continue
        if not replaced:
            raise ValueError(f"Cannot find classifier head for {n}")
    else:
        replaced = False
        for attr_name in ("fc", "head", "classifier"):
            layer = getattr(model, attr_name, None)
            if isinstance(layer, nn.Linear):
                setattr(model, attr_name, _head(layer.in_features))
                replaced = True
                break
            if isinstance(layer, nn.Sequential) and len(layer) > 0 and isinstance(layer[-1], nn.Linear):
                layer[-1] = _head(layer[-1].in_features)
                replaced = True
                break
        if not replaced:
            raise ValueError(f"Cannot automatically replace classifier for {n}")
    _verify_classifier(model, model_name, num_classes)
    return model
# Optimizer groups / freezing
def _get_head_keywords(model_name: str) -> List[str]:
    """Return name fragments that identify classifier-head parameters for ``model_name``.

    Rules are evaluated in order and the first match wins. The fragments are
    meant for ``name_matches_keywords``. Unknown architectures get a generic
    fallback list. A fresh list is returned on every call.
    """
    rules = [
        (lambda s: "VGG" in s, ["classifier.6"]),
        (lambda s: s == "inception_v3", ["fc.", "AuxLogits.fc"]),
        (lambda s: "GoogLeNet" in s, ["fc.", "aux1.fc2", "aux2.fc2"]),
        (lambda s: "ResNe" in s, ["fc."]),
        (lambda s: "DenseNet" in s, ["classifier."]),
        (lambda s: "MobileNet" in s, ["classifier.3", "classifier.2", "classifier."]),
        (lambda s: "MnasNet" in s, ["classifier.1", "classifier."]),
        (lambda s: "EfficientNet" in s or "ConvNeXt" in s, ["classifier.", "head.fc"]),
        (lambda s: "RegNet" in s or "ShuffleNet" in s, ["fc."]),
        (lambda s: "ViT" in s or _is_vit_family(s), ["heads.", "head.", "classifier."]),
    ]
    for matches, keywords in rules:
        if matches(model_name):
            return keywords
    return ["fc.", "classifier.", "head.", "heads."]
def get_parameter_groups(
    model_name: str,
    model: nn.Module,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
):
    """Split model parameters into backbone and head optimizer groups.

    Head parameters are those whose names match ``_get_head_keywords``. If
    none match, a single group containing every parameter at ``head_lr`` is
    returned (with a warning).

    Returns:
        A list of optimizer param-group dicts ``{"params": [...], "lr": ...}``:
        ``[backbone, head]``, or a single group as described above.
    """
    head_kw = _get_head_keywords(model_name)
    buckets = {True: [], False: []}
    for name, param in model.named_parameters():
        buckets[name_matches_keywords(name, head_kw)].append(param)
    head_p, back_p = buckets[True], buckets[False]

    if not head_p:
        print(f"Warning: no head parameters matched for {model_name}; all params use head_lr.")
        return [{"params": list(model.parameters()), "lr": head_lr}]

    n_back = sum(p.numel() for p in back_p)
    n_head = sum(p.numel() for p in head_p)
    print(
        f"Parameter groups | backbone: {n_back:,} (lr={backbone_lr}) | "
        f"head: {n_head:,} (lr={head_lr})"
    )
    return [{"params": back_p, "lr": backbone_lr}, {"params": head_p, "lr": head_lr}]
def set_backbone_trainable(model_name: str, model: nn.Module, train_backbone: bool) -> None:
    """Toggle gradient tracking for the backbone; head parameters always stay trainable.

    Args:
        model_name: Registry name used to look up the head keywords.
        model: Model whose parameters are updated in place.
        train_backbone: When False, only parameters matching the head keywords
            keep ``requires_grad=True``.
    """
    head_keywords = _get_head_keywords(model_name)
    for param_name, param in model.named_parameters():
        belongs_to_head = name_matches_keywords(param_name, head_keywords)
        param.requires_grad = train_backbone or belongs_to_head
def set_frozen_backbone_bn_eval(model_name: str, model: nn.Module) -> None:
    """Put backbone BatchNorm layers in eval mode and stop their parameters from training.

    Used while the backbone is frozen. BatchNorm modules whose names match the
    head keywords are left untouched. Must be called after ``model.train()``,
    since that call switches every module back to training mode.
    """
    head_keywords = _get_head_keywords(model_name)
    norm_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    backbone_norms = [
        module
        for module_name, module in model.named_modules()
        if isinstance(module, norm_types) and not name_matches_keywords(module_name, head_keywords)
    ]
    for norm in backbone_norms:
        norm.eval()  # freeze running mean/var updates
        for weight in norm.parameters():
            weight.requires_grad = False
def configure_small_batch_behavior(model_name: str, model: nn.Module, batch_size: int) -> nn.Module:
    """Disable auxiliary classifiers of Inception v3 / GoogLeNet when ``batch_size`` < 2.

    Presumably done because the auxiliary heads cannot train on single-sample
    batches (TODO confirm). Other models, and any batch size >= 2, are returned
    unchanged. The model is modified in place and also returned.
    """
    if batch_size >= 2:
        return model

    if model_name == "inception_v3":
        print("batch_size=1 detected: disabling Inception auxiliary classifier.")
        to_disable = ("aux_logits", "AuxLogits")
    elif "GoogLeNet" in model_name:
        print("batch_size=1 detected: disabling GoogLeNet auxiliary classifiers.")
        to_disable = ("aux_logits", "aux1", "aux2")
    else:
        to_disable = ()

    for attr in to_disable:
        if hasattr(model, attr):
            # The flag is switched off; the auxiliary sub-modules are removed.
            setattr(model, attr, False if attr == "aux_logits" else None)
    return model
# Forward helpers
def _extract_logits(output):
if torch.is_tensor(output):
return output
if hasattr(output, "logits") and torch.is_tensor(output.logits):
return output.logits
if isinstance(output, (tuple, list)) and len(output) > 0 and torch.is_tensor(output[0]):
return output[0]
raise TypeError("Unable to extract logits from model output.")
def _extract_aux_outputs(output):
aux_outputs = []
if isinstance(output, (tuple, list)):
aux_outputs.extend([o for o in output[1:] if torch.is_tensor(o)])
else:
for attr in ["aux_logits", "aux_logits2", "aux_logits1"]:
if hasattr(output, attr):
aux = getattr(output, attr)
if torch.is_tensor(aux):
aux_outputs.append(aux)
return aux_outputs
def forward_with_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    criterion,
    aux_weight: float = 0.3,
):
    """Run ``model`` on ``inputs`` and compute the (possibly auxiliary-augmented) loss.

    In training mode, each auxiliary output found in the model's return value
    adds ``aux_weight * criterion(aux, labels)`` to the main loss. In eval mode
    only the main loss is used.

    Returns:
        ``(logits, loss)`` where ``logits`` is the primary output.
    """
    prediction = model(inputs)
    logits = _extract_logits(prediction)
    extra_heads = _extract_aux_outputs(prediction)
    loss = criterion(logits, labels)
    if not model.training:
        return logits, loss
    for head_out in extra_heads:
        loss = loss + aux_weight * criterion(head_out, labels)
    return logits, loss
# Losses
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Per sample: ``-(1 - p_t) ** gamma * log(p_t)``, where ``p_t`` is the softmax
    probability of the target class. When ``alpha`` is given, the loss of each
    sample is multiplied by ``alpha[target]``.

    Args:
        alpha: Optional per-class weight tensor indexed by class id.
        gamma: Focusing exponent; ``gamma=0`` reduces to plain cross-entropy.
        reduction: "mean", "sum", or anything else for per-sample losses.
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of ``inputs`` (N, C) logits against ``targets`` (N,) class ids."""
        target_log_prob = torch.gather(
            F.log_softmax(inputs, dim=1), 1, targets.unsqueeze(1)
        ).squeeze(1)
        modulating = (1.0 - target_log_prob.exp()) ** self.gamma
        per_sample = -modulating * target_log_prob
        if self.alpha is not None:
            class_weight = self.alpha.to(inputs.device)[targets]
            per_sample = class_weight * per_sample
        if self.reduction == "sum":
            return per_sample.sum()
        if self.reduction == "mean":
            return per_sample.mean()
        return per_sample
def build_criterion(
    loss_type: str,
    class_weights: Optional[torch.Tensor] = None,
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
):
    """Create the training loss selected by ``loss_type`` (case-insensitive).

    - "focal": ``FocalLoss`` with ``class_weights`` as alpha (label smoothing unused).
    - "weighted_ce": cross-entropy weighted by ``class_weights``.
    - "ce": unweighted cross-entropy (``class_weights`` ignored).

    Raises:
        ValueError: for any other ``loss_type``.
    """
    key = loss_type.lower()
    factories = {
        "focal": lambda: FocalLoss(alpha=class_weights, gamma=focal_gamma, reduction="mean"),
        "weighted_ce": lambda: nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing),
        "ce": lambda: nn.CrossEntropyLoss(label_smoothing=label_smoothing),
    }
    factory = factories.get(key)
    if factory is None:
        raise ValueError(f"Unsupported loss_type: {key}")
    return factory()
# Training
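# - epochs: 90
# - freeze_backbone_epochs: the backbone is frozen (head-only training) for the first N epochs
# - warmup_ep = freeze_backbone_epochs (the LR warm-up spans the frozen phase)
# - the validation loop uses TTA (when use_tta is enabled)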
def _fold_lr_scheduler(optimizer, epochs: int, warmup_ep: int):
    """Build the per-epoch LR schedule: linear warm-up, then cosine annealing.

    ``warmup_ep`` epochs ramp the LR from 0.1x to 1x. Cosine annealing to 1e-7
    then runs over the remaining epochs. With ``warmup_ep <= 0`` the cosine
    schedule is returned on its own. A zero-length ``LinearLR`` inside
    ``SequentialLR`` puts its milestone at 0, which the scheduler never triggers,
    so the cosine phase would inherit the 0.1x warm-up LR instead of starting
    from the base LR.
    """
    warmup_ep = max(0, warmup_ep)
    sched_main = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, epochs - warmup_ep),
        eta_min=1e-7,
    )
    if warmup_ep == 0:
        return sched_main
    sched_warm = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_ep,
    )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[sched_warm, sched_main],
        milestones=[warmup_ep],
    )


def _resume_fold_checkpoint(ckpt_path, model, optimizer, scheduler, scaler, fold_id, epochs, primary_metric):
    """Try to resume a fold from its epoch-level checkpoint.

    Returns ``(start_epoch, best_monitor, best_results)``. If there is no
    checkpoint, or loading fails, it returns ``(0, -inf, None)``. On failure the
    model, optimizer, scheduler and scaler are rolled back to their state before
    the attempt, so a half-applied checkpoint never leaks into a "from scratch" run.
    """
    fresh_start = (0, -float("inf"), None)
    if not ckpt_path.is_file():
        return fresh_start
    import copy  # only needed when a checkpoint exists

    print(f"Fold {fold_id}: found epoch-level checkpoint, attempting to resume...")
    # Snapshot the untouched state (model weights copied to CPU) for rollback.
    model_snapshot = {
        k: (v.detach().to("cpu", copy=True) if torch.is_tensor(v) else copy.deepcopy(v))
        for k, v in model.state_dict().items()
    }
    other_snapshots = [(obj, copy.deepcopy(obj.state_dict())) for obj in (optimizer, scheduler, scaler)]
    try:
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        scaler.load_state_dict(ckpt["scaler_state_dict"])
        start_epoch = ckpt["epoch"] + 1
        best_monitor = ckpt["best_monitor"]
        best_results = ckpt.get("best_results", None)
    except Exception as exc:
        print(f"Fold {fold_id}: failed to load checkpoint ({exc}), training from scratch.")
        model.load_state_dict(model_snapshot)
        for obj, state in other_snapshots:
            obj.load_state_dict(state)
        return fresh_start
    print(
        f"Fold {fold_id}: resumed from epoch {start_epoch}/{epochs} "
        f"(best {primary_metric}={best_monitor:.2f})."
    )
    return start_epoch, best_monitor, best_results


def _run_train_epoch(model, train_loader, optimizer, scaler, criterion, amp_enabled, max_grad_norm):
    """Run one optimisation pass over ``train_loader``; return the mean per-sample loss.

    The mean is taken over the samples actually drawn, not ``len(dataset)``.
    A sampler may draw a different number of samples (e.g. a custom
    ``num_samples`` or ``drop_last``). Returns NaN if the loader yielded nothing.
    """
    non_blocking = device.type == "cuda"
    amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
    run_loss = 0.0
    n_seen = 0
    for inputs, labels, _meta in train_loader:
        inputs = inputs.to(device, non_blocking=non_blocking)
        labels = labels.to(device, non_blocking=non_blocking)
        optimizer.zero_grad(set_to_none=True)
        with amp_ctx():
            _, loss = forward_with_loss(model, inputs, labels, criterion, aux_weight=0.3)
        scaler.scale(loss).backward()
        if max_grad_norm is not None and max_grad_norm > 0:
            # Unscale first so the clipping threshold applies to the true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        scaler.step(optimizer)
        scaler.update()
        batch_n = inputs.size(0)
        run_loss += loss.item() * batch_n
        n_seen += batch_n
    return run_loss / n_seen if n_seen else float("nan")


def _collect_val_outputs(model, val_loader, use_tta: bool, amp_enabled: bool) -> Dict[str, list]:
    """Run inference over ``val_loader`` and gather predictions with per-image metadata.

    Returns a dict of equally long lists under the keys ``targets``, ``preds``,
    ``probs`` (per-class probability rows), ``paths``, ``patients`` and
    ``image_names``. With ``use_tta`` the probabilities come from ``predict_with_tta``.
    """
    model.eval()
    non_blocking = device.type == "cuda"
    out = {key: [] for key in ("targets", "preds", "probs", "paths", "patients", "image_names")}
    with torch.no_grad():
        for inputs, labels, meta in val_loader:
            inputs = inputs.to(device, non_blocking=non_blocking)
            labels = labels.to(device, non_blocking=non_blocking)
            if use_tta:
                # V7: use TTA for inference.
                probs = predict_with_tta(model, inputs, amp_enabled=amp_enabled)
            else:
                amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
                with amp_ctx():
                    output = model(inputs)
                # Softmax in float32 even when AMP produced half-precision logits.
                probs = torch.softmax(_extract_logits(output).float(), dim=1)
            pred = probs.argmax(dim=1)
            out["targets"].extend(labels.cpu().numpy().tolist())
            out["preds"].extend(pred.cpu().numpy().tolist())
            out["probs"].extend(probs.cpu().numpy().tolist())
            out["paths"].extend(list(meta["path"]))
            out["patients"].extend(list(meta["patient"]))
            out["image_names"].extend(list(meta["image_name"]))
    return out


def train_one_fold(
    model_name: str,
    model: nn.Module,
    train_loader,
    val_loader,
    epochs: int = 90,
    num_classes: int = 4,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
    class_weights: Optional[torch.Tensor] = None,
    fold_id: int = 1,
    save_dir: Optional[Path] = None,
    freeze_backbone_epochs: int = 8,
    max_grad_norm: float = 1.0,
    primary_metric: str = PRIMARY_METRIC,
    loss_type: str = "weighted_ce",
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
    use_tta: bool = True,
):
    """Train and validate ``model`` on one CV fold and return its best-epoch results.

    Schedule: for the first ``freeze_backbone_epochs`` epochs only the head
    trains and backbone BatchNorm layers are kept in eval mode. The same epochs
    form the linear LR warm-up; cosine annealing follows. After every epoch the
    model is validated (optionally with TTA). Whenever ``primary_metric``
    improves (ties broken by balanced accuracy), the results and weights are
    saved under ``save_dir`` with the ``fold{fold_id}_best`` prefix.

    The epoch-level checkpoint ``fold{fold_id}_checkpoint.pth`` is rewritten
    atomically after every epoch so an interrupted fold can resume. It is
    deleted once the fold finishes.

    Args:
        model_name: Registry name, used for head/backbone parameter selection.
        model: Model already moved to the global ``device``.
        train_loader / val_loader: Loaders yielding ``(inputs, labels, meta)``.
        save_dir: Output directory (defaults to ``Path(model_name)``).
        freeze_backbone_epochs: Frozen-backbone epochs, also the warm-up length.
        primary_metric: Key of the metrics dict used for model selection.

    Returns:
        The best-epoch results dict (metrics, report, predictions, targets,
        probabilities, per-image metadata, per-class report entries).

    Raises:
        RuntimeError: if no epoch produced a result (e.g. no epochs ran).
    """
    save_dir = Path(model_name) if save_dir is None else Path(save_dir)
    ensure_dir(save_dir)
    criterion = build_criterion(
        loss_type=loss_type,
        class_weights=class_weights,
        focal_gamma=focal_gamma,
        label_smoothing=label_smoothing,
    )
    print(
        f"Fold {fold_id}: Using loss_type='{loss_type}'"
        f"{' with class weights' if class_weights is not None else ''}."
    )
    print(
        f"Fold {fold_id}: backbone_lr={backbone_lr}, head_lr={head_lr}, "
        f"freeze_backbone_epochs={freeze_backbone_epochs}, "
        f"epochs={epochs}, use_tta={use_tta}."
    )
    param_groups = get_parameter_groups(model_name, model, backbone_lr, head_lr)
    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999), weight_decay=5e-4)
    # The LR warm-up deliberately spans the frozen-backbone phase.
    scheduler = _fold_lr_scheduler(optimizer, epochs, warmup_ep=freeze_backbone_epochs)
    amp_enabled = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    class_names = [f"class{i}" for i in range(num_classes)]

    ckpt_path = save_dir / f"fold{fold_id}_checkpoint.pth"
    start_epoch, best_monitor, best_results = _resume_fold_checkpoint(
        ckpt_path, model, optimizer, scheduler, scaler, fold_id, epochs, primary_metric
    )

    was_backbone_trainable = None
    for epoch in tqdm(
        range(start_epoch, epochs),
        desc=f"Fold {fold_id}",
        leave=False,
        initial=start_epoch,
        total=epochs,
    ):
        train_backbone = epoch >= freeze_backbone_epochs
        if was_backbone_trainable != train_backbone:
            set_backbone_trainable(model_name, model, train_backbone=train_backbone)
            stage = "unfrozen" if train_backbone else "frozen"
            print(f"Fold {fold_id}: backbone is now {stage} (epoch {epoch + 1}).")
            was_backbone_trainable = train_backbone
        model.train()
        if not train_backbone:
            # model.train() re-enabled BN statistic updates; undo that for the frozen backbone.
            set_frozen_backbone_bn_eval(model_name, model)
        ep_loss = _run_train_epoch(model, train_loader, optimizer, scaler, criterion, amp_enabled, max_grad_norm)
        scheduler.step()

        # ---- Validation phase: optional TTA ----
        val = _collect_val_outputs(model, val_loader, use_tta, amp_enabled)
        metrics, report = compute_metrics(val["targets"], val["preds"], num_classes, class_names)
        monitor = metrics[primary_metric]
        if (epoch + 1) % 5 == 0 or epoch == epochs - 1 or epoch == start_epoch:
            print(
                f"F{fold_id} E{epoch + 1}/{epochs} "
                f"Loss={ep_loss:.4f} "
                f"Macro-F1={metrics['macro_f1']:.2f}% "
                f"BA={metrics['balanced_accuracy']:.2f}% "
                f"Acc={metrics['accuracy']:.2f}%"
                f"{' [TTA]' if use_tta else ''}"
            )
        # Ties on the primary metric are broken by balanced accuracy.
        improved = (monitor > best_monitor) or (
            np.isclose(monitor, best_monitor)
            and best_results is not None
            and metrics["balanced_accuracy"] > best_results["metrics"]["balanced_accuracy"]
        )
        if improved:
            best_monitor = monitor
            best_results = {
                "best_epoch": epoch + 1,
                "metrics": metrics,
                "classification_report": report,
                "predictions": val["preds"],
                "targets": val["targets"],
                "image_path": val["paths"],
                "patients": val["patients"],
                "image_names": val["image_names"],
                "probabilities": val["probs"],
                "num_classes": num_classes,
                "per_class": [
                    report.get(f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0})
                    for i in range(num_classes)
                ],
            }
            save_fold_results(best_results, save_dir, tag=f"fold{fold_id}_best")
            torch.save(model.state_dict(), save_dir / f"fold{fold_id}_best.pth")

        # Write to a temp file and rename so an interrupt mid-save cannot corrupt
        # the checkpoint a later run will try to resume from.
        tmp_ckpt = ckpt_path.with_suffix(".pth.tmp")
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_monitor": best_monitor,
            "best_results": best_results,
        }, tmp_ckpt)
        os.replace(tmp_ckpt, ckpt_path)

    if ckpt_path.is_file():
        ckpt_path.unlink()
        print(f"Fold {fold_id}: removed epoch-level checkpoint (training complete).")
    if best_results is None:
        raise RuntimeError(f"Fold {fold_id}: no valid result was produced.")
    return best_results
def build_model_registry():
    """Map model names to zero-argument factories that build pretrained models.

    Models are only constructed when a factory is called. torchvision models are
    always registered. The timm models (built with ``img_size=512``) are added
    only when timm is importable; otherwise a notice is printed.

    Returns:
        dict mapping name -> callable returning an ``nn.Module``.
    """
    registry = {
        "DenseNet161": lambda: models.densenet161(weights=models.DenseNet161_Weights.DEFAULT),
        "ConvNeXt_Tiny": lambda: models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT),
        "ViT_B_16": lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT),
    }
    if not HAS_TIMM:
        print("Skipping timm models because timm is not installed.")
        return registry
    registry["SwinV2_T"] = lambda: timm.create_model(
        "swinv2_tiny_window8_256",
        pretrained=True,
        img_size=512,
    )
    registry["DeiT3_S"] = lambda: timm.create_model(
        "deit3_small_patch16_224",
        pretrained=True,
        img_size=512,
    )
    return registry
def parse_args():
    """Define and parse the command-line options for the patient-grouped CV benchmark.

    Boolean switches use ``argparse.BooleanOptionalAction`` when it exists
    (``--use_tta`` / ``--no-use_tta``). Otherwise ``--use_tta`` is on by default
    and ``--no_tta`` turns it off.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = argparse.ArgumentParser(
        description="ROP benchmark training with patient-grouped 5-fold CV (v7)."
    )
    add = parser.add_argument

    # Data and labels.
    add(
        "--excel_path",
        type=str,
        default="/media/fang/9fc99a7b-15d6-4e22-ab05-fe46e6058c39/felicia/Downloads/医生审核之后第一版12-17/部分公开数据集/公开数据集训练表_调整数据1.xlsx",
        help="Path to Excel with at least columns: patient, path, label.",
    )
    add("--group_col", type=str, default="patient", help="Grouping column for leakage-free split.")
    add("--num_classes", type=int, default=4)

    # Training schedule. V7: 90 epochs.
    add("--epochs", type=int, default=90)
    add("--n_folds", type=int, default=5)
    add("--batch_size", type=int, default=8)
    # V7: backbone_lr raised to 3e-5 and head_lr raised to 1e-3.
    add("--backbone_lr", type=float, default=3e-5)
    add("--head_lr", type=float, default=1e-3)
    add("--random_seed", type=int, default=42)
    add("--num_workers", type=int, default=min(8, os.cpu_count() or 2))

    # Class imbalance and loss.
    add(
        "--balance_mode",
        type=str,
        default="loss",
        choices=["none", "loss", "sampler"],
        help="Imbalance handling. 'loss' computes class weights; 'sampler' uses WeightedRandomSampler.",
    )
    add(
        "--loss_type",
        type=str,
        default="weighted_ce",
        choices=["weighted_ce", "focal", "ce"],
        help="weighted_ce is the recommended default.",
    )
    add("--focal_gamma", type=float, default=2.0)
    add("--label_smoothing", type=float, default=0.0)
    # V7: freeze_backbone_epochs raised to 8.
    add("--freeze_backbone_epochs", type=int, default=8)
    add("--max_grad_norm", type=float, default=1.0)
    add("--output_root", type=str, default="runs_rop_V7_old")

    # Boolean switches (BooleanOptionalAction requires Python 3.9+).
    bool_action = getattr(argparse, "BooleanOptionalAction", None)
    tta_help = "Enable 4-way TTA (flip) during validation."
    if bool_action is None:
        add("--use_tta", action="store_true", default=True, help=tta_help)
        add("--no_tta", dest="use_tta", action="store_false")
        add("--deterministic", action="store_true", default=False)
    else:
        add("--use_tta", action=bool_action, default=True, help=tta_help)
        add("--deterministic", action=bool_action, default=False)

    add(
        "--models",
        nargs="*",
        default=None,
        help="Optional subset of model names to train (e.g. --models DenseNet161 ViT_B_16).",
    )
    return parser.parse_args()
def main():
    """Run the full patient-grouped cross-validation benchmark.

    Loads the Excel table, checks the label count, and builds leakage-free
    fold splits. Each selected model is then trained on every fold, with a
    per-model k-fold summary and a global leaderboard CSV written under
    ``--output_root``.

    Completed work is reused on restart:
    - A model whose ``kfold_summary.json`` holds a summary is skipped.
    - A fold whose best metrics and weights exist is skipped, unless its
      epoch-level checkpoint is still present. In that case the fold was
      interrupted and is resumed instead.
    """
    args = parse_args()
    seed_everything(args.random_seed, deterministic=args.deterministic)
    print("\nLoading data...")
    df = load_and_prepare_data(args.excel_path, group_col=args.group_col)
    observed_num_classes = int(df["label"].nunique())
    if observed_num_classes != args.num_classes:
        raise ValueError(
            f"num_classes mismatch: args.num_classes={args.num_classes}, "
            f"but observed labels in Excel imply {observed_num_classes} classes after remapping."
        )
    fold_splits = build_fold_splits(
        df=df,
        n_folds=args.n_folds,
        random_seed=args.random_seed,
        group_col=args.group_col,
    )
    model_registry = build_model_registry()
    if args.models:
        requested = set(args.models)  # built once, not once per registry entry
        missing = [m for m in args.models if m not in model_registry]
        if missing:
            print(f"Warning: these models were not found and will be ignored: {missing}")
        model_registry = {k: v for k, v in model_registry.items() if k in requested}
    print(f"\nTotal models to train: {len(model_registry)}")
    for i, name in enumerate(model_registry, 1):
        print(f"{i:2d}. {name}")
    output_root = Path(args.output_root)
    ensure_dir(output_root)
    global_results = {}
    for model_idx, (model_name, model_fn) in enumerate(model_registry.items(), 1):
        print("\n" + "=" * 70)
        print(f"[{model_idx}/{len(model_registry)}] Model: {model_name}")
        print("=" * 70)
        model_dir = output_root / model_name
        ensure_dir(model_dir)
        summary_path = model_dir / "kfold_summary.json"
        if summary_path.is_file():
            try:
                with open(summary_path, "r", encoding="utf-8") as f:
                    old = json.load(f)
                old_summary = old.get("summary", {})
                if old_summary:
                    mean_primary = old_summary[PRIMARY_METRIC]["mean"]
                    std_primary = old_summary[PRIMARY_METRIC]["std"]
                    print(
                        f"[Skip] Found existing {args.n_folds}-fold summary: "
                        f"{PRIMARY_METRIC}={mean_primary:.2f}% +/- {std_primary:.2f}%"
                    )
                    global_results[model_name] = (mean_primary, std_primary)
                    continue
            except Exception as exc:
                print(f"Warning: could not reuse {summary_path} ({exc}); retraining {model_name}.")
        input_size = get_model_input_size(model_name)
        print(f"Input size: {input_size}x{input_size}")
        fold_results = []
        for fold_idx in range(args.n_folds):
            fold_id = fold_idx + 1
            print(f"\n-- Fold {fold_id}/{args.n_folds} --")
            metrics_json = model_dir / f"fold{fold_id}_best_metrics.json"
            weight_path = model_dir / f"fold{fold_id}_best.pth"
            # train_one_fold writes the best-so-far files whenever the monitor
            # improves, but deletes this epoch-level checkpoint only when the fold
            # finishes. If the checkpoint still exists, the best files belong to
            # an unfinished run and must not be treated as a final result.
            epoch_ckpt = model_dir / f"fold{fold_id}_checkpoint.pth"
            has_cache = metrics_json.is_file() and weight_path.is_file()
            if has_cache and epoch_ckpt.is_file():
                print(
                    f"Fold {fold_id}: epoch-level checkpoint present; cached best files are "
                    f"from an unfinished run, resuming training."
                )
            elif has_cache:
                try:
                    with open(metrics_json, "r", encoding="utf-8") as f:
                        cached = json.load(f)
                    cached_entry = {
                        "best_epoch": cached["best_epoch"],
                        "metrics": cached["metrics"],
                        "per_class": cached["per_class"],
                    }
                    cached_msg = (
                        f"Fold {fold_id}: cached result found "
                        f"(Macro-F1={cached['metrics']['macro_f1']:.2f}%, "
                        f"BA={cached['metrics']['balanced_accuracy']:.2f}%), skipped."
                    )
                except Exception as exc:
                    print(f"Fold {fold_id}: could not reuse cached result ({exc}); retraining.")
                else:
                    # Append only after everything parsed, so a bad file can't leave
                    # a half-recorded fold that is then retrained and appended twice.
                    fold_results.append(cached_entry)
                    print(cached_msg)
                    continue
            train_idx, val_idx = fold_splits[fold_idx]
            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)
            train_patients = set(train_df[args.group_col].astype(str).tolist())
            val_patients = set(val_df[args.group_col].astype(str).tolist())
            overlap = train_patients & val_patients
            if overlap:
                raise RuntimeError(
                    f"Leakage detected in fold {fold_id}: {len(overlap)} overlapping patients/groups."
                )
            print(f"Train: {len(train_df)} | Validation: {len(val_df)}")
            print(
                f"Train patients: {train_df[args.group_col].nunique()} | "
                f"Validation patients: {val_df[args.group_col].nunique()}"
            )
            print(
                f"Train class dist: {dict(train_df['label'].value_counts().sort_index())} | "
                f"Val class dist: {dict(val_df['label'].value_counts().sort_index())}"
            )
            train_loader, val_loader, class_weights = create_fold_loaders(
                train_df=train_df,
                val_df=val_df,
                input_size=input_size,
                batch_size=args.batch_size,
                num_classes=args.num_classes,
                balance_mode=args.balance_mode,
                num_workers=args.num_workers,
            )
            try:
                model = model_fn()
            except Exception as exc:
                print(f"Model creation failed for {model_name}: {exc}")
                break
            model = replace_classifier(model_name, model, args.num_classes)
            model = patch_vit_for_large_input(model, model_name, input_size)
            model = configure_small_batch_behavior(model_name, model, args.batch_size)
            model = model.to(device)
            # Sanity check: one forward pass must produce num_classes outputs.
            dummy = torch.randn(1, 3, input_size, input_size, device=device)
            model.eval()
            with torch.no_grad():
                out = model(dummy)
            out = _extract_logits(out)
            out_dim = out.shape[-1]
            if out_dim != args.num_classes:
                raise RuntimeError(
                    f"Fatal: classifier replacement failed for {model_name}. "
                    f"Output dim={out_dim}, expected={args.num_classes}."
                )
            print(f"Forward sanity check passed: output dim={out_dim}")
            del dummy, out
            if device.type == "cuda":
                torch.cuda.empty_cache()
            result = train_one_fold(
                model_name=model_name,
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=args.epochs,
                num_classes=args.num_classes,
                backbone_lr=args.backbone_lr,
                head_lr=args.head_lr,
                class_weights=class_weights,
                fold_id=fold_id,
                save_dir=model_dir,
                freeze_backbone_epochs=args.freeze_backbone_epochs,
                max_grad_norm=args.max_grad_norm,
                primary_metric=PRIMARY_METRIC,
                loss_type=args.loss_type,
                focal_gamma=args.focal_gamma,
                label_smoothing=args.label_smoothing,
                use_tta=args.use_tta,
            )
            fold_results.append(result)
            del model
            if device.type == "cuda":
                torch.cuda.empty_cache()
        if len(fold_results) == args.n_folds:
            mean_primary, std_primary = save_kfold_summary(
                model_name,
                fold_results,
                args.num_classes,
                model_dir,
            )
            global_results[model_name] = (mean_primary, std_primary)
        else:
            print(f"Warning: {model_name} completed only {len(fold_results)}/{args.n_folds} folds.")
    print("\n" + "=" * 70)
    print(f"Global leaderboard ({args.n_folds}-Fold CV)")
    print(f"Sorted by: {PRIMARY_METRIC}")
    print("=" * 70)
    sorted_results = sorted(global_results.items(), key=lambda x: x[1][0], reverse=True)
    print(f"{'Rank':<6} {'Model':<25} {PRIMARY_METRIC:>12} {'Std':>10}")
    print("-" * 62)
    for rank, (name, (mean_primary, std_primary)) in enumerate(sorted_results, 1):
        print(f"{rank:<6} {name:<25} {mean_primary:>11.2f}% {std_primary:>9.2f}%")
    leaderboard_path = output_root / f"global_leaderboard_{PRIMARY_METRIC}.csv"
    pd.DataFrame([
        {
            "rank": idx + 1,
            "model": name,
            f"mean_{PRIMARY_METRIC}": mean_primary,
            f"std_{PRIMARY_METRIC}": std_primary,
        }
        for idx, (name, (mean_primary, std_primary)) in enumerate(sorted_results)
    ]).to_csv(leaderboard_path, index=False)
    print(f"\nLeaderboard saved to: {leaderboard_path}")
if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()