File size: 7,629 Bytes
e101805
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
"""Extract patch-level features from a single WSI.

This script reads patch coordinates from an HDF5 file (typically produced by
`exaonepath.patches.patchfy`) and runs the EXAONE Path patch encoder on the
corresponding image regions to produce patch embeddings.

Input H5 (`--coords_h5_path`) requirements
-----------------------------------------
- Dataset: `coords` (shape: [N, 2]) containing (x, y) coordinates in **level-0** pixel space.
- Attribute on `coords`: `patch_size` (int). If missing, defaults to 256.
- Optional dataset: `contour_index` (shape: [N]). If missing, all patches are treated as contour 0.

Output H5 (`--out_h5_path`) keys
--------------------------------
- `features`: [N, C] patch embeddings
- `coords`: [N, 2] coordinates (copied from input)
- `contour_index`: [N] contour indices (copied from input or synthesized)

Notes
-----
- CUDA is required.
- This implementation loads all extracted features into memory before writing.
"""

import argparse
import os

import h5py
import numpy as np
import openslide
import torch
import torch.backends.cudnn as cudnn
from torch.amp import autocast
from torchvision import transforms

from exaonepath.models.patch_encoder_hf import PatchEncoder


def _open_coords_h5(coords_h5_path):
    """Open coordinates H5 file and read coordinates."""
    if not os.path.exists(coords_h5_path):
        print(f"Coords H5 file not found: {coords_h5_path}")
        return None, None, None

    with h5py.File(coords_h5_path, "r") as f:
        if "coords" not in f:
            print(f"Invalid coords H5 (missing 'coords'): {coords_h5_path}")
            return None, None, None

        coords = np.array(f["coords"])
        patch_size = int(f["coords"].attrs.get("patch_size", 256))
        if "contour_index" in f:
            contour_index = np.array(f["contour_index"])
        else:
            # Backward-compat: if contour_index isn't present, treat all patches as one contour.
            contour_index = np.zeros((len(coords),), dtype=np.int32)

    return coords, patch_size, contour_index


def _save_slide_features(out_h5_path, feat_list, coord_list, contour_list):
    """Concatenate per-batch tensors and write them to a fresh output H5 file."""
    out_dir = os.path.dirname(os.path.abspath(out_h5_path))
    os.makedirs(out_dir, exist_ok=True)

    # Fold the per-batch tensor lists into single [N, ...] arrays.
    datasets = {
        "features": torch.cat(feat_list, dim=0).numpy(),
        "coords": torch.cat(coord_list, dim=0).numpy(),
        "contour_index": torch.cat(contour_list, dim=0).numpy(),
    }

    with h5py.File(out_h5_path, "w") as f:
        for name, data in datasets.items():
            f.create_dataset(name, data=data)


def process_single_slide(
    slide_path,
    coords_h5_path,
    out_h5_path,
    model,
    transform,
    batch_size_per_gpu=32,
):
    """Extract patch embeddings for one WSI and save them to `out_h5_path`.

    Reads (x, y) level-0 coordinates from `coords_h5_path`, crops each
    `patch_size` square from the slide, runs batches of `batch_size_per_gpu`
    patches through `model` (on CUDA, bf16 autocast), and writes the
    concatenated features/coords/contour indices via `_save_slide_features`.

    Parameters
    ----------
    slide_path : str
        Path to the whole-slide image (opened with OpenSlide).
    coords_h5_path : str
        H5 file containing 'coords' (and optionally 'contour_index').
    out_h5_path : str
        Destination H5 path for the extracted features.
    model : torch.nn.Module
        Patch encoder already moved to CUDA and set to eval mode.
    transform : callable
        PIL-image -> tensor preprocessing transform.
    batch_size_per_gpu : int
        Number of patches per inference batch.
    """
    device = "cuda"
    print(f"Processing slide: {slide_path}")
    print(f"Coords H5: {coords_h5_path}")
    print(f"Output H5: {out_h5_path}")

    # Load coordinates; _open_coords_h5 returns Nones on missing/invalid input.
    coords, patch_size, contour_index = _open_coords_h5(coords_h5_path)
    if coords is None or contour_index is None:
        print("No coordinates loaded; aborting.")
        return

    if len(coords) == 0:
        print("No coordinates found (N=0); nothing to do.")
        return

    wsi = openslide.OpenSlide(slide_path)

    batch_images = []
    batch_coords = []
    batch_contours = []

    feat_list = []
    coord_list = []
    contour_list = []

    total_patches = len(coords)

    def _flush_batch():
        """Run the encoder on the pending batch and append results to the output lists."""
        imgs_tensor = torch.stack(batch_images).to(device, non_blocking=True)
        with torch.inference_mode(), autocast("cuda", torch.bfloat16):
            features = model(imgs_tensor)
        feat_list.append(features.detach().float().cpu())
        coord_list.append(torch.tensor(np.array(batch_coords)))
        contour_list.append(torch.tensor(np.array(batch_contours)))
        batch_images.clear()
        batch_coords.clear()
        batch_contours.clear()

    try:
        for j, (coord, contour_val) in enumerate(zip(coords, contour_index)):
            x, y = int(coord[0]), int(coord[1])

            try:
                patch = wsi.read_region((x, y), 0, (patch_size, patch_size)).convert("RGB")
                patch = transform(patch)
            except Exception as e:
                # Best-effort: skip unreadable patches but keep processing the slide.
                print(
                    f"Error extracting patch at index {j} (coord: {coord}) from slide {slide_path}: {e}"
                )
                continue

            batch_images.append(patch)
            batch_coords.append(coord)
            batch_contours.append(contour_val)

            if len(batch_images) >= batch_size_per_gpu:
                _flush_batch()

            # Log every ~1000 patches (use 1-based count for readability).
            if (j + 1) % 1000 == 0 or (j + 1) == total_patches:
                print(f"Processed {j + 1}/{total_patches} patches...")

        # Bug fix: flush any remainder unconditionally. The previous code only
        # flushed when the *last* index survived extraction, so a read failure
        # on the final patch silently dropped the whole tail batch.
        if batch_images:
            _flush_batch()
    finally:
        # Ensure the slide handle is released even if inference fails mid-loop.
        wsi.close()

    # Save results
    if feat_list:
        _save_slide_features(out_h5_path, feat_list, coord_list, contour_list)
    else:
        print(f"No features extracted for {slide_path}")


def main(args):
    """Entry point: load the patch encoder, build preprocessing, and process one slide."""
    # GPU-only execution (as required by the model release).
    if not torch.cuda.is_available():
        raise RuntimeError(
            "CUDA is required for this script, but torch.cuda.is_available() is False"
        )

    device = "cuda"
    cudnn.benchmark = True

    # Load model
    print("Loading model...")
    model = PatchEncoder.from_pretrained(args.repo_id)
    model.to(device).eval()

    # Standard ImageNet-style preprocessing: resize, center-crop, normalize.
    normalize = transforms.Normalize(
        mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
    )
    transform = transforms.Compose(
        [
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            normalize,
        ]
    )

    # Resolve the output path; default places it next to the coords file.
    out_h5_path = args.out_h5_path
    if not out_h5_path:
        root, _ = os.path.splitext(args.coords_h5_path)
        out_h5_path = root + "_features.h5"

    print("Extracting patch features ...")
    process_single_slide(
        args.slide_path,
        args.coords_h5_path,
        out_h5_path,
        model,
        transform,
        args.batch_size_per_gpu,
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        "Single WSI patch-feature extraction",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )

    parser.add_argument(
        "--slide_path",
        type=str,
        required=True,
        help="Path to a whole-slide image file (.svs/.tif/.tiff/.ndpi/.mrxs/...).",
    )
    parser.add_argument(
        "--coords_h5_path",
        type=str,
        required=True,
        help=(
            "Path to the coordinates HDF5 produced by `patchfy` (must contain dataset 'coords' "
            "and ideally 'contour_index')."
        ),
    )
    parser.add_argument(
        "--out_h5_path",
        type=str,
        default="",
        help=(
            # Bug fix: the old help text claimed '<coords_h5_path>_features.h5',
            # but main() strips the extension first (os.path.splitext), so the
            # actual default replaces the coords file's extension.
            "Output HDF5 path for patch features. If empty, defaults to the coords file path "
            "with its extension replaced by '_features.h5'."
        ),
    )
    parser.add_argument(
        "--batch_size_per_gpu",
        type=int,
        default=32,
        help="Batch size for patch encoder inference on a single GPU.",
    )
    parser.add_argument(
        "--repo_id",
        type=str,
        default="anonymous-bio/EXAONE-Path-2.5",
        help="Hugging Face repo id containing the EXAONE Path 2.5 patch encoder.",
    )

    args = parser.parse_args()
    main(args)