File size: 7,629 Bytes
e101805 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 |
"""Extract patch-level features from a single WSI.
This script reads patch coordinates from an HDF5 file (typically produced by
`exaonepath.patches.patchfy`) and runs the EXAONE Path patch encoder on the
corresponding image regions to produce patch embeddings.
Input H5 (`--coords_h5_path`) requirements
-----------------------------------------
- Dataset: `coords` (shape: [N, 2]) containing (x, y) coordinates in **level-0** pixel space.
- Attribute on `coords`: `patch_size` (int). If missing, defaults to 256.
- Optional dataset: `contour_index` (shape: [N]). If missing, all patches are treated as contour 0.
Output H5 (`--out_h5_path`) keys
--------------------------------
- `features`: [N, C] patch embeddings
- `coords`: [N, 2] coordinates (copied from input)
- `contour_index`: [N] contour indices (copied from input or synthesized)
Notes
-----
- CUDA is required.
- This implementation loads all extracted features into memory before writing.
"""
import argparse
import os
import h5py
import numpy as np
import openslide
import torch
import torch.backends.cudnn as cudnn
from torch.amp import autocast
from torchvision import transforms
from exaonepath.models.patch_encoder_hf import PatchEncoder
def _open_coords_h5(coords_h5_path):
"""Open coordinates H5 file and read coordinates."""
if not os.path.exists(coords_h5_path):
print(f"Coords H5 file not found: {coords_h5_path}")
return None, None, None
with h5py.File(coords_h5_path, "r") as f:
if "coords" not in f:
print(f"Invalid coords H5 (missing 'coords'): {coords_h5_path}")
return None, None, None
coords = np.array(f["coords"])
patch_size = int(f["coords"].attrs.get("patch_size", 256))
if "contour_index" in f:
contour_index = np.array(f["contour_index"])
else:
# Backward-compat: if contour_index isn't present, treat all patches as one contour.
contour_index = np.zeros((len(coords),), dtype=np.int32)
return coords, patch_size, contour_index
def _save_slide_features(out_h5_path, feat_list, coord_list, contour_list):
    """Concatenate per-batch tensors and write them to the output HDF5 file.

    Creates the parent directory if needed, then writes three datasets:
    ``features``, ``coords``, and ``contour_index``.
    """
    out_dir = os.path.dirname(os.path.abspath(out_h5_path))
    os.makedirs(out_dir, exist_ok=True)
    # Merge the per-batch tensors into one array per dataset.
    datasets = {
        "features": torch.cat(feat_list, dim=0).numpy(),
        "coords": torch.cat(coord_list, dim=0).numpy(),
        "contour_index": torch.cat(contour_list, dim=0).numpy(),
    }
    with h5py.File(out_h5_path, "w") as f:
        for name, data in datasets.items():
            f.create_dataset(name, data=data)
def process_single_slide(
    slide_path,
    coords_h5_path,
    out_h5_path,
    model,
    transform,
    batch_size_per_gpu=32,
):
    """Extract patch features for a single WSI and save them to HDF5.

    Args:
        slide_path: Path to the whole-slide image file.
        coords_h5_path: HDF5 file with a ``coords`` dataset (level-0 x/y
            coordinates) and optional ``contour_index`` dataset.
        out_h5_path: Destination HDF5 path for ``features``/``coords``/
            ``contour_index``.
        model: Patch encoder; called as ``model(images)`` on a CUDA batch.
        transform: Torchvision transform applied to each RGB patch.
        batch_size_per_gpu: Number of patches per inference batch.
    """
    device = "cuda"
    print(f"Processing slide: {slide_path}")
    print(f"Coords H5: {coords_h5_path}")
    print(f"Output H5: {out_h5_path}")

    # Load coordinates
    coords, patch_size, contour_index = _open_coords_h5(coords_h5_path)
    if coords is None or contour_index is None:
        print("No coordinates loaded; aborting.")
        return
    if len(coords) == 0:
        print("No coordinates found (N=0); nothing to do.")
        return

    batch_images = []
    batch_coords = []
    batch_contours = []
    feat_list = []
    coord_list = []
    contour_list = []
    total_patches = len(coords)

    def _flush_batch():
        # Run inference on the accumulated batch and store results on CPU.
        if not batch_images:
            return
        imgs_tensor = torch.stack(batch_images).to(device, non_blocking=True)
        with torch.inference_mode(), autocast("cuda", torch.bfloat16):
            features = model(imgs_tensor)
        feat_list.append(features.detach().float().cpu())
        coord_list.append(torch.tensor(np.array(batch_coords)))
        contour_list.append(torch.tensor(np.array(batch_contours)))
        batch_images.clear()
        batch_coords.clear()
        batch_contours.clear()

    # Open WSI; ensure the handle is released even if inference raises.
    wsi = openslide.OpenSlide(slide_path)
    try:
        for j, (coord, contour_val) in enumerate(zip(coords, contour_index)):
            x, y = int(coord[0]), int(coord[1])
            try:
                patch = wsi.read_region((x, y), 0, (patch_size, patch_size)).convert("RGB")
                patch = transform(patch)
            except Exception as e:
                print(
                    f"Error extracting patch at index {j} (coord: {coord}) from slide {slide_path}: {e}"
                )
                continue
            batch_images.append(patch)
            batch_coords.append(coord)
            batch_contours.append(contour_val)
            if len(batch_images) >= batch_size_per_gpu:
                _flush_batch()
            # Log every ~1000 patches (use 1-based count for readability).
            if (j + 1) % 1000 == 0 or (j + 1) == total_patches:
                print(f"Processed {j + 1}/{total_patches} patches...")
        # Bug fix: the original flushed only inside the loop when
        # j == total_patches - 1, so if the final patch read failed (continue),
        # any patches still buffered were silently dropped. Flushing after the
        # loop guarantees the remainder is always encoded.
        _flush_batch()
    finally:
        wsi.close()

    # Save results
    if feat_list:
        _save_slide_features(out_h5_path, feat_list, coord_list, contour_list)
    else:
        print(f"No features extracted for {slide_path}")
def main(args):
    """Load the patch encoder and extract features for the requested slide."""
    # GPU-only execution (as required by the model release).
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is required for this script, but torch.cuda.is_available() is False"
        )
    device = "cuda"
    cudnn.benchmark = True

    # Load the pretrained patch encoder onto the GPU in eval mode.
    print("Loading model...")
    model = PatchEncoder.from_pretrained(args.repo_id)
    model.to(device).eval()

    # Standard ImageNet-style preprocessing: resize, center-crop, normalize.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Resolve the output path: explicit flag wins, otherwise derive a
    # "<coords basename>_features.h5" sibling of the coords file.
    if args.out_h5_path:
        out_h5_path = args.out_h5_path
    else:
        stem, _ext = os.path.splitext(args.coords_h5_path)
        out_h5_path = f"{stem}_features.h5"

    print("Extracting patch features ...")
    process_single_slide(
        args.slide_path,
        args.coords_h5_path,
        out_h5_path,
        model,
        transform,
        args.batch_size_per_gpu,
    )
if __name__ == "__main__":
    # CLI definition; defaults are shown in --help via the formatter class.
    cli = argparse.ArgumentParser(
        "Single WSI patch-feature extraction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli.add_argument(
        "--slide_path", type=str, required=True,
        help="Path to a whole-slide image file (.svs/.tif/.tiff/.ndpi/.mrxs/...).",
    )
    cli.add_argument(
        "--coords_h5_path", type=str, required=True,
        help=(
            "Path to the coordinates HDF5 produced by `patchfy` (must contain dataset 'coords' "
            "and ideally 'contour_index')."
        ),
    )
    cli.add_argument(
        "--out_h5_path", type=str, default="",
        help="Output HDF5 path for patch features. If empty, defaults to '<coords_h5_path>_features.h5'.",
    )
    cli.add_argument(
        "--batch_size_per_gpu", type=int, default=32,
        help="Batch size for patch encoder inference on a single GPU.",
    )
    cli.add_argument(
        "--repo_id", type=str, default="anonymous-bio/EXAONE-Path-2.5",
        help="Hugging Face repo id containing the EXAONE Path 2.5 patch encoder.",
    )
    main(cli.parse_args())
|