import argparse
import json
import math
import os
import random
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Any

import cv2
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from sklearn.metrics import (
    accuracy_score,
    balanced_accuracy_score,
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
from sklearn.model_selection import GroupKFold, StratifiedKFold
from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
from torchvision import models, transforms
from torchvision.transforms import InterpolationMode
from tqdm import tqdm

try:
    # StratifiedGroupKFold may be missing from older scikit-learn releases.
    from sklearn.model_selection import StratifiedGroupKFold

    HAS_STRATIFIED_GROUP_KFOLD = True
except ImportError:
    StratifiedGroupKFold = None
    HAS_STRATIFIED_GROUP_KFOLD = False

try:
    import timm

    HAS_TIMM = True
except ImportError:
    HAS_TIMM = False
    print("Warning: timm is not installed. timm-based models will be skipped.")

# Headline metric: written first in every report and used as the default
# selection metric of train_one_fold.
PRIMARY_METRIC = "macro_f1"
# Square input side length used unless overridden in MODEL_INPUT_SIZES.
DEFAULT_INPUT_SIZE = 512


# Utilities

def seed_everything(seed: int = 42, deterministic: bool = False) -> None:
    """Seed Python, NumPy and PyTorch RNGs and configure cuDNN behaviour.

    Args:
        seed: Seed applied to ``random``, NumPy and PyTorch (CPU and all GPUs).
        deterministic: If True, disable cuDNN autotuning, force deterministic
            cuDNN kernels and request deterministic algorithms (warn-only).
            If False, enable cuDNN benchmark mode for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    # Only affects child processes (e.g. spawned DataLoader workers); the hash
    # seed of the running interpreter is fixed at startup.
    os.environ["PYTHONHASHSEED"] = str(seed)
    if deterministic:
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        try:
            torch.use_deterministic_algorithms(True, warn_only=True)
        except Exception:
            # Best effort: older torch versions lack ``warn_only``.
            pass
    else:
        torch.backends.cudnn.benchmark = True
        torch.backends.cudnn.deterministic = False


def ensure_dir(path: Path) -> None:
    """Create ``path`` (and parents) if it does not exist."""
    path.mkdir(parents=True, exist_ok=True)


def to_jsonable(obj: Any):
    """Recursively convert NumPy scalars/arrays and tuples into JSON-serialisable types.

    Dicts are converted value-wise, lists/tuples become lists, NumPy arrays
    become nested lists, and NumPy integer/float/bool scalars become Python
    scalars. Any other object is returned unchanged.
    """
    if isinstance(obj, dict):
        return {k: to_jsonable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [to_jsonable(v) for v in obj]
    if isinstance(obj, np.ndarray):
        return to_jsonable(obj.tolist())
    if isinstance(obj, (np.integer, np.floating, np.bool_)):
        return obj.item()
    return obj


def name_matches_keywords(name: str, keywords: List[str]) -> bool:
    """Return True if parameter/module ``name`` matches any keyword.

    A keyword matches when it is a substring of ``name``, or when ``name``
    equals the keyword without its trailing dots, or starts with that stripped
    keyword followed by ``"."``. Empty keywords are ignored (an empty
    substring would otherwise match every name).
    """
    if not name:
        return False
    for kw in keywords:
        if not kw:
            continue
        plain_kw = kw.rstrip(".")
        if kw in name or name == plain_kw or name.startswith(plain_kw + "."):
            return True
    return False


# Device

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
if device.type == "cuda":
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"GPU memory: {torch.cuda.get_device_properties(0).total_memory / 1024 ** 3:.1f} GB")
else:
    print("Warning: CUDA is not available. Training will be much slower on CPU.")


# Model

# Substrings (matched case-insensitively) that identify transformer-style models.
_VIT_KEYWORDS = [
    "ViT",
    "Swin",
    "Transformer",
    "DeiT",
    "MaxViT",
    "CoAtNet",
    "EfficientFormer",
    "FastViT",
    "CaFormer",
]


def _is_vit_family(model_name: str) -> bool:
    """Return True if ``model_name`` contains any transformer-family keyword."""
    return any(kw.lower() in model_name.lower() for kw in _VIT_KEYWORDS)


def _is_timm_model(model: nn.Module) -> bool:
    """Heuristically detect a timm model by its ``get_classifier``/``num_features`` attributes."""
    return hasattr(model, "get_classifier") and hasattr(model, "num_features")


# Per-model input-size overrides; everything else uses DEFAULT_INPUT_SIZE.
MODEL_INPUT_SIZES: Dict[str, int] = {
    "inception_v3": 299,
}


def get_model_input_size(model_name: str) -> int:
    """Return the square input size to use for ``model_name``."""
    return MODEL_INPUT_SIZES.get(model_name, DEFAULT_INPUT_SIZE)


# Metrics / IO

def compute_metrics(
    y_true: List[int],
    y_pred: List[int],
    num_classes: int,
    class_names: List[str],
) -> Tuple[Dict, Dict]:
    """Compute summary metrics (in percent) and a per-class sklearn report.

    Returns:
        ``(metrics, report)`` where ``metrics`` maps metric name to a
        percentage and ``report`` is ``classification_report(output_dict=True)``
        keyed by ``class_names``.
    """
    labels = list(range(num_classes))
    report = classification_report(
        y_true,
        y_pred,
        labels=labels,
        target_names=class_names,
        output_dict=True,
        zero_division=0,
    )
    metrics = {
        "accuracy": 100.0 * accuracy_score(y_true, y_pred),
        "balanced_accuracy": 100.0 * balanced_accuracy_score(y_true, y_pred),
        "macro_f1": 100.0 * f1_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
        "macro_precision": 100.0 * precision_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
        "macro_recall": 100.0 * recall_score(y_true, y_pred, labels=labels, average="macro", zero_division=0),
        "weighted_f1": 100.0 * f1_score(y_true, y_pred, labels=labels, average="weighted", zero_division=0),
    }
    return metrics, report


def save_fold_results(results: Dict, save_dir: Path, tag: str = "best") -> None:
    """Write one fold's text report, per-image predictions CSV and metrics JSON.

    ``results`` must provide ``metrics``, ``classification_report``,
    ``patients``, ``image_names``, ``targets``, ``predictions``,
    ``image_path``, ``probabilities``, ``num_classes`` and ``best_epoch``.
    """
    ensure_dir(save_dir)
    report_df = pd.DataFrame(results["classification_report"]).transpose()
    with open(save_dir / f"test_report_{tag}.txt", "w", encoding="utf-8") as f:
        f.write(f"Primary Metric ({PRIMARY_METRIC}): {results['metrics'][PRIMARY_METRIC]:.4f}\n")
        f.write(f"Accuracy: {results['metrics']['accuracy']:.4f}\n")
        f.write(f"Balanced Accuracy: {results['metrics']['balanced_accuracy']:.4f}\n")
        f.write(f"Macro F1: {results['metrics']['macro_f1']:.4f}\n")
        f.write(f"Macro Recall: {results['metrics']['macro_recall']:.4f}\n")
        f.write(f"Macro Precision: {results['metrics']['macro_precision']:.4f}\n\n")
        f.write("Classification Report:\n")
        f.write(report_df.to_string())

    pred_df = pd.DataFrame({
        "patient": results["patients"],
        "image_name": results["image_names"],
        "True": results["targets"],
        "Predicted": results["predictions"],
        "path": results["image_path"],
    })
    for c in range(results["num_classes"]):
        pred_df[f"prob_class{c}"] = [row[c] for row in results["probabilities"]]
    pred_df.to_csv(save_dir / f"predictions_{tag}.csv", index=False)

    payload = {
        "best_epoch": results["best_epoch"],
        "primary_metric": PRIMARY_METRIC,
        "metrics": results["metrics"],
        "per_class": [
            # Classes absent from the report fall back to zeros.
            results["classification_report"].get(
                f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0}
            )
            for i in range(results["num_classes"])
        ],
    }
    with open(save_dir / f"{tag}_metrics.json", "w", encoding="utf-8") as f:
        json.dump(to_jsonable(payload), f, indent=2, ensure_ascii=False)


def save_kfold_summary(
    model_name: str,
    fold_results: List[Dict],
    num_classes: int,
    save_dir: Path,
) -> Tuple[float, float]:
    """Summarise k-fold results (mean/std per metric and per class) and pooled OOF metrics.

    Writes ``kfold_summary.txt``/``kfold_summary.json``. When every fold result
    carries its per-image predictions, it also writes ``oof_report.txt`` and
    ``oof_predictions.csv``.

    Returns:
        ``(mean, std)`` of PRIMARY_METRIC across folds.

    Raises:
        ValueError: if ``fold_results`` is empty.
    """
    if not fold_results:
        raise ValueError("fold_results is empty; nothing to summarise.")
    ensure_dir(save_dir)
    metric_names = [
        "accuracy",
        "balanced_accuracy",
        "macro_f1",
        "macro_recall",
        "macro_precision",
        "weighted_f1",
    ]
    summary = {}
    for name in metric_names:
        values = [r["metrics"][name] for r in fold_results]
        summary[name] = {
            "mean": float(np.mean(values)),
            "std": float(np.std(values)),
        }

    lines = [
        "=" * 70,
        f"Model: {model_name}",
        # Derive the fold count instead of hard-coding "5-Fold".
        f"{len(fold_results)}-Fold Cross-Validation Summary",
        f"Primary Metric: {PRIMARY_METRIC}",
        "=" * 70,
        "",
    ]
    for i, r in enumerate(fold_results, 1):
        lines.append(
            f"Fold {i}: Macro-F1={r['metrics']['macro_f1']:.2f}% | "
            f"BA={r['metrics']['balanced_accuracy']:.2f}% | "
            f"Acc={r['metrics']['accuracy']:.2f}% | "
            f"BestEpoch={r['best_epoch']}"
        )
    lines.append("")
    for name in metric_names:
        lines.append(f"{name}: {summary[name]['mean']:.2f}% +/- {summary[name]['std']:.2f}%")
    lines.append("")
    lines.append("Per-class metrics (mean +/- std)")
    lines.append(f"{'class':<10} {'precision':>18} {'recall':>18} {'f1-score':>18}")

    per_class_summary = {}
    for c in range(num_classes):
        ps = [r["per_class"][c]["precision"] for r in fold_results]
        rs = [r["per_class"][c]["recall"] for r in fold_results]
        fs = [r["per_class"][c]["f1-score"] for r in fold_results]
        per_class_summary[c] = {
            "precision_mean": float(np.mean(ps)),
            "precision_std": float(np.std(ps)),
            "recall_mean": float(np.mean(rs)),
            "recall_std": float(np.std(rs)),
            "f1_mean": float(np.mean(fs)),
            "f1_std": float(np.std(fs)),
        }
        # Right-align every cell to the same width as the header columns.
        cells = [f"{np.mean(vals):.4f}+/-{np.std(vals):.4f}" for vals in (ps, rs, fs)]
        lines.append(f"class{c:<5} " + " ".join(f"{cell:>18}" for cell in cells))

    text = "\n".join(lines)
    print(text)
    with open(save_dir / "kfold_summary.txt", "w", encoding="utf-8") as f:
        f.write(text)
    with open(save_dir / "kfold_summary.json", "w", encoding="utf-8") as f:
        json.dump(
            to_jsonable({
                "model": model_name,
                "primary_metric": PRIMARY_METRIC,
                "summary": summary,
                "per_class": per_class_summary,
            }),
            f,
            indent=2,
            ensure_ascii=False,
        )

    all_targets, all_predictions, all_paths = [], [], []
    all_patients, all_image_names = [], []
    all_probabilities = []
    # Check every key that the pooling loop below reads, not just a subset.
    pooled_keys = ("targets", "predictions", "image_path", "probabilities", "patients", "image_names")
    pooled_ready = all(all(key in r for key in pooled_keys) for r in fold_results)
    if pooled_ready:
        for r in fold_results:
            all_targets.extend(r["targets"])
            all_predictions.extend(r["predictions"])
            all_paths.extend(r["image_path"])
            all_patients.extend(r["patients"])
            all_image_names.extend(r["image_names"])
            all_probabilities.extend(r["probabilities"])
        class_names = [f"class{i}" for i in range(num_classes)]
        pooled_metrics, pooled_report = compute_metrics(
            all_targets,
            all_predictions,
            num_classes,
            class_names,
        )
        with open(save_dir / "oof_report.txt", "w", encoding="utf-8") as f:
            f.write("Pooled out-of-fold metrics\n")
            f.write(f"Primary Metric ({PRIMARY_METRIC}): {pooled_metrics[PRIMARY_METRIC]:.4f}\n")
            for k, v in pooled_metrics.items():
                f.write(f"{k}: {v:.4f}\n")
            f.write("\nClassification Report:\n")
            f.write(pd.DataFrame(pooled_report).transpose().to_string())
        oof_df = pd.DataFrame({
            "patient": all_patients,
            "image_name": all_image_names,
            "True": all_targets,
            "Predicted": all_predictions,
            "path": all_paths,
        })
        for c in range(num_classes):
            oof_df[f"prob_class{c}"] = [row[c] for row in all_probabilities]
        oof_df.to_csv(save_dir / "oof_predictions.csv", index=False)

    return summary[PRIMARY_METRIC]["mean"], summary[PRIMARY_METRIC]["std"]


# Black-border cropping is no longer used in the pipeline; a fundus-region mask
# is applied after CLAHE + green-channel enhancement instead.

class BlackBorderCrop:
    """Crop black borders and obvious invalid background around the fundus."""

    def __init__(self, threshold: int = 10, margin_ratio: float = 0.02):
        self.threshold = threshold
        self.margin_ratio = margin_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = gray > self.threshold
        # Too few bright pixels: nothing sensible to crop to.
        if mask.sum() < 64:
            return pil_img.convert("RGB")
        ys, xs = np.where(mask)
        y1, y2 = ys.min(), ys.max()
        x1, x2 = xs.min(), xs.max()
        margin_y = int((y2 - y1 + 1) * self.margin_ratio)
        margin_x = int((x2 - x1 + 1) * self.margin_ratio)
        y1 = max(0, y1 - margin_y)
        y2 = min(img.shape[0], y2 + margin_y + 1)
        x1 = max(0, x1 - margin_x)
        x2 = min(img.shape[1], x2 + margin_x + 1)
        cropped = img[y1:y2, x1:x2]
        return Image.fromarray(cropped)


class FundusCircularCrop:
    """Crop to the padded enclosing circle of the largest bright region and black out pixels outside it."""

    def __init__(self, threshold: int = 8, radius_pad_ratio: float = 0.03):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
        mask = (gray > self.threshold).astype(np.uint8) * 255
        kernel = np.ones((5, 5), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.medianBlur(mask, 5)
        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return Image.fromarray(img)
        largest = max(contours, key=cv2.contourArea)
        (cx, cy), radius = cv2.minEnclosingCircle(largest)
        if radius < 10:
            return Image.fromarray(img)
        radius = radius * (1.0 + self.radius_pad_ratio)
        cx, cy, radius = float(cx), float(cy), float(radius)
        x1 = max(0, int(cx - radius))
        y1 = max(0, int(cy - radius))
        x2 = min(img.shape[1], int(cx + radius))
        y2 = min(img.shape[0], int(cy + radius))
        cropped = img[y1:y2, x1:x2]
        h, w = cropped.shape[:2]
        if h < 2 or w < 2:
            return Image.fromarray(img)
        local_cx = cx - x1
        local_cy = cy - y1
        # Clamp the mask radius so the circle fits inside the cropped window.
        rr = max(1, min(int(radius), min(h, w) // 2))
        yy, xx = np.ogrid[:h, :w]
        circle_mask = ((xx - local_cx) ** 2 + (yy - local_cy) ** 2) <= (rr ** 2)
        out = np.zeros_like(cropped)
        out[circle_mask] = cropped[circle_mask]
        return Image.fromarray(out)


class ResizeToSquare:
    """Resize an image to ``size x size`` with bilinear resampling (aspect ratio not preserved)."""

    def __init__(self, size: int):
        self.size = size

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        return pil_img.resize((self.size, self.size), resample=Image.BILINEAR)


class _LazyCLAHEMixin:
    """Hold CLAHE parameters and build the ``cv2.CLAHE`` object lazily.

    ``cv2.CLAHE`` instances cannot be pickled, which breaks DataLoader workers
    started with the ``spawn`` method (the default on Windows/macOS). The
    OpenCV object is therefore excluded from the pickled state and recreated
    on first use in each process.
    """

    def _init_clahe(self, clip_limit: float, grid: Tuple[int, int]) -> None:
        self.clip_limit = clip_limit
        self.grid = tuple(grid)
        self._clahe = None

    @property
    def clahe(self):
        if self._clahe is None:
            self._clahe = cv2.createCLAHE(clipLimit=self.clip_limit, tileGridSize=self.grid)
        return self._clahe

    def __getstate__(self):
        state = self.__dict__.copy()
        state["_clahe"] = None
        return state


class LightCLAHE(_LazyCLAHEMixin):
    """Apply CLAHE to the L channel of the LAB representation of an RGB image."""

    def __init__(self, clip_limit: float = 2.0, grid: Tuple[int, int] = (8, 8)):
        self._init_clahe(clip_limit, grid)

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        lab = cv2.cvtColor(img, cv2.COLOR_RGB2LAB)
        l, a, b = cv2.split(lab)
        l = self.clahe.apply(l)
        out = cv2.cvtColor(cv2.merge([l, a, b]), cv2.COLOR_LAB2RGB)
        return Image.fromarray(out)


class GreenChannelEnhancement(_LazyCLAHEMixin):
    """Blend a CLAHE-equalised green channel into the original green channel.

    The new green channel is ``(1 - blend_alpha) * g + blend_alpha * CLAHE(g)``;
    red and blue channels are left unchanged.
    """

    def __init__(
        self,
        clip_limit: float = 2.5,
        grid: Tuple[int, int] = (8, 8),
        blend_alpha: float = 0.30,
    ):
        self._init_clahe(clip_limit, grid)
        self.blend_alpha = blend_alpha

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        r, g, b = cv2.split(img)
        g_eq = self.clahe.apply(g)
        g_new = cv2.addWeighted(g, 1.0 - self.blend_alpha, g_eq, self.blend_alpha, 0.0)
        out = cv2.merge([r, g_new, b])
        return Image.fromarray(out)


class FundusEyeMask:
    """Zero out pixels outside the (padded, optionally soft-edged) enclosing circle of the fundus."""

    def __init__(
        self,
        threshold: int = 8,
        radius_pad_ratio: float = 0.03,
        morph_kernel: int = 7,
        blur_kernel: int = 5,
    ):
        self.threshold = threshold
        self.radius_pad_ratio = radius_pad_ratio
        self.morph_kernel = morph_kernel
        self.blur_kernel = blur_kernel

    def __call__(self, pil_img: Image.Image) -> Image.Image:
        img = np.array(pil_img.convert("RGB"))
        gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)

        # Robust threshold against dark background after CLAHE + green enhancement
        _, mask = cv2.threshold(gray, self.threshold, 255, cv2.THRESH_BINARY)
        kernel = np.ones((self.morph_kernel, self.morph_kernel), np.uint8)
        mask = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
        mask = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel)

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if not contours:
            return Image.fromarray(img)

        largest = max(contours, key=cv2.contourArea)
        (cx, cy), radius = cv2.minEnclosingCircle(largest)
        if radius < 10:
            return Image.fromarray(img)

        radius = radius * (1.0 + self.radius_pad_ratio)
        yy, xx = np.ogrid[:img.shape[0], :img.shape[1]]
        circle_mask = (((xx - cx) ** 2 + (yy - cy) ** 2) <= (radius ** 2)).astype(np.uint8) * 255

        if self.blur_kernel and self.blur_kernel > 1:
            # GaussianBlur requires an odd kernel size.
            k = self.blur_kernel if self.blur_kernel % 2 == 1 else self.blur_kernel + 1
            circle_mask = cv2.GaussianBlur(circle_mask, (k, k), 0)

        circle_mask_f = (circle_mask.astype(np.float32) / 255.0)[..., None]
        out = (img.astype(np.float32) * circle_mask_f).clip(0, 255).astype(np.uint8)
        return Image.fromarray(out)


_light_clahe = LightCLAHE()
_green_enhance = GreenChannelEnhancement()
_eye_mask = FundusEyeMask()

# Cache of (train_transform, val_transform) keyed by input size, so each
# pipeline is built only once per size.
_transform_cache: Dict[int, Tuple[transforms.Compose, transforms.Compose]] = {}


def build_transforms(input_size: int = DEFAULT_INPUT_SIZE):
    """
    Preprocessing pipeline:
    - BlackBorderCrop is no longer used
    - resize to input_size -> CLAHE -> green-channel enhancement -> fundus-region mask
    - the mask keeps only the eye region and blanks pixels outside the fundus edge

    Training augmentation:
    - horizontal flip + vertical flip
    - small random rotation (±15°) + slight translation + scale jitter (0.85~1.15)
    - moderate ColorJitter
    - light Gaussian blur

    Returns:
        ``(train_tf, val_tf)``; results are cached per ``input_size``.
    """
    if input_size in _transform_cache:
        return _transform_cache[input_size]

    preprocess = [
        ResizeToSquare(input_size),
        _light_clahe,
        _green_enhance,
        _eye_mask,
    ]

    train_tf = transforms.Compose(
        preprocess
        + [
            transforms.RandomHorizontalFlip(p=0.5),
            transforms.RandomVerticalFlip(p=0.5),
            transforms.RandomAffine(
                degrees=15,
                translate=(0.05, 0.05),
                scale=(0.85, 1.15),
                interpolation=InterpolationMode.BILINEAR,
                fill=0,
            ),
            transforms.ColorJitter(
                brightness=0.20,
                contrast=0.20,
                saturation=0.10,
                hue=0.02,
            ),
            transforms.GaussianBlur(kernel_size=3, sigma=(0.1, 0.8)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    val_tf = transforms.Compose(
        preprocess
        + [
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ]
    )

    _transform_cache[input_size] = (train_tf, val_tf)
    return train_tf, val_tf


# TTA (Test-Time Augmentation)
# 4-way TTA: original / horizontal flip / vertical flip / both flips

def predict_with_tta(
    model: nn.Module,
    inputs: torch.Tensor,
    amp_enabled: bool = False,
) -> torch.Tensor:
    """Return class probabilities averaged over 4 flip views of ``inputs``.

    The views are: original, flipped along the last dim (horizontal), flipped
    along the second-to-last dim (vertical), and flipped along both.

    The caller is responsible for ``model.eval()`` and ``torch.no_grad()``.

    Args:
        model: Classifier; its output is passed through ``_extract_logits``.
        inputs: Image batch; flips use dims -1 and -2 (presumably ``(N, C, H, W)``).
        amp_enabled: Run the forward passes under CUDA autocast.

    Returns:
        Float32 tensor of averaged softmax probabilities, one row per sample.
    """
    # Flip dims for each view; () is the unmodified input. Views are created
    # one at a time so only one flipped copy of the batch lives at once.
    flip_dims = ((), (-1,), (-2,), (-1, -2))
    probs_list = []
    for dims in flip_dims:
        view = inputs.flip(dims) if dims else inputs
        amp_ctx = torch.autocast(device_type="cuda") if amp_enabled else nullcontext()
        with amp_ctx:
            out = model(view)
        logits = _extract_logits(out)
        # Softmax in fp32 so averaged probabilities are not limited by fp16 precision.
        probs_list.append(torch.softmax(logits.float(), dim=1))
    return torch.stack(probs_list, dim=0).mean(dim=0)


# Dataset

class ImageDataset(Dataset):
    """Dataset over a DataFrame with ``path``, ``label`` and ``patient`` columns.

    Each item is ``(image, label_tensor, meta)`` where ``meta`` holds the
    image path, patient id and image name. ``image_name`` falls back to the
    file name of ``path`` when the column is absent.
    """

    def __init__(self, df: pd.DataFrame, transform=None):
        self.df = df.reset_index(drop=True).copy()
        self.transform = transform
        self.paths = self.df["path"].astype(str).tolist()
        self.labels = self.df["label"].astype(int).tolist()
        self.patients = self.df["patient"].astype(str).tolist()
        if "image_name" in self.df.columns:
            self.image_names = self.df["image_name"].astype(str).tolist()
        else:
            self.image_names = [Path(p).name for p in self.paths]

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, idx: int):
        img_path = self.paths[idx]
        label = self.labels[idx]
        try:
            # Context manager releases the file handle deterministically.
            with Image.open(img_path) as img:
                image = img.convert("RGB")
        except Exception as exc:
            raise RuntimeError(f"Failed to open image: {img_path}") from exc
        if self.transform is not None:
            image = self.transform(image)
        meta = {
            "path": img_path,
            "patient": self.patients[idx],
            "image_name": self.image_names[idx],
        }
        return image, torch.tensor(label, dtype=torch.long), meta


# Data loading / grouped splitting

def validate_image_paths(df: pd.DataFrame, path_col: str = "path") -> pd.DataFrame:
    """Drop rows whose ``path_col`` is not an existing file; return the (re-indexed if filtered) frame."""
    total = len(df)
    # astype(str) so NaN/None paths count as missing instead of raising TypeError.
    mask = df[path_col].astype(str).apply(os.path.isfile)
    missing = total - int(mask.sum())
    if missing > 0:
        print(f"Warning: {missing}/{total} paths do not exist and will be removed.")
        df = df.loc[mask].reset_index(drop=True)
    else:
        print(f"All {total} image paths are valid.")
    return df


def load_and_prepare_data(excel_path: str, group_col: str = "patient") -> pd.DataFrame:
    """Load the Excel index, normalise labels to a 4-class 0-based scheme and drop missing images.

    Raises:
        KeyError: if ``path``, ``label`` or ``group_col`` columns are missing.
        ValueError: if any group identifier is empty/NaN.
    """
    df = pd.read_excel(excel_path, engine="openpyxl")
    required_cols = {"path", "label", group_col}
    missing_cols = required_cols - set(df.columns)
    if missing_cols:
        raise KeyError(f"Missing required columns in Excel: {sorted(missing_cols)}")

    df = df.copy()
    df[group_col] = df[group_col].astype(str).str.strip()
    # astype(str) turns NaN/None into the strings "nan"/"None"; treat them as invalid.
    if df[group_col].isin(["", "nan", "None"]).any():
        bad_rows = int(df[group_col].isin(["", "nan", "None"]).sum())
        raise ValueError(f"Found {bad_rows} rows with invalid patient/group identifiers in column '{group_col}'.")

    df["label"] = df["label"].replace({"AROP": 5})
    df["label"] = pd.to_numeric(df["label"], errors="raise").astype(int)
    # Heuristic: a minimum label of 1 is taken to mean 1-based labels.
    if df["label"].min() == 1:
        df["label"] = df["label"] - 1

    # Merge old labels 4 and 5 into class 3 -> final 4-class setup
    df["label"] = df["label"].replace({4: 3, 5: 3})

    df = validate_image_paths(df, path_col="path")

    if group_col != "patient":
        df["patient"] = df[group_col].astype(str)

    unique_labels = sorted(df["label"].unique().tolist())
    print(f"Dataset size: {len(df)} images")
    print(f"Unique patients: {df[group_col].nunique()}")
    print(f"Class distribution: {dict(df['label'].value_counts().sort_index())}")
    print(f"Observed labels: {unique_labels}")
    return df


def _approximate_group_stratified_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str,
):
    """Group-aware split by stratifying groups on their majority label.

    Returns a list of ``(train_idx, val_idx)`` arrays of POSITIONAL indices
    into ``df`` (the same convention as sklearn splitters), so callers can
    use ``df.iloc`` regardless of ``df``'s index.
    """
    group_df = (
        df.groupby(group_col)["label"]
        .agg(lambda x: x.value_counts().index[0])
        .reset_index()
    )
    if group_df[group_col].nunique() < n_folds:
        raise ValueError(
            f"Number of unique groups ({group_df[group_col].nunique()}) is smaller than n_folds={n_folds}."
        )

    skf = StratifiedKFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
    splits = []
    group_ids = group_df[group_col].values
    group_labels = group_df["label"].values
    for group_train_idx, group_val_idx in skf.split(group_ids, group_labels):
        train_groups = set(group_ids[group_train_idx])
        val_groups = set(group_ids[group_val_idx])
        # flatnonzero gives positions, unlike df.index[...] which gives labels.
        train_idx = np.flatnonzero(df[group_col].isin(train_groups).to_numpy())
        val_idx = np.flatnonzero(df[group_col].isin(val_groups).to_numpy())
        splits.append((train_idx, val_idx))
    return splits


def build_fold_splits(
    df: pd.DataFrame,
    n_folds: int,
    random_seed: int,
    group_col: str = "patient",
):
    """Build patient-grouped, label-stratified CV splits and verify there is no group leakage.

    Uses StratifiedGroupKFold when available, falling back to majority-label
    group stratification. Returns ``(train_idx, val_idx)`` positional arrays.

    Raises:
        ValueError: fewer unique groups than ``n_folds``.
        RuntimeError: a group appears in both train and val of some fold.
    """
    groups = df[group_col].astype(str).values
    labels = df["label"].values

    if len(np.unique(groups)) < n_folds:
        raise ValueError(
            f"Unique groups in '{group_col}' = {len(np.unique(groups))}, which is smaller than n_folds={n_folds}."
        )

    if HAS_STRATIFIED_GROUP_KFOLD:
        print(
            f"Using StratifiedGroupKFold with group_col='{group_col}', n_folds={n_folds}, seed={random_seed}."
        )
        try:
            splitter = StratifiedGroupKFold(
                n_splits=n_folds,
                shuffle=True,
                random_state=random_seed,
            )
            splits = list(splitter.split(df, y=labels, groups=groups))
        except ValueError as exc:
            print(f"StratifiedGroupKFold failed: {exc}")
            print("Falling back to approximate grouped stratification using patient-majority labels.")
            splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)
    else:
        print("StratifiedGroupKFold is unavailable. Falling back to approximate grouped stratification.")
        splits = _approximate_group_stratified_splits(df, n_folds, random_seed, group_col)

    for fold_id, (train_idx, val_idx) in enumerate(splits, 1):
        train_groups = set(df.iloc[train_idx][group_col].astype(str).tolist())
        val_groups = set(df.iloc[val_idx][group_col].astype(str).tolist())
        overlap = train_groups & val_groups
        if overlap:
            raise RuntimeError(
                f"Data leakage detected in fold {fold_id}: {len(overlap)} overlapping groups."
            )
    return splits


def _compute_class_weights(train_df: pd.DataFrame, num_classes: int) -> torch.Tensor:
    """Inverse-frequency class weights ``total / (num_classes * count_c)`` on ``device``.

    Classes absent from ``train_df`` are treated as having count 1.
    """
    counts = train_df["label"].value_counts().sort_index()
    total = len(train_df)
    weights = [total / (num_classes * counts.get(c, 1)) for c in range(num_classes)]
    return torch.tensor(weights, dtype=torch.float32, device=device)


def _make_weighted_sampler(train_df: pd.DataFrame) -> WeightedRandomSampler:
    """Sampler drawing each sample with probability proportional to 1 / (its class count)."""
    counts = train_df["label"].value_counts().to_dict()
    sample_weights = train_df["label"].map(lambda x: 1.0 / counts[x]).astype(float).values
    sample_weights = torch.as_tensor(sample_weights, dtype=torch.double)
    return WeightedRandomSampler(sample_weights, num_samples=len(sample_weights), replacement=True)


def create_fold_loaders(
    train_df: pd.DataFrame,
    val_df: pd.DataFrame,
    input_size: int = DEFAULT_INPUT_SIZE,
    batch_size: int = 8,
    num_classes: int = 4,
    balance_mode: str = "loss",
    num_workers: int = 4,
):
    """Build train/val DataLoaders for one fold.

    Args:
        balance_mode: ``"sampler"`` -> WeightedRandomSampler, ``"loss"`` ->
            class weights returned for the criterion, anything else -> none.

    Returns:
        ``(train_loader, val_loader, class_weights)``; ``class_weights`` is
        None unless ``balance_mode == "loss"``.
    """
    train_tf, val_tf = build_transforms(input_size)
    sampler = None
    class_weights = None

    if balance_mode == "sampler":
        sampler = _make_weighted_sampler(train_df)
        print("Training loader uses WeightedRandomSampler for class balancing.")
    elif balance_mode == "loss":
        class_weights = _compute_class_weights(train_df, num_classes)
        print(f"Training loss uses class weights: {class_weights.detach().cpu().numpy().tolist()}")
    else:
        print("No imbalance correction is used.")

    # A final batch of size 1 breaks BatchNorm in train mode, so drop it.
    drop_last = (batch_size > 1) and (len(train_df) % batch_size == 1)
    if drop_last:
        print(
            f"Training loader will drop the last singleton batch "
            f"(train_size={len(train_df)}, batch_size={batch_size}) to avoid BatchNorm issues."
        )

    pin_memory = device.type == "cuda"
    train_loader = DataLoader(
        ImageDataset(train_df, train_tf),
        batch_size=batch_size,
        shuffle=(sampler is None),
        sampler=sampler,
        num_workers=num_workers,
        pin_memory=pin_memory,
        drop_last=drop_last,
        persistent_workers=(num_workers > 0),
    )
    val_loader = DataLoader(
        ImageDataset(val_df, val_tf),
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=pin_memory,
        persistent_workers=(num_workers > 0),
    )
    return train_loader, val_loader, class_weights


# ViT positional embedding interpolation

def patch_vit_for_large_input(
    model: nn.Module,
    model_name: str,
    input_size: int,
) -> nn.Module:
    """Bicubically resize a torchvision-style ViT ``encoder.pos_embedding`` for ``input_size``.

    Only applies when ``"ViT"`` is in ``model_name`` and the model exposes
    ``encoder.pos_embedding`` and ``patch_size``. The class-token embedding is
    kept; the patch grid is interpolated. ``model.image_size`` is updated if
    present. Returns the (modified in place) model.
    """
    if "ViT" not in model_name:
        return model
    if not (hasattr(model, "encoder") and hasattr(model.encoder, "pos_embedding")):
        print(f"Warning: cannot find pos_embedding for {model_name}, skip interpolation.")
        return model

    patch_size = model.patch_size
    expected_patches = (input_size // patch_size) ** 2
    pos_embed = model.encoder.pos_embedding
    # First token is the class token; the rest form the patch grid.
    current_patches = pos_embed.shape[1] - 1

    if current_patches == expected_patches:
        print(f"[ViT] pos_embedding already matches input_size={input_size}, no interpolation needed.")
        return model

    h_old = w_old = math.isqrt(current_patches)
    h_new = w_new = math.isqrt(expected_patches)
    if h_old * w_old != current_patches:
        print(
            f"Warning: pos_embedding of {model_name} is not a square grid "
            f"({current_patches} patches), skip interpolation."
        )
        return model

    print(
        f"[ViT] Interpolating pos_embedding: {current_patches} -> {expected_patches} patches "
        f"for input_size={input_size}."
    )

    # Resizing is a one-off weight surgery, not part of the training graph.
    with torch.no_grad():
        cls_token = pos_embed[:, :1, :]
        patch_tokens = pos_embed[:, 1:, :]
        dim = patch_tokens.shape[-1]
        patch_tokens = (
            patch_tokens
            .reshape(1, h_old, w_old, dim)
            .permute(0, 3, 1, 2)
            .float()
        )
        patch_tokens = F.interpolate(
            patch_tokens,
            size=(h_new, w_new),
            mode="bicubic",
            align_corners=False,
        )
        # Restore the original dtype so the cat with cls_token is consistent.
        patch_tokens = (
            patch_tokens.permute(0, 2, 3, 1)
            .reshape(1, expected_patches, dim)
            .to(pos_embed.dtype)
        )
        new_pos_embed = torch.cat([cls_token, patch_tokens], dim=1)

    model.encoder.pos_embedding = nn.Parameter(new_pos_embed)
    if hasattr(model, "image_size"):
        model.image_size = input_size
    return model


# Classifier replacement

def _find_last_linear(module: nn.Module):
    """Return ``module`` if Linear, else the last Linear found in a Sequential/``head``; None otherwise."""
    if isinstance(module, nn.Linear):
        return module
    if isinstance(module, nn.Sequential):
        for child in reversed(list(module.children())):
            result = _find_last_linear(child)
            if result is not None:
                return result
    if hasattr(module, "head") and isinstance(module.head, (nn.Linear, nn.Sequential)):
        return _find_last_linear(module.head)
    return None


def _verify_classifier(model: nn.Module, model_name: str, expected_classes: int) -> None:
    """Check the first found classifier head outputs ``expected_classes``.

    Raises:
        RuntimeError: if the found head has the wrong ``out_features``.
    """
    for attr_name in ["fc", "head", "classifier", "heads"]:
        if not hasattr(model, attr_name):
            continue
        layer = getattr(model, attr_name)
        last_linear = _find_last_linear(layer)
        if last_linear is not None:
            if last_linear.out_features != expected_classes:
                raise RuntimeError(
                    f"Classifier replacement failed for {model_name}: "
                    f"out_features={last_linear.out_features}, expected={expected_classes}"
                )
            print(f"Verified {model_name}: classifier -> {expected_classes} classes (in={last_linear.in_features})")
            return
    print(f"Warning: failed to automatically verify classifier for {model_name}")


def replace_classifier(
    model_name: str,
    model: nn.Module,
    num_classes: int,
    dropout: float = 0.3,
) -> nn.Module:
    """Replace the model's classification head(s) with ``Dropout -> Linear(num_classes)``.

    timm models use ``reset_classifier``; torchvision models are dispatched on
    ``model_name``. Auxiliary heads (Inception/GoogLeNet) are replaced too
    when present.

    Raises:
        ValueError: if no classifier head can be located.
        RuntimeError: if the verification step finds a wrong output size.
    """
    if _is_timm_model(model):
        in_feat = model.num_features
        orig_classifier = model.get_classifier()
        print(f"[timm] {model_name}: original classifier={type(orig_classifier).__name__}, num_features={in_feat}")
        model.reset_classifier(num_classes)
        new_fc = model.get_classifier()
        wrapped = False
        if isinstance(new_fc, nn.Linear):
            # Size the wrapper from the freshly reset Linear: for some timm heads
            # (presumably those with a hidden pre-logits layer) its input width
            # differs from num_features.
            head_in = new_fc.in_features
            for parent_attr, child_attr in [
                ("head", "fc"),
                ("head", "head"),
                (None, "head"),
                (None, "classifier"),
                (None, "fc"),
            ]:
                try:
                    parent = getattr(model, parent_attr) if parent_attr else model
                    child = getattr(parent, child_attr)
                except AttributeError:
                    continue
                if child is new_fc:
                    setattr(
                        parent,
                        child_attr,
                        nn.Sequential(nn.Dropout(dropout), nn.Linear(head_in, num_classes)),
                    )
                    wrapped = True
                    break
        if not wrapped:
            print(f"[timm] {model_name}: reset_classifier({num_classes}) applied (no Dropout wrapper).")
        _verify_classifier(model, model_name, num_classes)
        return model

    n = model_name
    if "VGG" in n:
        in_feat = model.classifier[6].in_features
        model.classifier[6] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif n == "inception_v3":
        # AuxLogits may already be None (configure_small_batch_behavior disables it).
        if getattr(model, "AuxLogits", None) is not None:
            aux_in = model.AuxLogits.fc.in_features
            model.AuxLogits.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux_in, num_classes))
        fc_in = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(fc_in, num_classes))
    elif "GoogLeNet" in n:
        in_feat = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
        if hasattr(model, "aux1") and model.aux1 is not None and hasattr(model.aux1, "fc2"):
            aux1_in = model.aux1.fc2.in_features
            model.aux1.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux1_in, num_classes))
        if hasattr(model, "aux2") and model.aux2 is not None and hasattr(model.aux2, "fc2"):
            aux2_in = model.aux2.fc2.in_features
            model.aux2.fc2 = nn.Sequential(nn.Dropout(dropout), nn.Linear(aux2_in, num_classes))
    elif "ResNe" in n:
        in_feat = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "DenseNet" in n:
        in_feat = model.classifier.in_features
        model.classifier = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "MobileNet" in n:
        in_feat = model.classifier[-1].in_features
        model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "MnasNet" in n:
        in_feat = model.classifier[-1].in_features
        model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "EfficientNet" in n:
        in_feat = model.classifier[-1].in_features
        model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "ConvNeXt" in n:
        in_feat = model.classifier[-1].in_features
        model.classifier[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "RegNet" in n or "ShuffleNet" in n:
        in_feat = model.fc.in_features
        model.fc = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
    elif "ViT" in n:
        if hasattr(model, "heads") and hasattr(model.heads, "head"):
            in_feat = model.heads.head.in_features
            model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
        elif hasattr(model, "head") and isinstance(model.head, nn.Linear):
            in_feat = model.head.in_features
            model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
        else:
            raise ValueError(f"Cannot find classifier head for {n}")
    elif "Swin" in n:
        if hasattr(model, "head") and isinstance(model.head, nn.Linear):
            in_feat = model.head.in_features
            model.head = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
        elif hasattr(model, "heads") and hasattr(model.heads, "head"):
            in_feat = model.heads.head.in_features
            model.heads = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
        else:
            raise ValueError(f"Cannot find classifier head for {n}")
    elif _is_vit_family(n):
        replaced = False
        for attr in ["heads.head", "head", "classifier"]:
            parts = attr.split(".")
            obj = model
            try:
                for p in parts:
                    obj = getattr(obj, p)
                if isinstance(obj, nn.Linear):
                    in_feat = obj.in_features
                    parent = model
                    for p in parts[:-1]:
                        parent = getattr(parent, p)
                    setattr(parent, parts[-1], nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)))
                    replaced = True
                    break
            except AttributeError:
                continue
        if not replaced:
            raise ValueError(f"Cannot find classifier head for {n}")
    else:
        # Generic fallback: the first Linear (or Sequential ending in Linear)
        # found under fc/head/classifier.
        replaced = False
        for attr_name in ["fc", "head", "classifier"]:
            if not hasattr(model, attr_name):
                continue
            layer = getattr(model, attr_name)
            if isinstance(layer, nn.Linear):
                in_feat = layer.in_features
                setattr(model, attr_name, nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes)))
                replaced = True
                break
            if isinstance(layer, nn.Sequential) and len(layer) > 0 and isinstance(layer[-1], nn.Linear):
                in_feat = layer[-1].in_features
                layer[-1] = nn.Sequential(nn.Dropout(dropout), nn.Linear(in_feat, num_classes))
                replaced = True
                break
        if not replaced:
            raise ValueError(f"Cannot automatically replace classifier for {n}")

    _verify_classifier(model, model_name, num_classes)
    return model


# Optimizer groups / freezing

def _get_head_keywords(model_name: str) -> List[str]:
    """Parameter-name keywords (see name_matches_keywords) identifying head params for ``model_name``."""
    n = model_name
    if "VGG" in n:
        return ["classifier.6"]
    if n == "inception_v3":
        return ["fc.", "AuxLogits.fc"]
    if "GoogLeNet" in n:
        return ["fc.", "aux1.fc2", "aux2.fc2"]
    if "ResNe" in n:
        return ["fc."]
    if "DenseNet" in n:
        return ["classifier."]
    if "MobileNet" in n:
        return ["classifier.3", "classifier.2", "classifier."]
    if "MnasNet" in n:
        return ["classifier.1", "classifier."]
    if "EfficientNet" in n or "ConvNeXt" in n:
        return ["classifier.", "head.fc"]
    if "RegNet" in n or "ShuffleNet" in n:
        return ["fc."]
    if "ViT" in n or _is_vit_family(n):
        return ["heads.", "head.", "classifier."]
    return ["fc.", "classifier.", "head.", "heads."]


def get_parameter_groups(
    model_name: str,
    model: nn.Module,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
):
    """Split parameters into backbone/head optimizer groups with separate learning rates.

    Falls back to a single group at ``head_lr`` when no head parameter matches.
    """
    head_kw = _get_head_keywords(model_name)
    head_p, back_p = [], []
    for name, param in model.named_parameters():
        if name_matches_keywords(name, head_kw):
            head_p.append(param)
        else:
            back_p.append(param)

    if not head_p:
        print(f"Warning: no head parameters matched for {model_name}; all params use head_lr.")
        return [{"params": list(model.parameters()), "lr": head_lr}]

    print(
        f"Parameter groups | backbone: {sum(p.numel() for p in back_p):,} (lr={backbone_lr}) | "
        f"head: {sum(p.numel() for p in head_p):,} (lr={head_lr})"
    )
    return [{"params": back_p, "lr": backbone_lr}, {"params": head_p, "lr": head_lr}]


def set_backbone_trainable(model_name: str, model: nn.Module, train_backbone: bool) -> None:
    """Enable grads for head params always, and for backbone params only if ``train_backbone``."""
    head_kw = _get_head_keywords(model_name)
    for name, param in model.named_parameters():
        is_head = name_matches_keywords(name, head_kw)
        param.requires_grad = train_backbone or is_head


def set_frozen_backbone_bn_eval(model_name: str, model: nn.Module) -> None:
    """Put non-head BatchNorm layers in eval mode and freeze their affine params.

    Note: a later ``model.train()`` call switches these layers back to train mode.
    """
    head_kw = _get_head_keywords(model_name)
    bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
    for name, module in model.named_modules():
        if isinstance(module, bn_types) and not name_matches_keywords(name, head_kw):
            module.eval()
            for param in module.parameters():
                param.requires_grad = False


def configure_small_batch_behavior(model_name: str, model: nn.Module, batch_size: int) -> nn.Module:
    """Disable Inception/GoogLeNet auxiliary classifiers when ``batch_size < 2``."""
    if batch_size >= 2:
        return model
    if model_name == "inception_v3":
        print("batch_size=1 detected: disabling Inception auxiliary classifier.")
        if hasattr(model, "aux_logits"):
            model.aux_logits = False
        if hasattr(model, "AuxLogits"):
            model.AuxLogits = None
    elif "GoogLeNet" in model_name:
        print("batch_size=1 detected: disabling GoogLeNet auxiliary classifiers.")
        if hasattr(model, "aux_logits"):
            model.aux_logits = False
        if hasattr(model, "aux1"):
            model.aux1 = None
        if hasattr(model, "aux2"):
            model.aux2 = None
    return model


# Forward helpers

def _extract_logits(output):
    """Return the main logits tensor from a tensor, an object with ``.logits``, or a tuple/list.

    Raises:
        TypeError: if no logits tensor can be found.
    """
    if torch.is_tensor(output):
        return output
    if hasattr(output, "logits") and torch.is_tensor(output.logits):
        return output.logits
    if isinstance(output, (tuple, list)) and len(output) > 0 and torch.is_tensor(output[0]):
        return output[0]
    raise TypeError("Unable to extract logits from model output.")


def _extract_aux_outputs(output):
    """Collect auxiliary logits tensors (tuple elements after the first, or ``aux_logits*`` attributes)."""
    aux_outputs = []
    if isinstance(output, (tuple, list)):
        aux_outputs.extend([o for o in output[1:] if torch.is_tensor(o)])
    else:
        for attr in ["aux_logits", "aux_logits2", "aux_logits1"]:
            if hasattr(output, attr):
                aux = getattr(output, attr)
                if torch.is_tensor(aux):
                    aux_outputs.append(aux)
    return aux_outputs


def forward_with_loss(
    model: nn.Module,
    inputs: torch.Tensor,
    labels: torch.Tensor,
    criterion,
    aux_weight: float = 0.3,
):
    """Forward ``inputs`` and compute the loss; aux-head losses are added (scaled) in train mode only.

    Returns:
        ``(logits, loss)``.
    """
    output = model(inputs)
    logits = _extract_logits(output)
    aux_outputs = _extract_aux_outputs(output)
    loss = criterion(logits, labels)
    if model.training and aux_outputs:
        for aux in aux_outputs:
            loss = loss + aux_weight * criterion(aux, labels)
    return logits, loss


# Losses

class FocalLoss(nn.Module):
    """Multi-class focal loss ``-alpha_t * (1 - p_t)^gamma * log(p_t)``.

    Args:
        alpha: Optional per-class weight tensor indexed by target label.
        gamma: Focusing parameter.
        reduction: ``"mean"``, ``"sum"``, or anything else for per-sample losses.
    """

    def __init__(self, alpha: Optional[torch.Tensor] = None, gamma: float = 2.0, reduction: str = "mean"):
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        log_probs = F.log_softmax(inputs, dim=1)
        log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        pt = log_pt.exp()
        loss = -((1.0 - pt) ** self.gamma) * log_pt
        if self.alpha is not None:
            alpha_t = self.alpha.to(inputs.device)[targets]
            loss = alpha_t * loss
        if self.reduction == "mean":
            return loss.mean()
        if self.reduction == "sum":
            return loss.sum()
        return loss


def build_criterion(
    loss_type: str,
    class_weights: Optional[torch.Tensor] = None,
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
):
    """Create the training loss: ``"focal"``, ``"weighted_ce"`` or ``"ce"`` (case-insensitive).

    ``label_smoothing`` applies to the CE variants only.

    Raises:
        ValueError: for an unknown ``loss_type``.
    """
    loss_type = loss_type.lower()
    if loss_type == "focal":
        return FocalLoss(alpha=class_weights, gamma=focal_gamma, reduction="mean")
    if loss_type == "weighted_ce":
        return nn.CrossEntropyLoss(weight=class_weights, label_smoothing=label_smoothing)
    if loss_type == "ce":
        return nn.CrossEntropyLoss(label_smoothing=label_smoothing)
    raise ValueError(f"Unsupported loss_type: {loss_type}")


# Training
# - epochs 90
# - freeze_backbone_epochs
# - warmup_ep =
#   freeze_backbone_epochs
# - the validation loop optionally uses TTA


def _build_lr_scheduler(optimizer, epochs: int, warmup_ep: int):
    """Build the per-epoch LR schedule: linear warmup, then cosine decay to 1e-7.

    Args:
        optimizer: Optimizer whose parameter groups are scheduled.
        epochs: Total number of training epochs.
        warmup_ep: Number of warmup epochs, ramping linearly from 0.1x to 1x
            the base learning rates.

    Returns:
        A scheduler that is stepped once per epoch.

    When ``warmup_ep <= 0``, only the cosine schedule is used. A zero-length
    ``LinearLR`` inside ``SequentialLR`` would apply its 0.1 start factor and
    never undo it. Because ``CosineAnnealingLR`` updates the current LR
    recursively, the whole fold would then train at a tenth of the configured
    learning rates.
    """
    if warmup_ep <= 0:
        return torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=max(1, epochs),
            eta_min=1e-7,
        )
    sched_main = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=max(1, epochs - warmup_ep),
        eta_min=1e-7,
    )
    sched_warm = torch.optim.lr_scheduler.LinearLR(
        optimizer,
        start_factor=0.1,
        end_factor=1.0,
        total_iters=warmup_ep,
    )
    return torch.optim.lr_scheduler.SequentialLR(
        optimizer,
        schedulers=[sched_warm, sched_main],
        milestones=[warmup_ep],
    )


def _resume_from_checkpoint(
    ckpt_path: Path,
    model: nn.Module,
    optimizer,
    scheduler,
    scaler,
    fold_id: int,
    epochs: int,
    primary_metric: str,
):
    """Try to restore training state from an epoch-level checkpoint.

    Returns:
        ``(start_epoch, best_monitor, best_results)``.

    On any failure, the model, optimizer, scheduler and scaler are rolled back
    to their pre-load states, and ``(0, -inf, None)`` is returned. Without the
    rollback, a failure midway (e.g. an incompatible scheduler state after the
    model weights were already loaded) would "train from scratch" on checkpoint
    weights.
    """
    import copy  # local import: only the resume path needs it

    print(f"Fold {fold_id}: found epoch-level checkpoint, attempting to resume...")
    # Snapshot before loading. load_state_dict mutates in place, so a partial
    # failure must be undone. This temporarily holds a second copy of the weights.
    backups = [
        (obj, copy.deepcopy(obj.state_dict()))
        for obj in (model, optimizer, scheduler, scaler)
    ]
    try:
        # weights_only=False unpickles arbitrary objects. Only load checkpoints
        # written by this script.
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        model.load_state_dict(ckpt["model_state_dict"])
        optimizer.load_state_dict(ckpt["optimizer_state_dict"])
        scheduler.load_state_dict(ckpt["scheduler_state_dict"])
        scaler.load_state_dict(ckpt["scaler_state_dict"])
        start_epoch = ckpt["epoch"] + 1
        best_monitor = ckpt["best_monitor"]
        best_results = ckpt.get("best_results", None)
        print(
            f"Fold {fold_id}: resumed from epoch {start_epoch}/{epochs} "
            f"(best {primary_metric}={best_monitor:.2f})."
        )
    except Exception as exc:
        print(f"Fold {fold_id}: failed to load checkpoint ({exc}), training from scratch.")
        for obj, state in backups:
            obj.load_state_dict(state)
        return 0, -float("inf"), None
    return start_epoch, best_monitor, best_results


def train_one_fold(
    model_name: str,
    model: nn.Module,
    train_loader,
    val_loader,
    epochs: int = 90,
    num_classes: int = 4,
    backbone_lr: float = 3e-5,
    head_lr: float = 1e-3,
    class_weights: Optional[torch.Tensor] = None,
    fold_id: int = 1,
    save_dir: Optional[Path] = None,
    freeze_backbone_epochs: int = 8,
    max_grad_norm: float = 1.0,
    primary_metric: str = PRIMARY_METRIC,
    loss_type: str = "weighted_ce",
    focal_gamma: float = 2.0,
    label_smoothing: float = 0.0,
    use_tta: bool = True,
):
    """Train and validate one CV fold, keeping the epoch with the best ``primary_metric``.

    Epoch schedule:

    - For the first ``freeze_backbone_epochs`` epochs, only head parameters
      train, and backbone BatchNorm layers stay in eval mode.
    - The same number of epochs is used for linear LR warmup, followed by
      cosine decay.

    Validation and saving:

    - Validation runs every epoch, with TTA (``predict_with_tta``) when
      ``use_tta`` is set.
    - Ties on ``primary_metric`` are broken by balanced accuracy.
    - The best epoch's results and weights are written to ``save_dir`` as
      ``fold{fold_id}_best*``.
    - An epoch-level checkpoint (``fold{fold_id}_checkpoint.pth``) is written
      after every epoch and is resumed from if present. It is deleted once the
      fold finishes.

    Args:
        model_name: Registry name; selects head keywords for param groups/freezing.
        model: Model with its classifier already replaced. It is assumed to be
            on ``device`` already (``main`` moves it there).
        train_loader: Yields ``(inputs, labels, meta)`` batches.
        val_loader: Yields ``(inputs, labels, meta)`` batches. ``meta`` must
            provide ``"path"``, ``"patient"`` and ``"image_name"``.
        save_dir: Output directory; defaults to ``Path(model_name)``.

    Returns:
        The ``best_results`` dict of the best epoch.

    Raises:
        RuntimeError: if no epoch produced a result.
    """
    save_dir = Path(model_name) if save_dir is None else Path(save_dir)
    ensure_dir(save_dir)

    criterion = build_criterion(
        loss_type=loss_type,
        class_weights=class_weights,
        focal_gamma=focal_gamma,
        label_smoothing=label_smoothing,
    )
    # Move loss buffers (e.g. CrossEntropyLoss class weights) next to the logits.
    criterion = criterion.to(device)
    print(
        f"Fold {fold_id}: Using loss_type='{loss_type}'"
        f"{' with class weights' if class_weights is not None else ''}."
    )
    print(
        f"Fold {fold_id}: backbone_lr={backbone_lr}, head_lr={head_lr}, "
        f"freeze_backbone_epochs={freeze_backbone_epochs}, "
        f"epochs={epochs}, use_tta={use_tta}."
    )

    param_groups = get_parameter_groups(model_name, model, backbone_lr, head_lr)
    optimizer = torch.optim.AdamW(param_groups, betas=(0.9, 0.999), weight_decay=5e-4)
    # Warmup length deliberately equals the frozen-backbone phase.
    scheduler = _build_lr_scheduler(optimizer, epochs, warmup_ep=freeze_backbone_epochs)

    amp_enabled = device.type == "cuda"
    amp_ctx = torch.cuda.amp.autocast if amp_enabled else nullcontext
    non_blocking = device.type == "cuda"
    scaler = torch.cuda.amp.GradScaler(enabled=amp_enabled)
    class_names = [f"class{i}" for i in range(num_classes)]

    best_monitor = -float("inf")
    best_results = None
    was_backbone_trainable = None
    start_epoch = 0

    ckpt_path = save_dir / f"fold{fold_id}_checkpoint.pth"
    if ckpt_path.is_file():
        start_epoch, best_monitor, best_results = _resume_from_checkpoint(
            ckpt_path,
            model,
            optimizer,
            scheduler,
            scaler,
            fold_id=fold_id,
            epochs=epochs,
            primary_metric=primary_metric,
        )

    for epoch in tqdm(
        range(start_epoch, epochs),
        desc=f"Fold {fold_id}",
        leave=False,
        initial=start_epoch,
        total=epochs,
    ):
        train_backbone = epoch >= freeze_backbone_epochs
        if was_backbone_trainable is None or was_backbone_trainable != train_backbone:
            set_backbone_trainable(model_name, model, train_backbone=train_backbone)
            stage = "unfrozen" if train_backbone else "frozen"
            print(f"Fold {fold_id}: backbone is now {stage} (epoch {epoch + 1}).")
            was_backbone_trainable = train_backbone

        model.train()
        if not train_backbone:
            # model.train() just re-enabled BN statistic updates; undo that for the frozen backbone.
            set_frozen_backbone_bn_eval(model_name, model)

        # ---- training ----
        run_loss = 0.0
        n_seen = 0
        for inputs, labels, _meta in train_loader:
            inputs = inputs.to(device, non_blocking=non_blocking)
            labels = labels.to(device, non_blocking=non_blocking)
            optimizer.zero_grad(set_to_none=True)
            with amp_ctx():
                _logits, loss = forward_with_loss(model, inputs, labels, criterion, aux_weight=0.3)
            scaler.scale(loss).backward()
            if max_grad_norm is not None and max_grad_norm > 0:
                # Unscale first so the clipping threshold applies to the real gradients.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            scaler.step(optimizer)
            scaler.update()
            batch_n = inputs.size(0)
            run_loss += loss.item() * batch_n
            n_seen += batch_n
        scheduler.step()
        # Average over the samples actually drawn. A WeightedRandomSampler or
        # drop_last makes this differ from len(train_loader.dataset).
        ep_loss = run_loss / max(n_seen, 1)

        # ---- validation phase (optional TTA) ----
        model.eval()
        all_t, all_p, all_paths, all_probs = [], [], [], []
        all_patients, all_image_names = [], []
        with torch.no_grad():
            for inputs, labels, meta in val_loader:
                inputs = inputs.to(device, non_blocking=non_blocking)
                labels = labels.to(device, non_blocking=non_blocking)
                if use_tta:
                    # V7: inference with TTA
                    probs = predict_with_tta(model, inputs, amp_enabled=amp_enabled)
                else:
                    with amp_ctx():
                        output = model(inputs)
                        logits = _extract_logits(output)
                        # softmax under autocast is computed in float32
                        probs = torch.softmax(logits, dim=1)
                pred = probs.argmax(dim=1)
                all_t.extend(labels.cpu().numpy().tolist())
                all_p.extend(pred.cpu().numpy().tolist())
                all_probs.extend(probs.cpu().numpy().tolist())
                all_paths.extend(list(meta["path"]))
                all_patients.extend(list(meta["patient"]))
                all_image_names.extend(list(meta["image_name"]))

        metrics, report = compute_metrics(all_t, all_p, num_classes, class_names)
        monitor = metrics[primary_metric]

        if (epoch + 1) % 5 == 0 or epoch == epochs - 1 or epoch == start_epoch:
            print(
                f"F{fold_id} E{epoch + 1}/{epochs} "
                f"Loss={ep_loss:.4f} "
                f"Macro-F1={metrics['macro_f1']:.2f}% "
                f"BA={metrics['balanced_accuracy']:.2f}% "
                f"Acc={metrics['accuracy']:.2f}%"
                f"{' [TTA]' if use_tta else ''}"
            )

        # A tie on the primary metric is broken by balanced accuracy.
        improved = (monitor > best_monitor) or (
            np.isclose(monitor, best_monitor)
            and best_results is not None
            and metrics["balanced_accuracy"] > best_results["metrics"]["balanced_accuracy"]
        )
        if improved:
            best_monitor = monitor
            best_results = {
                "best_epoch": epoch + 1,
                "metrics": metrics,
                "classification_report": report,
                "predictions": all_p,
                "targets": all_t,
                "image_path": all_paths,
                "patients": all_patients,
                "image_names": all_image_names,
                "probabilities": all_probs,
                "num_classes": num_classes,
                "per_class": [
                    report.get(f"class{i}", {"precision": 0, "recall": 0, "f1-score": 0})
                    for i in range(num_classes)
                ],
            }
            save_fold_results(best_results, save_dir, tag=f"fold{fold_id}_best")
            torch.save(model.state_dict(), save_dir / f"fold{fold_id}_best.pth")

        # Epoch-level checkpoint for crash recovery (see _resume_from_checkpoint).
        torch.save({
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "scaler_state_dict": scaler.state_dict(),
            "best_monitor": best_monitor,
            "best_results": best_results,
        }, ckpt_path)

    if ckpt_path.is_file():
        ckpt_path.unlink()
        print(f"Fold {fold_id}: removed epoch-level checkpoint (training complete).")

    if best_results is None:
        raise RuntimeError(f"Fold {fold_id}: no valid result was produced.")
    return best_results


def build_model_registry():
    """Return a mapping of model name -> zero-argument factory of a pretrained model.

    timm-based entries are registered only when timm is installed. Their
    ``img_size`` comes from ``get_model_input_size``, the same source ``main``
    uses to size the data pipeline, so the two cannot drift apart.
    """
    reg = {}
    reg["DenseNet161"] = lambda: models.densenet161(weights=models.DenseNet161_Weights.DEFAULT)
    reg["ConvNeXt_Tiny"] = lambda: models.convnext_tiny(weights=models.ConvNeXt_Tiny_Weights.DEFAULT)
    reg["ViT_B_16"] = lambda: models.vit_b_16(weights=models.ViT_B_16_Weights.DEFAULT)

    if HAS_TIMM:
        reg["SwinV2_T"] = lambda: timm.create_model(
            "swinv2_tiny_window8_256",
            pretrained=True,
            img_size=get_model_input_size("SwinV2_T"),
        )
        reg["DeiT3_S"] = lambda: timm.create_model(
            "deit3_small_patch16_224",
            pretrained=True,
            img_size=get_model_input_size("DeiT3_S"),
        )
    else:
        print("Skipping timm models because timm is not installed.")
    return reg


def parse_args():
    """Parse command-line options for the k-fold benchmark run."""
    parser = argparse.ArgumentParser(description="ROP benchmark training with patient-grouped 5-fold CV (v7).")
    # BooleanOptionalAction (Python 3.9+) gives --flag/--no-flag; older Pythons use the fallback below.
    boolean_action = getattr(argparse, "BooleanOptionalAction", None)
    parser.add_argument(
        "--excel_path",
        type=str,
        default="/media/fang/9fc99a7b-15d6-4e22-ab05-fe46e6058c39/felicia/Downloads/医生审核之后第一版12-17/部分公开数据集/公开数据集训练表_调整数据1.xlsx",
        help="Path to Excel with at least columns: patient, path, label.",
    )
    parser.add_argument("--group_col", type=str, default="patient", help="Grouping column for leakage-free split.")
    parser.add_argument("--num_classes", type=int, default=4)
    # V7: epochs 90
    parser.add_argument("--epochs", type=int, default=90)
    parser.add_argument("--n_folds", type=int, default=5)
    parser.add_argument("--batch_size", type=int, default=8)
    # V7: backbone_lr raised to 3e-5
    parser.add_argument("--backbone_lr", type=float, default=3e-5)
    # V7: head_lr raised to 1e-3
    parser.add_argument("--head_lr", type=float, default=1e-3)
    parser.add_argument("--random_seed", type=int, default=42)
    parser.add_argument("--num_workers", type=int, default=min(8, os.cpu_count() or 2))
    parser.add_argument(
        "--balance_mode",
        type=str,
        default="loss",
        choices=["none", "loss", "sampler"],
        help="Imbalance handling. 'loss' computes class weights; 'sampler' uses WeightedRandomSampler.",
    )
    parser.add_argument(
        "--loss_type",
        type=str,
        default="weighted_ce",
        choices=["weighted_ce", "focal", "ce"],
        help="weighted_ce is the recommended default.",
    )
    parser.add_argument("--focal_gamma", type=float, default=2.0)
    parser.add_argument("--label_smoothing", type=float, default=0.0)
    # V7: freeze_backbone_epochs raised to 8
    parser.add_argument("--freeze_backbone_epochs", type=int, default=8)
    parser.add_argument("--max_grad_norm", type=float, default=1.0)
    parser.add_argument("--output_root", type=str, default="runs_rop_V7_old")
    if boolean_action is not None:
        parser.add_argument("--use_tta", action=boolean_action, default=True, help="Enable 4-way TTA (flip) during validation.")
        parser.add_argument("--deterministic", action=boolean_action, default=False)
    else:
        parser.add_argument("--use_tta", action="store_true", default=True, help="Enable 4-way TTA (flip) during validation.")
        parser.add_argument("--no_tta", dest="use_tta", action="store_false")
        parser.add_argument("--deterministic", action="store_true", default=False)
    parser.add_argument(
        "--models",
        nargs="*",
        default=None,
        help="Optional subset of model names to train (e.g. --models DenseNet161 ViT_B_16).",
    )
    return parser.parse_args()


def main():
    """Run patient-grouped k-fold CV for every selected model and write a leaderboard.

    Reruns with the same arguments resume where the previous run stopped:

    - A model with an existing ``kfold_summary.json`` is skipped.
    - A fold with both ``fold{k}_best_metrics.json`` and ``fold{k}_best.pth``
      present reuses its cached result.
    - Unfinished folds resume from their epoch-level checkpoint inside
      ``train_one_fold``.
    """
    args = parse_args()
    seed_everything(args.random_seed, deterministic=args.deterministic)

    print("\nLoading data...")
    df = load_and_prepare_data(args.excel_path, group_col=args.group_col)
    observed_num_classes = int(df["label"].nunique())
    if observed_num_classes != args.num_classes:
        raise ValueError(
            f"num_classes mismatch: args.num_classes={args.num_classes}, "
            f"but observed labels in Excel imply {observed_num_classes} classes after remapping."
        )

    fold_splits = build_fold_splits(
        df=df,
        n_folds=args.n_folds,
        random_seed=args.random_seed,
        group_col=args.group_col,
    )

    model_registry = build_model_registry()
    if args.models:
        requested = set(args.models)
        selected = {k: v for k, v in model_registry.items() if k in requested}
        missing = [m for m in args.models if m not in model_registry]
        if missing:
            print(f"Warning: these models were not found and will be ignored: {missing}")
        model_registry = selected

    print(f"\nTotal models to train: {len(model_registry)}")
    for i, name in enumerate(model_registry, 1):
        print(f"{i:2d}. {name}")

    output_root = Path(args.output_root)
    ensure_dir(output_root)
    global_results = {}

    for model_idx, (model_name, model_fn) in enumerate(model_registry.items(), 1):
        print("\n" + "=" * 70)
        print(f"[{model_idx}/{len(model_registry)}] Model: {model_name}")
        print("=" * 70)

        model_dir = output_root / model_name
        ensure_dir(model_dir)

        # Skip models whose full k-fold summary already exists.
        summary_path = model_dir / "kfold_summary.json"
        if summary_path.is_file():
            try:
                with open(summary_path, "r", encoding="utf-8") as f:
                    old = json.load(f)
                old_summary = old.get("summary", {})
                if old_summary:
                    mean_primary = old_summary[PRIMARY_METRIC]["mean"]
                    std_primary = old_summary[PRIMARY_METRIC]["std"]
                    print(
                        f"[Skip] Found existing {args.n_folds}-fold summary: "
                        f"{PRIMARY_METRIC}={mean_primary:.2f}% +/- {std_primary:.2f}%"
                    )
                    global_results[model_name] = (mean_primary, std_primary)
                    continue
            except Exception as exc:
                print(f"Warning: could not reuse {summary_path} ({exc}); retraining {model_name}.")

        input_size = get_model_input_size(model_name)
        print(f"Input size: {input_size}x{input_size}")

        fold_results = []
        for fold_idx in range(args.n_folds):
            fold_id = fold_idx + 1
            print(f"\n-- Fold {fold_id}/{args.n_folds} --")

            metrics_json = model_dir / f"fold{fold_id}_best_metrics.json"
            weight_path = model_dir / f"fold{fold_id}_best.pth"
            if metrics_json.is_file() and weight_path.is_file():
                try:
                    with open(metrics_json, "r", encoding="utf-8") as f:
                        cached = json.load(f)
                    cached_entry = {
                        "best_epoch": cached["best_epoch"],
                        "metrics": cached["metrics"],
                        "per_class": cached["per_class"],
                    }
                    cached_msg = (
                        f"Fold {fold_id}: cached result found "
                        f"(Macro-F1={cached['metrics']['macro_f1']:.2f}%, "
                        f"BA={cached['metrics']['balanced_accuracy']:.2f}%), skipped."
                    )
                except Exception as exc:
                    print(f"Fold {fold_id}: could not reuse cached result ({exc}); retraining.")
                else:
                    # Append only after everything parsed, so a failure above never
                    # leaves a half-recorded fold that then gets trained a second time.
                    fold_results.append(cached_entry)
                    print(cached_msg)
                    continue

            train_idx, val_idx = fold_splits[fold_idx]
            train_df = df.iloc[train_idx].reset_index(drop=True)
            val_df = df.iloc[val_idx].reset_index(drop=True)

            # Guard against patient-level leakage between train and validation.
            train_patients = set(train_df[args.group_col].astype(str).tolist())
            val_patients = set(val_df[args.group_col].astype(str).tolist())
            overlap = train_patients & val_patients
            if overlap:
                raise RuntimeError(
                    f"Leakage detected in fold {fold_id}: {len(overlap)} overlapping patients/groups."
                )

            print(f"Train: {len(train_df)} | Validation: {len(val_df)}")
            print(
                f"Train patients: {train_df[args.group_col].nunique()} | "
                f"Validation patients: {val_df[args.group_col].nunique()}"
            )
            print(
                f"Train class dist: {dict(train_df['label'].value_counts().sort_index())} | "
                f"Val class dist: {dict(val_df['label'].value_counts().sort_index())}"
            )

            train_loader, val_loader, class_weights = create_fold_loaders(
                train_df=train_df,
                val_df=val_df,
                input_size=input_size,
                batch_size=args.batch_size,
                num_classes=args.num_classes,
                balance_mode=args.balance_mode,
                num_workers=args.num_workers,
            )

            try:
                model = model_fn()
            except Exception as exc:
                print(f"Model creation failed for {model_name}: {exc}")
                break

            model = replace_classifier(model_name, model, args.num_classes)
            model = patch_vit_for_large_input(model, model_name, input_size)
            model = configure_small_batch_behavior(model_name, model, args.batch_size)
            model = model.to(device)

            # Sanity check: one dummy forward must yield num_classes logits.
            dummy = torch.randn(1, 3, input_size, input_size, device=device)
            model.eval()
            with torch.no_grad():
                out = model(dummy)
                out = _extract_logits(out)
                out_dim = out.shape[-1]
            if out_dim != args.num_classes:
                raise RuntimeError(
                    f"Fatal: classifier replacement failed for {model_name}. "
                    f"Output dim={out_dim}, expected={args.num_classes}."
                )
            print(f"Forward sanity check passed: output dim={out_dim}")
            del dummy, out
            if device.type == "cuda":
                torch.cuda.empty_cache()

            result = train_one_fold(
                model_name=model_name,
                model=model,
                train_loader=train_loader,
                val_loader=val_loader,
                epochs=args.epochs,
                num_classes=args.num_classes,
                backbone_lr=args.backbone_lr,
                head_lr=args.head_lr,
                class_weights=class_weights,
                fold_id=fold_id,
                save_dir=model_dir,
                freeze_backbone_epochs=args.freeze_backbone_epochs,
                max_grad_norm=args.max_grad_norm,
                primary_metric=PRIMARY_METRIC,
                loss_type=args.loss_type,
                focal_gamma=args.focal_gamma,
                label_smoothing=args.label_smoothing,
                use_tta=args.use_tta,
            )
            fold_results.append(result)

            del model
            if device.type == "cuda":
                torch.cuda.empty_cache()

        if len(fold_results) == args.n_folds:
            mean_primary, std_primary = save_kfold_summary(
                model_name,
                fold_results,
                args.num_classes,
                model_dir,
            )
            global_results[model_name] = (mean_primary, std_primary)
        else:
            print(f"Warning: {model_name} completed only {len(fold_results)}/{args.n_folds} folds.")

    print("\n" + "=" * 70)
    print(f"Global leaderboard ({args.n_folds}-Fold CV)")
    print(f"Sorted by: {PRIMARY_METRIC}")
    print("=" * 70)

    sorted_results = sorted(global_results.items(), key=lambda x: x[1][0], reverse=True)
    print(f"{'Rank':<6} {'Model':<25} {PRIMARY_METRIC:>12} {'Std':>10}")
    print("-" * 62)
    for rank, (name, (mean_primary, std_primary)) in enumerate(sorted_results, 1):
        print(f"{rank:<6} {name:<25} {mean_primary:>11.2f}% {std_primary:>9.2f}%")

    leaderboard_path = output_root / f"global_leaderboard_{PRIMARY_METRIC}.csv"
    pd.DataFrame([
        {
            "rank": idx + 1,
            "model": name,
            f"mean_{PRIMARY_METRIC}": mean_primary,
            f"std_{PRIMARY_METRIC}": std_primary,
        }
        for idx, (name, (mean_primary, std_primary)) in enumerate(sorted_results)
    ]).to_csv(leaderboard_path, index=False)
    print(f"\nLeaderboard saved to: {leaderboard_path}")


if __name__ == "__main__":
    main()