Spaces:
Sleeping
Sleeping
| import argparse | |
| import json | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from PIL import Image | |
| from augmentations import get_val_transforms | |
| from model import DeepSeeNet | |
# Number of output classes for each DeepSeeNet sub-model, keyed by the task
# name that ``load_model`` receives.
N_CLASSES = dict(DRUS=3, PIG=2)
class AlbumentationsTransform:
    """Adapter that applies a keyword-style image transform to an image.

    The wrapped ``transform`` is called as ``transform(image=array)`` where
    ``array`` is the input converted with ``np.asarray``. It must return a
    mapping with an ``"image"`` entry, and that entry is the result of
    calling this adapter.
    """

    def __init__(self, transform):
        self.transform = transform

    def __call__(self, image):
        pixels = np.asarray(image)
        outputs = self.transform(image=pixels)
        return outputs["image"]
class FeatureExtractor(nn.Module):
    """
    Wraps a classifier and captures the input to the final Linear layer.

    This is intended to recover the penultimate feature vector used before
    the classification head. For the paper-faithful setup, we extract:
        DRUS left/right features
        PIG left/right features

    "Final" means the last ``nn.Linear`` yielded by ``model.modules()``.
    A forward pre-hook on that layer stores its (detached) input each time
    the wrapped model runs; ``forward`` returns that stored tensor instead
    of the classifier output.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        # Most recent captured features; reset at the start of every forward.
        self.features = None
        final_linear = self._find_last_linear(model)
        final_linear.register_forward_pre_hook(self._hook)

    @staticmethod
    def _find_last_linear(model: nn.Module) -> nn.Linear:
        """Return the last ``nn.Linear`` found while iterating ``model.modules()``.

        Declared as a static method because it takes only ``model``; without
        the decorator the ``self._find_last_linear(model)`` call in
        ``__init__`` would pass two positional arguments and fail.

        Raises:
            RuntimeError: if the model contains no ``nn.Linear`` module.
        """
        last_linear = None
        for module in model.modules():
            if isinstance(module, nn.Linear):
                last_linear = module
        if last_linear is None:
            raise RuntimeError("Could not find a final nn.Linear layer in the model.")
        return last_linear

    def _hook(self, module, inputs):
        """Forward pre-hook: record the first positional input of the layer."""
        x = inputs[0]
        if isinstance(x, (tuple, list)):
            x = x[0]
        if x.ndim > 2:
            # Collapse everything after the batch dimension into one vector.
            x = torch.flatten(x, start_dim=1)
        self.features = x.detach()

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the captured features.

        Raises:
            RuntimeError: if the hooked layer was never executed during the
                forward pass, so nothing was captured.
        """
        self.features = None
        _ = self.model(x)
        if self.features is None:
            raise RuntimeError("Feature hook did not capture any features.")
        return self.features
def parse_args():
    """Parse the command-line options of the feature-extraction script.

    Only ``--image-root`` is required; every other option has a default.
    Returns the ``argparse.Namespace`` produced by parsing ``sys.argv``.
    """
    # (flag, keyword arguments for ``add_argument``), in registration order.
    option_table = (
        (
            "--json",
            {
                "default": "data/AREDS1_all_survival_small_Status_late_amd_20190601.json",
                "help": "Reference JSON containing PATID, LE_PATHNAME, and RE_PATHNAME.",
            },
        ),
        (
            "--image-root",
            {
                "required": True,
                "help": "Root directory containing AREDS images.",
            },
        ),
        ("--drusen-weights", {"default": "deepseenet/weights/drus.pt"}),
        ("--pigment-weights", {"default": "deepseenet/weights/pig.pt"}),
        ("--output", {"default": "data/areds1_deepseenet_features.npz"}),
        ("--backbone", {"default": "inception_v3"}),
        ("--image-size", {"type": int, "default": 1024}),
        ("--batch-size", {"type": int, "default": 16}),
        ("--num-workers", {"type": int, "default": 4}),
        (
            "--on-missing",
            {
                "choices": ["error", "skip"],
                "default": "error",
                "help": "Whether to error or skip patients with missing LE/RE images.",
            },
        ),
    )

    cli = argparse.ArgumentParser()
    for flag, options in option_table:
        cli.add_argument(flag, **options)
    return cli.parse_args()
def load_json(path):
    """Load a JSON file whose top-level value must be a list.

    The file is decoded explicitly as UTF-8 so the result does not depend on
    the platform's default text encoding.

    Args:
        path: Location of the JSON file.

    Returns:
        The decoded list.

    Raises:
        ValueError: if the top-level JSON value is not a list.
    """
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"Expected JSON list, got {type(data)}")
    return data
def load_model(path, task, backbone, device):
    """Load DeepSeeNet weights and wrap the model in a ``FeatureExtractor``.

    The checkpoint may be either a dict holding the state dict under
    ``"model"`` (optionally with training metadata under ``"args"``) or a
    bare state dict. When the metadata names a ``backbone`` it takes
    precedence over the ``backbone`` argument.

    Args:
        path: Checkpoint file to load.
        task: Key into ``N_CLASSES`` selecting the number of output classes.
        backbone: Backbone name used when the checkpoint does not record one.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        A ``FeatureExtractor`` around the loaded model, in eval mode.
    """
    # NOTE(security): torch.load unpickles the file, so only load checkpoints
    # that come from a trusted source.
    checkpoint = torch.load(path, map_location=device)

    # Read the metadata defensively: the checkpoint may not be a dict, and
    # "args" may be missing, None, or a non-dict object such as an
    # argparse.Namespace (assumption - confirm against the training script).
    checkpoint_args = checkpoint.get("args") if isinstance(checkpoint, dict) else None
    if checkpoint_args is None:
        checkpoint_args = {}
    elif not isinstance(checkpoint_args, dict):
        checkpoint_args = getattr(checkpoint_args, "__dict__", {})

    model = DeepSeeNet(
        n_classes=N_CLASSES[task],
        backbone=checkpoint_args.get("backbone", backbone),
        pretrained=False,
    ).to(device)

    if isinstance(checkpoint, dict) and "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)
    model.eval()
    return FeatureExtractor(model).to(device).eval()
def resolve_image_path(image_root, rel_path):
    """Join ``rel_path`` (converted to ``str``) onto ``image_root`` as a ``Path``."""
    return Path(image_root).joinpath(str(rel_path))
def load_image(path, transform):
    """Open the image at ``path``, convert it to RGB and apply ``transform``.

    The source image is opened in a ``with`` block so its underlying file is
    closed as soon as the RGB copy exists, rather than staying open until
    garbage collection. This matters when many images are read in sequence.

    Args:
        path: Image file to read.
        transform: Callable applied to the RGB ``PIL.Image``.

    Returns:
        Whatever ``transform`` returns for the RGB image.
    """
    with Image.open(path) as source:
        image = source.convert("RGB")
    return transform(image)
class AREDSPatientImageDataset(torch.utils.data.Dataset):
    """Dataset yielding one sample per row, with a left-eye and a right-eye image.

    Each input row must provide ``PATID``, ``LE_PATHNAME`` and
    ``RE_PATHNAME``. Both path names are resolved against ``image_root`` at
    construction time. A row with a missing file raises
    ``FileNotFoundError`` when ``on_missing == "error"``; otherwise a
    ``[skip]`` line is printed and the row is dropped.
    """

    def __init__(self, rows, image_root, transform, on_missing="error"):
        self.image_root = Path(image_root)
        self.transform = transform
        self.on_missing = on_missing
        self.rows = []

        for record in rows:
            patient_id = record["PATID"]
            left_path = resolve_image_path(self.image_root, record["LE_PATHNAME"])
            right_path = resolve_image_path(self.image_root, record["RE_PATHNAME"])
            has_left = left_path.exists()
            has_right = right_path.exists()

            if has_left and has_right:
                self.rows.append(
                    dict(
                        PATID=patient_id,
                        LE_PATHNAME=str(record["LE_PATHNAME"]),
                        RE_PATHNAME=str(record["RE_PATHNAME"]),
                        le_path=left_path,
                        re_path=right_path,
                    )
                )
                continue

            # At least one eye image is absent: fail or skip as configured.
            problem = (
                f"Missing image for PATID={patient_id}: "
                f"LE exists={has_left} ({left_path}), "
                f"RE exists={has_right} ({right_path})"
            )
            if on_missing == "error":
                raise FileNotFoundError(problem)
            print(f"[skip] {problem}")

    def __len__(self):
        return len(self.rows)

    def __getitem__(self, idx):
        """Load and transform both eye images of the ``idx``-th kept row."""
        entry = self.rows[idx]
        left_image = load_image(entry["le_path"], self.transform)
        right_image = load_image(entry["re_path"], self.transform)
        sample = {
            "patid": int(entry["PATID"]),
            "le_image": left_image,
            "re_image": right_image,
            "le_pathname": entry["LE_PATHNAME"],
            "re_pathname": entry["RE_PATHNAME"],
        }
        return sample
def collate_fn(batch):
    """Merge per-sample dicts into one batch dict.

    Image tensors are stacked along a new leading dimension; patient ids and
    path names are gathered into NumPy arrays.
    """

    def gather(key):
        return [sample[key] for sample in batch]

    collated = {}
    collated["patids"] = np.array(gather("patid"))
    collated["le_images"] = torch.stack(gather("le_image"), dim=0)
    collated["re_images"] = torch.stack(gather("re_image"), dim=0)
    collated["le_pathnames"] = np.array(gather("le_pathname"))
    collated["re_pathnames"] = np.array(gather("re_pathname"))
    return collated
def make_feature_names(feature_dim):
    """Return column names for the concatenated per-patient feature vector.

    Names are ``<prefix>_<index>`` with the index zero-padded to at least
    three digits, grouped by prefix in the order LE_DRUS, RE_DRUS, LE_PIG,
    RE_PIG, with ``feature_dim`` names per prefix.
    """
    prefixes = ("LE_DRUS", "RE_DRUS", "LE_PIG", "RE_PIG")
    return np.array(
        [
            f"{prefix}_{index:03d}"
            for prefix in prefixes
            for index in range(feature_dim)
        ]
    )
def extract_features(loader, drus_model, pig_model, device):
    """Run both feature extractors over every batch and collect the results.

    For each batch, the left-eye and right-eye images are passed through
    ``drus_model`` and ``pig_model``. The four outputs are concatenated along
    the feature axis in the order LE-DRUS, RE-DRUS, LE-PIG, RE-PIG.

    Inference runs under ``torch.no_grad()`` so no autograd state is kept for
    the forward passes. The returned values do not change, since they were
    already detached.

    Args:
        loader: Iterable of batches with keys ``le_images``, ``re_images``,
            ``patids``, ``le_pathnames`` and ``re_pathnames``.
        drus_model: Callable mapping an image batch to a 2-D feature tensor.
        pig_model: Callable mapping an image batch to a 2-D feature tensor.
        device: Device the image batches are moved to.

    Returns:
        Tuple ``(features, patids, le_pathnames, re_pathnames, feature_names)``
        of NumPy arrays. ``feature_names`` is built from the width of the
        first LE-DRUS output, so it assumes all four outputs share that width.

    Raises:
        ValueError: if ``loader`` yields no batches.
    """
    # tqdm is optional; fall back to the plain loader when it is unavailable.
    try:
        from tqdm import tqdm
    except ImportError:
        iterator = loader
    else:
        iterator = tqdm(loader, desc="Extracting DeepSeeNet features")

    all_features = []
    all_patids = []
    all_le_pathnames = []
    all_re_pathnames = []
    feature_dim = None

    with torch.no_grad():
        for batch in iterator:
            le = batch["le_images"].to(device, non_blocking=True)
            re = batch["re_images"].to(device, non_blocking=True)

            le_drus = drus_model(le).detach().cpu().numpy()
            re_drus = drus_model(re).detach().cpu().numpy()
            le_pig = pig_model(le).detach().cpu().numpy()
            re_pig = pig_model(re).detach().cpu().numpy()

            if feature_dim is None:
                feature_dim = le_drus.shape[1]
                print(f"Detected feature dimension per model/eye: {feature_dim}")

            all_features.append(
                np.concatenate([le_drus, re_drus, le_pig, re_pig], axis=1)
            )
            all_patids.append(batch["patids"])
            all_le_pathnames.append(batch["le_pathnames"])
            all_re_pathnames.append(batch["re_pathnames"])

    if not all_features:
        # np.concatenate([]) would raise an opaque ValueError here; say why.
        raise ValueError(
            "The data loader produced no batches; there are no features to extract."
        )

    features = np.concatenate(all_features, axis=0)
    patids = np.concatenate(all_patids, axis=0)
    le_pathnames = np.concatenate(all_le_pathnames, axis=0)
    re_pathnames = np.concatenate(all_re_pathnames, axis=0)
    feature_names = make_feature_names(feature_dim)
    return features, patids, le_pathnames, re_pathnames, feature_names
def main():
    """Extract DeepSeeNet features for every usable patient and save them.

    Steps: parse CLI options, read the reference JSON, build the dataset and
    loader, load the drusen and pigment models, run extraction, then write a
    compressed ``.npz`` file holding ``features`` (float32), ``patids``,
    ``le_pathnames``, ``re_pathnames`` and ``feature_names``.

    Exits with a message when no patient has both images available, instead
    of loading the models and failing later during extraction.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    rows = load_json(args.json)
    print(f"Loaded JSON rows: {len(rows)}")

    transform = AlbumentationsTransform(get_val_transforms(args.image_size))
    dataset = AREDSPatientImageDataset(
        rows=rows,
        image_root=args.image_root,
        transform=transform,
        on_missing=args.on_missing,
    )
    print(f"Usable patients: {len(dataset)}")
    if len(dataset) == 0:
        raise SystemExit("No usable patients found; nothing to extract.")

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_fn,
    )

    drus_model = load_model(
        args.drusen_weights,
        task="DRUS",
        backbone=args.backbone,
        device=device,
    )
    pig_model = load_model(
        args.pigment_weights,
        task="PIG",
        backbone=args.backbone,
        device=device,
    )

    features, patids, le_pathnames, re_pathnames, feature_names = extract_features(
        loader=loader,
        drus_model=drus_model,
        pig_model=pig_model,
        device=device,
    )
    print(f"Final feature matrix: {features.shape}")
    if features.shape[1] != 512:
        print(
            "[warning] Expected paper-faithful feature dimension of 512 "
            f"but got {features.shape[1]}. This likely means each model's "
            f"penultimate feature dimension is {features.shape[1] // 4}, not 128."
        )

    output = Path(args.output)
    # numpy appends ".npz" when the name lacks it; apply the same rule up
    # front so the path reported below is the file that was actually written.
    if not output.name.endswith(".npz"):
        output = output.with_name(output.name + ".npz")
    output.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output,
        features=features.astype(np.float32),
        patids=patids,
        le_pathnames=le_pathnames,
        re_pathnames=re_pathnames,
        feature_names=feature_names,
    )
    print(f"Saved: {output}")
# Run the extraction pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()