roobee79 committed on
Commit
7747544
·
verified ·
1 Parent(s): 5f524ef

Upload 7 files

Browse files
Files changed (8) hide show
  1. .gitattributes +1 -0
  2. HNE2cell_all_patch73_jit.pt +3 -0
  3. inference.py +316 -0
  4. normalize.py +277 -0
  5. patchify.py +254 -0
  6. post_processing.py +348 -0
  7. standard-ilc.tif +3 -0
  8. tools.py +400 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ standard-ilc.tif filter=lfs diff=lfs merge=lfs -text
HNE2cell_all_patch73_jit.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cd428608461bf295e4abfa646fb834c2d24acedca0155d43fdb817da42936593
3
+ size 5138307858
inference.py ADDED
@@ -0,0 +1,316 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HNE2Cell — Step 3: Cell Detection & Classification Inference
3
+
4
+ Run the HNE2Cell model on extracted patches to detect and classify cells.
5
+ Outputs: per-patch cell masks (PNG) and centroid CSVs with cell type annotations.
6
+
7
+ Usage:
8
+ python inference.py \
9
+ --input_dir /path/to/patch_folders \
10
+ --output_dir /path/to/results \
11
+ --model_path ./HNE2cell_all_patch73_jit.pt \
12
+ --magnification 40 \
13
+ --batch_size 32
14
+
15
+ Cell Types (16 classes):
16
+ 0: Background 4: B 8: DC 12: Epithelial
17
+ 1: Malignant 5: Plasma 9: Fibroblast 13: Immune_Other
18
+ 2: CD4T 6: Macrophage 10: Endothelial 14: Stromal_Other
19
+ 3: CD8T 7: Myeloid 11: Pericyte 15: Dead
20
+ """
21
+
22
+ import os
23
+ import argparse
24
+ import glob
25
+
26
+ import cv2
27
+ import numpy as np
28
+ import pandas as pd
29
+ import torch
30
+ from PIL import Image
31
+ from torch.cuda.amp import autocast
32
+ from torch.utils.data import DataLoader, Dataset
33
+ from torchvision import transforms
34
+ from tqdm import tqdm
35
+
36
+ from post_processing import DetectionCellPostProcessor
37
+
38
+ # ========================== Constants ======================================
39
+
40
# Cell type names ordered by class index (index 0 is the background class).
_CELL_TYPE_NAMES = (
    "Background",
    "Malignant",
    "CD4T",
    "CD8T",
    "B",
    "Plasma",
    "Macrophage",
    "Myeloid",
    "DC",
    "Fibroblast",
    "Endothelial",
    "Pericyte",
    "Epithelial",
    "Immune_Other",
    "Stromal_Other",
    "Dead",
)

# Class index -> cell type name.
CELL_TYPES = dict(enumerate(_CELL_TYPE_NAMES))

# RGBA colors for mask visualization, keyed by class index.
# NOTE(review): classes 6 (Macrophage) and 13 (Immune_Other) share the same
# color, so they cannot be told apart in a rendered mask — confirm intended.
CELL_COLORS = {
    class_id: list(rgba)
    for class_id, rgba in enumerate(
        (
            (0, 0, 0, 0),  # Background
            (255, 0, 0, 255),  # Malignant
            (30, 144, 255, 255),  # CD4T
            (65, 105, 225, 255),  # CD8T
            (0, 0, 255, 255),  # B
            (100, 149, 237, 255),  # Plasma
            (176, 224, 230, 255),  # Macrophage
            (70, 130, 180, 255),  # Myeloid
            (0, 191, 255, 255),  # DC
            (34, 139, 34, 255),  # Fibroblast
            (60, 179, 113, 255),  # Endothelial
            (50, 205, 50, 255),  # Pericyte
            (255, 140, 0, 255),  # Epithelial
            (176, 224, 230, 255),  # Immune_Other
            (107, 142, 35, 255),  # Stromal_Other
            (128, 128, 128, 255),  # Dead
        )
    )
}
78
+
79
# ImageNet-style normalization fitted to H&E data: per-channel (R, G, B)
# mean and standard deviation used to standardize each patch.
_HNE_MEAN = [0.707223, 0.578729, 0.703617]
_HNE_STD = [0.211883, 0.230117, 0.177517]

# Preprocessing applied to every patch before it is fed to the model:
# resize to 224x224, convert to a tensor, then standardize each channel.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=_HNE_MEAN, std=_HNE_STD),
    ]
)
90
+
91
+
92
+ # ========================== Dataset ========================================
93
+
94
+
95
class PatchDataset(Dataset):
    """Dataset that yields ``(file_path, image)`` pairs for patch image files.

    Args:
        file_paths: Sequence of image paths to read with OpenCV.
        transform: Optional callable applied to the patch after it has been
            converted to a PIL image. When falsy, the RGB numpy array is
            returned unchanged.
    """

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            # Unreadable file: warn and substitute an all-black 256x256 patch
            # so the batch can still be assembled.
            print(f"[WARN] Failed to load: {path}")
            bgr = np.zeros((256, 256, 3), dtype=np.uint8)
        # OpenCV loads BGR; the rest of the pipeline works on RGB.
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if not self.transform:
            return path, rgb
        return path, self.transform(Image.fromarray(rgb))
113
+
114
+
115
+ # ========================== Inference ======================================
116
+
117
+
118
def process_batch(
    batch,
    model,
    device,
    mask_output_dir,
    centroid_records,
    magnification=40,
):
    """Run inference on one batch and save masks + centroid info.

    Args:
        batch: ``(file_paths, images)`` pair; ``images`` is moved to *device*
            and passed to *model* as a whole.
        model: Callable returning a mapping with the keys ``"cell_type_map"``,
            ``"nuclei_binary_map"``, ``"hv_map"`` and ``"tissue_type_map"``,
            each indexable by sample.
        device: Device (``torch.device`` or string) the model runs on.
        mask_output_dir: Existing directory that receives ``<stem>_mask.png``.
        centroid_records: List extended in place with one dict per detected
            cell (``slide_id``, ``x``, ``y``, ``celltype``, ``celltype_name``).
        magnification: Forwarded to ``DetectionCellPostProcessor``.
    """
    file_paths, images = batch

    # ``torch.cuda.amp.autocast`` only applies to CUDA. Entering it
    # unconditionally on a CPU run just triggers its warn-and-disable path on
    # every batch, so enable it only when the model really runs on CUDA.
    use_amp = torch.device(device).type == "cuda"
    with torch.no_grad():
        with autocast(enabled=use_amp):
            outputs = model(images.to(device, non_blocking=True))

    for i, fpath in enumerate(file_paths):
        # File stem of the patch; used both as CSV key and mask file name.
        slide_id = os.path.splitext(os.path.basename(fpath))[0]

        # Extract per-sample predictions as float32 CPU tensors
        cell_type_map = outputs["cell_type_map"][i].float().detach().cpu()
        nuclei_binary_map = outputs["nuclei_binary_map"][i].float().detach().cpu()
        hv_map = outputs["hv_map"][i].float().detach().cpu()
        tissue_type_map = outputs["tissue_type_map"][i].float().detach().cpu()

        # Build prediction map: argmax of tissue type, cell type and nuclei
        # maps (one channel each) followed by the channels of the HV map.
        pred_map = np.concatenate(
            [
                torch.argmax(tissue_type_map, dim=0)[..., None].numpy(),
                torch.argmax(cell_type_map, dim=0)[..., None].numpy(),
                torch.argmax(nuclei_binary_map, dim=0)[..., None].numpy(),
                hv_map.permute(1, 2, 0).numpy(),
            ],
            axis=-1,
        )

        # Post-processing
        post_processor = DetectionCellPostProcessor(
            nr_types=cell_type_map.shape[0],
            magnification=magnification,
            gt=False,
        )
        _, type_pred = post_processor.post_process_cell_segmentation(pred_map)

        # Create mask image: white canvas, each cell filled with its class color.
        # NOTE(review): the canvas is fixed at 256x256 while TRANSFORM resizes
        # inputs to 224x224 — confirm the model's output resolution matches.
        mask = np.ones((256, 256, 3), dtype=np.uint8) * 255
        for cell in type_pred.values():
            ctype = cell["type"]
            rgba = CELL_COLORS.get(ctype, [255, 255, 255, 255])
            # OpenCV draws in BGR order; the alpha component is dropped.
            bgr = [rgba[2], rgba[1], rgba[0]]
            cv2.fillPoly(mask, [cell["contour"]], bgr)

            centroid_records.append(
                {
                    "slide_id": slide_id,
                    "x": cell["centroid"][0],
                    "y": cell["centroid"][1],
                    "celltype": ctype,
                    "celltype_name": CELL_TYPES.get(ctype, "Unknown"),
                }
            )

        # Save mask only if non-trivial (at least one non-white pixel drawn)
        if not np.all(mask == 255):
            cv2.imwrite(
                os.path.join(mask_output_dir, f"{slide_id}_mask.png"), mask
            )
184
+
185
+
186
def run_inference(
    patch_folders: list[str],
    model,
    device,
    output_dir: str,
    magnification: int = 40,
    batch_size: int = 32,
    num_workers: int = 4,
):
    """Run inference over a list of patch folders.

    For every folder containing ``*.png`` patches, masks are written to
    ``<output_dir>/mask_patches/<folder_name>/`` and the detected cells to
    ``<output_dir>/centroid/<folder_name>_centroid.csv``.

    Args:
        patch_folders: Directories that directly contain PNG patches.
        model: Model moved to *device* and switched to eval mode.
        device: Device (``torch.device`` or string) used for inference.
        output_dir: Root directory for all outputs (created on demand).
        magnification: Forwarded to ``process_batch``.
        batch_size: DataLoader batch size.
        num_workers: DataLoader worker processes; ``0`` loads in-process.
    """
    model.to(device).eval()

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned memory only benefits host-to-GPU copies.
        "pin_memory": torch.device(device).type == "cuda",
        "shuffle": False,
    }
    if num_workers > 0:
        # DataLoader raises ValueError if these options are given without
        # worker processes, so pass them only in multi-process mode.
        loader_kwargs["prefetch_factor"] = 2
        loader_kwargs["persistent_workers"] = True

    for folder in patch_folders:
        # normpath drops a trailing separator, which would otherwise make
        # basename() return "" and produce outputs named "_centroid.csv".
        folder_name = os.path.basename(os.path.normpath(folder))
        png_files = sorted(glob.glob(os.path.join(folder, "*.png")))

        if not png_files:
            print(f"[SKIP] {folder}: no PNG patches found")
            continue

        mask_dir = os.path.join(output_dir, "mask_patches", f"{folder_name}")
        centroid_path = os.path.join(output_dir, "centroid", f"{folder_name}_centroid.csv")
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(os.path.dirname(centroid_path), exist_ok=True)

        dataset = PatchDataset(png_files, transform=TRANSFORM)
        loader = DataLoader(dataset, **loader_kwargs)

        centroids = []
        for batch in tqdm(loader, desc=f"Inference: {folder_name}"):
            process_batch(
                batch, model, device, mask_dir, centroids, magnification
            )

        # Explicit columns keep the CSV header even when no cell was detected.
        df = pd.DataFrame(
            centroids,
            columns=["slide_id", "x", "y", "celltype", "celltype_name"],
        )
        df.to_csv(centroid_path, index=False)
        print(f"[DONE] {folder_name} → {centroid_path} ({len(df)} cells)")
231
+
232
+
233
+ # =============================== CLI =======================================
234
+
235
+
236
def main():
    """Command-line entry point: parse options, load the model, run inference."""
    cli = argparse.ArgumentParser(description="HNE2Cell inference")
    cli.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing patch folders (each with *.png)",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for masks and centroid CSVs",
    )
    cli.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the TorchScript JIT model (.pt)",
    )
    cli.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Magnification of input patches. 40x recommended. (default: 40)",
    )
    cli.add_argument("--batch_size", type=int, default=32)
    cli.add_argument("--num_workers", type=int, default=4)
    cli.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device: 'cuda', 'cpu', or 'auto' (default: auto)",
    )
    args = cli.parse_args()

    # Resolve the device: an explicit choice wins, "auto" prefers CUDA.
    if args.device != "auto":
        device = torch.device(args.device)
    elif torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")
    print(f"Using device: {device}")

    # Load the TorchScript model straight onto the selected device.
    print(f"Loading model: {args.model_path}")
    model = torch.jit.load(args.model_path, map_location=device)
    model.eval()

    if args.magnification == 20:
        print(
            "⚠️ Running at 20x. Results are usable but 40x is recommended "
            "for best accuracy, especially for small immune cells."
        )

    # Every sub-directory of input_dir is treated as one patch folder.
    entries = glob.glob(os.path.join(args.input_dir, "*"))
    patch_folders = sorted(filter(os.path.isdir, entries))
    # Fall back to input_dir itself when it holds the PNG patches directly.
    if not patch_folders and glob.glob(os.path.join(args.input_dir, "*.png")):
        patch_folders = [args.input_dir]

    print(f"Found {len(patch_folders)} patch folder(s)")

    run_inference(
        patch_folders,
        model,
        device,
        args.output_dir,
        magnification=args.magnification,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
313
+
314
+
315
+ if __name__ == "__main__":
316
+ main()
normalize.py ADDED
@@ -0,0 +1,277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HNE2Cell — Step 1: Reinhard Color Normalization
3
+
4
+ Normalize H&E stained whole-slide images (WSI) to a reference color distribution
5
+ using the Reinhard method in LAB color space.
6
+
7
+ Supported input formats: .svs, .tif, .tiff, .ndpi
8
+ Output: Aligned-hne.tif (full-resolution normalized), Aligned-hne.jpg (4x downsampled preview)
9
+
10
+ Usage:
11
+ python normalize.py \
12
+ --input_dir /path/to/slides \
13
+ --target /path/to/standard-ilc.tif \
14
+ --patch_size 128 \
15
+ --saturation_threshold 0.1
16
+ """
17
+
18
+ import os
19
+ import argparse
20
+ import glob
21
+
22
+ import numpy as np
23
+ import tifffile as tiff
24
+ from PIL import Image
25
+ from skimage import color
26
+
27
+ Image.MAX_IMAGE_PIXELS = None
28
+ os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2, 40))
29
+
30
+ # ---------------------------------------------------------------------------
31
+ # Optional: openslide (only needed for .svs / .ndpi)
32
+ # ---------------------------------------------------------------------------
33
+ try:
34
+ import openslide
35
+
36
+ OPENSLIDE_AVAILABLE = True
37
+ except ImportError:
38
+ OPENSLIDE_AVAILABLE = False
39
+
40
+
41
+ # ============================= I/O helpers =================================
42
+
43
+
44
def load_image(image_path: str, level: int = 0) -> np.ndarray:
    """Load a whole-slide image as an RGB numpy array.

    Supports .svs/.ndpi (via OpenSlide) and .tif/.tiff (via tifffile).

    Args:
        image_path: Path to the image; the format is chosen by extension
            (case-insensitive).
        level: OpenSlide pyramid level to read; only used for .svs/.ndpi.

    Returns:
        The image as a ``uint8`` array. TIFF input is returned with at most
        three channels in the last axis (an alpha channel is dropped).

    Raises:
        ImportError: A .svs/.ndpi file was given but openslide is missing.
        ValueError: The file extension is not supported.
    """
    ext = os.path.splitext(image_path)[1].lower()

    if ext in (".svs", ".ndpi"):
        if not OPENSLIDE_AVAILABLE:
            raise ImportError(
                "openslide-python is required to read .svs/.ndpi files. "
                "Install it with: pip install openslide-python"
            )
        slide = openslide.OpenSlide(image_path)
        try:
            region = slide.read_region((0, 0), level, slide.level_dimensions[level])
            return np.array(region.convert("RGB"))
        finally:
            # Release the slide handle even if reading or conversion fails.
            slide.close()

    if ext in (".tif", ".tiff"):
        image = tiff.imread(image_path)
        if image.ndim == 2:
            # Grayscale: replicate the single channel into R, G and B.
            image = np.stack((image,) * 3, axis=-1)
        elif image.ndim == 4 and image.shape[0] == 1:
            # Drop a leading singleton axis.
            image = image[0]
        if image.ndim == 3 and image.shape[-1] == 4:
            # RGBA: downstream code unpacks and converts exactly 3 channels.
            image = np.ascontiguousarray(image[..., :3])
        # Ensure uint8 (values are clipped to [0, 255], not rescaled)
        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    raise ValueError(f"Unsupported file format: {ext}")
75
+
76
+
77
+ # ======================== Saturation filtering =============================
78
+
79
+
80
def calculate_saturation(patch: Image.Image) -> float:
    """Return the mean of the patch's HSV saturation channel divided by 255."""
    saturation = np.array(patch.convert("HSV"))[:, :, 1]
    return np.mean(saturation / 255.0)
83
+
84
+
85
def extract_high_saturation_patches(
    image: np.ndarray, patch_size: int, saturation_threshold: float
) -> list:
    """Collect the tiles whose mean saturation reaches the threshold.

    The image is cut into non-overlapping ``patch_size`` squares; partial
    tiles at the right/bottom edge are ignored. Tiles are visited column by
    column (x outer, y inner).

    Returns:
        List of ``((x0, y0), tile_array)`` entries for the tiles that passed.
    """
    full = Image.fromarray(image)
    width, height = full.size
    n_cols = width // patch_size
    n_rows = height // patch_size

    kept = []
    for x0 in range(0, n_cols * patch_size, patch_size):
        for y0 in range(0, n_rows * patch_size, patch_size):
            tile = full.crop((x0, y0, x0 + patch_size, y0 + patch_size))
            if not (calculate_saturation(tile) >= saturation_threshold):
                continue
            kept.append(((x0, y0), np.array(tile)))
    return kept
100
+
101
+
102
def reconstruct_from_patches(
    width: int, height: int, patch_size: int, patches: list
) -> np.ndarray:
    """Paste tiles onto a black ``(height, width, 3)`` uint8 canvas.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        patch_size: Side length every tile must have.
        patches: Iterable of ``((x0, y0), tile_array)`` entries.

    Returns:
        The canvas; areas not covered by a tile stay black. Tiles whose shape
        is not ``(patch_size, patch_size, 3)`` are ignored.
    """
    expected_shape = (patch_size, patch_size, 3)
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for (left, top), tile in patches:
        if tile.shape != expected_shape:
            continue
        canvas[top : top + patch_size, left : left + patch_size, :] = tile
    return canvas
111
+
112
+
113
+ # =================== Reinhard color normalization ==========================
114
+
115
+
116
def _color_convert_chunked(image, func, chunk_size=16384):
    """Apply *func* tile by tile and collect the results as float32.

    The image is processed in ``chunk_size`` x ``chunk_size`` spatial tiles so
    that *func* never has to handle the whole array at once.

    Args:
        image: Array of shape ``(H, W, C)``.
        func: Callable mapping a tile to an array of the same shape.
        chunk_size: Maximum tile height/width in pixels.

    Returns:
        float32 array with the same shape as *image*.
    """
    n_rows, n_cols, _ = image.shape
    result = np.zeros_like(image, dtype=np.float32)
    for top in range(0, n_rows, chunk_size):
        row_sel = slice(top, min(top + chunk_size, n_rows))
        for left in range(0, n_cols, chunk_size):
            col_sel = slice(left, min(left + chunk_size, n_cols))
            result[row_sel, col_sel, :] = func(image[row_sel, col_sel, :])
    return result
126
+
127
+
128
def reinhard_normalize(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Reinhard color normalization in LAB space.

    Each LAB channel of *source* is shifted and scaled so that the mean and
    standard deviation of its tissue pixels match those of *target*.

    Tissue pixels are those that are not pure black in RGB. A single mask is
    shared by all three channels: testing each LAB channel against zero
    separately would treat neutral tissue pixels (a or b equal to 0) as
    background in that channel only, and relies on exact float zeros.

    Args:
        source: RGB image of shape ``(H, W, 3)`` with black background.
        target: RGB reference image of shape ``(H, W, 3)``.

    Returns:
        RGB image as float32 (the dtype allocated by
        ``_color_convert_chunked``); presumably in [0, 1] as produced by
        ``skimage.color.lab2rgb`` — confirm against the skimage docs.
    """
    src_lab = _color_convert_chunked(source, color.rgb2lab)
    tgt_lab = color.rgb2lab(target)

    # True where at least one RGB component is non-zero.
    src_tissue = source.any(axis=-1)
    tgt_tissue = target.any(axis=-1)

    if src_tissue.any() and tgt_tissue.any():
        for ch in range(3):
            src_ch = src_lab[:, :, ch]
            src_vals = src_ch[src_tissue]
            tgt_vals = tgt_lab[:, :, ch][tgt_tissue]

            # Accumulate in float64: float32 sums over a whole slide lose
            # precision.
            src_mean = src_vals.mean(dtype=np.float64)
            src_std = src_vals.std(dtype=np.float64)
            tgt_mean = tgt_vals.mean(dtype=np.float64)
            tgt_std = tgt_vals.std(dtype=np.float64)

            # A (nearly) constant channel cannot be rescaled; leave it as is.
            if src_std < 1e-6:
                continue

            # Background pixels are forced to 0 so they stay black.
            src_lab[:, :, ch] = np.where(
                src_tissue,
                (src_ch - src_mean) * (tgt_std / src_std) + tgt_mean,
                0,
            )

    return _color_convert_chunked(src_lab, color.lab2rgb)
160
+
161
+
162
+ # ============================= Main pipeline ===============================
163
+
164
+
165
+ def normalize_slide(
166
+ slide_path: str,
167
+ target_image: np.ndarray,
168
+ patch_size: int = 128,
169
+ saturation_threshold: float = 0.1,
170
+ output_dir: str | None = None,
171
+ skip_existing: bool = True,
172
+ ):
173
+ """Full normalization pipeline for a single slide."""
174
+
175
+ if output_dir is None:
176
+ output_dir = os.path.dirname(slide_path)
177
+
178
+ output_tif = os.path.join(output_dir, "Aligned-hne.tif")
179
+ output_jpg = os.path.join(output_dir, "Aligned-hne.jpg")
180
+
181
+ if skip_existing and os.path.exists(output_tif):
182
+ print(f"[SKIP] {slide_path} — Aligned-hne.tif already exists.")
183
+ return
184
+
185
+ print(f"[LOAD] {slide_path}")
186
+ raw = load_image(slide_path)
187
+ h, w = raw.shape[:2]
188
+
189
+ # 1. Saturation-based tissue detection
190
+ patches = extract_high_saturation_patches(
191
+ raw, patch_size, saturation_threshold
192
+ )
193
+ reconstructed = reconstruct_from_patches(w, h, patch_size, patches)
194
+
195
+ # 2. (Optional) save intermediate reconstruction
196
+ recon_path = os.path.join(output_dir, "recon.tif")
197
+ bigtiff = reconstructed.nbytes > 4 * 1024**3
198
+ tiff.imwrite(recon_path, reconstructed, bigtiff=bigtiff)
199
+
200
+ # 3. Reinhard normalization
201
+ normalized = reinhard_normalize(reconstructed, target_image)
202
+ normalized_u8 = (normalized * 255).astype(np.uint8)
203
+
204
+ # 4. Save outputs
205
+ tiff.imwrite(output_tif, normalized_u8, bigtiff=bigtiff)
206
+
207
+ resized = Image.fromarray(normalized_u8).resize(
208
+ (w // 4, h // 4), Image.LANCZOS
209
+ )
210
+ resized.save(output_jpg, quality=90)
211
+
212
+ print(f"[DONE] {slide_path} → {output_tif}")
213
+
214
+
215
+ # =============================== CLI =======================================
216
+
217
+
218
def main():
    """Command-line entry point: normalize every slide found under --input_dir.

    Slides are searched recursively. Without ``--output_dir`` each slide's
    outputs are written next to it. With ``--output_dir`` and several slides,
    each slide gets its own sub-folder named after the slide file, because the
    output file names are fixed and would otherwise collide (only the first
    slide would be processed and the rest skipped).
    """
    parser = argparse.ArgumentParser(
        description="Reinhard color normalization for H&E WSIs"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory to search for slide files (.svs, .tif, .tiff, .ndpi)",
    )
    parser.add_argument(
        "--target",
        type=str,
        required=True,
        help="Path to the reference/target image (.tif)",
    )
    parser.add_argument("--patch_size", type=int, default=128)
    parser.add_argument("--saturation_threshold", type=float, default=0.1)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help=(
            "If set, outputs go here (one sub-folder per slide when several "
            "slides are found). Otherwise, outputs are saved next to each slide."
        ),
    )

    args = parser.parse_args()

    # Load target image once
    target_image = load_image(args.target)

    # Collect slides
    extensions = ("*.svs", "*.tif", "*.tiff", "*.ndpi")
    slides = []
    for ext in extensions:
        slides.extend(glob.glob(os.path.join(args.input_dir, "**", ext), recursive=True))

    # Exclude files that are already outputs; sort for a reproducible order.
    generated = ("Aligned-hne.tif", "Aligned-hne.tiff", "recon.tif")
    slides = sorted(s for s in slides if os.path.basename(s) not in generated)

    print(f"Found {len(slides)} slide(s) in {args.input_dir}")

    # A single slide keeps the flat layout; several slides sharing one output
    # directory are separated into per-slide sub-folders.
    per_slide_dirs = args.output_dir is not None and len(slides) > 1

    for slide_path in slides:
        output_dir = args.output_dir
        if per_slide_dirs:
            slide_stem = os.path.splitext(os.path.basename(slide_path))[0]
            output_dir = os.path.join(args.output_dir, slide_stem)

        # Best-effort batch: report a failing slide and continue with the next.
        try:
            if output_dir:
                os.makedirs(output_dir, exist_ok=True)
            normalize_slide(
                slide_path,
                target_image,
                patch_size=args.patch_size,
                saturation_threshold=args.saturation_threshold,
                output_dir=output_dir,
            )
        except Exception as e:
            print(f"[ERROR] {slide_path}: {e}")
274
+
275
+
276
+ if __name__ == "__main__":
277
+ main()
patchify.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ HNE2Cell — Step 2: Patch Extraction
3
+
4
+ Extract overlapping patches from color-normalized H&E images for cell detection.
5
+ Supports both 20x and 40x magnification (40x recommended for best results).
6
+
7
+ Usage:
8
+ # 40x (recommended)
9
+ python patchify.py \
10
+ --input_dir /path/to/slides \
11
+ --patch_size 256 \
12
+ --overlap 64 \
13
+ --magnification 40 \
14
+ --workers 8
15
+
16
+ # 20x (supported but 40x preferred)
17
+ python patchify.py \
18
+ --input_dir /path/to/slides \
19
+ --patch_size 256 \
20
+ --overlap 64 \
21
+ --magnification 20 \
22
+ --workers 8
23
+
24
+ Notes:
25
+ - 40x magnification is recommended for optimal cell detection accuracy.
26
+ - 20x is supported and functional, but fine-grained cell boundaries
27
+ (especially small immune cells) may be less precise.
28
+ - Input: Aligned-hne.tif (output of normalize.py)
29
+ - Output: <section>/patches_<mag>x_p<patch>_o<overlap>/<name>_<x>_<y>.png
30
+ """
31
+
32
+ import os
33
+ import argparse
34
+ import glob
35
+ import time
36
+ from multiprocessing import Pool
37
+
38
+ import numpy as np
39
+ from PIL import Image
40
+ from tqdm import tqdm
41
+
42
+ Image.MAX_IMAGE_PIXELS = None
43
+
44
+
45
+ # ========================== Utility functions ==============================
46
+
47
+
48
def black_to_white(pil_img: Image.Image) -> Image.Image:
    """Turn pure-black pixels white to avoid dark-border artifacts.

    A pixel counts as black when its first three channels are all 0. Such
    pixels are set to 255 in every channel. Images that do not have at least
    three channels are returned as a plain copy.
    """
    pixels = np.array(pil_img)
    has_color_channels = pixels.ndim == 3 and pixels.shape[2] >= 3
    if has_color_channels:
        is_black = ~pixels[..., :3].any(axis=-1)
        pixels[is_black] = 255
    return Image.fromarray(pixels)
55
+
56
+
57
def make_start_positions(length: int, patch_size: int, stride: int) -> list[int]:
    """Return patch start offsets along one axis.

    Offsets advance by *stride*; a final offset is appended when needed so
    that the last patch ends exactly at *length*. When the axis is shorter
    than one patch, the single offset 0 is returned.
    """
    final_start = length - patch_size
    if final_start < 0:
        return [0]
    positions = list(range(0, final_start + 1, stride))
    if positions[-1] != final_start:
        positions.append(final_start)
    return positions
66
+
67
+
68
+ # ========================== Core patching ==================================
69
+
70
+
71
def extract_patches(
    image_path: str,
    output_dir: str,
    patch_size: int = 256,
    overlap: int = 64,
    prefix: str = "patch",
) -> int:
    """Crop overlapping patches from a single image and save as PNG.

    Patches are named ``<prefix>_<x0>_<y0>.png`` where ``(x0, y0)`` is the
    top-left corner in the source image. Pure-black pixels are turned white
    before saving.

    Args:
        image_path: Image to cut into patches.
        output_dir: Directory receiving the PNG files (created if missing).
        patch_size: Side length of each square patch in pixels.
        overlap: Overlap between neighbouring patches in pixels.
        prefix: File name prefix for the patches.

    Returns:
        The number of patches saved.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``patch_size``.
    """
    # Validate with a real exception: an assert would be stripped under -O.
    # Checking first also avoids creating output_dir for an invalid request.
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(f"overlap ({overlap}) must be < patch_size ({patch_size})")

    os.makedirs(output_dir, exist_ok=True)

    # convert() returns an independent in-memory image, so the file handle
    # can be released as soon as the conversion is done.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    width, height = img.size

    xs = make_start_positions(width, patch_size, stride)
    ys = make_start_positions(height, patch_size, stride)

    count = 0
    with tqdm(total=len(xs) * len(ys), desc=prefix, unit="patch", leave=False) as pbar:
        for x0 in xs:
            for y0 in ys:
                patch = img.crop((x0, y0, x0 + patch_size, y0 + patch_size))
                patch = black_to_white(patch)
                patch.save(
                    os.path.join(output_dir, f"{prefix}_{x0}_{y0}.png"),
                    format="PNG",
                )
                count += 1
                pbar.update(1)
    return count
106
+
107
+
108
+ # =================== Per-section processing (for Pool) =====================
109
+
110
# Worker configuration read by _process_section(). main() fills in the keys
# "patch_size", "overlap", "magnification" and "input_filename" before any
# section is processed.
_ARGS = {}


def _process_section(section_dir: str) -> str:
    """Process a single section directory. Designed for multiprocessing.Pool.

    Looks for ``<input_filename>.tif`` (or ``.tiff``) inside *section_dir* and
    writes its patches to
    ``<section_dir>/patches_<mag>x_p<patch_size>_o<overlap>/``.

    Returns:
        A one-line status message: ``[SKIP] ...`` when the input image is
        missing, otherwise ``[OK] ...`` with patch count and timing.
    """

    patch_size = _ARGS["patch_size"]
    overlap = _ARGS["overlap"]
    magnification = _ARGS["magnification"]
    input_filename = _ARGS["input_filename"]

    # Locate input file
    candidates = [
        os.path.join(section_dir, f"{input_filename}.tif"),
        os.path.join(section_dir, f"{input_filename}.tiff"),
    ]
    image_path = next((p for p in candidates if os.path.exists(p)), None)

    if image_path is None:
        return f"[SKIP] {section_dir}: {input_filename}.tif not found"

    stride = patch_size - overlap
    out_dir = os.path.join(
        section_dir,
        f"patches_{magnification}x_p{patch_size}_o{overlap}",
    )

    # abspath resolves "." and drops a trailing separator; basename() of
    # "some/dir/" would be "" and every patch would be named "_<x>_<y>.png".
    section_name = os.path.basename(os.path.abspath(section_dir))

    t0 = time.time()
    n = extract_patches(
        image_path=image_path,
        output_dir=out_dir,
        patch_size=patch_size,
        overlap=overlap,
        prefix=section_name,
    )
    dt = time.time() - t0

    return (
        f"[OK] {section_name} | {magnification}x | "
        f"stride={stride} | {n} patches | {dt:.1f}s → {out_dir}"
    )
154
+
155
+
156
+ # =============================== CLI =======================================
157
+
158
+
159
def _init_worker(config: dict) -> None:
    """Install the patching configuration in the current process."""
    global _ARGS
    _ARGS = dict(config)


def main():
    """Command-line entry point: extract patches for every section folder.

    Section folders are the sub-directories of ``--input_dir`` matching
    ``--pattern``. If there are none but ``--input_dir`` itself holds the
    normalized image, it is processed as a single section.
    """
    parser = argparse.ArgumentParser(
        description="Extract overlapping patches from normalized H&E images"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory containing section folders with Aligned-hne.tif files",
    )
    parser.add_argument(
        "--input_filename",
        type=str,
        default="Aligned-hne",
        help="Base filename of the normalized image (default: Aligned-hne)",
    )
    parser.add_argument(
        "--patch_size", type=int, default=256, help="Patch size in pixels (default: 256)"
    )
    parser.add_argument(
        "--overlap", type=int, default=64, help="Overlap in pixels (default: 64)"
    )
    parser.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Slide magnification. 40x recommended; 20x supported. (default: 40)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern to match section folders (default: '*')",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="Number of parallel workers (default: 8)"
    )

    args = parser.parse_args()

    if args.magnification == 20:
        print(
            "⚠️ 20x magnification is supported but 40x is recommended for best "
            "cell detection accuracy (especially small immune cells)."
        )

    # Collect section directories
    section_dirs = sorted(
        p
        for p in glob.glob(os.path.join(args.input_dir, args.pattern))
        if os.path.isdir(p)
    )

    if not section_dirs:
        # Maybe input_dir itself contains the image directly
        candidates = [
            os.path.join(args.input_dir, f"{args.input_filename}.tif"),
            os.path.join(args.input_dir, f"{args.input_filename}.tiff"),
        ]
        if any(os.path.exists(c) for c in candidates):
            section_dirs = [args.input_dir]
        else:
            raise SystemExit(
                f"No section folders matching '{args.pattern}' found in {args.input_dir}"
            )

    print(f"Found {len(section_dirs)} section(s) | {args.magnification}x | "
          f"patch={args.patch_size} overlap={args.overlap}")

    config = {
        "patch_size": args.patch_size,
        "overlap": args.overlap,
        "magnification": args.magnification,
        "input_filename": args.input_filename,
    }
    # Configure this process (used directly by the sequential path).
    _init_worker(config)

    if args.workers <= 1 or len(section_dirs) == 1:
        results = [_process_section(d) for d in tqdm(section_dirs, desc="Sections")]
    else:
        # Pass the configuration through the pool initializer: workers started
        # with the "spawn" method re-import this module and would otherwise see
        # an empty _ARGS.
        with Pool(
            processes=min(args.workers, len(section_dirs)),
            initializer=_init_worker,
            initargs=(config,),
        ) as pool:
            results = list(
                tqdm(
                    pool.imap_unordered(_process_section, section_dirs),
                    total=len(section_dirs),
                    desc="Sections",
                )
            )

    print("\n".join(results))
251
+
252
+
253
+ if __name__ == "__main__":
254
+ main()
post_processing.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # PostProcessing Pipeline
3
+ #
4
+ # Adapted from HoverNet
5
+ # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
6
+ # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
7
+ #
8
+ # @ Fabian Hörst, fabian.hoerst@uk-essen.de
9
+ # Institute for Artificial Intelligence in Medicine,
10
+ # University Medicine Essen
11
+
12
+
13
+ import warnings
14
+ from typing import Tuple, Literal
15
+
16
+ import cv2
17
+ import numpy as np
18
+ from scipy.ndimage import measurements
19
+ from scipy.ndimage.morphology import binary_fill_holes
20
+ from skimage.segmentation import watershed
21
+ import torch
22
+
23
+ # import sys
24
+ # sys.path.append("/home01/k123a01/CellViTR/cell_segmentation/utils/")
25
+ from tools import get_bounding_box, remove_small_objects
26
+
27
+
28
def noop(*args, **kargs):
    """Accept any positional/keyword arguments, do nothing, and return ``None``."""
    return None
30
+
31
+
32
+ warnings.warn = noop
33
+
34
+
35
class DetectionCellPostProcessor:
    """Convert raw prediction maps into an instance map plus per-nucleus metadata."""

    def __init__(
        self,
        nr_types: int = None,
        magnification: Literal[20, 40] = 40,
        gt: bool = False,
    ) -> None:
        """DetectionCellPostProcessor for postprocessing prediction maps and get detected cells

        Args:
            nr_types (int, optional): Number of cell types, including background (background = 0).
                When None, no type information is extracted. Defaults to None.
            magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40.
            gt (bool, optional): If this is gt data (used that we do not suppress tiny cells that may be
                noise in a prediction map). Defaults to False.

        Raises:
            NotImplementedError: Unknown magnification
        """
        self.nr_types = nr_types
        self.magnification = magnification
        self.gt = gt

        if magnification == 40:
            self.object_size = 10
            self.k_size = 21
        elif magnification == 20:
            self.object_size = 3  # 3 or 40, we used 5
            self.k_size = 11  # 11 or 41, we used 13
        else:
            raise NotImplementedError("Unknown magnification")
        if gt:  # to not suppress something in gt!
            self.object_size = 100
            self.k_size = 21

    def post_process_cell_segmentation(
        self,
        pred_map: np.ndarray,
    ) -> Tuple[np.ndarray, dict]:
        """Post processing of one image tile

        Args:
            pred_map (np.ndarray): Stacked network output for one tile, channels last.
                * If ``nr_types`` is set: channel 1 is read as the per-pixel type map and the
                  channels from index 2 onwards (nuclei probability, X-map, Y-map) feed the
                  instance separation; channel 0 is not read by this method.
                * If ``nr_types`` is None: the whole array is used as the
                  (nuclei probability, X-map, Y-map) stack.

        Returns:
            Tuple[np.ndarray, dict]:
                np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
                dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value.
                    For each instance, the dictionary contains the keys: bbox (bounding box), centroid
                    (centroid coordinates), contour, type_prob (probability), type (nuclei type) and
                    all_type_prob (pixel fraction per type, indexed by type id). When ``nr_types`` is
                    None, type and type_prob stay None and all_type_prob stays empty.
        """
        if self.nr_types is not None:
            pred_type = pred_map[..., 1:2].astype(np.int32)
            pred_inst = pred_map[..., 2:]
        else:
            # No type branch available: the typing pass below is skipped instead of
            # failing on an undefined type map.
            pred_type = None
            pred_inst = pred_map

        pred_inst = np.squeeze(pred_inst)
        pred_inst = self.__proc_np_hv(
            pred_inst, object_size=self.object_size, ksize=self.k_size
        )

        # Drop the background label explicitly; slicing off the first unique value
        # would discard a real nucleus whenever the tile contains no background pixel.
        inst_id_list = np.unique(pred_inst)
        inst_id_list = inst_id_list[inst_id_list != 0]

        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            inst_map = inst_map[rmin:rmax, cmin:cmax].astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            # findContours returns (contours, hierarchy) in OpenCV 4.x and
            # (image, contours, hierarchy) in 3.x; [-2] is the contour list in both.
            contours = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            inst_contour = np.squeeze(contours[0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.ndim != 2 or inst_contour.shape[0] < 3:
                continue
            inst_centroid = np.array(
                [
                    inst_moment["m10"] / inst_moment["m00"],
                    inst_moment["m01"] / inst_moment["m00"],
                ]
            )
            # Shift from bbox-local to tile coordinates.
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
                "all_type_prob": [],
            }

        if pred_type is None:
            return pred_inst, inst_info_dict

        #### * Get class of each instance id (majority vote over the instance's pixels)
        for inst_id, inst_info in inst_info_dict.items():
            rmin, cmin, rmax, cmax = inst_info["bbox"].flatten()
            inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] == inst_id
            inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
            inst_type = inst_type_crop[inst_map_crop]
            type_ids, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = list(zip(type_ids, type_pixels))

            # Small epsilon keeps the division well-defined.
            total_pixels = np.sum(inst_map_crop) + 1.0e-6

            # Fraction of the instance's pixels assigned to each observed cell type.
            type_probs = {
                cell_type: float(pixel_count / total_pixels)
                for cell_type, pixel_count in type_list
            }
            # Dense probability vector over all cell types (0.0 for unobserved types).
            inst_info["all_type_prob"] = [
                type_probs.get(i, 0.0) for i in range(self.nr_types)
            ]

            type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
            inst_type = type_list[0][0]
            if inst_type == 0 and len(type_list) > 1:
                # ! pick the 2nd most dominant if exist
                inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            inst_info["type"] = int(inst_type)
            inst_info["type_prob"] = float(type_dict[inst_type] / total_pixels)

        return pred_inst, inst_info_dict

    @staticmethod
    def _minmax_01(arr: np.ndarray) -> np.ndarray:
        """Min-max normalise ``arr`` to [0, 1] as float32 via ``cv2.normalize``."""
        return cv2.normalize(
            arr,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )

    def __proc_np_hv(
        self, pred: np.ndarray, object_size: int = 10, ksize: int = 21
    ) -> np.ndarray:
        """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer)

        Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed

        Args:
            pred (np.ndarray): Prediction output, assuming. Shape: (H, W, 3)
                * channel 0 contain probability map of nuclei
                * channel 1 containing the regressed X-map
                * channel 2 containing the regressed Y-map
            object_size (int, optional): Smallest marker size kept before watershed. Defaults to 10
            ksize (int, optional): Sobel Kernel size. Defaults to 21
        Returns:
            np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
        """
        pred = np.array(pred, dtype=np.float32)

        blb_raw = pred[..., 0]
        h_dir_raw = pred[..., 1]
        v_dir_raw = pred[..., 2]

        # Binary foreground: threshold, then drop connected components below 10 px
        # (this threshold is fixed and independent of ``object_size``).
        blb = np.array(blb_raw >= 0.5, dtype=np.int32)
        blb = measurements.label(blb)[0]
        blb = remove_small_objects(blb, min_size=10)
        blb[blb > 0] = 1  # background is 0 already

        h_dir = self._minmax_01(h_dir_raw)
        v_dir = self._minmax_01(v_dir_raw)

        # Gradients of the X/Y maps, normalised and inverted.
        sobelh = 1 - self._minmax_01(cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize))
        sobelv = 1 - self._minmax_01(cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize))

        overall = np.maximum(sobelh, sobelv)
        overall = overall - (1 - blb)
        overall[overall < 0] = 0

        dist = (1.0 - overall) * blb
        ## nuclei values form mountains so inverse to get basins
        dist = -cv2.GaussianBlur(dist, (3, 3), 0)

        overall = np.array(overall >= 0.4, dtype=np.int32)

        # Markers: foreground minus the thresholded edge response, cleaned up morphologically.
        marker = blb - overall
        marker[marker < 0] = 0
        marker = binary_fill_holes(marker).astype("uint8")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
        marker = measurements.label(marker)[0]
        marker = remove_small_objects(marker, min_size=object_size)

        proced_pred = watershed(dist, markers=marker, mask=blb)

        return proced_pred
265
+
266
+
267
def calculate_instances(
    pred_types: torch.Tensor, pred_insts: torch.Tensor
) -> list[dict]:
    """Best used for GT

    Args:
        pred_types (torch.Tensor): Binary or type map ground-truth.
            Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class.
        pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W)

    Returns:
        list[dict]: One dictionary per batch element with nuclei informations, output similar to
            post_process_cell_segmentation (keys: bbox, centroid, contour, type_prob, type)
    """
    type_preds = []
    pred_types = pred_types.permute(0, 2, 3, 1)
    for i in range(pred_types.shape[0]):
        # Arg-max only the current sample instead of the whole batch on every iteration.
        pred_type = torch.argmax(pred_types[i], dim=-1).detach().cpu().numpy()
        pred_inst = pred_insts[i].detach().cpu().numpy()

        # Drop the background label explicitly; slicing off the first unique value
        # would discard a real nucleus when the map contains no background pixel.
        inst_id_list = np.unique(pred_inst)
        inst_id_list = inst_id_list[inst_id_list != 0]

        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            inst_map = inst_map[rmin:rmax, cmin:cmax].astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            # findContours returns (contours, hierarchy) in OpenCV 4.x and
            # (image, contours, hierarchy) in 3.x; [-2] is the contour list in both.
            contours = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            inst_contour = np.squeeze(contours[0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.ndim != 2 or inst_contour.shape[0] < 3:
                continue
            inst_centroid = np.array(
                [
                    inst_moment["m10"] / inst_moment["m00"],
                    inst_moment["m01"] / inst_moment["m00"],
                ]
            )
            # Shift from bbox-local to image coordinates.
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
            }

        #### * Get class of each instance id (majority vote over the instance's pixels)
        for inst_id, inst_info in inst_info_dict.items():
            rmin, cmin, rmax, cmax = inst_info["bbox"].flatten()
            inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] == inst_id
            inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
            inst_type = inst_type_crop[inst_map_crop]
            type_ids, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = sorted(
                zip(type_ids, type_pixels), key=lambda x: x[1], reverse=True
            )
            inst_type = type_list[0][0]
            if inst_type == 0 and len(type_list) > 1:
                # ! pick the 2nd most dominant if exist
                inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            # Small epsilon keeps the division well-defined.
            type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
            inst_info["type"] = int(inst_type)
            inst_info["type_prob"] = float(type_prob)
        type_preds.append(inst_info_dict)

    return type_preds
344
+
345
+
346
+
347
+
348
+
standard-ilc.tif ADDED

Git LFS Details

  • SHA256: b24fc1a67a87b3dbc60c92fe80b2b6b7d58d6190cde75aa803bab7156a1b2838
  • Pointer size: 134 Bytes
  • Size of remote file: 355 MB
tools.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ # Helpful functions Pipeline
3
+ #
4
+ # Adapted from HoverNet
5
+ # HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
6
+ # Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
7
+ #
8
+ # @ Fabian Hörst, fabian.hoerst@uk-essen.de
9
+ # Institute for Artificial Intelligence in Medicine,
10
+ # University Medicine Essen
11
+
12
+
13
+ import math
14
+ from typing import Tuple
15
+
16
+ import numpy as np
17
+ import scipy
18
+ from numba import njit, prange
19
+ from scipy import ndimage
20
+ from scipy.optimize import linear_sum_assignment
21
+ from skimage.draw import polygon
22
+
23
+
24
def get_bounding_box(img):
    """Return ``[rmin, rmax, cmin, cmax]`` enclosing the truthy entries of ``img``.

    The max values are exclusive (one past the last truthy row/column), so the
    result can be used directly as slice bounds.
    """
    row_hits = np.where(np.any(img, axis=1))[0]
    col_hits = np.where(np.any(img, axis=0))[0]
    first_row, last_row = row_hits[[0, -1]]
    first_col, last_col = col_hits[[0, -1]]
    # +1 turns the inclusive last index into an exclusive slice end
    return [first_row, last_row + 1, first_col, last_col + 1]
35
+
36
+
37
@njit
def cropping_center(x, crop_shape, batch=False):
    """Crop an input image at the centre.

    Args:
        x: input array; the first two axes are cropped when ``batch`` is False,
            axes 1 and 2 when ``batch`` is True (axis 0 is then kept whole)
        crop_shape: dimensions of cropped array; entry 0 is the size along the
            first cropped axis, entry 1 along the second
        batch: whether ``x`` carries a leading axis that must not be cropped

    Returns:
        x: cropped array

    """
    orig_shape = x.shape
    if not batch:
        # offsets that centre the crop window (truncated towards zero)
        h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
        x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
    else:
        # same computation, shifted by one axis to skip the leading axis
        h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
        x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
    return x
59
+
60
+
61
def remove_small_objects(pred, min_size=64, connectivity=1):
    """Zero out connected components that contain fewer than ``min_size`` pixels.

    Derived from skimage.morphology.remove_small_objects, without the warning that
    is emitted when only a single label is present. Note that ``pred`` is modified
    in place and the same array object is returned.

    Args:
        pred: input labelled array (or boolean mask, which is labelled internally)
        min_size: minimum size of instance in output array
        connectivity: The connectivity defining the neighborhood of a pixel
            (only used when ``pred`` is boolean).

    Returns:
        out: ``pred`` with instances under min_size removed

    Raises:
        ValueError: if the label array contains negative values

    """
    if min_size == 0:  # nothing can be smaller than 0 pixels
        return pred

    if pred.dtype == bool:
        structure = ndimage.generate_binary_structure(pred.ndim, connectivity)
        labelled = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, structure, output=labelled)
    else:
        labelled = pred

    try:
        sizes = np.bincount(labelled.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )

    # per-pixel lookup: does this pixel belong to an undersized component?
    undersized = (sizes < min_size)[labelled]
    pred[undersized] = 0

    return pred
102
+
103
+
104
def pair_coordinates(
    setA: np.ndarray, setB: np.ndarray, radius: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Use the Munkres or Kuhn-Munkres algorithm to find the most optimal
    unique pairing (largest possible match) when pairing points in set B
    against points in set A, using distance as cost function.

    Args:
        setA (np.ndarray): np.array (float32) of size Nx2 contains the XY coordinates
            of N different points
        setB (np.ndarray): np.array (float32) of size Mx2 contains the XY coordinates
            of M different points
        radius (float): valid area around a point in setA to consider
            a given coordinate in setB a candidate for match

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            pairing: array of index pairs, shape (n_pairs, 2); column 0 indexes
                set A and column 1 indexes the matched point in set B
            unpairedA: indices of the points in set A left unpaired
            unpairedB: indices of the points in set B left unpaired
    """
    # Import the submodule explicitly: a bare ``import scipy`` does not guarantee
    # that ``scipy.spatial`` is loaded as an attribute on every SciPy version.
    from scipy.spatial.distance import cdist

    # * Euclidean distance as the cost matrix
    pair_distance = cdist(setA, setB, metric="euclidean")

    # * Munkres pairing with scipy library
    # the algorithm return (row indices, matched column indices)
    # if there is multiple same cost in a row, index of first occurence
    # is return, thus the unique pairing is ensured
    indicesA, paired_indicesB = linear_sum_assignment(pair_distance)

    # keep only the assignments whose distance lies within the radius
    within_radius = pair_distance[indicesA, paired_indicesB] <= radius
    pairedA = indicesA[within_radius]
    pairedB = paired_indicesB[within_radius]

    pairing = np.concatenate([pairedA[:, None], pairedB[:, None]], axis=-1)
    unpairedA = np.delete(np.arange(setA.shape[0]), pairedA)
    unpairedB = np.delete(np.arange(setB.shape[0]), pairedB)

    return pairing, unpairedA, unpairedB
148
+
149
+
150
def fix_duplicates(inst_map: np.ndarray) -> np.ndarray:
    """Give every spatially disconnected piece of an instance its own label.

    The first connected component of each label keeps its id; every further
    component is assigned a fresh id above the current maximum. ``inst_map`` is
    modified in place and returned.

    Parameters
    ----------
        inst_map : np.ndarray
            Instance labelled mask. Shape (H, W).

    Returns
    -------
        np.ndarray:
            The instance labelled mask without duplicated indices.
            Shape (H, W).
    """
    highest_id = np.amax(inst_map)
    # labels are collected once up front; ids created below are not revisited
    labels_present = [label for label in np.unique(inst_map) if label != 0]

    for label in labels_present:
        components = ndimage.label(np.array(inst_map == label, np.uint8))[0]
        # component 1 keeps the original label, components 2.. get new ids
        components[components > 1] += highest_id
        relabelled = components > 1
        inst_map[relabelled] = components[relabelled]
        highest_id = np.amax(inst_map)

    return inst_map
177
+
178
+
179
def polygons_to_label_coord(
    coord: np.ndarray, shape: Tuple[int, int], labels: np.ndarray = None
) -> np.ndarray:
    """Rasterise polygons into an instance-labelled mask of the given shape.

    Parameters
    ----------
        coord : np.ndarray
            Polygon vertices. Shape: (n_polys, 2, n_rays); for each polygon the
            first row is passed to the rasteriser as row coordinates and the
            second as column coordinates.
        shape : Tuple[int, int]
            Shape of the output mask.
        labels : np.ndarray, optional
            One label index per polygon; polygon ``k`` is drawn with value
            ``labels[k] + 1``. Defaults to ``0 .. n_polys - 1``.

    Returns
    -------
        np.ndarray:
            Instance labelled mask. Shape: (H, W). Later polygons overwrite
            earlier ones where they overlap.
    """
    coord = np.asarray(coord)
    if labels is None:
        labels = np.arange(len(coord))

    assert coord.ndim == 3 and coord.shape[1] == 2 and len(coord) == len(labels)

    canvas = np.zeros(shape, np.int32)

    for label, (row_coords, col_coords) in zip(labels, coord):
        rr, cc = polygon(row_coords, col_coords, shape)
        canvas[rr, cc] = label + 1  # 0 is reserved for background

    return canvas
211
+
212
+
213
def ray_angles(n_rays: int = 32):
    """Return ``n_rays`` evenly spaced angles in ``[0, 2*pi)`` (end point excluded)."""
    full_turn = 2 * np.pi
    return np.linspace(start=0, stop=full_turn, num=n_rays, endpoint=False)
216
+
217
+
218
def dist_to_coord(
    dist: np.ndarray, points: np.ndarray, scale_dist: Tuple[int, int] = (1, 1)
) -> np.ndarray:
    """Turn per-ray radial distances around centre points into polygon vertices.

    Parameters
    ----------
        dist : np.ndarray
            The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays).
        points : np.ndarray
            The centroids of the instances. Shape: (n_polys, 2).
        scale_dist : Tuple[int, int], default=(1, 1)
            Scaling factor, one entry per coordinate axis.

    Returns
    -------
        np.ndarray:
            Cartesian coordinates of the polygons (float32). Shape (n_polys, 2, n_rays).
            Along axis 1, entry 0 is the sine component and entry 1 the cosine
            component, each offset by the matching column of ``points``.
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    assert (
        dist.ndim == 2
        and points.ndim == 2
        and len(dist) == len(points)
        and points.shape[1] == 2
        and len(scale_dist) == 2
    )
    angles = ray_angles(dist.shape[1])
    # (2, n_rays) unit directions, broadcast against (n_polys, 1, n_rays) distances
    directions = np.array([np.sin(angles), np.cos(angles)])
    coord = (dist[:, np.newaxis] * directions).astype(np.float32)
    coord *= np.asarray(scale_dist).reshape(1, 2, 1)
    coord += points[..., np.newaxis]
    return coord
254
+
255
+
256
def polygons_to_label(
    dist: np.ndarray,
    points: np.ndarray,
    shape: Tuple[int, int],
    prob: np.ndarray = None,
    thresh: float = -np.inf,
    scale_dist: Tuple[int, int] = (1, 1),
) -> np.ndarray:
    """Render radial-distance polygons into an instance labelled mask.

    Polygons whose ``prob`` is not above ``thresh`` are dropped; the rest are
    drawn in ascending ``prob`` order.

    Parameters
    ----------
        dist : np.ndarray
            The centerpoint pixels of the radial distance map. Shape (n_polys, n_rays).
        points : np.ndarray
            The centroids of the instances. Shape: (n_polys, 2).
        shape : Tuple[int, int]:
            Shape of the output mask.
        prob : np.ndarray, optional
            One score per polygon. Shape: (n_polys,). Defaults to ``inf`` for
            every polygon.
        thresh : float, default=-np.inf
            Polygons with ``prob <= thresh`` are discarded.
        scale_dist : Tuple[int, int], default=(1, 1)
            Scaling factor.

    Returns
    -------
        np.ndarray:
            Instance labelled mask. Shape (H, W).
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    if prob is None:
        prob = np.inf * np.ones(len(points))
    else:
        prob = np.asarray(prob)

    assert dist.ndim == 2 and points.ndim == 2 and len(dist) == len(points)
    assert len(points) == len(prob) and points.shape[1] == 2 and prob.ndim == 1

    # discard polygons at or below the threshold
    keep = prob > thresh
    points, dist, prob = points[keep], dist[keep], prob[keep]

    # stable ascending sort by score; the sort indices double as polygon labels
    order = np.argsort(prob, kind="stable")
    coord = dist_to_coord(dist[order], points[order], scale_dist=scale_dist)

    return polygons_to_label_coord(coord, shape=shape, labels=order)
306
+
307
+
308
@njit(cache=True, fastmath=True)
def intersection(boxA: np.ndarray, boxB: np.ndarray):
    """Compute area of intersection of two boxes.

    Each box is read at indices 0..3 along its last axis; indices 0 and 2 are
    treated as the lower/upper bound on one axis, indices 1 and 3 as the
    lower/upper bound on the other.

    Parameters
    ----------
        boxA : np.ndarray
            First box
        boxB : np.ndarray
            Second box

    Returns
    -------
        float64:
            Area of intersection; 0.0 when the boxes do not overlap (or only touch)
    """
    # overlap extent along the first axis
    xA = max(boxA[..., 0], boxB[..., 0])
    xB = min(boxA[..., 2], boxB[..., 2])
    dx = xB - xA
    if dx <= 0:
        return 0.0

    # overlap extent along the second axis
    yA = max(boxA[..., 1], boxB[..., 1])
    yB = min(boxA[..., 3], boxB[..., 3])
    dy = yB - yA
    if dy <= 0.0:
        return 0.0

    return dx * dy
337
+
338
+
339
@njit(parallel=True)
def get_bboxes(
    dist: np.ndarray, points: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Get bounding boxes from the non-zero pixels of the radial distance maps.

    This is basically a translation from the stardist repo cpp code to python

    NOTE: jit compiled and parallelized with numba.

    Parameters
    ----------
        dist : np.ndarray
            The non-zero values of the radial distance maps. Shape: (n_nonzero, n_rays).
        points : np.ndarray
            The yx-coordinates of the non-zero points. Shape (n_nonzero, 2).

    Returns
    -------
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
            In order: per-polygon minimum x, minimum y, maximum x, maximum y,
            the bbox areas, and the largest radial distance over all polygons.
    """
    n_polys = dist.shape[0]
    n_rays = dist.shape[1]

    bbox_x1 = np.zeros(n_polys)
    bbox_x2 = np.zeros(n_polys)
    bbox_y1 = np.zeros(n_polys)
    bbox_y2 = np.zeros(n_polys)

    areas = np.zeros(n_polys)
    # angular step between consecutive rays
    angle_pi = 2 * math.pi / n_rays
    max_dist = 0

    for i in prange(n_polys):
        max_radius_outer = 0
        # points are stored as (y, x)
        py = points[i, 0]
        px = points[i, 1]

        for k in range(n_rays):
            # end point of ray k, measured from the polygon centre
            d = dist[i, k]
            y = py + d * np.sin(angle_pi * k)
            x = px + d * np.cos(angle_pi * k)

            if k == 0:
                # the first ray initialises the box
                bbox_x1[i] = x
                bbox_x2[i] = x
                bbox_y1[i] = y
                bbox_y2[i] = y
            else:
                # later rays only grow the box
                bbox_x1[i] = min(x, bbox_x1[i])
                bbox_x2[i] = max(x, bbox_x2[i])
                bbox_y1[i] = min(y, bbox_y1[i])
                bbox_y2[i] = max(y, bbox_y2[i])

            max_radius_outer = max(d, max_radius_outer)

        areas[i] = (bbox_x2[i] - bbox_x1[i]) * (bbox_y2[i] - bbox_y1[i])
        # running maximum across polygons, updated inside the prange loop
        max_dist = max(max_dist, max_radius_outer)

    return bbox_x1, bbox_y1, bbox_x2, bbox_y2, areas, max_dist