Upload 7 files
Browse files- .gitattributes +1 -0
- HNE2cell_all_patch73_jit.pt +3 -0
- inference.py +316 -0
- normalize.py +277 -0
- patchify.py +254 -0
- post_processing.py +348 -0
- standard-ilc.tif +3 -0
- tools.py +400 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
standard-ilc.tif filter=lfs diff=lfs merge=lfs -text
|
HNE2cell_all_patch73_jit.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:cd428608461bf295e4abfa646fb834c2d24acedca0155d43fdb817da42936593
|
| 3 |
+
size 5138307858
|
inference.py
ADDED
|
@@ -0,0 +1,316 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HNE2Cell — Step 3: Cell Detection & Classification Inference
|
| 3 |
+
|
| 4 |
+
Run the HNE2Cell model on extracted patches to detect and classify cells.
|
| 5 |
+
Outputs: per-patch cell masks (PNG) and centroid CSVs with cell type annotations.
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
python inference.py \
|
| 9 |
+
--input_dir /path/to/patch_folders \
|
| 10 |
+
--output_dir /path/to/results \
|
| 11 |
+
--model_path ./HNE2cell_all_patch73_jit.pt \
|
| 12 |
+
--magnification 40 \
|
| 13 |
+
--batch_size 32
|
| 14 |
+
|
| 15 |
+
Cell Types (16 classes):
|
| 16 |
+
0: Background 4: B 8: DC 12: Epithelial
|
| 17 |
+
1: Malignant 5: Plasma 9: Fibroblast 13: Immune_Other
|
| 18 |
+
2: CD4T 6: Macrophage 10: Endothelial 14: Stromal_Other
|
| 19 |
+
3: CD8T 7: Myeloid 11: Pericyte 15: Dead
|
| 20 |
+
"""
|
| 21 |
+
|
| 22 |
+
import os
|
| 23 |
+
import argparse
|
| 24 |
+
import glob
|
| 25 |
+
|
| 26 |
+
import cv2
|
| 27 |
+
import numpy as np
|
| 28 |
+
import pandas as pd
|
| 29 |
+
import torch
|
| 30 |
+
from PIL import Image
|
| 31 |
+
from torch.cuda.amp import autocast
|
| 32 |
+
from torch.utils.data import DataLoader, Dataset
|
| 33 |
+
from torchvision import transforms
|
| 34 |
+
from tqdm import tqdm
|
| 35 |
+
|
| 36 |
+
from post_processing import DetectionCellPostProcessor
|
| 37 |
+
|
| 38 |
+
# ========================== Constants ======================================
|
| 39 |
+
|
| 40 |
+
# Class index -> cell type name. The position in the tuple is the class index,
# so index 0 is "Background" and index 15 is "Dead".
CELL_TYPES = dict(
    enumerate(
        (
            "Background",
            "Malignant",
            "CD4T",
            "CD8T",
            "B",
            "Plasma",
            "Macrophage",
            "Myeloid",
            "DC",
            "Fibroblast",
            "Endothelial",
            "Pericyte",
            "Epithelial",
            "Immune_Other",
            "Stromal_Other",
            "Dead",
        )
    )
)
|
| 58 |
+
|
| 59 |
+
# RGBA colors for mask visualization
|
| 60 |
+
# Class index -> [R, G, B, A] color. The position in the tuple is the class
# index used as the dictionary key.
# NOTE(review): indices 6 and 13 carry the same color, so those two classes
# cannot be told apart in a rendered mask - confirm this is intended.
CELL_COLORS = {
    class_index: list(rgba)
    for class_index, rgba in enumerate(
        (
            (0, 0, 0, 0),
            (255, 0, 0, 255),
            (30, 144, 255, 255),
            (65, 105, 225, 255),
            (0, 0, 255, 255),
            (100, 149, 237, 255),
            (176, 224, 230, 255),
            (70, 130, 180, 255),
            (0, 191, 255, 255),
            (34, 139, 34, 255),
            (60, 179, 113, 255),
            (50, 205, 50, 255),
            (255, 140, 0, 255),
            (176, 224, 230, 255),
            (107, 142, 35, 255),
            (128, 128, 128, 255),
        )
    )
}
|
| 78 |
+
|
| 79 |
+
# ImageNet-style normalization fitted to H&E data
|
| 80 |
+
# Preprocessing applied to each patch: resize to 224x224, convert to a tensor,
# then normalize each channel with the fixed mean/std below.
_PATCH_CHANNEL_MEAN = [0.707223, 0.578729, 0.703617]
_PATCH_CHANNEL_STD = [0.211883, 0.230117, 0.177517]
_PATCH_PREPROCESSING = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_PATCH_CHANNEL_MEAN, std=_PATCH_CHANNEL_STD),
]
TRANSFORM = transforms.Compose(_PATCH_PREPROCESSING)
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
# ========================== Dataset ========================================
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
class PatchDataset(Dataset):
    """Dataset over patch image files that yields ``(path, image)`` pairs.

    Each file is read with OpenCV and converted from BGR to RGB. When a
    transform is given, the RGB array is wrapped in a PIL image and passed
    through it; otherwise the RGB array itself is returned.
    """

    def __init__(self, file_paths, transform=None):
        self.file_paths = file_paths
        self.transform = transform

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals an unreadable file with None: warn and
            # substitute an all-black 256x256 image instead of raising.
            print(f"[WARN] Failed to load: {path}")
            bgr = np.zeros((256, 256, 3), dtype=np.uint8)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if not self.transform:
            return path, rgb
        return path, self.transform(Image.fromarray(rgb))
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# ========================== Inference ======================================
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def process_batch(
    batch,
    model,
    device,
    mask_output_dir,
    centroid_records,
    magnification=40,
):
    """Run the model on one batch, write per-patch masks and collect centroids.

    Args:
        batch: ``(file_paths, images)`` pair; ``images`` is moved to ``device``
            and fed to ``model``.
        model: Callable whose result is indexed with the keys
            ``cell_type_map``, ``nuclei_binary_map``, ``hv_map`` and
            ``tissue_type_map``.
        device: Target device for the image tensor.
        mask_output_dir: Directory that receives ``<patch name>_mask.png``.
        centroid_records: List extended in place with one dict per cell.
        magnification: Forwarded to ``DetectionCellPostProcessor``.
    """
    paths, images = batch

    # Forward pass with gradient tracking off, inside the autocast context.
    with torch.no_grad(), autocast():
        outputs = model(images.to(device, non_blocking=True))

    head_names = ("cell_type_map", "nuclei_binary_map", "hv_map", "tissue_type_map")
    argmax_order = ("tissue_type_map", "cell_type_map", "nuclei_binary_map")

    for idx, path in enumerate(paths):
        patch_name, _ = os.path.splitext(os.path.basename(path))

        # Per-sample predictions as float CPU tensors.
        heads = {
            name: outputs[name][idx].float().detach().cpu() for name in head_names
        }

        # Prediction map: one argmax channel per classification head, followed
        # by the hv_map channels moved to the last axis.
        channels = [
            torch.argmax(heads[name], dim=0)[..., None].numpy()
            for name in argmax_order
        ]
        channels.append(heads["hv_map"].permute(1, 2, 0).numpy())
        pred_map = np.concatenate(channels, axis=-1)

        post_processor = DetectionCellPostProcessor(
            nr_types=heads["cell_type_map"].shape[0],
            magnification=magnification,
            gt=False,
        )
        _, cells = post_processor.post_process_cell_segmentation(pred_map)

        # White canvas; each detected cell is painted in its class color.
        # NOTE(review): the canvas is fixed at 256x256 regardless of the
        # resolution of the model output - confirm the two always match.
        mask = np.full((256, 256, 3), 255, dtype=np.uint8)
        for cell in cells.values():
            ctype = cell["type"]
            red, green, blue = CELL_COLORS.get(ctype, [255, 255, 255, 255])[:3]
            # OpenCV expects BGR channel order; alpha is not used.
            cv2.fillPoly(mask, [cell["contour"]], [blue, green, red])

            centroid_x, centroid_y = cell["centroid"][0], cell["centroid"][1]
            centroid_records.append(
                {
                    "slide_id": patch_name,
                    "x": centroid_x,
                    "y": centroid_y,
                    "celltype": ctype,
                    "celltype_name": CELL_TYPES.get(ctype, "Unknown"),
                }
            )

        # Skip writing masks that are still entirely white.
        if np.any(mask != 255):
            mask_path = os.path.join(mask_output_dir, f"{patch_name}_mask.png")
            cv2.imwrite(mask_path, mask)
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
def run_inference(
    patch_folders: list[str],
    model,
    device,
    output_dir: str,
    magnification: int = 40,
    batch_size: int = 32,
    num_workers: int = 4,
):
    """Run inference over a list of patch folders.

    For every folder that contains ``*.png`` files, masks are written to
    ``<output_dir>/mask_patches/<folder name>/`` and the collected centroids to
    ``<output_dir>/centroid/<folder name>_centroid.csv``. Folders without PNG
    files are reported and skipped.

    Args:
        patch_folders: Directories holding the patch images.
        model: Model moved to ``device`` and put in eval mode before use.
        device: Device used for the forward passes.
        output_dir: Root directory for masks and centroid CSVs.
        magnification: Forwarded to ``process_batch``.
        batch_size: Batch size of the data loader.
        num_workers: Number of loader worker processes; ``0`` loads in the
            calling process.
    """
    model.to(device).eval()

    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
        "shuffle": False,
    }
    if num_workers > 0:
        # DataLoader rejects these two options when there are no worker
        # processes, so only pass them for multi-process loading.
        loader_options["prefetch_factor"] = 2
        loader_options["persistent_workers"] = True

    # Fixed column order so that a folder without detections still produces a
    # CSV with a header row.
    centroid_columns = ["slide_id", "x", "y", "celltype", "celltype_name"]

    for folder in patch_folders:
        # normpath drops a trailing separator, which would otherwise make
        # basename return an empty string.
        folder_name = os.path.basename(os.path.normpath(folder))
        png_files = sorted(glob.glob(os.path.join(folder, "*.png")))

        if not png_files:
            print(f"[SKIP] {folder}: no PNG patches found")
            continue

        mask_dir = os.path.join(output_dir, "mask_patches", f"{folder_name}")
        centroid_path = os.path.join(
            output_dir, "centroid", f"{folder_name}_centroid.csv"
        )
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(os.path.dirname(centroid_path), exist_ok=True)

        dataset = PatchDataset(png_files, transform=TRANSFORM)
        loader = DataLoader(dataset, **loader_options)

        centroids = []
        for batch in tqdm(loader, desc=f"Inference: {folder_name}"):
            process_batch(batch, model, device, mask_dir, centroids, magnification)

        df = pd.DataFrame(centroids, columns=centroid_columns)
        df.to_csv(centroid_path, index=False)
        print(f"[DONE] {folder_name} → {centroid_path} ({len(df)} cells)")
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
# =============================== CLI =======================================
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def main():
    """Command-line entry point: parse arguments, load the model, run inference.

    The input directory is validated and the patch folders are collected
    before the model file is loaded, so a wrong or empty ``--input_dir`` is
    reported immediately instead of after the model has been read.
    """
    parser = argparse.ArgumentParser(description="HNE2Cell inference")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing patch folders (each with *.png)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for masks and centroid CSVs",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the TorchScript JIT model (.pt)",
    )
    parser.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Magnification of input patches. 40x recommended. (default: 40)",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device: 'cuda', 'cpu', or 'auto' (default: auto)",
    )

    args = parser.parse_args()

    if not os.path.isdir(args.input_dir):
        parser.error(f"--input_dir is not a directory: {args.input_dir}")

    # Collect patch folders: every sub-directory of input_dir, or input_dir
    # itself when it has no sub-directories but directly contains PNG files.
    patch_folders = sorted(
        p for p in glob.glob(os.path.join(args.input_dir, "*")) if os.path.isdir(p)
    )
    if not patch_folders and glob.glob(os.path.join(args.input_dir, "*.png")):
        patch_folders = [args.input_dir]

    print(f"Found {len(patch_folders)} patch folder(s)")
    if not patch_folders:
        # Nothing to process, so do not load the model at all.
        return

    # Device
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load model
    print(f"Loading model: {args.model_path}")
    model = torch.jit.load(args.model_path, map_location=device)
    model.eval()

    if args.magnification == 20:
        print(
            "⚠️ Running at 20x. Results are usable but 40x is recommended "
            "for best accuracy, especially for small immune cells."
        )

    run_inference(
        patch_folders,
        model,
        device,
        args.output_dir,
        magnification=args.magnification,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
if __name__ == "__main__":
|
| 316 |
+
main()
|
normalize.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HNE2Cell — Step 1: Reinhard Color Normalization
|
| 3 |
+
|
| 4 |
+
Normalize H&E stained whole-slide images (WSI) to a reference color distribution
|
| 5 |
+
using the Reinhard method in LAB color space.
|
| 6 |
+
|
| 7 |
+
Supported input formats: .svs, .tif, .tiff, .ndpi
|
| 8 |
+
Output: Aligned-hne.tif (full-resolution normalized), Aligned-hne.jpg (4x downsampled preview)
|
| 9 |
+
|
| 10 |
+
Usage:
|
| 11 |
+
python normalize.py \
|
| 12 |
+
--input_dir /path/to/slides \
|
| 13 |
+
--target /path/to/standard-ilc.tif \
|
| 14 |
+
--patch_size 128 \
|
| 15 |
+
--saturation_threshold 0.1
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
import os
|
| 19 |
+
import argparse
|
| 20 |
+
import glob
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import tifffile as tiff
|
| 24 |
+
from PIL import Image
|
| 25 |
+
from skimage import color
|
| 26 |
+
|
| 27 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 28 |
+
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2, 40))
|
| 29 |
+
|
| 30 |
+
# ---------------------------------------------------------------------------
|
| 31 |
+
# Optional: openslide (only needed for .svs / .ndpi)
|
| 32 |
+
# ---------------------------------------------------------------------------
|
| 33 |
+
try:
|
| 34 |
+
import openslide
|
| 35 |
+
|
| 36 |
+
OPENSLIDE_AVAILABLE = True
|
| 37 |
+
except ImportError:
|
| 38 |
+
OPENSLIDE_AVAILABLE = False
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
# ============================= I/O helpers =================================
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def load_image(image_path: str, level: int = 0) -> np.ndarray:
    """Load a whole-slide image as an RGB numpy array.

    Supports .svs/.ndpi (via OpenSlide) and .tif/.tiff (via tifffile).

    Args:
        image_path: Path to the slide; the extension selects the reader.
        level: OpenSlide pyramid level to read. Ignored for TIFF input.

    Returns:
        Array of shape (height, width, 3). TIFF input is returned as uint8.

    Raises:
        ImportError: A .svs/.ndpi file was given but openslide is missing.
        ValueError: The extension is unsupported, or the TIFF data cannot be
            interpreted as a grayscale, RGB or RGBA image.
    """
    ext = os.path.splitext(image_path)[1].lower()

    if ext in (".svs", ".ndpi"):
        if not OPENSLIDE_AVAILABLE:
            raise ImportError(
                "openslide-python is required to read .svs/.ndpi files. "
                "Install it with: pip install openslide-python"
            )
        slide = openslide.OpenSlide(image_path)
        try:
            region = slide.read_region((0, 0), level, slide.level_dimensions[level])
            return np.array(region.convert("RGB"))
        finally:
            # Release the slide handle even when reading fails.
            slide.close()

    if ext in (".tif", ".tiff"):
        image = tiff.imread(image_path)
        if image.ndim == 2:
            # Grayscale: replicate the single channel three times.
            image = np.stack((image,) * 3, axis=-1)
        elif image.ndim == 4 and image.shape[0] == 1:
            image = image[0]

        if image.ndim == 3 and image.shape[-1] == 4:
            # Drop the alpha channel; downstream code only accepts
            # three-channel patches.
            image = image[..., :3]
        if image.ndim != 3 or image.shape[-1] != 3:
            raise ValueError(
                f"Unsupported TIFF layout {image.shape} in {image_path}; "
                "expected (H, W), (H, W, 3) or (H, W, 4)"
            )

        # Ensure RGB uint8
        # NOTE(review): wider dtypes are clipped to 0-255, not rescaled -
        # confirm this is what 16-bit inputs need.
        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    raise ValueError(f"Unsupported file format: {ext}")
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
# ======================== Saturation filtering =============================
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
def calculate_saturation(patch: Image.Image) -> float:
    """Return the mean of the HSV saturation channel of *patch* divided by 255."""
    saturation = np.array(patch.convert("HSV"))[:, :, 1]
    return np.mean(saturation / 255.0)
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def extract_high_saturation_patches(
    image: np.ndarray, patch_size: int, saturation_threshold: float
) -> list:
    """Collect the grid tiles of *image* whose mean saturation meets the threshold.

    The image is tiled on a ``patch_size`` grid starting at the top-left
    corner; only whole tiles are visited, so any remainder at the right and
    bottom edges is not returned.

    Returns:
        List of ``((x0, y0), tile_array)`` entries, column by column.
    """
    pil_img = Image.fromarray(image)
    width, height = pil_img.size

    kept = []
    for col in range(width // patch_size):
        left = col * patch_size
        for row in range(height // patch_size):
            top = row * patch_size
            tile = pil_img.crop((left, top, left + patch_size, top + patch_size))
            if calculate_saturation(tile) >= saturation_threshold:
                kept.append(((left, top), np.array(tile)))
    return kept
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def reconstruct_from_patches(
    width: int, height: int, patch_size: int, patches: list
) -> np.ndarray:
    """Paste patches onto a black ``height`` x ``width`` RGB canvas.

    Args:
        width: Canvas width in pixels.
        height: Canvas height in pixels.
        patch_size: Edge length every accepted patch must have.
        patches: ``((x0, y0), array)`` entries; ``(x0, y0)`` is the top-left
            corner of the patch on the canvas.

    Returns:
        uint8 array of shape (height, width, 3). Entries whose array is not
        exactly (patch_size, patch_size, 3) are ignored.
    """
    expected_shape = (patch_size, patch_size, 3)
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    for (left, top), tile in patches:
        if tile.shape != expected_shape:
            continue
        canvas[top : top + patch_size, left : left + patch_size, :] = tile
    return canvas
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# =================== Reinhard color normalization ==========================
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def _color_convert_chunked(image, func, chunk_size=16384):
|
| 117 |
+
"""Apply color conversion function in spatial chunks to limit memory."""
|
| 118 |
+
h, w, _ = image.shape
|
| 119 |
+
out = np.zeros_like(image, dtype=np.float32)
|
| 120 |
+
for i in range(0, h, chunk_size):
|
| 121 |
+
for j in range(0, w, chunk_size):
|
| 122 |
+
out[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :] = func(
|
| 123 |
+
image[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :]
|
| 124 |
+
)
|
| 125 |
+
return out
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def reinhard_normalize(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Match the LAB channel statistics of *source* to those of *target*.

    Both images are converted to LAB. For each channel, the mean and standard
    deviation are computed over the non-zero values only, and the non-zero
    source values are shifted and scaled to the target statistics; zero values
    stay zero. A channel is left untouched when either image has no non-zero
    values in it or when the source standard deviation is below 1e-6.

    Returns:
        The adjusted image converted back to RGB, as a float32 array produced
        by ``_color_convert_chunked``. The caller scales it by 255, i.e.
        treats it as lying in [0, 1].
    """
    src_lab = _color_convert_chunked(source, color.rgb2lab)
    tgt_lab = color.rgb2lab(target)

    for ch in range(3):
        src_ch, tgt_ch = src_lab[:, :, ch], tgt_lab[:, :, ch]

        src_mask = src_ch != 0
        src_vals = src_ch[src_mask]
        tgt_vals = tgt_ch[tgt_ch != 0]
        if not (len(src_vals) and len(tgt_vals)):
            continue

        src_mean, src_std = src_vals.mean(), src_vals.std()
        if src_std < 1e-6:
            # A flat channel cannot be rescaled without dividing by ~0.
            continue
        tgt_mean, tgt_std = tgt_vals.mean(), tgt_vals.std()

        src_lab[:, :, ch] = np.where(
            src_mask,
            (src_ch - src_mean) * (tgt_std / src_std) + tgt_mean,
            0,
        )

    return _color_convert_chunked(src_lab, color.lab2rgb)
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
# ============================= Main pipeline ===============================
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
def normalize_slide(
|
| 166 |
+
slide_path: str,
|
| 167 |
+
target_image: np.ndarray,
|
| 168 |
+
patch_size: int = 128,
|
| 169 |
+
saturation_threshold: float = 0.1,
|
| 170 |
+
output_dir: str | None = None,
|
| 171 |
+
skip_existing: bool = True,
|
| 172 |
+
):
|
| 173 |
+
"""Full normalization pipeline for a single slide."""
|
| 174 |
+
|
| 175 |
+
if output_dir is None:
|
| 176 |
+
output_dir = os.path.dirname(slide_path)
|
| 177 |
+
|
| 178 |
+
output_tif = os.path.join(output_dir, "Aligned-hne.tif")
|
| 179 |
+
output_jpg = os.path.join(output_dir, "Aligned-hne.jpg")
|
| 180 |
+
|
| 181 |
+
if skip_existing and os.path.exists(output_tif):
|
| 182 |
+
print(f"[SKIP] {slide_path} — Aligned-hne.tif already exists.")
|
| 183 |
+
return
|
| 184 |
+
|
| 185 |
+
print(f"[LOAD] {slide_path}")
|
| 186 |
+
raw = load_image(slide_path)
|
| 187 |
+
h, w = raw.shape[:2]
|
| 188 |
+
|
| 189 |
+
# 1. Saturation-based tissue detection
|
| 190 |
+
patches = extract_high_saturation_patches(
|
| 191 |
+
raw, patch_size, saturation_threshold
|
| 192 |
+
)
|
| 193 |
+
reconstructed = reconstruct_from_patches(w, h, patch_size, patches)
|
| 194 |
+
|
| 195 |
+
# 2. (Optional) save intermediate reconstruction
|
| 196 |
+
recon_path = os.path.join(output_dir, "recon.tif")
|
| 197 |
+
bigtiff = reconstructed.nbytes > 4 * 1024**3
|
| 198 |
+
tiff.imwrite(recon_path, reconstructed, bigtiff=bigtiff)
|
| 199 |
+
|
| 200 |
+
# 3. Reinhard normalization
|
| 201 |
+
normalized = reinhard_normalize(reconstructed, target_image)
|
| 202 |
+
normalized_u8 = (normalized * 255).astype(np.uint8)
|
| 203 |
+
|
| 204 |
+
# 4. Save outputs
|
| 205 |
+
tiff.imwrite(output_tif, normalized_u8, bigtiff=bigtiff)
|
| 206 |
+
|
| 207 |
+
resized = Image.fromarray(normalized_u8).resize(
|
| 208 |
+
(w // 4, h // 4), Image.LANCZOS
|
| 209 |
+
)
|
| 210 |
+
resized.save(output_jpg, quality=90)
|
| 211 |
+
|
| 212 |
+
print(f"[DONE] {slide_path} → {output_tif}")
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# =============================== CLI =======================================
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def main():
    """Command-line entry point: normalize every slide found under --input_dir.

    Slides are processed in sorted order. Files that are outputs of this
    script, and the reference image itself, are not treated as slides. A
    failure on one slide is reported and does not stop the remaining slides;
    the number of failures is printed at the end.
    """
    parser = argparse.ArgumentParser(
        description="Reinhard color normalization for H&E WSIs"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory to search for slide files (.svs, .tif, .tiff, .ndpi)",
    )
    parser.add_argument(
        "--target",
        type=str,
        required=True,
        help="Path to the reference/target image (.tif)",
    )
    parser.add_argument("--patch_size", type=int, default=128)
    parser.add_argument("--saturation_threshold", type=float, default=0.1)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help="If set, all outputs go here. Otherwise, outputs are saved next to each slide.",
    )

    args = parser.parse_args()

    # Load target image once
    target_image = load_image(args.target)

    # Collect slides
    extensions = ("*.svs", "*.tif", "*.tiff", "*.ndpi")
    found = []
    for ext in extensions:
        found.extend(glob.glob(os.path.join(args.input_dir, "**", ext), recursive=True))

    # Exclude files that are already outputs, and the reference image in case
    # it is stored under input_dir.
    generated_names = ("Aligned-hne.tif", "Aligned-hne.tiff", "recon.tif")
    target_path = os.path.abspath(args.target)
    slides = sorted(
        s
        for s in found
        if os.path.basename(s) not in generated_names
        and os.path.abspath(s) != target_path
    )

    print(f"Found {len(slides)} slide(s) in {args.input_dir}")

    failed = 0
    for slide_path in slides:
        try:
            normalize_slide(
                slide_path,
                target_image,
                patch_size=args.patch_size,
                saturation_threshold=args.saturation_threshold,
                output_dir=args.output_dir,
            )
        except Exception as e:
            # Best effort: report and carry on with the remaining slides.
            failed += 1
            print(f"[ERROR] {slide_path}: {e}")

    if failed:
        print(f"[SUMMARY] {failed} of {len(slides)} slide(s) failed")
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
if __name__ == "__main__":
|
| 277 |
+
main()
|
patchify.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
HNE2Cell — Step 2: Patch Extraction
|
| 3 |
+
|
| 4 |
+
Extract overlapping patches from color-normalized H&E images for cell detection.
|
| 5 |
+
Supports both 20x and 40x magnification (40x recommended for best results).
|
| 6 |
+
|
| 7 |
+
Usage:
|
| 8 |
+
# 40x (recommended)
|
| 9 |
+
python patchify.py \
|
| 10 |
+
--input_dir /path/to/slides \
|
| 11 |
+
--patch_size 256 \
|
| 12 |
+
--overlap 64 \
|
| 13 |
+
--magnification 40 \
|
| 14 |
+
--workers 8
|
| 15 |
+
|
| 16 |
+
# 20x (supported but 40x preferred)
|
| 17 |
+
python patchify.py \
|
| 18 |
+
--input_dir /path/to/slides \
|
| 19 |
+
--patch_size 256 \
|
| 20 |
+
--overlap 64 \
|
| 21 |
+
--magnification 20 \
|
| 22 |
+
--workers 8
|
| 23 |
+
|
| 24 |
+
Notes:
|
| 25 |
+
- 40x magnification is recommended for optimal cell detection accuracy.
|
| 26 |
+
- 20x is supported and functional, but fine-grained cell boundaries
|
| 27 |
+
(especially small immune cells) may be less precise.
|
| 28 |
+
- Input: Aligned-hne.tif (output of normalize.py)
|
| 29 |
+
- Output: <section>/patches_<mag>x_p<patch>_o<overlap>/<name>_<x>_<y>.png
|
| 30 |
+
"""
|
| 31 |
+
|
| 32 |
+
import os
|
| 33 |
+
import argparse
|
| 34 |
+
import glob
|
| 35 |
+
import time
|
| 36 |
+
from multiprocessing import Pool
|
| 37 |
+
|
| 38 |
+
import numpy as np
|
| 39 |
+
from PIL import Image
|
| 40 |
+
from tqdm import tqdm
|
| 41 |
+
|
| 42 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
# ========================== Utility functions ==============================
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def black_to_white(pil_img: Image.Image) -> Image.Image:
    """Return a copy of the image with pure-black pixels turned white.

    A pixel counts as pure black when its first three channels are all 0; it
    is then set to 255 in every channel. Images without at least three
    channels are returned unchanged (as a new image).
    """
    pixels = np.array(pil_img)
    has_color_channels = pixels.ndim == 3 and pixels.shape[2] >= 3
    if has_color_channels:
        pure_black = np.all(pixels[..., :3] == 0, axis=-1)
        pixels[pure_black] = 255
    return Image.fromarray(pixels)
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def make_start_positions(length: int, patch_size: int, stride: int) -> list[int]:
    """Generate start positions so the last patch always reaches the edge.

    Positions advance by *stride* from 0. When the regular grid does not end
    exactly at ``length - patch_size``, that position is appended so the final
    patch is flush with the edge. If the axis is shorter than one patch, a
    single start at 0 is returned.

    Raises:
        ValueError: If *stride* is not positive.
    """
    if stride <= 0:
        # A non-positive stride would never advance along the axis.
        raise ValueError(f"stride must be positive, got {stride}")
    if length < patch_size:
        return [0]

    last = length - patch_size
    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# ========================== Core patching ==================================
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def extract_patches(
    image_path: str,
    output_dir: str,
    patch_size: int = 256,
    overlap: int = 64,
    prefix: str = "patch",
) -> int:
    """Crop overlapping patches from a single image and save as PNG.

    Patches are written to ``<output_dir>/<prefix>_<x0>_<y0>.png`` where
    ``(x0, y0)`` is the top-left corner of the patch. Pure-black pixels are
    replaced with white before saving.

    Args:
        image_path: Image to cut into patches.
        output_dir: Destination directory; created when missing.
        patch_size: Edge length of each square patch.
        overlap: Pixels shared by neighbouring patches; must be smaller than
            ``patch_size``.
        prefix: File name prefix for the saved patches.

    Returns:
        The number of patches saved.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``patch_size``.
    """
    stride = patch_size - overlap
    # Explicit check instead of assert, which is stripped under ``python -O``;
    # done first so invalid arguments do not create the output directory.
    if stride <= 0:
        raise ValueError(f"overlap ({overlap}) must be < patch_size ({patch_size})")

    os.makedirs(output_dir, exist_ok=True)

    # convert() returns an independent RGB image, so the file can be closed.
    with Image.open(image_path) as source:
        img = source.convert("RGB")
    width, height = img.size

    xs = make_start_positions(width, patch_size, stride)
    ys = make_start_positions(height, patch_size, stride)

    count = 0
    with tqdm(total=len(xs) * len(ys), desc=prefix, unit="patch", leave=False) as pbar:
        for x0 in xs:
            for y0 in ys:
                patch = img.crop((x0, y0, x0 + patch_size, y0 + patch_size))
                patch = black_to_white(patch)
                patch.save(
                    os.path.join(output_dir, f"{prefix}_{x0}_{y0}.png"),
                    format="PNG",
                )
                count += 1
                pbar.update(1)
    return count
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
# =================== Per-section processing (for Pool) =====================
|
| 109 |
+
|
| 110 |
+
# Shared patching settings. main() fills this in before any section is
# processed; _process_section() reads it.
_ARGS = {}


def _process_section(section_dir: str) -> str:
    """Patchify the image of one section directory (multiprocessing.Pool task).

    Settings (``patch_size``, ``overlap``, ``magnification``,
    ``input_filename``) are read from the module-level ``_ARGS`` dict.
    The image is looked up as ``<input_filename>.tif`` and then
    ``<input_filename>.tiff`` inside ``section_dir``; patches are written to
    ``<section_dir>/patches_<magnification>x_p<patch_size>_o<overlap>``.

    Args:
        section_dir: Directory expected to contain the input image.

    Returns:
        A one-line status message: ``[SKIP] ...`` when no input image exists,
        otherwise ``[OK] ...`` with patch count and elapsed time.
    """
    patch_size = _ARGS["patch_size"]
    overlap = _ARGS["overlap"]
    magnification = _ARGS["magnification"]
    input_filename = _ARGS["input_filename"]

    # Locate input file
    candidates = [
        os.path.join(section_dir, f"{input_filename}.tif"),
        os.path.join(section_dir, f"{input_filename}.tiff"),
    ]
    image_path = next((p for p in candidates if os.path.exists(p)), None)

    if image_path is None:
        return f"[SKIP] {section_dir}: {input_filename}.tif/.tiff not found"

    stride = patch_size - overlap
    out_dir = os.path.join(
        section_dir,
        f"patches_{magnification}x_p{patch_size}_o{overlap}",
    )

    # abspath() drops trailing separators and resolves "." so the patch
    # filename prefix is never empty for inputs such as "slides/sec1/".
    section_name = os.path.basename(os.path.abspath(section_dir))

    t0 = time.time()
    n = extract_patches(
        image_path=image_path,
        output_dir=out_dir,
        patch_size=patch_size,
        overlap=overlap,
        prefix=section_name,
    )
    dt = time.time() - t0

    return (
        f"[OK] {section_name} | {magnification}x | "
        f"stride={stride} | {n} patches | {dt:.1f}s → {out_dir}"
    )
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
# =============================== CLI =======================================
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
def _init_worker(config: dict) -> None:
    """Install the shared patching settings in a worker process.

    Used as the ``multiprocessing.Pool`` initializer so that ``_ARGS`` is
    populated even when workers do not inherit the parent's module globals
    (the ``spawn`` and ``forkserver`` start methods).
    """
    global _ARGS
    _ARGS = dict(config)


def main():
    """Command-line entry point: patchify every matching section directory.

    Section directories are the sub-directories of ``--input_dir`` matching
    ``--pattern``. If there are none but ``--input_dir`` itself contains the
    input image, it is treated as the single section. Sections are processed
    serially when ``--workers <= 1`` or only one section exists, otherwise in
    a process pool. One status line per section is printed at the end.

    Raises:
        SystemExit: If the arguments are invalid or nothing to process is found.
    """
    global _ARGS

    parser = argparse.ArgumentParser(
        description="Extract overlapping patches from normalized H&E images"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory containing section folders with Aligned-hne.tif files",
    )
    parser.add_argument(
        "--input_filename",
        type=str,
        default="Aligned-hne",
        help="Base filename of the normalized image (default: Aligned-hne)",
    )
    parser.add_argument(
        "--patch_size", type=int, default=256, help="Patch size in pixels (default: 256)"
    )
    parser.add_argument(
        "--overlap", type=int, default=64, help="Overlap in pixels (default: 64)"
    )
    parser.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Slide magnification. 40x recommended; 20x supported. (default: 40)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern to match section folders (default: '*')",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="Number of parallel workers (default: 8)"
    )

    args = parser.parse_args()

    # Fail fast on an impossible stride instead of erroring inside a worker.
    if args.overlap >= args.patch_size:
        parser.error(
            f"--overlap ({args.overlap}) must be smaller than "
            f"--patch_size ({args.patch_size})"
        )

    if args.magnification == 20:
        print(
            "⚠️ 20x magnification is supported but 40x is recommended for best "
            "cell detection accuracy (especially small immune cells)."
        )

    # Collect section directories
    section_dirs = sorted(
        p
        for p in glob.glob(os.path.join(args.input_dir, args.pattern))
        if os.path.isdir(p)
    )

    if not section_dirs:
        # Maybe input_dir itself contains the image directly
        candidates = [
            os.path.join(args.input_dir, f"{args.input_filename}.tif"),
            os.path.join(args.input_dir, f"{args.input_filename}.tiff"),
        ]
        if any(os.path.exists(c) for c in candidates):
            section_dirs = [args.input_dir]
        else:
            raise SystemExit(
                f"No section folders matching '{args.pattern}' found in {args.input_dir}"
            )

    print(f"Found {len(section_dirs)} section(s) | {args.magnification}x | "
          f"patch={args.patch_size} overlap={args.overlap}")

    # Settings for the serial path (read directly from this process).
    _ARGS = {
        "patch_size": args.patch_size,
        "overlap": args.overlap,
        "magnification": args.magnification,
        "input_filename": args.input_filename,
    }

    if args.workers <= 1 or len(section_dirs) == 1:
        results = [_process_section(d) for d in tqdm(section_dirs, desc="Sections")]
    else:
        # Hand the settings to each worker explicitly; relying on inherited
        # globals only works with the "fork" start method.
        with Pool(
            processes=min(args.workers, len(section_dirs)),
            initializer=_init_worker,
            initargs=(_ARGS,),
        ) as pool:
            results = list(
                tqdm(
                    pool.imap_unordered(_process_section, section_dirs),
                    total=len(section_dirs),
                    desc="Sections",
                )
            )

    print("\n".join(results))


if __name__ == "__main__":
    main()
|
post_processing.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# PostProcessing Pipeline
|
| 3 |
+
#
|
| 4 |
+
# Adapted from HoverNet
|
| 5 |
+
# HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
|
| 6 |
+
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
|
| 7 |
+
#
|
| 8 |
+
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
|
| 9 |
+
# Institute for Artificial Intelligence in Medicine,
|
| 10 |
+
# University Medicine Essen
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import warnings
|
| 14 |
+
from typing import Tuple, Literal
|
| 15 |
+
|
| 16 |
+
import cv2
|
| 17 |
+
import numpy as np
|
| 18 |
+
from scipy.ndimage import measurements
|
| 19 |
+
from scipy.ndimage.morphology import binary_fill_holes
|
| 20 |
+
from skimage.segmentation import watershed
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
# import sys
|
| 24 |
+
# sys.path.append("/home01/k123a01/CellViTR/cell_segmentation/utils/")
|
| 25 |
+
from tools import get_bounding_box, remove_small_objects
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def noop(*args, **kargs):
    """Accept any positional/keyword arguments, do nothing, return None."""
    return None


# Rebind warnings.warn to the no-op above: from this import onward, every
# call to warnings.warn() in the process is silently discarded.
warnings.warn = noop
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
class DetectionCellPostProcessor:
    """Convert network prediction maps into individual nucleus instances."""

    def __init__(
        self,
        nr_types: int = None,
        magnification: Literal[20, 40] = 40,
        gt: bool = False,
    ) -> None:
        """DetectionCellPostProcessor for postprocessing prediction maps and get detected cells

        Args:
            nr_types (int, optional): Number of cell types, including background (background = 0). Defaults to None.
            magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40.
            gt (bool, optional): If this is gt data. When True, ``object_size`` is set to 100 and
                ``k_size`` to 21 regardless of magnification. Defaults to False.

        Raises:
            NotImplementedError: Unknown magnification
        """
        self.nr_types = nr_types
        self.magnification = magnification
        self.gt = gt

        # object_size: minimum marker size kept before the watershed.
        # k_size: Sobel kernel size used on the horizontal/vertical maps.
        if magnification == 40:
            self.object_size = 10
            self.k_size = 21
        elif magnification == 20:
            self.object_size = 3  # 3 or 40, we used 5
            self.k_size = 11  # 11 or 41, we used 13
        else:
            raise NotImplementedError("Unknown magnification")
        if gt:
            # NOTE(review): intended "to not suppress something in gt", but a
            # larger object_size removes *more* small markers - confirm.
            self.object_size = 100
            self.k_size = 21

    def post_process_cell_segmentation(
        self,
        pred_map: np.ndarray,
    ) -> Tuple[np.ndarray, dict]:
        """Post processing of one image tile

        Args:
            pred_map (np.ndarray): Stacked prediction maps, channels last. When ``nr_types`` is set,
                channel 1 is read as the per-pixel type map and channels 2 onward are passed to the
                instance step, which reads three channels (nuclei probability, horizontal map,
                vertical map); channel 0 is not used here. When ``nr_types`` is None the whole array
                is passed to the instance step.
                NOTE(review): this slicing implies at least five channels when ``nr_types`` is set -
                confirm against the caller.

        Returns:
            Tuple[np.ndarray, dict]:
                np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
                dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value.
                    For each instance, the dictionary contains the keys: bbox (bounding box), centroid (centroid coordinates),
                    contour, type_prob (probability), type (nuclei type), all_type_prob (list with one
                    pixel fraction per type index). type, type_prob stay None and all_type_prob stays
                    empty when ``nr_types`` is None.
        """
        if self.nr_types is not None:
            pred_type = pred_map[..., 1:2]
            pred_inst = pred_map[..., 2:]
            pred_type = pred_type.astype(np.int32)
        else:
            pred_type = None
            pred_inst = pred_map

        pred_inst = np.squeeze(pred_inst)
        pred_inst = self.__proc_np_hv(
            pred_inst, object_size=self.object_size, ksize=self.k_size
        )

        # Exclude background explicitly: slicing off the first unique value
        # would drop a real instance when the tile has no background pixel.
        inst_ids = np.unique(pred_inst)
        inst_id_list = inst_ids[inst_ids != 0]

        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            inst_map = inst_map[
                inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
            ]
            inst_map = inst_map.astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            # [-2] selects the contour list for both the OpenCV 3 return
            # layout (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
            contours = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            inst_contour = np.squeeze(contours[0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.shape[0] < 3:
                continue
            if len(inst_contour.shape) != 2:
                continue  # ! check for trickery shape
            inst_centroid = [
                (inst_moment["m10"] / inst_moment["m00"]),
                (inst_moment["m01"] / inst_moment["m00"]),
            ]
            inst_centroid = np.array(inst_centroid)
            # Shift from bounding-box coordinates back to tile coordinates.
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
                "all_type_prob": [],
            }

        # Without a type map there is nothing to classify; leave the type
        # fields at their defaults instead of touching an undefined pred_type.
        if self.nr_types is None:
            return pred_inst, inst_info_dict

        #### * Get class of each instance id (inst_id = number of detected nucleus)
        for inst_id, inst_info in inst_info_dict.items():
            rmin, cmin, rmax, cmax = (inst_info["bbox"]).flatten()
            inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
            inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
            inst_map_crop = inst_map_crop == inst_id
            inst_type = inst_type_crop[inst_map_crop]
            type_list, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = list(zip(type_list, type_pixels))

            # Fraction of this instance's pixels assigned to each cell type.
            type_probs = {}
            # Small epsilon so the division can never be by zero.
            total_pixels = np.sum(inst_map_crop) + 1.0e-6
            for cell_type, pixel_count in type_list:
                type_probs[cell_type] = float(pixel_count / total_pixels)

            # One entry per type index; types absent from the instance get 0.0.
            inst_info["all_type_prob"] = [
                type_probs.get(i, 0.0) for i in range(self.nr_types)
            ]

            type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
            inst_type = type_list[0][0]
            if inst_type == 0:  # ! pick the 2nd most dominant if exist
                if len(type_list) > 1:
                    inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            type_prob = type_dict[inst_type] / total_pixels
            inst_info["type"] = int(inst_type)
            inst_info["type_prob"] = float(type_prob)

        return pred_inst, inst_info_dict

    def __proc_np_hv(
        self, pred: np.ndarray, object_size: int = 10, ksize: int = 21
    ) -> np.ndarray:
        """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer)

        Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed

        Args:
            pred (np.ndarray): Prediction output, assuming. Shape: (H, W, 3)
                * channel 0 contain probability map of nuclei
                * channel 1 containing the regressed X-map
                * channel 2 containing the regressed Y-map
            object_size (int, optional): Smallest marker size kept before the watershed. Defaults to 10
            ksize (int, optional): Sobel Kernel size. Defaults to 21
        Returns:
            np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W)
        """
        pred = np.array(pred, dtype=np.float32)

        blb_raw = pred[..., 0]
        h_dir_raw = pred[..., 1]
        v_dir_raw = pred[..., 2]

        # Binary foreground mask: threshold at 0.5, drop components under
        # 10 px, then collapse the labels back to {0, 1}.
        blb = np.array(blb_raw >= 0.5, dtype=np.int32)

        blb = measurements.label(blb)[0]
        blb = remove_small_objects(blb, min_size=10)
        blb[blb > 0] = 1  # background is 0 already

        # Rescale both direction maps to [0, 1] before taking gradients.
        h_dir = cv2.normalize(
            h_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )
        v_dir = cv2.normalize(
            v_dir_raw,
            None,
            alpha=0,
            beta=1,
            norm_type=cv2.NORM_MINMAX,
            dtype=cv2.CV_32F,
        )

        sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize)
        sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize)

        # Min-max normalise the gradients and invert them.
        sobelh = 1 - (
            cv2.normalize(
                sobelh,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )
        sobelv = 1 - (
            cv2.normalize(
                sobelv,
                None,
                alpha=0,
                beta=1,
                norm_type=cv2.NORM_MINMAX,
                dtype=cv2.CV_32F,
            )
        )

        # Combined edge response, zeroed outside the foreground mask.
        overall = np.maximum(sobelh, sobelv)
        overall = overall - (1 - blb)
        overall[overall < 0] = 0

        dist = (1.0 - overall) * blb
        ## nuclei values form mountains so inverse to get basins
        dist = -cv2.GaussianBlur(dist, (3, 3), 0)

        overall = np.array(overall >= 0.4, dtype=np.int32)

        # Markers = foreground minus strong edges, cleaned up morphologically
        # and with components under object_size removed.
        marker = blb - overall
        marker[marker < 0] = 0
        marker = binary_fill_holes(marker).astype("uint8")
        kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5))
        marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel)
        marker = measurements.label(marker)[0]
        marker = remove_small_objects(marker, min_size=object_size)

        proced_pred = watershed(dist, markers=marker, mask=blb)

        return proced_pred
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
def calculate_instances(
    pred_types: torch.Tensor, pred_insts: torch.Tensor
) -> list[dict]:
    """Best used for GT

    Builds, for every sample of the batch, a dictionary describing each
    labelled instance (bounding box, centroid, contour, majority type).

    Args:
        pred_types (torch.Tensor): Binary or type map ground-truth.
            Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class.
        pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W)

    Returns:
        list[dict]: One dictionary per sample with nuclei information, output similar to
            post_process_cell_segmentation (keys: bbox, centroid, contour, type_prob, type)
    """
    type_preds = []
    pred_types = pred_types.permute(0, 2, 3, 1)
    for i in range(pred_types.shape[0]):
        # argmax only over this sample; reducing the whole batch on every
        # iteration repeated the same work B times.
        pred_type = torch.argmax(pred_types[i], dim=-1).detach().cpu().numpy()
        pred_inst = pred_insts[i].detach().cpu().numpy()

        # Exclude background explicitly: slicing off the first unique value
        # would drop a real instance when the map has no background pixel.
        inst_ids = np.unique(pred_inst)
        inst_id_list = inst_ids[inst_ids != 0]

        inst_info_dict = {}
        for inst_id in inst_id_list:
            inst_map = pred_inst == inst_id
            rmin, rmax, cmin, cmax = get_bounding_box(inst_map)
            inst_bbox = np.array([[rmin, cmin], [rmax, cmax]])
            inst_map = inst_map[
                inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1]
            ]
            inst_map = inst_map.astype(np.uint8)
            inst_moment = cv2.moments(inst_map)
            # [-2] selects the contour list for both the OpenCV 3 return
            # layout (image, contours, hierarchy) and OpenCV 4 (contours, hierarchy).
            contours = cv2.findContours(
                inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )[-2]
            inst_contour = np.squeeze(contours[0].astype("int32"))
            # < 3 points dont make a contour, so skip, likely artifact too
            # as the contours obtained via approximation => too small or sthg
            if inst_contour.shape[0] < 3:
                continue
            if len(inst_contour.shape) != 2:
                continue  # ! check for trickery shape
            inst_centroid = [
                (inst_moment["m10"] / inst_moment["m00"]),
                (inst_moment["m01"] / inst_moment["m00"]),
            ]
            inst_centroid = np.array(inst_centroid)
            # Shift from bounding-box coordinates back to image coordinates.
            inst_contour[:, 0] += inst_bbox[0][1]  # X
            inst_contour[:, 1] += inst_bbox[0][0]  # Y
            inst_centroid[0] += inst_bbox[0][1]  # X
            inst_centroid[1] += inst_bbox[0][0]  # Y
            inst_info_dict[inst_id] = {  # inst_id should start at 1
                "bbox": inst_bbox,
                "centroid": inst_centroid,
                "contour": inst_contour,
                "type_prob": None,
                "type": None,
            }

        #### * Get class of each instance id (inst_id = number of detected nucleus)
        for inst_id, inst_info in inst_info_dict.items():
            rmin, cmin, rmax, cmax = (inst_info["bbox"]).flatten()
            inst_map_crop = pred_inst[rmin:rmax, cmin:cmax]
            inst_type_crop = pred_type[rmin:rmax, cmin:cmax]
            inst_map_crop = inst_map_crop == inst_id
            inst_type = inst_type_crop[inst_map_crop]
            type_list, type_pixels = np.unique(inst_type, return_counts=True)
            type_list = list(zip(type_list, type_pixels))
            type_list = sorted(type_list, key=lambda x: x[1], reverse=True)
            inst_type = type_list[0][0]
            if inst_type == 0:  # ! pick the 2nd most dominant if exist
                if len(type_list) > 1:
                    inst_type = type_list[1][0]
            type_dict = {v[0]: v[1] for v in type_list}
            # Epsilon keeps the division safe from a zero pixel count.
            type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6)
            inst_info["type"] = int(inst_type)
            inst_info["type_prob"] = float(type_prob)
        type_preds.append(inst_info_dict)

    return type_preds
|
| 344 |
+
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
|
| 348 |
+
|
standard-ilc.tif
ADDED
|
|
Git LFS Details
|
tools.py
ADDED
|
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
# Helpful functions Pipeline
|
| 3 |
+
#
|
| 4 |
+
# Adapted from HoverNet
|
| 5 |
+
# HoverNet Network (https://doi.org/10.1016/j.media.2019.101563)
|
| 6 |
+
# Code Snippet adapted from HoverNet implementation (https://github.com/vqdang/hover_net)
|
| 7 |
+
#
|
| 8 |
+
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
|
| 9 |
+
# Institute for Artificial Intelligence in Medicine,
|
| 10 |
+
# University Medicine Essen
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
import math
|
| 14 |
+
from typing import Tuple
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import scipy
|
| 18 |
+
from numba import njit, prange
|
| 19 |
+
from scipy import ndimage
|
| 20 |
+
from scipy.optimize import linear_sum_assignment
|
| 21 |
+
from skimage.draw import polygon
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def get_bounding_box(img):
    """Return the bounding box of the non-zero region of a 2-D mask.

    Args:
        img: array whose truthy entries mark the object

    Returns:
        list ``[rmin, rmax, cmin, cmax]``; the max values are exclusive, so
        ``img[rmin:rmax, cmin:cmax]`` covers the whole object. Raises
        IndexError if ``img`` has no truthy entry.

    """
    row_hits = np.nonzero(np.any(img, axis=1))[0]
    col_hits = np.nonzero(np.any(img, axis=0))[0]
    first_row, last_row = row_hits[[0, -1]]
    first_col, last_col = col_hits[[0, -1]]
    # +1 turns the last occupied index into an exclusive slice bound
    return [first_row, last_row + 1, first_col, last_col + 1]
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
@njit
def cropping_center(x, crop_shape, batch=False):
    """Crop an input image at the centre.

    Args:
        x: input array
        crop_shape: dimensions of cropped array; index 0 is the crop height
            and index 1 the crop width
        batch: if False, axes 0 and 1 of ``x`` are cropped; if True, axis 0 is
            kept whole and axes 1 and 2 are cropped instead

    Returns:
        x: cropped array

    """
    orig_shape = x.shape
    if not batch:
        # offsets are truncated by int(), so an odd margin leaves the extra
        # pixel on the bottom/right side
        h0 = int((orig_shape[0] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[1] - crop_shape[1]) * 0.5)
        x = x[h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
    else:
        # same computation, shifted by one axis to skip the leading axis
        h0 = int((orig_shape[1] - crop_shape[0]) * 0.5)
        w0 = int((orig_shape[2] - crop_shape[1]) * 0.5)
        x = x[:, h0 : h0 + crop_shape[0], w0 : w0 + crop_shape[1], ...]
    return x
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def remove_small_objects(pred, min_size=64, connectivity=1):
    """Zero out connected components that have fewer than ``min_size`` pixels.

    Derived from skimage.morphology.remove_small_objects, without the warning
    emitted when only a single label is present. The input array is modified
    in place and is also the returned object.

    Args:
        pred: input labelled array (integer labels), or a boolean mask whose
            components are labelled internally
        min_size: minimum size of instance in output array
        connectivity: The connectivity defining the neighborhood of a pixel
            (only used when ``pred`` is boolean).

    Returns:
        out: ``pred`` itself, with instances under min_size set to 0

    """
    if min_size == 0:  # nothing can be smaller than 0 pixels
        return pred

    if pred.dtype == bool:
        structure = ndimage.generate_binary_structure(pred.ndim, connectivity)
        labelled = np.zeros_like(pred, dtype=np.int32)
        ndimage.label(pred, structure, output=labelled)
    else:
        labelled = pred

    try:
        sizes = np.bincount(labelled.ravel())
    except ValueError:
        raise ValueError(
            "Negative value labels are not supported. Try "
            "relabeling the input with `scipy.ndimage.label` or "
            "`skimage.morphology.label`."
        )

    # Look up, per pixel, whether its label's component is below min_size.
    undersized = (sizes < min_size)[labelled]
    pred[undersized] = 0

    return pred
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def pair_coordinates(
    setA: np.ndarray, setB: np.ndarray, radius: float
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Use the Munkres or Kuhn-Munkres algorithm to find the most optimal
    unique pairing (largest possible match) when pairing points in set B
    against points in set A, using distance as cost function.

    Args:
        setA (np.ndarray): np.array (float32) of size Nx2 containing the XY coordinates
            of N different points
        setB (np.ndarray): np.array (float32) of size Mx2 containing the XY coordinates
            of M different points
        radius (float): valid area around a point in setA to consider
            a given coordinate in setB a candidate for match

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]:
            pairing: array of index pairs with shape (K, 2); in each row,
                column 0 is the index into set A and column 1 the index of
                the point in set B it is paired with
            unpairedA: indices of the points in set A left unpaired
            unpairedB: indices of the points in set B left unpaired
    """
    # Imported explicitly: a bare `import scipy` does not guarantee that the
    # scipy.spatial submodule has been loaded on every SciPy version.
    from scipy.spatial.distance import cdist

    # * Euclidean distance as the cost matrix
    pair_distance = cdist(setA, setB, metric="euclidean")

    # * Munkres pairing with scipy library
    # the algorithm returns (row indices, matched column indices)
    indicesA, paired_indicesB = linear_sum_assignment(pair_distance)

    # extract the paired cost and remove instances
    # outside of designated radius
    pair_cost = pair_distance[indicesA, paired_indicesB]
    within_radius = pair_cost <= radius

    pairedA = indicesA[within_radius]
    pairedB = paired_indicesB[within_radius]

    pairing = np.concatenate([pairedA[:, None], pairedB[:, None]], axis=-1)
    unpairedA = np.delete(np.arange(setA.shape[0]), pairedA)
    unpairedB = np.delete(np.arange(setB.shape[0]), pairedB)

    return pairing, unpairedA, unpairedB
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def fix_duplicates(inst_map: np.ndarray) -> np.ndarray:
    """Give every disconnected piece of a repeated label its own label.

    For each non-zero label, the first connected component keeps the label
    and every further component is assigned a new id above the current
    maximum. The input array is modified in place.

    Parameters
    ----------
        inst_map : np.ndarray
            Instance labelled mask. Shape (H, W).

    Returns
    -------
        np.ndarray:
            The same array, without duplicated indices.
            Shape (H, W).
    """
    max_id = np.amax(inst_map)

    # np.unique returns a fresh array, so relabelling below does not affect
    # the set of ids being visited.
    for inst_id in np.unique(inst_map):
        if inst_id == 0:
            continue  # background

        components = ndimage.label(np.array(inst_map == inst_id, np.uint8))[0]
        # Components 2, 3, ... are the duplicates; shift them past max_id.
        components[components > 1] += max_id
        duplicates = components > 1
        inst_map[duplicates] = components[duplicates]
        max_id = np.amax(inst_map)

    return inst_map
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
def polygons_to_label_coord(
    coord: np.ndarray, shape: Tuple[int, int], labels: np.ndarray = None
) -> np.ndarray:
    """Rasterize polygons into an instance labelled mask of a given shape.

    Polygons are drawn one after another in the given order, so a polygon
    drawn later overwrites the pixels of earlier ones where they overlap.

    Parameters
    ----------
    coord : np.ndarray
        Polygon vertex coordinates. Shape: (n_polys, 2, n_rays). The two
        rows of each polygon are passed positionally to `polygon`.
    shape : Tuple[int, int]
        Shape of the output mask.
    labels : np.ndarray, optional
        One index per polygon. Polygon `k` is painted with the value
        `labels[k] + 1`. Defaults to `np.arange(n_polys)`.

    Returns
    -------
    np.ndarray:
        Instance labelled mask (int32). Shape: (H, W).
    """
    coord = np.asarray(coord)
    labels = np.arange(len(coord)) if labels is None else labels

    assert coord.ndim == 3 and coord.shape[1] == 2 and len(coord) == len(labels)

    mask = np.zeros(shape, np.int32)

    for label, poly in zip(labels, coord):
        rows, cols = polygon(*poly, shape)
        # +1 keeps 0 reserved for the background.
        mask[rows, cols] = label + 1

    return mask
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def ray_angles(n_rays: int = 32):
    """Return `n_rays` evenly spaced angles in [0, 2*pi), end point excluded."""
    full_turn = 2 * np.pi
    return np.linspace(0, full_turn, num=n_rays, endpoint=False)
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def dist_to_coord(
    dist: np.ndarray, points: np.ndarray, scale_dist: Tuple[int, int] = (1, 1)
) -> np.ndarray:
    """Convert radial distances and centroids from polar to cartesian coordinates.

    For every polygon the ray lengths are projected with sin (first output
    row) and cos (second output row) of the ray angles, scaled per axis and
    shifted by the centroid.

    Parameters
    ----------
    dist : np.ndarray
        The centerpoint pixels of the radial distance map.
        Shape (n_polys, n_rays).
    points : np.ndarray
        The centroids of the instances. Shape: (n_polys, 2).
    scale_dist : Tuple[int, int], default=(1, 1)
        Scaling factor, one value per coordinate axis.

    Returns
    -------
    np.ndarray:
        Cartesian coordinates of the polygons (float32).
        Shape (n_polys, 2, n_rays).
    """
    dist = np.asarray(dist)
    points = np.asarray(points)

    # Validated in the same order as a single chained condition would be.
    assert dist.ndim == 2
    assert points.ndim == 2
    assert len(dist) == len(points)
    assert points.shape[1] == 2
    assert len(scale_dist) == 2

    angles = ray_angles(dist.shape[1])
    # Unit ray directions, shape (2, n_rays): row 0 = sin, row 1 = cos.
    directions = np.array([np.sin(angles), np.cos(angles)])

    coord = (dist[:, np.newaxis] * directions).astype(np.float32)
    # In-place updates keep the result in float32.
    coord *= np.asarray(scale_dist).reshape(1, 2, 1)
    coord += points[..., np.newaxis]

    return coord
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def polygons_to_label(
    dist: np.ndarray,
    points: np.ndarray,
    shape: Tuple[int, int],
    prob: np.ndarray = None,
    thresh: float = -np.inf,
    scale_dist: Tuple[int, int] = (1, 1),
) -> np.ndarray:
    """Convert radial distances and center points to an instance labelled mask.

    Candidates whose `prob` is not strictly above `thresh` are dropped. The
    remaining ones are rendered in ascending order of `prob`, so polygons
    with a higher score are drawn last and win where polygons overlap.

    Parameters
    ----------
    dist : np.ndarray
        The centerpoint pixels of the radial distance map.
        Shape (n_polys, n_rays).
    points : np.ndarray
        The centroids of the instances. Shape: (n_polys, 2).
    shape : Tuple[int, int]:
        Shape of the output mask.
    prob : np.ndarray, optional
        One score per polygon. Shape: (n_polys,). If None, every polygon
        gets a score of +inf.
    thresh : float, default=-np.inf
        Only polygons with `prob > thresh` are rendered.
    scale_dist : Tuple[int, int], default=(1, 1)
        Scaling factor.

    Returns
    -------
    np.ndarray:
        Instance labelled mask. Shape (H, W). A rendered polygon gets the
        value of its index among the thresholded candidates plus one.
    """
    dist = np.asarray(dist)
    points = np.asarray(points)
    if prob is None:
        prob = np.inf * np.ones(len(points))
    else:
        prob = np.asarray(prob)

    assert dist.ndim == 2 and points.ndim == 2 and len(dist) == len(points)
    assert len(points) == len(prob) and points.shape[1] == 2 and prob.ndim == 1

    # Discard candidates at or below the threshold.
    keep = prob > thresh
    points, dist, prob = points[keep], dist[keep], prob[keep]

    # Stable ascending sort: ties keep their input order.
    order = np.argsort(prob, kind="stable")
    points, dist = points[order], dist[order]

    coord = dist_to_coord(dist, points, scale_dist=scale_dist)

    return polygons_to_label_coord(coord, shape=shape, labels=order)
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
@njit(cache=True, fastmath=True)
def intersection(boxA: np.ndarray, boxB: np.ndarray):
    """Compute the area of intersection of two axis-aligned boxes.

    Both boxes are read at indices 0..3 of their last axis as
    (x-min, y-min, x-max, y-max).

    Parameters
    ----------
    boxA : np.ndarray
        First box.
    boxB : np.ndarray
        Second box.

    Returns
    -------
    float64:
        Area of intersection, 0.0 if the boxes do not overlap.
    """
    # Overlap along x; bail out early when there is none.
    left = max(boxA[..., 0], boxB[..., 0])
    right = min(boxA[..., 2], boxB[..., 2])
    width = right - left
    if width <= 0:
        return 0.0

    # Overlap along y.
    top = max(boxA[..., 1], boxB[..., 1])
    bottom = min(boxA[..., 3], boxB[..., 3])
    height = bottom - top
    if height <= 0.0:
        return 0.0

    return width * height
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
@njit(parallel=True)
def get_bboxes(
    dist: np.ndarray, points: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
    """Get bounding boxes from the non-zero pixels of the radial distance maps.

    For every polygon the end point of each ray is computed from the center
    point and the ray length, and the axis-aligned bounding box of these end
    points is recorded together with its area.

    This is basically a translation from the stardist repo cpp code to python

    NOTE: jit compiled and parallelized with numba.

    Parameters
    ----------
    dist : np.ndarray
        The non-zero values of the radial distance maps.
        Shape: (n_nonzero, n_rays).
    points : np.ndarray
        The yx-coordinates of the non-zero points. Shape (n_nonzero, 2).

    Returns
    -------
    Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray, int]:
        In this order: x-min, y-min, x-max and y-max of every bounding box
        (each of shape (n_nonzero,)), the bounding box areas (shape
        (n_nonzero,)) and the maximum radial distance over all polygons.
        NOTE(review): the last element is annotated as int but is derived
        from the values of `dist` - presumably a float for float input;
        confirm.
    """
    n_polys = dist.shape[0]
    n_rays = dist.shape[1]

    # One slot per polygon; iteration i only writes index i.
    bbox_x1 = np.zeros(n_polys)
    bbox_x2 = np.zeros(n_polys)
    bbox_y1 = np.zeros(n_polys)
    bbox_y2 = np.zeros(n_polys)

    areas = np.zeros(n_polys)
    # Angular step between two consecutive rays.
    angle_pi = 2 * math.pi / n_rays
    max_dist = 0

    for i in prange(n_polys):
        max_radius_outer = 0
        # Column 0 of `points` is y, column 1 is x.
        py = points[i, 0]
        px = points[i, 1]

        for k in range(n_rays):
            d = dist[i, k]
            # End point of ray k.
            y = py + d * np.sin(angle_pi * k)
            x = px + d * np.cos(angle_pi * k)

            if k == 0:
                # The first ray initializes the box.
                bbox_x1[i] = x
                bbox_x2[i] = x
                bbox_y1[i] = y
                bbox_y2[i] = y
            else:
                # Later rays can only grow the box.
                bbox_x1[i] = min(x, bbox_x1[i])
                bbox_x2[i] = max(x, bbox_x2[i])
                bbox_y1[i] = min(y, bbox_y1[i])
                bbox_y2[i] = max(y, bbox_y2[i])

            max_radius_outer = max(d, max_radius_outer)

        areas[i] = (bbox_x2[i] - bbox_x1[i]) * (bbox_y2[i] - bbox_y1[i])
        # Scalar updated with max() across prange iterations.
        # NOTE(review): this relies on numba inferring a max-reduction for
        # `max_dist` - confirm for the numba version in use.
        max_dist = max(max_dist, max_radius_outer)

    return bbox_x1, bbox_y1, bbox_x2, bbox_y2, areas, max_dist
|