"""Extract per-patient DeepSeeNet feature vectors and save them to an ``.npz``.

For every row of the reference JSON, the left-eye and right-eye images are run
through a drusen (DRUS) model and a pigment (PIG) model.  The input to each
model's last ``nn.Linear`` layer is captured, and the four resulting vectors
are concatenated in the order LE_DRUS, RE_DRUS, LE_PIG, RE_PIG.
"""

import argparse
import json
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
from PIL import Image

from augmentations import get_val_transforms
from model import DeepSeeNet

# Value passed as ``n_classes`` to ``DeepSeeNet`` for each task.
N_CLASSES = {
    "DRUS": 3,
    "PIG": 2,
}


class AlbumentationsTransform:
    """Adapt a keyword-style transform to a plain ``transform(image)`` call.

    The wrapped transform is invoked as ``transform(image=array)`` and must
    return a mapping with an ``"image"`` entry, which is what gets returned.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        # The wrapped transform is fed a NumPy array rather than a PIL image.
        return self.transform(image=np.asarray(image))["image"]


class FeatureExtractor(nn.Module):
    """
    Wraps a classifier and captures the input to the final Linear layer.

    This is intended to recover the penultimate feature vector used before
    the classification head.

    For the paper-faithful setup, we extract:
        DRUS left/right features
        PIG left/right features
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.features = None

        final_linear = self._find_last_linear(model)
        # Keep the handle so the hook can be detached via ``.remove()``.
        self._hook_handle = final_linear.register_forward_pre_hook(self._hook)

    @staticmethod
    def _find_last_linear(model: nn.Module) -> nn.Linear:
        """Return the last ``nn.Linear`` yielded by ``model.modules()``.

        "Last" refers to the traversal order of ``modules()`` (registration
        order), which is not necessarily the order of execution.

        Raises:
            RuntimeError: If the model contains no ``nn.Linear`` layer.
        """
        last_linear = None
        for module in model.modules():
            if isinstance(module, nn.Linear):
                last_linear = module

        if last_linear is None:
            raise RuntimeError("Could not find a final nn.Linear layer in the model.")
        return last_linear

    def _hook(self, module, inputs):
        """Forward pre-hook: stash the (flattened, detached) layer input."""
        x = inputs[0]
        if isinstance(x, (tuple, list)):
            x = x[0]
        if x.ndim > 2:
            # Collapse everything but the batch dimension into one vector.
            x = torch.flatten(x, start_dim=1)
        self.features = x.detach()

    def forward(self, x):
        """Run the wrapped model and return the captured features.

        Raises:
            RuntimeError: If the hooked layer was not executed during the
                forward pass, so nothing was captured.
        """
        # Reset first so a stale result from a previous call is never returned.
        self.features = None
        _ = self.model(x)

        if self.features is None:
            raise RuntimeError("Feature hook did not capture any features.")
        return self.features


def parse_args():
    """Parse the command-line arguments of the feature-extraction script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--json",
        default="data/AREDS1_all_survival_small_Status_late_amd_20190601.json",
        help="Reference JSON containing PATID, LE_PATHNAME, and RE_PATHNAME.",
    )
    parser.add_argument(
        "--image-root",
        required=True,
        help="Root directory containing AREDS images.",
    )
    parser.add_argument(
        "--drusen-weights",
        default="deepseenet/weights/drus.pt",
        help="Checkpoint file for the DRUS model.",
    )
    parser.add_argument(
        "--pigment-weights",
        default="deepseenet/weights/pig.pt",
        help="Checkpoint file for the PIG model.",
    )
    parser.add_argument(
        "--output",
        default="data/areds1_deepseenet_features.npz",
        help="Destination .npz file.",
    )
    parser.add_argument(
        "--backbone",
        default="inception_v3",
        help="Backbone name used when the checkpoint does not record one.",
    )
    parser.add_argument(
        "--image-size",
        type=int,
        default=1024,
        help="Image size passed to the validation transforms.",
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=16,
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=4,
    )
    parser.add_argument(
        "--on-missing",
        choices=["error", "skip"],
        default="error",
        help="Whether to error or skip patients with missing LE/RE images.",
    )
    return parser.parse_args()


def load_json(path):
    """Load ``path`` as JSON and require the top-level value to be a list.

    Raises:
        ValueError: If the decoded JSON document is not a list.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError(f"Expected JSON list, got {type(data)}")
    return data


def load_model(path, task, backbone, device):
    """Load a DeepSeeNet checkpoint and wrap it in a ``FeatureExtractor``.

    Two checkpoint layouts are accepted: a dict with a ``"model"`` entry
    holding the state dict (optionally with ``"args"``), or a bare state dict.

    Args:
        path: Checkpoint file.
        task: Key into ``N_CLASSES`` (``"DRUS"`` or ``"PIG"``).
        backbone: Fallback backbone name, used only when the checkpoint's
            ``"args"`` do not contain ``"backbone"``.
        device: Device the model is moved to and tensors are mapped to.

    Returns:
        The feature extractor, on ``device`` and in eval mode.
    """
    # NOTE(security): ``torch.load`` unpickles the file; only load checkpoints
    # from a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # Only consult "args" when the checkpoint really is a dict, and accept it
    # being stored either as a dict or as an ``argparse.Namespace``.
    raw_args = checkpoint.get("args") if isinstance(checkpoint, dict) else None
    if isinstance(raw_args, argparse.Namespace):
        raw_args = vars(raw_args)
    checkpoint_args = raw_args if isinstance(raw_args, dict) else {}

    model = DeepSeeNet(
        n_classes=N_CLASSES[task],
        backbone=checkpoint_args.get("backbone", backbone),
        pretrained=False,
    ).to(device)

    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()

    return FeatureExtractor(model).to(device).eval()


def resolve_image_path(image_root, rel_path):
    """Join ``rel_path`` (coerced to ``str``) onto ``image_root``."""
    return Path(image_root) / str(rel_path)


def load_image(path, transform):
    """Open ``path``, convert it to RGB, and return ``transform(image)``."""
    # The context manager guarantees the underlying file is closed;
    # ``convert`` returns a new, fully loaded image that outlives it.
    with Image.open(path) as raw:
        image = raw.convert("RGB")
    return transform(image)


class AREDSPatientImageDataset(torch.utils.data.Dataset):
    """Dataset yielding one left-eye/right-eye image pair per patient row.

    Rows whose LE or RE image file does not exist are either rejected with
    ``FileNotFoundError`` (``on_missing="error"``) or reported and dropped
    (any other value, e.g. ``"skip"``).
    """

    def __init__(self, rows, image_root, transform, on_missing="error"):
        """
        Args:
            rows: Iterable of mappings with ``PATID``, ``LE_PATHNAME`` and
                ``RE_PATHNAME`` entries.
            image_root: Directory the path names are resolved against.
            transform: Callable applied to each loaded RGB image.
            on_missing: ``"error"`` to raise on a missing image, otherwise
                the row is skipped.

        Raises:
            FileNotFoundError: If an image is missing and
                ``on_missing == "error"``.
        """
        self.rows = []
        self.image_root = Path(image_root)
        self.transform = transform
        self.on_missing = on_missing

        for row in rows:
            patid = row["PATID"]
            le_path = resolve_image_path(self.image_root, row["LE_PATHNAME"])
            re_path = resolve_image_path(self.image_root, row["RE_PATHNAME"])

            le_exists = le_path.exists()
            re_exists = re_path.exists()

            if not le_exists or not re_exists:
                msg = (
                    f"Missing image for PATID={patid}: "
                    f"LE exists={le_exists} ({le_path}), "
                    f"RE exists={re_exists} ({re_path})"
                )
                if on_missing == "error":
                    raise FileNotFoundError(msg)
                print(f"[skip] {msg}")
                continue

            self.rows.append(
                {
                    "PATID": patid,
                    "LE_PATHNAME": str(row["LE_PATHNAME"]),
                    "RE_PATHNAME": str(row["RE_PATHNAME"]),
                    "le_path": le_path,
                    "re_path": re_path,
                }
            )

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        """Load and transform both images of the ``idx``-th usable row."""
        row = self.rows[idx]
        le_img = load_image(row["le_path"], self.transform)
        re_img = load_image(row["re_path"], self.transform)

        return {
            "patid": int(row["PATID"]),
            "le_image": le_img,
            "re_image": re_img,
            "le_pathname": row["LE_PATHNAME"],
            "re_pathname": row["RE_PATHNAME"],
        }


def collate_fn(batch):
    """Collate dataset items: stack images, gather ids/paths as arrays."""
    return {
        "patids": np.array([x["patid"] for x in batch]),
        "le_images": torch.stack([x["le_image"] for x in batch], dim=0),
        "re_images": torch.stack([x["re_image"] for x in batch], dim=0),
        "le_pathnames": np.array([x["le_pathname"] for x in batch]),
        "re_pathnames": np.array([x["re_pathname"] for x in batch]),
    }


def make_feature_names(feature_dim, pig_feature_dim=None):
    """Build column names ``<PREFIX>_<index:03d>`` for the feature matrix.

    Names are emitted in the order LE_DRUS, RE_DRUS, LE_PIG, RE_PIG.

    Args:
        feature_dim: Number of columns in each DRUS group (and in each PIG
            group when ``pig_feature_dim`` is not given).
        pig_feature_dim: Number of columns in each PIG group; defaults to
            ``feature_dim``.

    Returns:
        A NumPy array of names.
    """
    if pig_feature_dim is None:
        pig_feature_dim = feature_dim

    group_dims = [
        ("LE_DRUS", feature_dim),
        ("RE_DRUS", feature_dim),
        ("LE_PIG", pig_feature_dim),
        ("RE_PIG", pig_feature_dim),
    ]
    names = [
        f"{prefix}_{i:03d}"
        for prefix, dim in group_dims
        for i in range(dim)
    ]
    return np.array(names)


@torch.no_grad()
def extract_features(loader, drus_model, pig_model, device):
    """Run both models on every batch and assemble the feature matrix.

    Args:
        loader: Iterable of batches shaped like the output of ``collate_fn``.
        drus_model: Callable returning DRUS features for an image batch.
        pig_model: Callable returning PIG features for an image batch.
        device: Device the image batches are moved to.

    Returns:
        ``(features, patids, le_pathnames, re_pathnames, feature_names)``
        where ``features`` has one row per patient with columns ordered
        LE_DRUS, RE_DRUS, LE_PIG, RE_PIG, and ``feature_names`` has one
        entry per column.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    all_features = []
    all_patids = []
    all_le_pathnames = []
    all_re_pathnames = []

    # tqdm is optional: fall back to the plain loader when it is unavailable.
    try:
        from tqdm import tqdm
    except ImportError:
        iterator = loader
    else:
        iterator = tqdm(loader, desc="Extracting DeepSeeNet features")

    feature_dim = None
    pig_feature_dim = None

    for batch in iterator:
        le = batch["le_images"].to(device, non_blocking=True)
        re = batch["re_images"].to(device, non_blocking=True)

        le_drus = drus_model(le).detach().cpu().numpy()
        re_drus = drus_model(re).detach().cpu().numpy()
        le_pig = pig_model(le).detach().cpu().numpy()
        re_pig = pig_model(re).detach().cpu().numpy()

        if feature_dim is None:
            feature_dim = le_drus.shape[1]
            pig_feature_dim = le_pig.shape[1]
            print(f"Detected feature dimension per model/eye: {feature_dim}")
            if pig_feature_dim != feature_dim:
                # The two models may use different backbones; the names are
                # built from both widths so they stay aligned with columns.
                print(
                    "[warning] PIG feature dimension per eye is "
                    f"{pig_feature_dim}, which differs from DRUS ({feature_dim})."
                )

        patient_features = np.concatenate(
            [le_drus, re_drus, le_pig, re_pig],
            axis=1,
        )

        all_features.append(patient_features)
        all_patids.append(batch["patids"])
        all_le_pathnames.append(batch["le_pathnames"])
        all_re_pathnames.append(batch["re_pathnames"])

    if not all_features:
        # np.concatenate would raise an opaque ValueError on an empty list.
        raise ValueError(
            "The loader produced no batches; there are no images to extract "
            "features from."
        )

    features = np.concatenate(all_features, axis=0)
    patids = np.concatenate(all_patids, axis=0)
    le_pathnames = np.concatenate(all_le_pathnames, axis=0)
    re_pathnames = np.concatenate(all_re_pathnames, axis=0)
    feature_names = make_feature_names(feature_dim, pig_feature_dim)

    return features, patids, le_pathnames, re_pathnames, feature_names


def main():
    """Script entry point: load data and models, extract, and save features."""
    args = parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    rows = load_json(args.json)
    print(f"Loaded JSON rows: {len(rows)}")

    transform = AlbumentationsTransform(get_val_transforms(args.image_size))

    dataset = AREDSPatientImageDataset(
        rows=rows,
        image_root=args.image_root,
        transform=transform,
        on_missing=args.on_missing,
    )
    print(f"Usable patients: {len(dataset)}")

    if len(dataset) == 0:
        # Fail before loading the models when there is nothing to process.
        raise ValueError(
            f"No usable patients found in {args.json} under {args.image_root}."
        )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep output rows in the order of the JSON rows
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_fn,
    )

    drus_model = load_model(
        args.drusen_weights,
        task="DRUS",
        backbone=args.backbone,
        device=device,
    )
    pig_model = load_model(
        args.pigment_weights,
        task="PIG",
        backbone=args.backbone,
        device=device,
    )

    features, patids, le_pathnames, re_pathnames, feature_names = extract_features(
        loader=loader,
        drus_model=drus_model,
        pig_model=pig_model,
        device=device,
    )

    print(f"Final feature matrix: {features.shape}")

    if features.shape[1] != 512:
        print(
            "[warning] Expected paper-faithful feature dimension of 512 "
            f"but got {features.shape[1]}. This likely means each model's "
            f"penultimate feature dimension is {features.shape[1] // 4}, not 128."
        )

    output = Path(args.output)
    output.parent.mkdir(parents=True, exist_ok=True)

    np.savez_compressed(
        output,
        features=features.astype(np.float32),
        patids=patids,
        le_pathnames=le_pathnames,
        re_pathnames=re_pathnames,
        feature_names=feature_names,
    )

    print(f"Saved: {output}")


if __name__ == "__main__":
    main()