| import argparse |
| import json |
| import math |
| import os |
| import random |
| from contextlib import nullcontext |
| from pathlib import Path |
| from typing import Dict, List, Optional, Tuple, Any |
|
|
| import cv2 |
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from sklearn.metrics import ( |
| accuracy_score, |
| balanced_accuracy_score, |
| classification_report, |
| f1_score, |
| precision_score, |
| recall_score, |
| ) |
| from sklearn.model_selection import GroupKFold, StratifiedKFold |
| from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler |
| from torchvision import models, transforms |
| from torchvision.transforms import InterpolationMode |
| from tqdm import tqdm |
|
|
| try: |
| from sklearn.model_selection import StratifiedGroupKFold |
| HAS_STRATIFIED_GROUP_KFOLD = True |
| except Exception: |
| StratifiedGroupKFold = None |
| HAS_STRATIFIED_GROUP_KFOLD = False |
|
|
| try: |
| import timm |
| HAS_TIMM = True |
| except ImportError: |
| HAS_TIMM = False |
| print("Warning: timm is not installed. timm-based models will be skipped.") |
|
|
# Key into the metrics dict that is reported as the primary metric in the
# per-fold and k-fold summaries.
PRIMARY_METRIC = "macro_f1"
# Fallback square input resolution for models without an entry in
# MODEL_INPUT_SIZES.
DEFAULT_INPUT_SIZE = 512
|
|
| |
|
|
def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed the Python, NumPy and PyTorch RNGs and configure cuDNN.

    Args:
        seed: Value passed to ``random``, ``numpy.random`` and torch (CPU and
            all CUDA devices). Also exported as ``PYTHONHASHSEED``.
        deterministic: When True, disable cuDNN autotuning, request
            deterministic cuDNN kernels and ask torch for deterministic
            algorithms (warn-only). When False, enable cuDNN autotuning.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # NOTE: the running interpreter fixed its hash seed at startup; this export
    # is only visible to code that reads the environment afterwards.
    os.environ["PYTHONHASHSEED"] = str(seed)

    # The two cuDNN switches are always opposites of each other here.
    torch.backends.cudnn.benchmark = not deterministic
    torch.backends.cudnn.deterministic = deterministic

    if deterministic:
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception as exc:
            # Still best-effort (e.g. a torch build without ``warn_only``),
            # but report it instead of failing silently.
            print(f"Warning: could not enable deterministic algorithms: {exc}")
|
|
def ensure_dir(path: Path) -> None:
    """Create ``path`` and any missing parents; do nothing if it already exists."""
    os.makedirs(path, exist_ok=True)
|
|
def to_jsonable(obj: Any):
    """Recursively convert ``obj`` into something ``json.dump`` accepts.

    Handles dicts (including NumPy-scalar keys), lists, tuples, NumPy arrays,
    any NumPy scalar (ints, floats, bools, ...) and path-like objects.
    Everything else is returned unchanged.

    Args:
        obj: Arbitrary (possibly nested) Python / NumPy value.

    Returns:
        A structure built from dict / list / plain Python scalars where the
        input used the container and NumPy types listed above.
    """
    if isinstance(obj, dict):
        # json rejects NumPy scalars as keys, so unwrap those as well.
        return {
            (k.item() if isinstance(k, np.generic) else k): to_jsonable(v)
            for k, v in obj.items()
        }
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        # tolist() yields nested lists of Python scalars; recurse to cover
        # object arrays that may hold further NumPy values.
        return to_jsonable(obj.tolist())
    if isinstance(obj, np.generic):
        # Covers np.integer / np.floating and also np.bool_, np.str_, ...
        return obj.item()
    if isinstance(obj, os.PathLike):
        return os.fspath(obj)
    return obj
|
|
def name_matches_keywords(name: str, keywords: List[str]) -> bool:
    """Return True when ``name`` matches at least one keyword.

    A keyword matches if it occurs anywhere in ``name``, or if ``name`` equals
    the keyword with trailing dots removed, or starts with that stripped
    keyword followed by a dot. An empty ``name`` never matches.
    """
    if not name:
        return False

    def _hits(keyword: str) -> bool:
        stem = keyword.rstrip(".")
        if keyword in name:
            return True
        return name == stem or name.startswith(stem + ".")

    return any(_hits(keyword) for keyword in keywords)
|
|
|
|
| |
|
|
# Select the compute device once at import time and report what was found.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")
if device.type != "cuda":
    print("Warning: CUDA is not available. Training will be much slower on CPU.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
|
|
|
|
| |
|
|
| _VIT_KEYWORDS = [ |
| "ViT", "Swin", "Transformer", "DeiT", "MaxViT", "CoAtNet", |
| "EfficientFormer", "FastViT", "CaFormer", |
| ] |
|
|
| def _is_vit_family(model_name: str) -> bool: |
| return any(kw.lower() in model_name.lower() for kw in _VIT_KEYWORDS) |
|
|
| def _is_timm_model(model: nn.Module) -> bool: |
| return hasattr(model, "get_classifier") and hasattr(model, "num_features") |
|
|
# Per-model overrides of the square input resolution.
MODEL_INPUT_SIZES: Dict[str, int] = {"inception_v3": 299}


def get_model_input_size(model_name: str) -> int:
    """Return the input size registered for ``model_name``, else ``DEFAULT_INPUT_SIZE``."""
    if model_name in MODEL_INPUT_SIZES:
        return MODEL_INPUT_SIZES[model_name]
    return DEFAULT_INPUT_SIZE
|
|
|
|
| |
def compute_metrics(
    y_true: List[int],
    y_pred: List[int],
    num_classes: int,
    class_names: List[str],
) -> Tuple[Dict, Dict]:
    """Compute summary metrics (in percent) and a per-class report.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        num_classes: Number of classes; labels ``0 .. num_classes - 1`` are
            always included in the averaged scores and in the report.
        class_names: Display names passed to ``classification_report``.

    Returns:
        ``(metrics, report)`` where ``metrics`` maps metric name to a value
        scaled by 100 and ``report`` is the ``classification_report`` dict.
    """
    label_ids = list(range(num_classes))

    def _averaged(score_fn, average: str) -> float:
        # Shared call shape for the precision / recall / F1 family.
        return 100.0 * score_fn(
            y_true,
            y_pred,
            labels=label_ids,
            average=average,
            zero_division=0,
        )

    report = classification_report(
        y_true,
        y_pred,
        labels=label_ids,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )

    metrics = {
        "accuracy": 100.0 * accuracy_score(y_true, y_pred),
        "balanced_accuracy": 100.0 * balanced_accuracy_score(y_true, y_pred),
        "macro_f1": _averaged(f1_score, "macro"),
        "macro_precision": _averaged(precision_score, "macro"),
        "macro_recall": _averaged(recall_score, "macro"),
        "weighted_f1": _averaged(f1_score, "weighted"),
    }
    return metrics, report
|
|
def save_fold_results(results: Dict, save_dir: Path, tag: str = "best") -> None:
    """Write one fold's evaluation artefacts into ``save_dir``.

    Three files are produced: ``test_report_{tag}.txt`` (headline metrics plus
    the classification report table), ``predictions_{tag}.csv`` (one row per
    image with per-class probabilities) and ``{tag}_metrics.json``.

    Args:
        results: Dict read for the keys ``classification_report``, ``metrics``,
            ``patients``, ``image_names``, ``targets``, ``predictions``,
            ``image_path``, ``probabilities``, ``num_classes`` and
            ``best_epoch``.
        save_dir: Output directory; created if missing.
        tag: Suffix/prefix used in the file names.
    """
    ensure_dir(save_dir)
    n_classes_key = "num_classes"

    # 1) Plain-text report.
    report_table = pd.DataFrame(results["classification_report"]).transpose()
    fold_metrics = results["metrics"]
    header = "".join(
        [
            f"Primary Metric ({PRIMARY_METRIC}): {fold_metrics[PRIMARY_METRIC]:.4f}\n",
            f"Accuracy: {fold_metrics['accuracy']:.4f}\n",
            f"Balanced Accuracy: {fold_metrics['balanced_accuracy']:.4f}\n",
            f"Macro F1: {fold_metrics['macro_f1']:.4f}\n",
            f"Macro Recall: {fold_metrics['macro_recall']:.4f}\n",
            f"Macro Precision: {fold_metrics['macro_precision']:.4f}\n\n",
            "Classification Report:\n",
        ]
    )
    report_path = save_dir / f"test_report_{tag}.txt"
    report_path.write_text(header + report_table.to_string(), encoding="utf-8")

    # 2) Per-image predictions with one probability column per class.
    columns = {
        "patient": results["patients"],
        "image_name": results["image_names"],
        "True": results["targets"],
        "Predicted": results["predictions"],
        "path": results["image_path"],
    }
    for class_idx in range(results[n_classes_key]):
        columns[f"prob_class{class_idx}"] = [
            probs[class_idx] for probs in results["probabilities"]
        ]
    pd.DataFrame(columns).to_csv(save_dir / f"predictions_{tag}.csv", index=False)

    # 3) Machine-readable metrics; classes missing from the report get zeros.
    per_class = []
    for class_idx in range(results[n_classes_key]):
        per_class.append(
            results["classification_report"].get(
                f"class{class_idx}", {"precision": 0, "recall": 0, "f1-score": 0}
            )
        )
    payload = {
        "best_epoch": results["best_epoch"],
        "primary_metric": PRIMARY_METRIC,
        "metrics": results["metrics"],
        "per_class": per_class,
    }
    with open(save_dir / f"{tag}_metrics.json", "w", encoding="utf-8") as handle:
        json.dump(to_jsonable(payload), handle, indent=2, ensure_ascii=False)
|
|
def save_kfold_summary(
    model_name: str,
    fold_results: List[Dict],
    num_classes: int,
    save_dir: Path,
) -> Tuple[float, float]:
    """Aggregate per-fold results, write summary files and return the headline score.

    Always written: ``kfold_summary.txt`` and ``kfold_summary.json``.
    Written only when every fold carries per-sample data (``targets``,
    ``predictions``, ``image_path``, ``probabilities``, ``patients``,
    ``image_names``): ``oof_report.txt`` and ``oof_predictions.csv``.

    Args:
        model_name: Name printed in the summary.
        fold_results: One dict per fold with at least ``metrics``,
            ``best_epoch`` and ``per_class`` (indexable by class id, each
            entry holding ``precision``, ``recall`` and ``f1-score``).
        num_classes: Number of classes to summarise.
        save_dir: Output directory; created if missing.

    Returns:
        ``(mean, std)`` of ``PRIMARY_METRIC`` across folds (``np.std`` default,
        i.e. population standard deviation).

    Raises:
        ValueError: If ``fold_results`` is empty.
    """
    if not fold_results:
        raise ValueError("fold_results is empty; cannot summarise cross-validation.")

    ensure_dir(save_dir)

    metric_names = [
        "accuracy",
        "balanced_accuracy",
        "macro_f1",
        "macro_recall",
        "macro_precision",
        "weighted_f1",
    ]
    summary = {}
    for name in metric_names:
        values = [r["metrics"][name] for r in fold_results]
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }

    lines = [
        "=" * 70,
        f"Model: {model_name}",
        # Report the real number of folds instead of a hard-coded 5.
        f"{len(fold_results)}-Fold Cross-Validation Summary",
        f"Primary Metric: {PRIMARY_METRIC}",
        "=" * 70,
        "",
    ]
    for i, r in enumerate(fold_results, 1):
        lines.append(
            f"Fold {i}: Macro-F1={r['metrics']['macro_f1']:.2f}% | "
            f"BA={r['metrics']['balanced_accuracy']:.2f}% | "
            f"Acc={r['metrics']['accuracy']:.2f}% | "
            f"BestEpoch={r['best_epoch']}"
        )
    lines.append("")
    for name in metric_names:
        lines.append(f"{name}: {summary[name]['mean']:.2f}% +/- {summary[name]['std']:.2f}%")

    lines.append("")
    lines.append("Per-class metrics (mean +/- std)")
    lines.append(f"{'class':<10} {'precision':>18} {'recall':>18} {'f1-score':>18}")

    def _mean_std_cell(values) -> str:
        # Format the whole "mean+/-std" cell first so it can be aligned as a
        # unit under the 18-character column headers.
        return f"{np.mean(values):.4f}+/-{np.std(values):.4f}"

    per_class_summary = {}
    for c in range(num_classes):
        ps = [r["per_class"][c]["precision"] for r in fold_results]
        rs = [r["per_class"][c]["recall"] for r in fold_results]
        fs = [r["per_class"][c]["f1-score"] for r in fold_results]
        per_class_summary[c] = {
            "precision_mean": float(np.mean(ps)),
            "precision_std": float(np.std(ps)),
            "recall_mean": float(np.mean(rs)),
            "recall_std": float(np.std(rs)),
            "f1_mean": float(np.mean(fs)),
            "f1_std": float(np.std(fs)),
        }
        lines.append(
            f"{'class' + str(c):<10} "
            f"{_mean_std_cell(ps):>18} "
            f"{_mean_std_cell(rs):>18} "
            f"{_mean_std_cell(fs):>18}"
        )

    text = "\n".join(lines)
    print(text)
    with open(save_dir / "kfold_summary.txt", "w", encoding="utf-8") as f:
        f.write(text)

    with open(save_dir / "kfold_summary.json", "w", encoding="utf-8") as f:
        json.dump(
            to_jsonable({
                "model": model_name,
                "primary_metric": PRIMARY_METRIC,
                "summary": summary,
                "per_class": per_class_summary,
            }),
            f,
            indent=2,
            ensure_ascii=False,
        )

    # Pooled out-of-fold outputs need every per-sample field that is read
    # below, including "patients" and "image_names".
    pooled_keys = (
        "targets",
        "predictions",
        "image_path",
        "probabilities",
        "patients",
        "image_names",
    )
    pooled_ready = all(all(key in r for key in pooled_keys) for r in fold_results)
    if pooled_ready:
        all_targets, all_predictions, all_paths = [], [], []
        all_patients, all_image_names = [], []
        all_probabilities = []
        for r in fold_results:
            all_targets.extend(r["targets"])
            all_predictions.extend(r["predictions"])
            all_paths.extend(r["image_path"])
            all_patients.extend(r["patients"])
            all_image_names.extend(r["image_names"])
            all_probabilities.extend(r["probabilities"])

        class_names = [f"class{i}" for i in range(num_classes)]
        pooled_metrics, pooled_report = compute_metrics(
            all_targets,
            all_predictions,
            num_classes,
            class_names,
        )

        with open(save_dir / "oof_report.txt", "w", encoding="utf-8") as f:
            f.write("Pooled out-of-fold metrics\n")
            f.write(f"Primary Metric ({PRIMARY_METRIC}): {pooled_metrics[PRIMARY_METRIC]:.4f}\n")
            for k, v in pooled_metrics.items():
                f.write(f"{k}: {v:.4f}\n")
            f.write("\nClassification Report:\n")
            f.write(pd.DataFrame(pooled_report).transpose().to_string())

        oof_df = pd.DataFrame({
            "patient": all_patients,
            "image_name": all_image_names,
            "True": all_targets,
            "Predicted": all_predictions,
            "path": all_paths,
        })
        for c in range(num_classes):
            oof_df[f"prob_class{c}"] = [row[c] for row in all_probabilities]
        oof_df.to_csv(save_dir / "oof_predictions.csv", index=False)
    else:
        print("Skipping pooled out-of-fold report: per-sample fields are missing in at least one fold.")

    return summary[PRIMARY_METRIC]["mean"], summary[PRIMARY_METRIC]["std"]
|
|
|
|
| |
|
|
class BlackBorderCrop:
    """Crop black borders and obvious invalid background around the fundus.

    Pixels whose grayscale value exceeds ``threshold`` count as content. The
    image is cropped to the bounding box of the content, enlarged on every
    side by ``margin_ratio`` of the box size and clamped to the image.
    """

    def __init__(self, threshold: int = 10, margin_ratio: float = 0.02):
        self.threshold = threshold
        self.margin_ratio = margin_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        rgb = np.array(pil_img.convert("RGB"))
        content = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY) > self.threshold

        # Too few bright pixels to trust a bounding box: return the input as RGB.
        if np.count_nonzero(content) < 64:
            return pil_img.convert("RGB")

        # First / last row and column that contain any content pixel.
        rows = np.flatnonzero(content.any(axis=1))
        cols = np.flatnonzero(content.any(axis=0))
        top, bottom = rows[0], rows[-1]
        left, right = cols[0], cols[-1]

        pad_y = int((bottom - top + 1) * self.margin_ratio)
        pad_x = int((right - left + 1) * self.margin_ratio)

        height, width = rgb.shape[0], rgb.shape[1]
        row_start = max(0, top - pad_y)
        row_stop = min(height, bottom + pad_y + 1)  # exclusive end
        col_start = max(0, left - pad_x)
        col_stop = min(width, right + pad_x + 1)  # exclusive end

        return Image.fromarray(rgb[row_start:row_stop, col_start:col_stop])
|
|
class FundusCircularCrop:
    """Crop to the square around the largest bright region and black out the rest.

    The largest external contour of the thresholded image is enclosed in a
    circle; the image is cropped to that circle's (padded, clamped) bounding
    box and every pixel outside the circle is set to zero. The unmodified RGB
    image is returned when no contour is found, the circle radius is below 10
    pixels, or the crop is smaller than 2 pixels on a side.
    """

    def __init__(self, threshold: int = 8, radius_pad_ratio: float = 0.03):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        rgb = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        # Binary foreground, cleaned with a 5x5 closing and a median filter.
        foreground = np.where(gray > self.threshold, 255, 0).astype(np.uint8)
        foreground = cv2.morphologyEx(
            foreground, cv2.MORPH_CLOSE, np.ones((5, 5), np.uint8)
        )
        foreground = cv2.medianBlur(foreground, 5)

        contours, _ = cv2.findContours(
            foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
        )
        if len(contours) == 0:
            return Image.fromarray(rgb)

        biggest = max(contours, key=cv2.contourArea)
        (cx, cy), radius = cv2.minEnclosingCircle(biggest)
        if radius < 10:
            return Image.fromarray(rgb)

        cx, cy = float(cx), float(cy)
        radius = float(radius * (1.0 + self.radius_pad_ratio))

        # Bounding box of the padded circle, clamped to the image.
        img_h, img_w = rgb.shape[0], rgb.shape[1]
        left = max(0, int(cx - radius))
        top = max(0, int(cy - radius))
        right = min(img_w, int(cx + radius))
        bottom = min(img_h, int(cy + radius))

        crop = rgb[top:bottom, left:right]
        crop_h, crop_w = crop.shape[:2]
        if crop_h < 2 or crop_w < 2:
            return Image.fromarray(rgb)

        # The mask radius never exceeds half of the shorter crop side.
        disc_radius = max(1, min(int(radius), min(crop_h, crop_w) // 2))
        rows, cols = np.ogrid[:crop_h, :crop_w]
        inside = ((cols - (cx - left)) ** 2 + (rows - (cy - top)) ** 2) <= (disc_radius ** 2)

        masked = np.zeros_like(crop)
        masked[inside] = crop[inside]
        return Image.fromarray(masked)
|
|
class ResizeToSquare:
    """Resize a PIL image to ``size`` x ``size`` with bilinear resampling (aspect ratio is not preserved)."""

    def __init__(self, size: int):
        self.size = size

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        target_shape = (self.size, self.size)
        return pil_img.resize(target_shape, resample=Image.BILINEAR)
|
|
class LightCLAHE:
    """Apply CLAHE to the L channel of the image in LAB colour space.

    The transform can be pickled (needed when DataLoader workers are started
    with the ``spawn`` method): the underlying ``cv2.CLAHE`` object is dropped
    on pickling and rebuilt from the stored parameters on unpickling.

    Args:
        clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        grid: ``tileGridSize`` passed to ``cv2.createCLAHE``.
    """

    def __init__(self, clip_limit: float = 2.0, grid: Tuple[int, int] = (8, 8)):
        self.clip_limit = clip_limit
        self.grid = tuple(grid)
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)

    def __getstate__(self):
        # cv2.CLAHE objects cannot be pickled; ship only the parameters.
        state = self.__dict__.copy()
        state["clahe"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        lightness, chan_a, chan_b = cv2.split(lab)
        lightness = self.clahe.apply(lightness)
        out = cv2.cvtColor(cv2.merge([lightness, chan_a, chan_b]), cv2.COLOR_LAB2RGB)
        return Image.fromarray(out)
|
|
class GreenChannelEnhancement:
    """Blend a CLAHE-equalised green channel back into the image.

    The output green channel is
    ``(1 - blend_alpha) * green + blend_alpha * CLAHE(green)``; red and blue
    are left untouched.

    The transform can be pickled (needed when DataLoader workers are started
    with the ``spawn`` method): the underlying ``cv2.CLAHE`` object is dropped
    on pickling and rebuilt from the stored parameters on unpickling.

    Args:
        clip_limit: ``clipLimit`` passed to ``cv2.createCLAHE``.
        grid: ``tileGridSize`` passed to ``cv2.createCLAHE``.
        blend_alpha: Weight of the equalised channel in the blend.
    """

    def __init__(
        self,
        clip_limit: float = 2.5,
        grid: Tuple[int, int] = (8, 8),
        blend_alpha: float = 0.30,
    ):
        self.clip_limit = clip_limit
        self.grid = tuple(grid)
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)
        self.blend_alpha = blend_alpha

    def __getstate__(self):
        # cv2.CLAHE objects cannot be pickled; ship only the parameters.
        state = self.__dict__.copy()
        state["clahe"] = None
        return state

    def __setstate__(self, state):
        self.__dict__.update(state)
        self.clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        r, g, b = cv2.split(img)
        g_eq = self.clahe.apply(g)
        g_new = cv2.addWeighted(g, 1.0 - self.blend_alpha, g_eq, self.blend_alpha, 0.0)
        out = cv2.merge([r, g_new, b])
        return Image.fromarray(out)
|
|
|
|
class FundusEyeMask:
    """Keep the circular region around the largest bright blob; darken everything else.

    Unlike a crop, the output has the same size as the input. The mask edge
    is optionally softened with a Gaussian blur. The unmodified RGB image is
    returned when no contour is found or the enclosing circle has a radius
    below 10 pixels.

    Args:
        threshold: Grayscale value above which a pixel counts as foreground.
        radius_pad_ratio: Relative enlargement of the enclosing circle.
        morph_kernel: Side of the square kernel for the closing/opening steps.
        blur_kernel: Gaussian kernel size for the mask edge; values of 0/1
            disable blurring and even values are bumped to the next odd one.
    """

    def __init__(
        self,
        threshold: int = 8,
        radius_pad_ratio: float = 0.03,
        morph_kernel: int = 7,
        blur_kernel: int = 5,
    ):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio
        self.morph_kernel = morph_kernel
        self.blur_kernel = blur_kernel

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        rgb = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)

        # Binary foreground, then closing followed by opening with one kernel.
        _, binary = cv2.threshold(gray, self.threshold, 255, cv2.THRESH_BINARY)
        struct = np.ones((self.morph_kernel, self.morph_kernel), np.uint8)
        for morph_op in (cv2.MORPH_CLOSE, cv2.MORPH_OPEN):
            binary = cv2.morphologyEx(binary, morph_op, struct)

        contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if len(contours) == 0:
            return Image.fromarray(rgb)

        biggest = max(contours, key=cv2.contourArea)
        (cx, cy), radius = cv2.minEnclosingCircle(biggest)
        if radius < 10:
            return Image.fromarray(rgb)

        radius = radius * (1.0 + self.radius_pad_ratio)
        rows, cols = np.ogrid[:rgb.shape[0], :rgb.shape[1]]
        dist_sq = (cols - cx) ** 2 + (rows - cy) ** 2
        disc = np.where(dist_sq <= (radius ** 2), 255, 0).astype(np.uint8)

        ksize = self.blur_kernel
        if ksize and ksize > 1:
            if ksize % 2 == 0:
                ksize = ksize + 1  # GaussianBlur needs an odd kernel size
            disc = cv2.GaussianBlur(disc, (ksize, ksize), 0)

        # Scale each pixel by the (possibly soft) mask in [0, 1].
        alpha = (disc.astype(np.float32) / 255.0)[..., None]
        masked = (rgb.astype(np.float32) * alpha).clip(0, 255).astype(np.uint8)
        return Image.fromarray(masked)
|
|
# Module-level preprocessing instances, created once with default parameters
# and shared by every pipeline that build_transforms assembles.
_light_clahe = LightCLAHE()
_green_enhance = GreenChannelEnhancement()
_eye_mask = FundusEyeMask()
# Cache of (train_transform, val_transform) pairs keyed by input size.
_transform_cache: Dict[int, Tuple[transforms.Compose, transforms.Compose]] = {}
|
|
def build_transforms(input_size: int = DEFAULT_INPUT_SIZE):
    """Build (and cache) the training and validation transform pipelines.

    Preprocessing (shared by both pipelines):
        - BlackBorderCrop is no longer used.
        - Resize to ``input_size`` -> CLAHE -> green-channel enhancement ->
          fundus region mask.
        - The mask keeps only the eye region and blanks unrelated pixels
          outside the fundus edge.

    Training-only augmentation:
        - Horizontal and vertical flips.
        - Small random rotation (+/-15 degrees), slight translation and scale
          jitter (0.85 to 1.15).
        - Moderate ColorJitter.
        - Slight Gaussian blur.

    Returns:
        ``(train_transform, val_transform)``; the same pair is returned for
        repeated calls with the same ``input_size``.
    """
    cached_pair = _transform_cache.get(input_size)
    if cached_pair is not None:
        return cached_pair

    def _tensor_tail():
        # Fresh ToTensor / Normalize instances for each pipeline.
        return [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]

    shared_steps = [
        ResizeToSquare(input_size),
        _light_clahe,
        _green_enhance,
        _eye_mask,
    ]

    augmentation_steps = [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomVerticalFlip(p=0.5),
        transforms.RandomAffine(
            degrees=15,
            translate=(0.05, 0.05),
            scale=(0.85, 1.15),
            interpolation=InterpolationMode.BILINEAR,
            fill=0,
        ),
        transforms.ColorJitter(
            brightness=0.20,
            contrast=0.20,
            saturation=0.10,
            hue=0.02,
        ),
        transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.8)),
    ]

    train_tf = transforms.Compose(shared_steps + augmentation_steps + _tensor_tail())
    val_tf = transforms.Compose(shared_steps + _tensor_tail())

    _transform_cache[input_size] = (train_tf, val_tf)
    return train_tf, val_tf
|
|
|
|
| |
| |
|
|
def predict_with_tta(
    model: nn.Module,
    inputs: torch.Tensor,
    amp_enabled: bool = False,
) -> torch.Tensor:
    """Average class probabilities over four flip variants of ``inputs``.

    The variants are: unchanged, flipped along the last dim, flipped along
    the second-to-last dim, and flipped along both. This function does not
    switch the model to eval mode or disable gradients; the caller is expected
    to do so.

    Args:
        model: Network to evaluate.
        inputs: Input batch; the last two dims are the ones flipped
            (presumably height and width; confirm against the caller).
        amp_enabled: Run the forward passes under CUDA autocast when True.

    Returns:
        Mean softmax probabilities (softmax over dim 1) in float32.
    """
    flip_dims = ((), (-1,), (-2,), (-1, -2))
    probs_list = []
    for dims in flip_dims:
        variant = inputs.flip(dims) if dims else inputs
        # torch.autocast replaces the deprecated torch.cuda.amp.autocast.
        amp_ctx = torch.autocast(device_type="cuda") if amp_enabled else nullcontext()
        with amp_ctx:
            out = model(variant)
        # _extract_logits is defined elsewhere in this file.
        logits = _extract_logits(out)
        # Under AMP the logits may be half precision; take the softmax in
        # float32 so the averaged probabilities keep full precision.
        probs_list.append(torch.softmax(logits.float(), dim=1))

    return torch.stack(probs_list, dim=0).mean(dim=0)
|
|
| |
|
|
class ImageDataset(Dataset):
    """Dataset yielding ``(image, label, meta)`` from a DataFrame of image paths.

    Args:
        df: Frame with columns ``path``, ``label`` and ``patient``; an
            optional ``image_name`` column is used when present, otherwise
            the file name of each path is used.
        transform: Optional callable applied to the loaded RGB PIL image.
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        self.df = df.reset_index(drop=True).copy()
        self.transform = transform

        # Materialise plain Python lists once so __getitem__ avoids pandas.
        self.paths = self.df["path"].astype(str).tolist()
        self.labels = self.df["label"].astype(int).tolist()
        self.patients = self.df["patient"].astype(str).tolist()
        if "image_name" in self.df.columns:
            self.image_names = self.df["image_name"].astype(str).tolist()
        else:
            self.image_names = [Path(p).name for p in self.paths]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        """Load sample ``idx``.

        Returns:
            ``(image, label_tensor, meta)`` where ``meta`` holds ``path``,
            ``patient`` and ``image_name``.

        Raises:
            RuntimeError: If the image cannot be opened or converted.
        """
        img_path = self.paths[idx]
        label = self.labels[idx]

        try:
            # The context manager closes the underlying file explicitly;
            # convert() returns an independent, fully loaded image.
            with Image.open(img_path) as handle:
                image = handle.convert("RGB")
        except Exception as exc:
            raise RuntimeError(f"Failed to open image: {img_path}") from exc

        if self.transform is not None:
            image = self.transform(image)

        meta = {
            "path": img_path,
            "patient": self.patients[idx],
            "image_name": self.image_names[idx],
        }
        return image, torch.tensor(label, dtype=torch.long), meta
|
|
|
|
| |
|
|
def validate_image_paths(df: pd.DataFrame, path_col: str = "path") -> pd.DataFrame:
    """Drop rows whose ``path_col`` does not point to an existing file.

    Cells that are not path-like at all (for example NaN from an empty Excel
    cell) are treated as missing instead of raising.

    Args:
        df: Input frame.
        path_col: Name of the column holding file paths.

    Returns:
        ``df`` itself when every path exists, otherwise a filtered copy with
        a fresh 0-based index.
    """

    def _is_existing_file(value) -> bool:
        # os.path.isfile raises TypeError on floats such as NaN, so guard
        # the type first.
        return isinstance(value, (str, bytes, os.PathLike)) and os.path.isfile(value)

    total = len(df)
    mask = df[path_col].map(_is_existing_file).astype(bool)
    missing = total - int(mask.sum())
    if missing > 0:
        print(f"Warning: {missing}/{total} paths do not exist and will be removed.")
        df = df.loc[mask].reset_index(drop=True)
    else:
        print(f"All {total} image paths are valid.")
    return df
|
|
def load_and_prepare_data(excel_path: str, group_col: str = "patient") -> pd.DataFrame:
    """Read the annotation spreadsheet and normalise groups, labels and paths.

    Steps, in order:
        1. Require the columns ``path``, ``label`` and ``group_col``.
        2. Convert group ids to stripped strings; reject empty / "nan" /
           "None" ids.
        3. Map the label "AROP" to 5 and convert labels to int.
        4. If the smallest label is 1, shift all labels down by one.
        5. Merge labels 4 and 5 into 3.
        6. Drop rows whose image file does not exist.
        7. Mirror ``group_col`` into a ``patient`` column when they differ.

    Args:
        excel_path: Path of the Excel file (read with openpyxl).
        group_col: Column holding the grouping (patient) identifier.

    Returns:
        The cleaned DataFrame.

    Raises:
        KeyError: If a required column is missing.
        ValueError: If any group identifier is invalid.
    """
    raw = pd.read_excel(excel_path, engine="openpyxl")

    absent = {"path", "label", group_col}.difference(raw.columns)
    if absent:
        raise KeyError(f"Missing required columns in Excel: {sorted(absent)}")

    df = raw.copy()
    df[group_col] = df[group_col].astype(str).str.strip()
    invalid_group = df[group_col].isin(["", "nan", "None"])
    if invalid_group.any():
        bad_rows = int(invalid_group.sum())
        raise ValueError(f"Found {bad_rows} rows with invalid patient/group identifiers in column '{group_col}'.")

    # Label normalisation; the order of these steps matters.
    labels = df["label"].replace({"AROP": 5})
    labels = pd.to_numeric(labels, errors="raise").astype(int)
    if labels.min() == 1:
        labels = labels - 1
    df["label"] = labels.replace({4: 3, 5: 3})

    df = validate_image_paths(df, path_col="path")

    if group_col != "patient":
        df["patient"] = df[group_col].astype(str)

    observed = sorted(df["label"].unique().tolist())
    print(f"Dataset size: {len(df)} images")
    print(f"Unique patients: {df[group_col].nunique()}")
    print(f"Class distribution: {dict(df['label'].value_counts().sort_index())}")
    print(f"Observed labels: {observed}")
    return df
|
|
def _approximate_group_stratified_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str,
):
    """Build group-disjoint folds stratified by each group's most frequent label.

    Every group is reduced to a single label (its most frequent one), the
    groups are split with ``StratifiedKFold`` and the group-level folds are
    expanded back to rows.

    Args:
        df: Frame with a ``label`` column and ``group_col``.
        n_folds: Number of folds.
        random_seed: Seed for the shuffled ``StratifiedKFold``.
        group_col: Column holding the group identifier.

    Returns:
        List of ``(train_idx, val_idx)`` arrays of *positional* row indices
        (suitable for ``df.iloc``), matching what sklearn splitters return.

    Raises:
        ValueError: If there are fewer unique groups than ``n_folds``.
    """
    group_df = (
        df.groupby(group_col)["label"]
        .agg(lambda x: x.value_counts().index[0])
        .reset_index()
    )
    n_groups = group_df[group_col].nunique()
    if n_groups < n_folds:
        raise ValueError(
            f"Number of unique groups ({n_groups}) is smaller than n_folds={n_folds}."
        )

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
    group_ids = group_df[group_col].values
    group_labels = group_df["label"].values
    row_groups = df[group_col]

    splits = []
    for group_train_idx, group_val_idx in skf.split(group_ids, group_labels):
        train_groups = set(group_ids[group_train_idx])
        val_groups = set(group_ids[group_val_idx])

        # Use positions, not index labels: callers select rows with df.iloc,
        # which would pick wrong rows if df has a non-default index.
        train_idx = np.flatnonzero(row_groups.isin(train_groups).to_numpy())
        val_idx = np.flatnonzero(row_groups.isin(val_groups).to_numpy())
        splits.append((train_idx, val_idx))

    return splits
|
|
def build_fold_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str = "patient",
):
    """Create group-disjoint, label-stratified cross-validation splits.

    ``StratifiedGroupKFold`` is used when available; if it is unavailable or
    raises ``ValueError``, the approximate majority-label splitter is used
    instead. Every fold is then checked for groups shared between train and
    validation.

    Returns:
        List of ``(train_idx, val_idx)`` pairs consumed with ``df.iloc``.

    Raises:
        ValueError: If there are fewer unique groups than ``n_folds``.
        RuntimeError: If any fold has overlapping train/validation groups.
    """
    group_ids = df[group_col].astype(str).values
    targets = df["label"].values

    n_unique_groups = len(np.unique(group_ids))
    if n_unique_groups < n_folds:
        raise ValueError(
            f"Unique groups in '{group_col}' = {n_unique_groups}, which is smaller than n_folds={n_folds}."
        )

    splits = None
    if not HAS_STRATIFIED_GROUP_KFOLD:
        print("StratifiedGroupKFold is unavailable. Falling back to approximate grouped stratification.")
    else:
        print(
            f"Using StratifiedGroupKFold with group_col='{group_col}', n_folds={n_folds}, seed={random_seed}."
        )
        try:
            sgkf = StratifiedGroupKFold(
                n_splits=n_folds,
                shuffle=True,
                random_state=random_seed,
            )
            splits = list(sgkf.split(df, y=targets, groups=group_ids))
        except ValueError as exc:
            print(f"StratifiedGroupKFold failed: {exc}")
            print("Falling back to approximate grouped stratification using patient-majority labels.")

    if splits is None:
        splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)

    # Leakage guard: no group may appear on both sides of a fold.
    for fold_id, (train_idx, val_idx) in enumerate(splits, 1):
        train_side = set(df.iloc[train_idx][group_col].astype(str).tolist())
        val_side = set(df.iloc[val_idx][group_col].astype(str).tolist())
        shared = train_side.intersection(val_side)
        if shared:
            raise RuntimeError(
                f"Data leakage detected in fold {fold_id}: {len(shared)} overlapping groups."
            )
    return splits
|
|
def _compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """Return per-class loss weights ``n_samples / (num_classes * class_count)``.

    A class absent from ``train_df`` is treated as having a count of 1. The
    result is a float32 tensor on the module-level ``device``.
    """
    per_class_count = train_df["label"].value_counts().sort_index()
    n_samples = len(train_df)

    weights = []
    for class_idx in range(num_classes):
        class_count = per_class_count.get(class_idx, 1)
        weights.append(n_samples / (num_classes * class_count))

    return torch.tensor(weights, dtype=torch.float32, device=device)
|
|
| def _make_weighted_sampler(train_df: pd.DataFrame) -> WeightedRandomSampler: |
| counts = train_df["label"].value_counts().to_dict() |
| sample_weights = train_df["label"].map(lambda x: 1.0 / counts[x]).astype(float).values |
| sample_weights = torch.as_tensor(sample_weights, dtype=torch.double) |
| return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True) |
|
|
def create_fold_loaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    input_size: int = DEFAULT_INPUT_SIZE,
    batch_size: int = 8,
    num_classes: int = 4,
    balance_mode: str = "loss",
    num_workers: int = 4,
):
    """Create the training and validation DataLoaders for one fold.

    Args:
        train_df: Training rows.
        val_df: Validation rows.
        input_size: Square input resolution passed to ``build_transforms``.
        batch_size: Batch size for both loaders.
        num_classes: Number of classes (used for loss weights).
        balance_mode: ``"sampler"`` uses a ``WeightedRandomSampler``;
            ``"loss"`` returns class weights for the loss; any other value
            disables imbalance correction.
        num_workers: Worker processes per loader.

    Returns:
        ``(train_loader, val_loader, class_weights)``; ``class_weights`` is
        ``None`` unless ``balance_mode == "loss"``.
    """
    train_tf, val_tf = build_transforms(input_size)

    sampler = None
    class_weights = None

    if balance_mode == "sampler":
        sampler = _make_weighted_sampler(train_df)
        print("Training loader uses WeightedRandomSampler for class balancing.")
    elif balance_mode == "loss":
        class_weights = _compute_class_weights(train_df, num_classes)
        print(f"Training loss uses class weights: {class_weights.detach().cpu().numpy().tolist()}")
    else:
        print("No imbalance correction is used.")

    # Only drop the final batch when it would contain exactly one sample.
    drop_last = (batch_size > 1) and (len(train_df) % batch_size == 1)
    if drop_last:
        print(
            f"Training loader will drop the last singleton batch "
            f"(train_size={len(train_df)}, batch_size={batch_size}) to avoid BatchNorm issues."
        )

    pin_memory = device.type == "cuda"

    train_loader = DataLoader(
        ImageDataset(train_df, train_tf),
        batch_size=batch_size,
        # shuffle and sampler are mutually exclusive in DataLoader.
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=(num_workers > 0),
    )
    val_loader = DataLoader(
        ImageDataset(val_df, val_tf),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),
    )
    return train_loader, val_loader, class_weights
|
|
| |
|
|
def patch_vit_for_large_input(
    model: nn.Module,
    model_name: str,
    input_size: int,
) -> nn.Module:
    """Resize a ViT's positional embedding to match ``input_size``.

    Only models whose name contains "ViT" and that expose
    ``model.encoder.pos_embedding`` and ``model.patch_size`` are touched. The
    first token of the embedding is kept as is; the remaining tokens are
    treated as a square grid and resized with bicubic interpolation.

    The embedding is replaced by a new ``nn.Parameter`` (same dtype and
    ``requires_grad`` as before), so call this before creating an optimizer.

    Args:
        model: Model to patch in place.
        model_name: Name used to decide whether patching applies.
        input_size: Target square input resolution in pixels.

    Returns:
        The same ``model`` object.

    Raises:
        ValueError: If the existing patch tokens do not form a square grid or
            the target grid would be empty.
    """
    if "ViT" not in model_name:
        return model

    if not (hasattr(model, "encoder") and hasattr(model.encoder, "pos_embedding")):
        print(f"Warning: cannot find pos_embedding for {model_name}, skip interpolation.")
        return model

    patch_size = model.patch_size
    side_new = input_size // patch_size
    expected_patches = side_new ** 2
    pos_embed = model.encoder.pos_embedding
    current_patches = pos_embed.shape[1] - 1  # minus the leading token

    if current_patches == expected_patches:
        print(f"[ViT] pos_embedding already matches input_size={input_size}, no interpolation needed.")
        return model

    side_old = math.isqrt(current_patches)
    if side_old < 1 or side_old * side_old != current_patches:
        raise ValueError(
            f"Cannot interpolate pos_embedding for {model_name}: "
            f"{current_patches} patch tokens do not form a square grid."
        )
    if side_new < 1:
        raise ValueError(
            f"input_size={input_size} is smaller than patch_size={patch_size} for {model_name}."
        )

    print(
        f"[ViT] Interpolating pos_embedding: {current_patches} -> {expected_patches} patches "
        f"for input_size={input_size}."
    )

    # Pure weight surgery: keep it out of the autograd graph.
    with torch.no_grad():
        cls_token = pos_embed[:, :1, :]
        patch_tokens = pos_embed[:, 1:, :]
        dim = patch_tokens.shape[-1]

        # (1, N, D) -> (1, D, H, W); interpolate in float32.
        grid = (
            patch_tokens
            .reshape(1, side_old, side_old, dim)
            .permute(0, 3, 1, 2)
            .float()
        )
        grid = F.interpolate(
            grid,
            size=(side_new, side_new),
            mode="bicubic",
            align_corners=False,
        )
        # Back to (1, N_new, D) in the embedding's original dtype.
        patch_tokens = (
            grid.permute(0, 2, 3, 1)
            .reshape(1, expected_patches, dim)
            .to(pos_embed.dtype)
        )
        new_pos_embed = torch.cat([cls_token, patch_tokens], dim=1).contiguous()

    model.encoder.pos_embedding = nn.Parameter(
        new_pos_embed,
        requires_grad=pos_embed.requires_grad,
    )

    if hasattr(model, "image_size"):
        model.image_size = input_size

    return model
|
|
| |
|
|
| def _find_last_linear(module: nn.Module): |
| if isinstance(module, nn.Linear): |
| return module |
| if isinstance(module, nn.Sequential): |
| for child in reversed(list(module.children())): |
| result = _find_last_linear(child) |
| if result is not None: |
| return result |
| if hasattr(module, "head") and isinstance(module.head, (nn.Linear, nn.Sequential)): |
| return _find_last_linear(module.head) |
| return None |
|
|
| def _verify_classifier(model: nn.Module, model_name: str, expected_classes: int) -> None: |
| for attr_name in ["fc", "head", "classifier", "heads"]: |
| if not hasattr(model, attr_name): |
| continue |
| layer = getattr(model, attr_name) |
| last_linear = _find_last_linear(layer) |
| if last_linear is not None: |
| if last_linear.out_features != expected_classes: |
| raise RuntimeError( |
| f"Classifier replacement failed for {model_name}: " |
| f"out_features={last_linear.out_features}, expected={expected_classes}" |
| ) |
| print(f"Verified {model_name}: classifier -> {expected_classes} classes (in={last_linear.in_features})") |
| return |
| print(f"Warning: failed to automatically verify classifier for {model_name}") |
|
|
| def replace_classifier( |
| model_name: str, |
| model: nn.Module, |
| num_classes: int, |
| dropout: float = 0.3, |
| ) -> nn.Module: |
| if _is_timm_model(model): |
| in_feat = model.num_features |
| orig_classifier = model.get_classifier() |
| print(f"[timm] {model_name}: original classifier={type(orig_classifier).__name__}, num_features={in_feat}") |
|
|
| model.reset_classifier(num_classes) |
| new_fc = model.get_classifier() |
|
|
| wrapped = False |
| if isinstance(new_fc, nn.Linear): |
| for parent_attr, child_attr in [ |
| ("head", "fc"), |
| ("head", "head"), |
| (None, "head"), |
| (None, "classifier"), |
| (None, "fc"), |
| ]: |
| try: |
| parent = getattr(model, parent_attr) if parent_attr else model |
| child = getattr(parent, child_attr) |
| if child is new_fc: |
| setattr( |
| parent, |
| child_attr, |
| nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)), |
| ) |
| wrapped = True |
| break |
| except AttributeError: |
| continue |
|
|
| if not wrapped: |
| print(f"[timm] {model_name}: reset_classifier({num_classes}) applied (no Dropout wrapper).") |
|
|
| _verify_classifier(model, model_name, num_classes) |
| return model |
|
|
| n = model_name |
|
|
| if "VGG" in n: |
| in_feat = model.classifier[6].in_features |
| model.classifier[6] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif n == "inception_v3": |
| aux_in = model.AuxLogits.fc.in_features |
| model.AuxLogits.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux_in, num_classes)) |
| fc_in = model.fc.in_features |
| model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(fc_in, num_classes)) |
|
|
| elif "GoogLeNet" in n: |
| in_feat = model.fc.in_features |
| model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| if hasattr(model, "aux1") and model.aux1 is not None and hasattr(model.aux1, "fc2"): |
| aux1_in = model.aux1.fc2.in_features |
| model.aux1.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux1_in, num_classes)) |
| if hasattr(model, "aux2") and model.aux2 is not None and hasattr(model.aux2, "fc2"): |
| aux2_in = model.aux2.fc2.in_features |
| model.aux2.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux2_in, num_classes)) |
|
|
| elif "ResNe" in n: |
| in_feat = model.fc.in_features |
| model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "DenseNet" in n: |
| in_feat = model.classifier.in_features |
| model.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "MobileNet" in n: |
| in_feat = model.classifier[-1].in_features |
| model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "MnasNet" in n: |
| in_feat = model.classifier[-1].in_features |
| model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "EfficientNet" in n: |
| in_feat = model.classifier[-1].in_features |
| model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "ConvNeXt" in n: |
| in_feat = model.classifier[-1].in_features |
| model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "RegNet" in n or "ShuffleNet" in n: |
| in_feat = model.fc.in_features |
| model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
|
|
| elif "ViT" in n: |
| if hasattr(model, "heads") and hasattr(model.heads, "head"): |
| in_feat = model.heads.head.in_features |
| model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| elif hasattr(model, "head") and isinstance(model.head, nn.Linear): |
| in_feat = model.head.in_features |
| model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| else: |
| raise ValueError(f"Cannot find classifier head for {n}") |
|
|
| elif "Swin" in n: |
| if hasattr(model, "head") and isinstance(model.head, nn.Linear): |
| in_feat = model.head.in_features |
| model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| elif hasattr(model, "heads") and hasattr(model.heads, "head"): |
| in_feat = model.heads.head.in_features |
| model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| else: |
| raise ValueError(f"Cannot find classifier head for {n}") |
|
|
| elif _is_vit_family(n): |
| replaced = False |
| for attr in ["heads.head", "head", "classifier"]: |
| parts = attr.split(".") |
| obj = model |
| try: |
| for p in parts: |
| obj = getattr(obj, p) |
| if isinstance(obj, nn.Linear): |
| in_feat = obj.in_features |
| parent = model |
| for p in parts[:-1]: |
| parent = getattr(parent, p) |
| setattr(parent, parts[-1], nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))) |
| replaced = True |
| break |
| except AttributeError: |
| continue |
| if not replaced: |
| raise ValueError(f"Cannot find classifier head for {n}") |
|
|
| else: |
| replaced = False |
| for attr_name in ["fc", "head", "classifier"]: |
| if not hasattr(model, attr_name): |
| continue |
| layer = getattr(model, attr_name) |
| if isinstance(layer, nn.Linear): |
| in_feat = layer.in_features |
| setattr(model, attr_name, nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))) |
| replaced = True |
| break |
| if isinstance(layer, nn.Sequential) and len(layer) > 0 and isinstance(layer[-1], nn.Linear): |
| in_feat = layer[-1].in_features |
| layer[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)) |
| replaced = True |
| break |
| if not replaced: |
| raise ValueError(f"Cannot automatically replace classifier for {n}") |
|
|
| _verify_classifier(model, model_name, num_classes) |
| return model |
|
|
|
|
| |
|
|
def _get_head_keywords(model_name: str) -> List[str]:
    """Return the parameter-name fragments that mark the classifier head.

    The fragments are handed to ``name_matches_keywords`` by the functions
    that split a model into "head" and "backbone" parameters.

    Args:
        model_name: Registry name of the architecture (e.g. ``"ViT_B_16"``).

    Returns:
        A fresh list of name fragments for the matching architecture family,
        or a generic fallback list when no family rule applies.
    """
    # Exact-name rule; the name contains none of the substrings tested below,
    # so checking it first does not change which rule wins.
    if model_name == "inception_v3":
        return ["fc.", "AuxLogits.fc"]

    # (substrings that select the family, head fragments) - first hit wins.
    family_rules = (
        (("VGG",), ["classifier.6"]),
        (("GoogLeNet",), ["fc.", "aux1.fc2", "aux2.fc2"]),
        (("ResNe",), ["fc."]),
        (("DenseNet",), ["classifier."]),
        (("MobileNet",), ["classifier.3", "classifier.2", "classifier."]),
        (("MnasNet",), ["classifier.1", "classifier."]),
        (("EfficientNet", "ConvNeXt"), ["classifier.", "head.fc"]),
        (("RegNet", "ShuffleNet"), ["fc."]),
    )
    for fragments, head_keywords in family_rules:
        if any(fragment in model_name for fragment in fragments):
            return list(head_keywords)

    # Transformer families are consulted last, exactly before the fallback.
    if "ViT" in model_name or _is_vit_family(model_name):
        return ["heads.", "head.", "classifier."]

    return ["fc.", "classifier.", "head.", "heads."]
|
|
def get_parameter_groups(
    model_name: str,
    model: nn.Module,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
):
    """Split model parameters into backbone and head optimizer groups.

    Args:
        model_name: Architecture name used to look up the head keywords.
        model: Model whose named parameters are partitioned.
        backbone_lr: Learning rate assigned to the backbone group.
        head_lr: Learning rate assigned to the head group.

    Returns:
        A list of optimizer param-group dicts: ``[backbone, head]`` when at
        least one head parameter is found, otherwise a single group holding
        every parameter at ``head_lr``.
    """
    keywords = _get_head_keywords(model_name)
    head_params, backbone_params = [], []
    for param_name, tensor in model.named_parameters():
        bucket = head_params if name_matches_keywords(param_name, keywords) else backbone_params
        bucket.append(tensor)

    if head_params:
        print(
            f"Parameter groups | backbone: {sum(p.numel() for p in backbone_params):,} (lr={backbone_lr}) | "
            f"head: {sum(p.numel() for p in head_params):,} (lr={head_lr})"
        )
        return [
            {"params": backbone_params, "lr": backbone_lr},
            {"params": head_params, "lr": head_lr},
        ]

    # Nothing matched the head keywords: fall back to one group for everything.
    print(f"Warning: no head parameters matched for {model_name}; all params use head_lr.")
    return [{"params": list(model.parameters()), "lr": head_lr}]
|
|
def set_backbone_trainable(model_name: str, model: nn.Module, train_backbone: bool) -> None:
    """Toggle ``requires_grad`` so that only the head, or the whole model, trains.

    Head parameters (those matching the architecture's head keywords) keep
    ``requires_grad`` enabled; every other parameter follows ``train_backbone``.
    """
    keywords = _get_head_keywords(model_name)
    for param_name, tensor in model.named_parameters():
        belongs_to_head = name_matches_keywords(param_name, keywords)
        tensor.requires_grad = train_backbone if train_backbone else belongs_to_head
|
|
def set_frozen_backbone_bn_eval(model_name: str, model: nn.Module) -> None:
    """Put every non-head batch-norm layer in eval mode and freeze its parameters.

    Called after ``model.train()`` while the backbone is frozen, so that the
    batch-norm modules outside the head are switched back to eval mode.
    """
    keywords = _get_head_keywords(model_name)
    norm_layers = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)

    for module_name, layer in model.named_modules():
        if not isinstance(layer, norm_layers):
            continue
        if name_matches_keywords(module_name, keywords):
            continue  # batch-norm inside the head keeps training
        layer.eval()
        for weight in layer.parameters():
            weight.requires_grad = False
|
|
def configure_small_batch_behavior(model_name: str, model: nn.Module, batch_size: int) -> nn.Module:
    """Disable auxiliary classifiers of Inception/GoogLeNet when batch_size < 2.

    Args:
        model_name: Architecture name; only ``"inception_v3"`` and names
            containing ``"GoogLeNet"`` are affected.
        model: Model to adjust in place.
        batch_size: Training batch size.

    Returns:
        The same ``model`` object (modified in place when applicable).
    """
    if batch_size >= 2:
        return model

    if model_name == "inception_v3":
        notice = "batch_size=1 detected: disabling Inception auxiliary classifier."
        aux_attrs = ("AuxLogits",)
    elif "GoogLeNet" in model_name:
        notice = "batch_size=1 detected: disabling GoogLeNet auxiliary classifiers."
        aux_attrs = ("aux1", "aux2")
    else:
        return model

    print(notice)
    if hasattr(model, "aux_logits"):
        model.aux_logits = False
    for attr in aux_attrs:
        if hasattr(model, attr):
            setattr(model, attr, None)
    return model
|
|
| |
|
|
def _extract_logits(output):
    """Return the main logits tensor from a model's forward output.

    Accepts a bare tensor, an object exposing a tensor ``logits`` attribute,
    or a non-empty tuple/list whose first element is a tensor.

    Raises:
        TypeError: If none of the supported output shapes applies.
    """
    if torch.is_tensor(output):
        return output

    named_logits = getattr(output, "logits", None)
    if torch.is_tensor(named_logits):
        return named_logits

    if isinstance(output, (tuple, list)) and output and torch.is_tensor(output[0]):
        return output[0]

    raise TypeError("Unable to extract logits from model output.")
|
|
def _extract_aux_outputs(output):
    """Collect auxiliary-classifier tensors from a model's forward output.

    For a tuple/list, every tensor after the first element is returned.
    For any other object, the tensor-valued attributes ``aux_logits``,
    ``aux_logits2`` and ``aux_logits1`` are returned in that order.

    Returns:
        A (possibly empty) list of tensors.
    """
    if isinstance(output, (tuple, list)):
        return [item for item in output[1:] if torch.is_tensor(item)]

    collected = []
    for field in ("aux_logits", "aux_logits2", "aux_logits1"):
        candidate = getattr(output, field, None)
        if torch.is_tensor(candidate):
            collected.append(candidate)
    return collected
|
|
def forward_with_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    criterion,
    aux_weight: float = 0.3,
):
    """Run a forward pass and compute the loss, including auxiliary heads.

    Args:
        model: Network to evaluate.
        inputs: Input batch passed to ``model``.
        labels: Target labels passed to ``criterion``.
        criterion: Callable ``criterion(logits, labels)`` returning a loss.
        aux_weight: Weight applied to each auxiliary-head loss.

    Returns:
        ``(logits, loss)`` where ``loss`` is the main loss plus, in training
        mode only, ``aux_weight`` times the loss of every auxiliary output.
    """
    raw_output = model(inputs)
    logits = _extract_logits(raw_output)
    aux_heads = _extract_aux_outputs(raw_output)

    total_loss = criterion(logits, labels)
    if model.training:
        # Auxiliary heads contribute only while training.
        for aux_logits in aux_heads:
            total_loss = total_loss + aux_weight * criterion(aux_logits, labels)
    return logits, total_loss
|
|
| |
|
|
class FocalLoss(nn.Module):
    """Multi-class focal loss computed from raw logits.

    Per-sample loss is ``-(1 - p_t) ** gamma * log(p_t)``, optionally scaled
    by a per-class weight ``alpha[target]``.

    Args:
        alpha: Optional per-class weight tensor indexed by target label.
        gamma: Focusing exponent.
        reduction: ``"mean"`` or ``"sum"``; any other value returns the
            unreduced per-sample loss.
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the focal loss of ``inputs`` (logits) against ``targets``."""
        # Log-probability of the true class for every sample.
        true_log_prob = F.log_softmax(inputs, dim=1).gather(1, targets.unsqueeze(1)).squeeze(1)
        true_prob = true_log_prob.exp()

        focusing = torch.pow(1.0 - true_prob, self.gamma)
        per_sample = focusing * (-true_log_prob)

        if self.alpha is not None:
            class_weight = self.alpha.to(inputs.device)[targets]
            per_sample = class_weight * per_sample

        reducer = {"mean": torch.mean, "sum": torch.sum}.get(self.reduction)
        return per_sample if reducer is None else reducer(per_sample)
|
|
def build_criterion(
    loss_type: str,
    class_weights: Optional[torch.Tensor] = None,
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
):
    """Create the training loss selected by ``loss_type`` (case-insensitive).

    Args:
        loss_type: ``"ce"``, ``"weighted_ce"`` or ``"focal"``.
        class_weights: Per-class weights for ``weighted_ce`` / ``focal``.
        focal_gamma: Focusing exponent for the focal loss.
        label_smoothing: Label smoothing for the cross-entropy variants
            (not forwarded to the focal loss).

    Raises:
        ValueError: If ``loss_type`` is not one of the supported names.
    """
    kind = loss_type.lower()

    if kind == "ce":
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    if kind == "weighted_ce":
        return nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    if kind == "focal":
        return FocalLoss(alpha=class_weights, gamma=focal_gamma, reduction="mean")

    raise ValueError(f"Unsupported loss_type: {kind}")
|
|
|
|
| |
| |
| |
| |
| |
def train_one_fold(
    model_name: str,
    model: nn.Module,
    train_loader,
    val_loader,
    epochs: int = 90,
    num_classes: int = 4,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
    class_weights: Optional[torch.Tensor] = None,
    fold_id: int = 1,
    save_dir: Optional[Path] = None,
    freeze_backbone_epochs: int = 8,
    max_grad_norm: float = 1.0,
    primary_metric: str = PRIMARY_METRIC,
    loss_type: str = "weighted_ce",
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
    use_tta: bool = True,
):
    """Train and validate one cross-validation fold, with epoch-level resume.

    The head is trained alone for ``freeze_backbone_epochs`` epochs (linear
    LR warm-up), after which the backbone is unfrozen and a cosine schedule
    takes over. After every epoch the model is validated; the best epoch
    (by ``primary_metric``, ties broken by balanced accuracy) is persisted
    through ``save_fold_results`` and as ``fold{fold_id}_best.pth``. A full
    training-state checkpoint ``fold{fold_id}_checkpoint.pth`` is rewritten
    every epoch and removed once training completes.

    Relies on a ``device`` object that is expected to be defined at module
    level outside this function.

    Args:
        model_name: Architecture name (selects head keywords, default dir).
        model: Model to train; modified in place.
        train_loader: Loader yielding ``(inputs, labels, meta)`` batches.
        val_loader: Loader yielding ``(inputs, labels, meta)``; ``meta`` must
            provide ``"path"``, ``"patient"`` and ``"image_name"``.
        epochs: Total number of epochs.
        num_classes: Number of target classes.
        backbone_lr: Learning rate of the backbone parameter group.
        head_lr: Learning rate of the head parameter group.
        class_weights: Optional per-class weights for the loss.
        fold_id: 1-based fold number used in file names and log lines.
        save_dir: Output directory; defaults to ``Path(model_name)``.
        freeze_backbone_epochs: Epochs with a frozen backbone (and warm-up).
        max_grad_norm: Gradient clipping norm; ``None`` or ``<= 0`` disables.
        primary_metric: Key of the metric used for model selection.
        loss_type: Loss name forwarded to ``build_criterion``.
        focal_gamma: Focal-loss gamma forwarded to ``build_criterion``.
        label_smoothing: Label smoothing forwarded to ``build_criterion``.
        use_tta: Use ``predict_with_tta`` during validation.

    Returns:
        Dict describing the best epoch (``best_epoch``, ``metrics``,
        ``classification_report``, predictions, targets, probabilities, ...).

    Raises:
        RuntimeError: If no epoch produced a result.
    """
    import copy  # local import: only used for the resume rollback snapshot

    save_dir = Path(model_name) if save_dir is None else Path(save_dir)
    ensure_dir(save_dir)

    criterion = build_criterion(
        loss_type=loss_type,
        class_weights=class_weights,
        focal_gamma=focal_gamma,
        label_smoothing=label_smoothing,
    )
    print(
        f"Fold {fold_id}: Using loss_type='{loss_type}'"
        f"{' with class weights' if class_weights is not None else ''}."
    )
    print(
        f"Fold {fold_id}: backbone_lr={backbone_lr}, head_lr={head_lr}, "
        f"freeze_backbone_epochs={freeze_backbone_epochs}, "
        f"epochs={epochs}, use_tta={use_tta}."
    )

    param_groups = get_parameter_groups(model_name, model, backbone_lr, head_lr)
    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999), weight_decay=5e-4)

    warmup_ep = max(0, freeze_backbone_epochs)
    sched_main = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, epochs - warmup_ep),
        eta_min=1e-7,
    )
    if warmup_ep > 0:
        sched_warm = torch.optim.lr_scheduler.LinearLR(
            optimizer,
            start_factor=0.1,
            end_factor=1.0,
            total_iters=warmup_ep,
        )
        scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer,
            schedulers=[sched_warm, sched_main],
            milestones=[warmup_ep],
        )
    else:
        # A zero-length warm-up inside SequentialLR would still apply the
        # 0.1 start factor and leave the whole cosine schedule scaled down,
        # so without a warm-up phase the cosine scheduler is used directly.
        scheduler = sched_main

    amp_enabled = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
    non_blocking = device.type == "cuda"
    class_names = [f"class{i}" for i in range(num_classes)]

    best_monitor = -float("inf")
    best_results = None
    was_backbone_trainable = None
    start_epoch = 0

    ckpt_path = save_dir / f"fold{fold_id}_checkpoint.pth"
    if ckpt_path.is_file():
        print(f"Fold {fold_id}: found epoch-level checkpoint, attempting to resume...")
        # Snapshot the pristine state so that a checkpoint which fails half-way
        # through loading can be rolled back; otherwise "training from scratch"
        # would silently continue from partially restored weights.
        pristine = {
            "model": {k: v.detach().cpu().clone() for k, v in model.state_dict().items()},
            "optimizer": copy.deepcopy(optimizer.state_dict()),
            "scheduler": copy.deepcopy(scheduler.state_dict()),
            "scaler": copy.deepcopy(scaler.state_dict()),
        }
        try:
            # NOTE: weights_only=False unpickles arbitrary objects; only load
            # checkpoints written by this script.
            ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
            model.load_state_dict(ckpt["model_state_dict"])
            optimizer.load_state_dict(ckpt["optimizer_state_dict"])
            scheduler.load_state_dict(ckpt["scheduler_state_dict"])
            scaler.load_state_dict(ckpt["scaler_state_dict"])
            resumed_epoch = ckpt["epoch"] + 1
            resumed_monitor = ckpt["best_monitor"]
            resumed_results = ckpt.get("best_results", None)
            resume_msg = (
                f"Fold {fold_id}: resumed from epoch {resumed_epoch}/{epochs} "
                f"(best {primary_metric}={resumed_monitor:.2f})."
            )
        except Exception as exc:
            print(f"Fold {fold_id}: failed to load checkpoint ({exc}), training from scratch.")
            model.load_state_dict(pristine["model"])
            optimizer.load_state_dict(pristine["optimizer"])
            scheduler.load_state_dict(pristine["scheduler"])
            scaler.load_state_dict(pristine["scaler"])
            start_epoch = 0
            best_monitor = -float("inf")
            best_results = None
        else:
            start_epoch = resumed_epoch
            best_monitor = resumed_monitor
            best_results = resumed_results
            print(resume_msg)
        del pristine

    for epoch in tqdm(
        range(start_epoch, epochs),
        desc=f"Fold {fold_id}",
        leave=False,
        initial=start_epoch,
        total=epochs,
    ):
        # (Re)apply the freeze state on the first epoch run and on every change,
        # which also covers resuming in the middle of either phase.
        train_backbone = epoch >= freeze_backbone_epochs
        if was_backbone_trainable is None or was_backbone_trainable != train_backbone:
            set_backbone_trainable(model_name, model, train_backbone=train_backbone)
            stage = "unfrozen" if train_backbone else "frozen"
            print(f"Fold {fold_id}: backbone is now {stage} (epoch {epoch + 1}).")
            was_backbone_trainable = train_backbone

        model.train()
        if not train_backbone:
            # model.train() re-enabled batch-norm updates; undo that for the
            # frozen backbone.
            set_frozen_backbone_bn_eval(model_name, model)

        run_loss = 0.0
        seen_samples = 0

        for inputs, labels, _meta in train_loader:
            inputs = inputs.to(device, non_blocking=non_blocking)
            labels = labels.to(device, non_blocking=non_blocking)

            optimizer.zero_grad(set_to_none=True)
            with amp_ctx():
                logits, loss = forward_with_loss(model, inputs, labels, criterion, aux_weight=0.3)

            scaler.scale(loss).backward()
            if max_grad_norm is not None and max_grad_norm > 0:
                # Gradients must be unscaled before clipping by norm.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()

            batch_n = inputs.size(0)
            run_loss += loss.item() * batch_n
            seen_samples += batch_n

        scheduler.step()
        # Average over the samples actually processed: the dataset length is
        # wrong when the loader drops the last batch or resamples, and would
        # divide by zero for an empty dataset.
        ep_loss = run_loss / max(1, seen_samples)

        # ---- validation ----
        model.eval()
        all_t, all_p, all_paths, all_probs = [], [], [], []
        all_patients, all_image_names = [], []

        with torch.no_grad():
            for inputs, labels, meta in val_loader:
                inputs = inputs.to(device, non_blocking=non_blocking)
                labels = labels.to(device, non_blocking=non_blocking)

                if use_tta:
                    probs = predict_with_tta(model, inputs, amp_enabled=amp_enabled)
                else:
                    with amp_ctx():
                        output = model(inputs)
                        logits = _extract_logits(output)
                        probs = torch.softmax(logits, dim=1)

                pred = probs.argmax(dim=1)

                all_t.extend(labels.cpu().numpy().tolist())
                all_p.extend(pred.cpu().numpy().tolist())
                all_probs.extend(probs.cpu().numpy().tolist())
                all_paths.extend(list(meta["path"]))
                all_patients.extend(list(meta["patient"]))
                all_image_names.extend(list(meta["image_name"]))

        metrics, report = compute_metrics(all_t, all_p, num_classes, class_names)
        monitor = metrics[primary_metric]

        if (epoch + 1) % 5 == 0 or epoch == epochs - 1 or epoch == start_epoch:
            print(
                f"F{fold_id} E{epoch + 1}/{epochs} "
                f"Loss={ep_loss:.4f} "
                f"Macro-F1={metrics['macro_f1']:.2f}% "
                f"BA={metrics['balanced_accuracy']:.2f}% "
                f"Acc={metrics['accuracy']:.2f}%"
                f"{' [TTA]' if use_tta else ''}"
            )

        # Better primary metric wins; on a tie, higher balanced accuracy wins.
        improved = (monitor > best_monitor) or (
            np.isclose(monitor, best_monitor)
            and best_results is not None
            and metrics["balanced_accuracy"] > best_results["metrics"]["balanced_accuracy"]
        )

        if improved:
            best_monitor = monitor
            best_results = {
                "best_epoch": epoch + 1,
                "metrics": metrics,
                "classification_report": report,
                "predictions": all_p,
                "targets": all_t,
                "image_path": all_paths,
                "patients": all_patients,
                "image_names": all_image_names,
                "probabilities": all_probs,
                "num_classes": num_classes,
                "per_class": [
                    report.get(f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0})
                    for i in range(num_classes)
                ],
            }
            save_fold_results(best_results, save_dir, tag=f"fold{fold_id}_best")
            torch.save(model.state_dict(), save_dir / f"fold{fold_id}_best.pth")

        # Epoch-level checkpoint: rewritten every epoch so an interrupted run
        # can resume from the last completed epoch.
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_monitor": best_monitor,
            "best_results": best_results,
        }, ckpt_path)

    if ckpt_path.is_file():
        ckpt_path.unlink()
        print(f"Fold {fold_id}: removed epoch-level checkpoint (training complete).")

    if best_results is None:
        raise RuntimeError(f"Fold {fold_id}: no valid result was produced.")

    return best_results
|
|
|
|
|
|
def build_model_registry():
    """Return an ordered mapping of model name -> zero-argument factory.

    The torchvision factories are always registered; the timm factories are
    added only when ``HAS_TIMM`` is true. Each factory builds the model with
    pretrained weights when it is called.
    """

    def _timm_factory(architecture: str):
        # Deferred construction of a timm model at 512x512 input size.
        return lambda: timm.create_model(
            architecture,
            pretrained=True,
            img_size=512,
        )

    registry = {
        "DenseNet161": lambda: models.densenet161(weights=models.DenseNet161_Weights.DEFAULT),
        "ConvNeXt_Tiny": lambda: models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT),
        "ViT_B_16": lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT),
    }

    if not HAS_TIMM:
        print("Skipping timm models because timm is not installed.")
        return registry

    registry["SwinV2_T"] = _timm_factory("swinv2_tiny_window8_256")
    registry["DeiT3_S"] = _timm_factory("deit3_small_patch16_224")
    return registry
|
|
|
|
def parse_args(argv: Optional[List[str]] = None):
    """Parse the command-line options of the benchmark script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, which is the original behaviour;
            passing a list allows programmatic use and testing.

    Returns:
        The populated ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(description="ROP benchmark training with patient-grouped 5-fold CV (v7).")

    # BooleanOptionalAction is not available on older Python versions.
    boolean_action = getattr(argparse, "BooleanOptionalAction", None)

    # NOTE: the default below is a machine-specific absolute path; pass
    # --excel_path explicitly on any other machine.
    parser.add_argument(
        "--excel_path",
        type=str,
        default="/media/fang/9fc99a7b-15d6-4e22-ab05-fe46e6058c39/felicia/Downloads/医生审核之后第一版12-17/部分公开数据集/公开数据集训练表_调整数据1.xlsx",
        help="Path to Excel with at least columns: patient, path, label.",
    )
    parser.add_argument("--group_col", type=str, default="patient", help="Grouping column for leakage-free split.")
    parser.add_argument("--num_classes", type=int, default=4)

    parser.add_argument("--epochs", type=int, default=90)
    parser.add_argument("--n_folds", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=8)

    parser.add_argument("--backbone_lr", type=float, default=3e-5)

    parser.add_argument("--head_lr", type=float, default=1e-3)
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=min(8, os.cpu_count() or 2))

    parser.add_argument(
        "--balance_mode",
        type=str,
        default="loss",
        choices=["none", "loss", "sampler"],
        help="Imbalance handling. 'loss' computes class weights; 'sampler' uses WeightedRandomSampler.",
    )
    parser.add_argument(
        "--loss_type",
        type=str,
        default="weighted_ce",
        choices=["weighted_ce", "focal", "ce"],
        help="weighted_ce is the recommended default.",
    )
    parser.add_argument("--focal_gamma", type=float, default=2.0)
    parser.add_argument("--label_smoothing", type=float, default=0.0)

    parser.add_argument("--freeze_backbone_epochs", type=int, default=8)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--output_root", type=str, default="runs_rop_V7_old")

    if boolean_action is not None:
        parser.add_argument("--use_tta", action=boolean_action, default=True,
                            help="Enable 4-way TTA (flip) during validation.")
        parser.add_argument("--deterministic", action=boolean_action, default=False)
    else:
        parser.add_argument("--use_tta", action="store_true", default=True,
                            help="Enable 4-way TTA (flip) during validation.")
        parser.add_argument("--deterministic", action="store_true", default=False)

    # Accepted in both branches so the flag that disables TTA does not depend
    # on the Python version (BooleanOptionalAction only adds --no-use_tta).
    parser.add_argument("--no_tta", dest="use_tta", action="store_false",
                        help="Disable TTA during validation.")

    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help="Optional subset of model names to train (e.g. --models DenseNet161 ViT_B_16).",
    )

    return parser.parse_args(argv)
|
|
|
|
def main():
    """Run the full benchmark: every selected model over every CV fold.

    Workflow:
      1. Parse arguments, seed RNGs, load the data table and build the folds.
      2. For each model, reuse an existing ``kfold_summary.json`` if present.
      3. For each fold, reuse cached best metrics only when the fold actually
         finished; otherwise (re)train it via ``train_one_fold``, which
         resumes from an epoch-level checkpoint when one exists.
      4. Write per-model summaries and a global leaderboard CSV.

    Raises:
        ValueError: If the number of observed labels differs from
            ``--num_classes``.
        RuntimeError: If a fold leaks groups between train and validation, or
            the replaced classifier has the wrong output dimension.
    """
    args = parse_args()
    seed_everything(args.random_seed, deterministic=args.deterministic)

    print("\nLoading data...")
    df = load_and_prepare_data(args.excel_path, group_col=args.group_col)

    observed_num_classes = int(df["label"].nunique())
    if observed_num_classes != args.num_classes:
        raise ValueError(
            f"num_classes mismatch: args.num_classes={args.num_classes}, "
            f"but observed labels in Excel imply {observed_num_classes} classes after remapping."
        )

    fold_splits = build_fold_splits(
        df=df,
        n_folds=args.n_folds,
        random_seed=args.random_seed,
        group_col=args.group_col,
    )

    model_registry = build_model_registry()
    if args.models:
        requested = set(args.models)  # built once, not per registry entry
        selected = {k: v for k, v in model_registry.items() if k in requested}
        missing = [m for m in args.models if m not in model_registry]
        if missing:
            print(f"Warning: these models were not found and will be ignored: {missing}")
        model_registry = selected

    print(f"\nTotal models to train: {len(model_registry)}")
    for i, name in enumerate(model_registry, 1):
        print(f"{i:2d}. {name}")

    output_root = Path(args.output_root)
    ensure_dir(output_root)

    global_results = {}

    for model_idx, (model_name, model_fn) in enumerate(model_registry.items(), 1):
        print("\n" + "=" * 70)
        print(f"[{model_idx}/{len(model_registry)}] Model: {model_name}")
        print("=" * 70)

        model_dir = output_root / model_name
        ensure_dir(model_dir)

        # ---- reuse a finished model summary when available ----
        summary_path = model_dir / "kfold_summary.json"
        if summary_path.is_file():
            cached_summary = None
            try:
                with open(summary_path, "r", encoding="utf-8") as f:
                    old = json.load(f)
                old_summary = old.get("summary", {})
                if old_summary:
                    mean_primary = old_summary[PRIMARY_METRIC]["mean"]
                    std_primary = old_summary[PRIMARY_METRIC]["std"]
                    skip_msg = (
                        f"[Skip] Found existing {args.n_folds}-fold summary: "
                        f"{PRIMARY_METRIC}={mean_primary:.2f}% +/- {std_primary:.2f}%"
                    )
                    cached_summary = (mean_primary, std_primary)
            except Exception as exc:
                # Best effort: an unreadable summary just means we recompute.
                cached_summary = None
                print(f"Warning: could not reuse {summary_path} ({exc}); recomputing.")
            if cached_summary is not None:
                print(skip_msg)
                global_results[model_name] = cached_summary
                continue

        input_size = get_model_input_size(model_name)
        print(f"Input size: {input_size}x{input_size}")

        fold_results = []

        for fold_idx in range(args.n_folds):
            fold_id = fold_idx + 1
            print(f"\n-- Fold {fold_id}/{args.n_folds} --")

            metrics_json = model_dir / f"fold{fold_id}_best_metrics.json"
            weight_path = model_dir / f"fold{fold_id}_best.pth"
            # train_one_fold writes this file every epoch and deletes it only
            # when the fold finishes, so its presence marks an interrupted fold.
            ckpt_path = model_dir / f"fold{fold_id}_checkpoint.pth"
            have_best_files = metrics_json.is_file() and weight_path.is_file()

            if have_best_files and ckpt_path.is_file():
                # The best-metrics files exist from the first improving epoch
                # on; treating them as a finished fold would silently truncate
                # an interrupted run.
                print(
                    f"Fold {fold_id}: unfinished epoch checkpoint found; "
                    f"resuming instead of using cached metrics."
                )
            elif have_best_files:
                try:
                    with open(metrics_json, "r", encoding="utf-8") as f:
                        cached = json.load(f)
                    # Build everything before touching fold_results so a
                    # malformed cache cannot leave a duplicate entry behind.
                    cached_entry = {
                        "best_epoch": cached["best_epoch"],
                        "metrics": cached["metrics"],
                        "per_class": cached["per_class"],
                    }
                    cached_msg = (
                        f"Fold {fold_id}: cached result found "
                        f"(Macro-F1={cached['metrics']['macro_f1']:.2f}%, "
                        f"BA={cached['metrics']['balanced_accuracy']:.2f}%), skipped."
                    )
                except Exception as exc:
                    print(f"Fold {fold_id}: could not read cached result ({exc}); retraining.")
                else:
                    fold_results.append(cached_entry)
                    print(cached_msg)
                    continue

            train_idx, val_idx = fold_splits[fold_idx]
            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)

            # Guard against group leakage between train and validation.
            train_patients = set(train_df[args.group_col].astype(str).tolist())
            val_patients = set(val_df[args.group_col].astype(str).tolist())
            overlap = train_patients & val_patients
            if overlap:
                raise RuntimeError(
                    f"Leakage detected in fold {fold_id}: {len(overlap)} overlapping patients/groups."
                )

            print(f"Train: {len(train_df)} | Validation: {len(val_df)}")
            print(
                f"Train patients: {train_df[args.group_col].nunique()} | "
                f"Validation patients: {val_df[args.group_col].nunique()}"
            )
            print(
                f"Train class dist: {dict(train_df['label'].value_counts().sort_index())} | "
                f"Val class dist: {dict(val_df['label'].value_counts().sort_index())}"
            )

            train_loader, val_loader, class_weights = create_fold_loaders(
                train_df=train_df,
                val_df=val_df,
                input_size=input_size,
                batch_size=args.batch_size,
                num_classes=args.num_classes,
                balance_mode=args.balance_mode,
                num_workers=args.num_workers,
            )

            try:
                model = model_fn()
            except Exception as exc:
                # Give up on this model; the remaining folds would fail too.
                print(f"Model creation failed for {model_name}: {exc}")
                break

            model = replace_classifier(model_name, model, args.num_classes)
            model = patch_vit_for_large_input(model, model_name, input_size)
            model = configure_small_batch_behavior(model_name, model, args.batch_size)
            model = model.to(device)

            # One dummy forward pass to confirm the new head's output size.
            dummy = torch.randn(1, 3, input_size, input_size, device=device)
            model.eval()
            with torch.no_grad():
                out = model(dummy)
            out = _extract_logits(out)
            out_dim = out.shape[-1]
            if out_dim != args.num_classes:
                raise RuntimeError(
                    f"Fatal: classifier replacement failed for {model_name}. "
                    f"Output dim={out_dim}, expected={args.num_classes}."
                )
            print(f"Forward sanity check passed: output dim={out_dim}")
            del dummy, out
            if device.type == "cuda":
                torch.cuda.empty_cache()

            result = train_one_fold(
                model_name=model_name,
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=args.epochs,
                num_classes=args.num_classes,
                backbone_lr=args.backbone_lr,
                head_lr=args.head_lr,
                class_weights=class_weights,
                fold_id=fold_id,
                save_dir=model_dir,
                freeze_backbone_epochs=args.freeze_backbone_epochs,
                max_grad_norm=args.max_grad_norm,
                primary_metric=PRIMARY_METRIC,
                loss_type=args.loss_type,
                focal_gamma=args.focal_gamma,
                label_smoothing=args.label_smoothing,
                use_tta=args.use_tta,
            )
            fold_results.append(result)

            del model
            if device.type == "cuda":
                torch.cuda.empty_cache()

        if len(fold_results) == args.n_folds:
            mean_primary, std_primary = save_kfold_summary(
                model_name,
                fold_results,
                args.num_classes,
                model_dir,
            )
            global_results[model_name] = (mean_primary, std_primary)
        else:
            print(f"Warning: {model_name} completed only {len(fold_results)}/{args.n_folds} folds.")

    print("\n" + "=" * 70)
    print(f"Global leaderboard ({args.n_folds}-Fold CV)")
    print(f"Sorted by: {PRIMARY_METRIC}")
    print("=" * 70)

    sorted_results = sorted(global_results.items(), key=lambda x: x[1][0], reverse=True)
    print(f"{'Rank':<6} {'Model':<25} {PRIMARY_METRIC:>12} {'Std':>10}")
    print("-" * 62)
    for rank, (name, (mean_primary, std_primary)) in enumerate(sorted_results, 1):
        print(f"{rank:<6} {name:<25} {mean_primary:>11.2f}% {std_primary:>9.2f}%")

    leaderboard_path = output_root / f"global_leaderboard_{PRIMARY_METRIC}.csv"
    pd.DataFrame([
        {
            "rank": idx + 1,
            "model": name,
            f"mean_{PRIMARY_METRIC}": mean_primary,
            f"std_{PRIMARY_METRIC}": std_primary,
        }
        for idx, (name, (mean_primary, std_primary)) in enumerate(sorted_results)
    ]).to_csv(leaderboard_path, index=False)
    print(f"\nLeaderboard saved to: {leaderboard_path}")
|
|
# Script entry point: run the benchmark only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|