AMDRisk / deepseenet /extract_features.py
Hou
add src
a7c73c5
Raw
History Blame Contribute Delete
10.1 kB
import argparse
import json
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
from PIL import Image
from augmentations import get_val_transforms
from model import DeepSeeNet
# Number of output classes of each DeepSeeNet classifier, keyed by task name.
# load_model uses this to size the classification head before loading weights.
N_CLASSES = dict(
    DRUS=3,
    PIG=2,
)
class AlbumentationsTransform:
    """Adapt an Albumentations-style transform to a plain ``transform(image)`` call.

    The wrapped callable receives the image converted with ``np.asarray``
    under the ``image`` keyword and must return a mapping; its ``"image"``
    entry is returned unchanged.
    """

    def __init__(self, transform):
        # Callable following the ``transform(image=...) -> {"image": ...}`` protocol.
        self.transform = transform

    def __call__(self, image):
        as_array = np.asarray(image)
        result = self.transform(image=as_array)
        return result["image"]
class FeatureExtractor(nn.Module):
    """
    Wrap a classifier and capture the input to its final ``nn.Linear`` layer.

    This is intended to recover the penultimate feature vector used before
    the classification head. For the paper-faithful setup, we extract:
        DRUS left/right features
        PIG left/right features

    "Final" means the last ``nn.Linear`` in ``model.modules()`` registration
    order, which is not necessarily the last Linear executed in ``forward``.
    The capturing hook stays attached to ``model`` until ``remove_hook()``
    is called.
    """

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model
        self.features = None
        final_linear = self._find_last_linear(model)
        # Keep the handle so the hook can be detached later; without it the
        # hook would remain registered on ``model`` for its whole lifetime.
        self._hook_handle = final_linear.register_forward_pre_hook(self._hook)

    @staticmethod
    def _find_last_linear(model: nn.Module) -> nn.Linear:
        """Return the last ``nn.Linear`` yielded by ``model.modules()``.

        Raises:
            RuntimeError: If ``model`` contains no ``nn.Linear``.
        """
        last_linear = None
        for module in model.modules():
            if isinstance(module, nn.Linear):
                last_linear = module
        if last_linear is None:
            raise RuntimeError("Could not find a final nn.Linear layer in the model.")
        return last_linear

    def _hook(self, module, inputs):
        """Forward pre-hook: store the Linear layer's (flattened, detached) input."""
        x = inputs[0]
        if isinstance(x, (tuple, list)):
            x = x[0]
        # Flatten everything after the batch dimension into one feature axis.
        if x.ndim > 2:
            x = torch.flatten(x, start_dim=1)
        self.features = x.detach()

    def remove_hook(self) -> None:
        """Detach the feature-capturing hook from the wrapped model (idempotent).

        After this, ``forward`` raises RuntimeError because nothing is captured.
        """
        if self._hook_handle is not None:
            self._hook_handle.remove()
            self._hook_handle = None

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the captured features.

        The model's own output is discarded.

        Raises:
            RuntimeError: If the hooked Linear layer was not executed.
        """
        # Reset first so a stale capture from a previous call is never returned.
        self.features = None
        _ = self.model(x)
        if self.features is None:
            raise RuntimeError("Feature hook did not capture any features.")
        return self.features
def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with attributes ``json``, ``image_root``,
        ``drusen_weights``, ``pigment_weights``, ``output``, ``backbone``,
        ``image_size``, ``batch_size``, ``num_workers`` and ``on_missing``.
    """
    cli = argparse.ArgumentParser()
    add = cli.add_argument

    # Inputs.
    add(
        "--json",
        default="data/AREDS1_all_survival_small_Status_late_amd_20190601.json",
        help="Reference JSON containing PATID, LE_PATHNAME, and RE_PATHNAME.",
    )
    add("--image-root", required=True, help="Root directory containing AREDS images.")

    # Plain string options: checkpoints, output location, backbone name.
    string_options = (
        ("--drusen-weights", "deepseenet/weights/drus.pt"),
        ("--pigment-weights", "deepseenet/weights/pig.pt"),
        ("--output", "data/areds1_deepseenet_features.npz"),
        ("--backbone", "inception_v3"),
    )
    for flag, default in string_options:
        add(flag, default=default)

    # Integer tuning knobs.
    int_options = (
        ("--image-size", 1024),
        ("--batch-size", 16),
        ("--num-workers", 4),
    )
    for flag, default in int_options:
        add(flag, type=int, default=default)

    add(
        "--on-missing",
        choices=["error", "skip"],
        default="error",
        help="Whether to error or skip patients with missing LE/RE images.",
    )
    return cli.parse_args()
def load_json(path):
    """Load a JSON file whose top-level value must be a list of records.

    Args:
        path: Path (str or os.PathLike) to the JSON file.

    Returns:
        The decoded list.

    Raises:
        ValueError: If the top-level JSON value is not a list.
    """
    # Decode explicitly as UTF-8 (the JSON interchange encoding) instead of
    # relying on open()'s locale-dependent default.
    with open(path, "r", encoding="utf-8") as f:
        data = json.load(f)
    if not isinstance(data, list):
        raise ValueError(f"Expected JSON list, got {type(data)}")
    return data
def load_model(path, task, backbone, device):
    """Rebuild a DeepSeeNet classifier from a checkpoint and wrap it for feature extraction.

    Accepted checkpoint layouts:
        * a dict with the weights under ``"model"`` and optional training
          args under ``"args"`` (a dict or an ``argparse.Namespace``);
        * a bare state dict.

    Args:
        path: Checkpoint file passed to ``torch.load``.
        task: Key into ``N_CLASSES`` (e.g. "DRUS" or "PIG") selecting the head size.
        backbone: Fallback backbone name, used only when the checkpoint args do
            not record a ``"backbone"``.
        device: Device the checkpoint is mapped to and the model is moved to.

    Returns:
        A ``FeatureExtractor`` wrapping the loaded model, in eval mode on ``device``.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. Recent PyTorch releases default to weights_only=True, which
    # may reject an argparse.Namespace stored under "args"; confirm against the
    # installed version.
    checkpoint = torch.load(path, map_location=device)
    is_dict_checkpoint = isinstance(checkpoint, dict)

    # Only mapping checkpoints can carry training args; guard before .get().
    checkpoint_args = checkpoint.get("args") if is_dict_checkpoint else None
    if isinstance(checkpoint_args, argparse.Namespace):
        # Args may have been saved as the Namespace itself rather than vars(args).
        checkpoint_args = vars(checkpoint_args)
    if not isinstance(checkpoint_args, dict):
        checkpoint_args = {}

    model = DeepSeeNet(
        n_classes=N_CLASSES[task],
        # Prefer the backbone recorded at training time; fall back to the CLI value.
        backbone=checkpoint_args.get("backbone", backbone),
        pretrained=False,
    ).to(device)

    state_dict = (
        checkpoint["model"]
        if is_dict_checkpoint and "model" in checkpoint
        else checkpoint
    )
    model.load_state_dict(state_dict)
    model.eval()
    return FeatureExtractor(model).to(device).eval()
def resolve_image_path(image_root, rel_path):
    """Join ``str(rel_path)`` onto ``image_root`` and return a ``pathlib.Path``.

    As with any ``Path`` join, an absolute ``rel_path`` replaces
    ``image_root`` entirely.
    """
    return Path(image_root).joinpath(str(rel_path))
def load_image(path, transform):
    """Open an image file, convert it to RGB and apply ``transform``.

    Args:
        path: Image file path passed to ``Image.open``.
        transform: Callable applied to the RGB PIL image.

    Returns:
        Whatever ``transform`` returns.
    """
    # Use Image.open as a context manager so the file handle is closed as soon
    # as the RGB copy exists. Otherwise it stays open until the lazy PIL image
    # is garbage-collected, which adds up across many DataLoader workers.
    # convert() loads the pixel data and returns a new image, so it is safe to
    # use after the file is closed.
    with Image.open(path) as image:
        rgb = image.convert("RGB")
    return transform(rgb)
class AREDSPatientImageDataset(torch.utils.data.Dataset):
    """Per-patient dataset yielding the left-eye (LE) and right-eye (RE) images together.

    Each input row must provide ``PATID``, ``LE_PATHNAME`` and ``RE_PATHNAME``.
    The pathnames are resolved against ``image_root``, and the existence of
    both files is checked once, at construction time.

    Args:
        rows: Iterable of mapping records (e.g. the list returned by load_json).
        image_root: Directory the LE/RE pathnames are resolved against.
        transform: Callable applied to each loaded RGB image.
        on_missing: ``"error"`` raises FileNotFoundError when either image is
            missing; ``"skip"`` prints a message and drops that patient.

    Raises:
        ValueError: If ``on_missing`` is not one of the supported values.
        FileNotFoundError: If an image is missing and ``on_missing == "error"``.
    """

    _ON_MISSING_CHOICES = ("error", "skip")

    def __init__(self, rows, image_root, transform, on_missing="error"):
        if on_missing not in self._ON_MISSING_CHOICES:
            # Reject typos instead of silently treating unknown modes as "skip".
            raise ValueError(
                f"on_missing must be one of {self._ON_MISSING_CHOICES}, "
                f"got {on_missing!r}"
            )
        self.rows = []
        self.image_root = Path(image_root)
        self.transform = transform
        self.on_missing = on_missing
        for row in rows:
            patid = row["PATID"]
            le_path = resolve_image_path(self.image_root, row["LE_PATHNAME"])
            re_path = resolve_image_path(self.image_root, row["RE_PATHNAME"])
            le_exists = le_path.exists()
            re_exists = re_path.exists()
            if not (le_exists and re_exists):
                msg = (
                    f"Missing image for PATID={patid}: "
                    f"LE exists={le_exists} ({le_path}), "
                    f"RE exists={re_exists} ({re_path})"
                )
                if on_missing == "error":
                    raise FileNotFoundError(msg)
                print(f"[skip] {msg}")
                continue
            self.rows.append(
                {
                    "PATID": patid,
                    # Keep the original relative pathnames for the saved output.
                    "LE_PATHNAME": str(row["LE_PATHNAME"]),
                    "RE_PATHNAME": str(row["RE_PATHNAME"]),
                    "le_path": le_path,
                    "re_path": re_path,
                }
            )

    def __len__(self):
        """Number of patients kept (both eye images present)."""
        return len(self.rows)

    def __getitem__(self, idx):
        """Load and transform both eye images for the ``idx``-th kept patient.

        Returns:
            Dict with ``patid`` (converted with ``int``), ``le_image`` and
            ``re_image`` (outputs of ``transform``), and the original
            ``le_pathname`` / ``re_pathname`` strings.
        """
        row = self.rows[idx]
        le_img = load_image(row["le_path"], self.transform)
        re_img = load_image(row["re_path"], self.transform)
        return {
            "patid": int(row["PATID"]),
            "le_image": le_img,
            "re_image": re_img,
            "le_pathname": row["LE_PATHNAME"],
            "re_pathname": row["RE_PATHNAME"],
        }
def collate_fn(batch):
    """Collate per-patient samples (as produced by AREDSPatientImageDataset) into one batch.

    Images are stacked along a new leading dimension. Patient IDs and the
    original LE/RE pathnames are gathered into NumPy arrays.
    """

    def column(key):
        # Gather one field across all samples, preserving batch order.
        return [sample[key] for sample in batch]

    return {
        "patids": np.array(column("patid")),
        "le_images": torch.stack(column("le_image"), dim=0),
        "re_images": torch.stack(column("re_image"), dim=0),
        "le_pathnames": np.array(column("le_pathname")),
        "re_pathnames": np.array(column("re_pathname")),
    }
def make_feature_names(feature_dim):
    """Return column names for the concatenated per-patient feature vector.

    The block order LE_DRUS, RE_DRUS, LE_PIG, RE_PIG matches the concatenation
    in extract_features. Each block has ``feature_dim`` names with a
    zero-padded index suffix (e.g. ``LE_DRUS_000``).
    """
    prefixes = ("LE_DRUS", "RE_DRUS", "LE_PIG", "RE_PIG")
    return np.array(
        [f"{prefix}_{index:03d}" for prefix in prefixes for index in range(feature_dim)]
    )
@torch.no_grad()
def extract_features(loader, drus_model, pig_model, device):
    """Run both feature extractors on every patient's left and right eye.

    Args:
        loader: Iterable of batch dicts shaped like collate_fn output (keys
            ``le_images``, ``re_images``, ``patids``, ``le_pathnames``,
            ``re_pathnames``).
        drus_model: Callable mapping an image batch to a 2-D feature tensor
            (the FeatureExtractor that load_model builds for "DRUS").
        pig_model: Same, for the "PIG" model.
        device: Device the image batches are moved to before the forward passes.

    Returns:
        Tuple ``(features, patids, le_pathnames, re_pathnames, feature_names)``.
        ``features`` has one row per patient laid out as
        ``[LE_DRUS | RE_DRUS | LE_PIG | RE_PIG]``.

    Raises:
        ValueError: If ``loader`` yields no batches. Also raised if the four
            per-eye feature blocks do not all share one width, since
            make_feature_names names every block with a single ``feature_dim``.
    """
    all_features = []
    all_patids = []
    all_le_pathnames = []
    all_re_pathnames = []

    # Progress bar is optional: fall back to the bare loader without tqdm.
    try:
        from tqdm import tqdm
        iterator = tqdm(loader, desc="Extracting DeepSeeNet features")
    except ImportError:
        iterator = loader

    feature_dim = None
    for batch in iterator:
        le_batch = batch["le_images"].to(device, non_blocking=True)
        re_batch = batch["re_images"].to(device, non_blocking=True)
        le_drus = drus_model(le_batch).detach().cpu().numpy()
        re_drus = drus_model(re_batch).detach().cpu().numpy()
        le_pig = pig_model(le_batch).detach().cpu().numpy()
        re_pig = pig_model(re_batch).detach().cpu().numpy()
        blocks = (le_drus, re_drus, le_pig, re_pig)

        if feature_dim is None:
            feature_dim = le_drus.shape[1]
            print(f"Detected feature dimension per model/eye: {feature_dim}")

        # Feature names assume every block has the same width; fail loudly
        # rather than emit names that do not line up with the columns.
        block_dims = [block.shape[1] for block in blocks]
        if any(dim != feature_dim for dim in block_dims):
            raise ValueError(
                "Inconsistent feature dimensions (LE_DRUS, RE_DRUS, LE_PIG, RE_PIG): "
                f"{block_dims}; expected {feature_dim} for all."
            )

        all_features.append(np.concatenate(blocks, axis=1))
        all_patids.append(batch["patids"])
        all_le_pathnames.append(batch["le_pathnames"])
        all_re_pathnames.append(batch["re_pathnames"])

    if feature_dim is None:
        # np.concatenate([]) would otherwise fail with an opaque message.
        raise ValueError("No batches to extract features from (empty loader/dataset).")

    features = np.concatenate(all_features, axis=0)
    patids = np.concatenate(all_patids, axis=0)
    le_pathnames = np.concatenate(all_le_pathnames, axis=0)
    re_pathnames = np.concatenate(all_re_pathnames, axis=0)
    feature_names = make_feature_names(feature_dim)
    return features, patids, le_pathnames, re_pathnames, feature_names
def main():
    """Command-line entry point: extract and save DeepSeeNet features.

    Steps:
        1. Parse CLI arguments.
        2. Build the per-patient dataset and loader.
        3. Load the drusen ("DRUS") and pigment ("PIG") models.
        4. Extract the concatenated features.
        5. Write them with ``np.savez_compressed``.

    Raises:
        SystemExit: If no patient has both eye images available.
    """
    args = parse_args()
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    rows = load_json(args.json)
    print(f"Loaded JSON rows: {len(rows)}")

    transform = AlbumentationsTransform(get_val_transforms(args.image_size))
    dataset = AREDSPatientImageDataset(
        rows=rows,
        image_root=args.image_root,
        transform=transform,
        on_missing=args.on_missing,
    )
    print(f"Usable patients: {len(dataset)}")
    if len(dataset) == 0:
        # Fail fast, before loading two checkpoints: an empty loader cannot
        # produce a feature matrix.
        raise SystemExit("No usable patients found; nothing to extract.")

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,  # keep output rows in JSON order
        num_workers=args.num_workers,
        pin_memory=(device.type == "cuda"),
        collate_fn=collate_fn,
    )

    drus_model = load_model(
        args.drusen_weights,
        task="DRUS",
        backbone=args.backbone,
        device=device,
    )
    pig_model = load_model(
        args.pigment_weights,
        task="PIG",
        backbone=args.backbone,
        device=device,
    )

    features, patids, le_pathnames, re_pathnames, feature_names = extract_features(
        loader=loader,
        drus_model=drus_model,
        pig_model=pig_model,
        device=device,
    )
    print(f"Final feature matrix: {features.shape}")
    # Four blocks (LE/RE x DRUS/PIG) are concatenated, so 512 total = 128 per block.
    if features.shape[1] != 512:
        print(
            "[warning] Expected paper-faithful feature dimension of 512 "
            f"but got {features.shape[1]}. This likely means each model's "
            f"penultimate feature dimension is {features.shape[1] // 4}, not 128."
        )

    output = Path(args.output)
    # np.savez_compressed appends ".npz" to file names that lack it; mirror
    # that so the "Saved:" message reports the file actually written.
    if not output.name.endswith(".npz"):
        output = output.with_name(output.name + ".npz")
    output.parent.mkdir(parents=True, exist_ok=True)
    np.savez_compressed(
        output,
        features=features.astype(np.float32),
        patids=patids,
        le_pathnames=le_pathnames,
        re_pathnames=re_pathnames,
        feature_names=feature_names,
    )
    print(f"Saved: {output}")
# Run the feature-extraction CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()