HNE2Cell / inference.py
roobee79's picture
Upload 7 files
7747544 verified
"""
HNE2Cell — Step 3: Cell Detection & Classification Inference
Run the HNE2Cell model on extracted patches to detect and classify cells.
Outputs: per-patch cell masks (PNG) and centroid CSVs with cell type annotations.

Usage:
    python inference.py \
        --input_dir /path/to/patch_folders \
        --output_dir /path/to/results \
        --model_path ./HNE2cell_all_patch73_jit.pt \
        --magnification 40 \
        --batch_size 32

Cell Types (16 classes):
     0: Background     4: B              8: DC            12: Epithelial
     1: Malignant      5: Plasma         9: Fibroblast    13: Immune_Other
     2: CD4T           6: Macrophage    10: Endothelial   14: Stromal_Other
     3: CD8T           7: Myeloid       11: Pericyte      15: Dead
"""
import os
import argparse
import glob
import cv2
import numpy as np
import pandas as pd
import torch
from PIL import Image
from torch.cuda.amp import autocast
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from tqdm import tqdm
from post_processing import DetectionCellPostProcessor
# ========================== Constants ======================================
# Class index -> cell-type label. The position of each label in the tuple is
# its class index (0 = Background ... 15 = Dead).
CELL_TYPES = dict(
    enumerate(
        (
            "Background",
            "Malignant",
            "CD4T",
            "CD8T",
            "B",
            "Plasma",
            "Macrophage",
            "Myeloid",
            "DC",
            "Fibroblast",
            "Endothelial",
            "Pericyte",
            "Epithelial",
            "Immune_Other",
            "Stromal_Other",
            "Dead",
        )
    )
)
# RGBA colors for mask visualization, keyed by class index. Each value is a
# fresh ``[R, G, B, A]`` list; the row position in the tuple is the class index.
# NOTE(review): classes 6 (Macrophage) and 13 (Immune_Other) share the same
# colour, so they cannot be told apart in a rendered mask - confirm intended.
CELL_COLORS = {
    class_idx: list(rgba)
    for class_idx, rgba in enumerate(
        (
            (0, 0, 0, 0),  # 0: Background (alpha 0)
            (255, 0, 0, 255),  # 1: Malignant
            (30, 144, 255, 255),  # 2: CD4T
            (65, 105, 225, 255),  # 3: CD8T
            (0, 0, 255, 255),  # 4: B
            (100, 149, 237, 255),  # 5: Plasma
            (176, 224, 230, 255),  # 6: Macrophage
            (70, 130, 180, 255),  # 7: Myeloid
            (0, 191, 255, 255),  # 8: DC
            (34, 139, 34, 255),  # 9: Fibroblast
            (60, 179, 113, 255),  # 10: Endothelial
            (50, 205, 50, 255),  # 11: Pericyte
            (255, 140, 0, 255),  # 12: Epithelial
            (176, 224, 230, 255),  # 13: Immune_Other
            (107, 142, 35, 255),  # 14: Stromal_Other
            (128, 128, 128, 255),  # 15: Dead
        )
    )
}
# Patch preprocessing pipeline: resize to 224x224, convert to a tensor, then
# apply ImageNet-style per-channel normalization fitted to H&E data.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.707223, 0.578729, 0.703617],  # per-channel mean
            [0.211883, 0.230117, 0.177517],  # per-channel std
        ),
    ]
)
# ========================== Dataset ========================================
class PatchDataset(Dataset):
    """Dataset yielding ``(file_path, image)`` pairs for a list of patch files.

    Each image is read with OpenCV and converted from BGR to RGB. When a
    ``transform`` is given, it is applied to the image wrapped as a PIL image;
    otherwise the RGB array itself is returned.
    """

    def __init__(self, file_paths, transform=None):
        self.transform = transform
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]
        bgr = cv2.imread(path)
        if bgr is None:
            # Unreadable file: warn and substitute an all-black 256x256 patch
            # so the batch keeps its size instead of aborting the whole run.
            print(f"[WARN] Failed to load: {path}")
            bgr = np.zeros((256, 256, 3), dtype=np.uint8)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
        if not self.transform:
            return path, rgb
        return path, self.transform(Image.fromarray(rgb))
# ========================== Inference ======================================
def process_batch(
    batch,
    model,
    device,
    mask_output_dir,
    centroid_records,
    magnification=40,
):
    """Run inference on one batch and save masks + centroid info.

    Args:
        batch: ``(file_paths, images)`` pair as yielded by the data loader.
        model: Callable whose output is indexable by the keys
            ``cell_type_map``, ``nuclei_binary_map``, ``hv_map`` and
            ``tissue_type_map``; each entry is indexed per sample.
        device: Device (string or ``torch.device``) the images are moved to.
        mask_output_dir: Directory that receives ``<patch>_mask.png`` files.
            It is not created here.
        centroid_records: List extended in place with one dict per detected
            cell (keys: ``slide_id``, ``x``, ``y``, ``celltype``,
            ``celltype_name``).
        magnification: Forwarded to ``DetectionCellPostProcessor``.
    """
    from contextlib import nullcontext

    file_paths, images = batch

    # CUDA autocast only affects CUDA ops, so enter it only on a CUDA device.
    # Off-GPU a no-op context avoids the deprecated ``torch.cuda.amp.autocast``
    # entry point and its "CUDA is not available" warning.
    use_amp = torch.device(device).type == "cuda"
    amp_context = torch.autocast(device_type="cuda") if use_amp else nullcontext()
    with torch.no_grad(), amp_context:
        outputs = model(images.to(device, non_blocking=True))

    for i, fpath in enumerate(file_paths):
        # Despite the key name, this is the patch file's stem.
        slide_id = os.path.splitext(os.path.basename(fpath))[0]

        # Extract per-sample predictions (cast to float32, moved to CPU)
        cell_type_map = outputs["cell_type_map"][i].float().detach().cpu()
        nuclei_binary_map = outputs["nuclei_binary_map"][i].float().detach().cpu()
        hv_map = outputs["hv_map"][i].float().detach().cpu()
        tissue_type_map = outputs["tissue_type_map"][i].float().detach().cpu()

        # Build prediction map [H, W, 5]: tissue class, cell class, nuclei
        # class, then the hv_map channels moved to the last axis.
        pred_map = np.concatenate(
            [
                torch.argmax(tissue_type_map, dim=0)[..., None].numpy(),
                torch.argmax(cell_type_map, dim=0)[..., None].numpy(),
                torch.argmax(nuclei_binary_map, dim=0)[..., None].numpy(),
                hv_map.permute(1, 2, 0).numpy(),
            ],
            axis=-1,
        )

        # Post-processing. A new processor is built per sample because it is
        # not visible here whether it keeps state between calls.
        post_processor = DetectionCellPostProcessor(
            nr_types=cell_type_map.shape[0],
            magnification=magnification,
            gt=False,
        )
        _, type_pred = post_processor.post_process_cell_segmentation(pred_map)

        # Create mask image on a white canvas.
        # NOTE(review): the canvas is fixed at 256x256 while TRANSFORM resizes
        # inputs to 224x224 - confirm contours are in 256x256 coordinates.
        mask = np.full((256, 256, 3), 255, dtype=np.uint8)
        for cell in type_pred.values():
            ctype = cell["type"]
            rgba = CELL_COLORS.get(ctype, [255, 255, 255, 255])
            # OpenCV draws in BGR order; the alpha component is dropped.
            bgr = [rgba[2], rgba[1], rgba[0]]
            cv2.fillPoly(mask, [cell["contour"]], bgr)
            centroid_records.append(
                {
                    "slide_id": slide_id,
                    "x": cell["centroid"][0],
                    "y": cell["centroid"][1],
                    "celltype": ctype,
                    "celltype_name": CELL_TYPES.get(ctype, "Unknown"),
                }
            )

        # Save mask only if non-trivial (something other than white was drawn)
        if not np.all(mask == 255):
            mask_path = os.path.join(mask_output_dir, f"{slide_id}_mask.png")
            # imwrite reports failure through its return value; surface it.
            if not cv2.imwrite(mask_path, mask):
                print(f"[WARN] Failed to write mask: {mask_path}")
def run_inference(
    patch_folders: list[str],
    model,
    device,
    output_dir: str,
    magnification: int = 40,
    batch_size: int = 32,
    num_workers: int = 4,
):
    """Run inference over a list of patch folders.

    For every folder containing ``*.png`` patches, masks are written to
    ``<output_dir>/mask_patches/<folder_name>/`` and all detected cells are
    written to ``<output_dir>/centroid/<folder_name>_centroid.csv``. Folders
    without PNG files are skipped with a message.

    Args:
        patch_folders: Directories to process.
        model: Model moved to ``device`` and switched to eval mode.
        device: Target device (string or ``torch.device``).
        output_dir: Root directory for masks and centroid CSVs.
        magnification: Forwarded to ``process_batch``.
        batch_size: Number of patches per batch.
        num_workers: Number of data loader workers; ``0`` loads in the main
            process.
    """
    model.to(device).eval()

    loader_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "shuffle": False,
        # Pinned host memory only helps host-to-GPU copies.
        "pin_memory": torch.device(device).type == "cuda",
    }
    if num_workers > 0:
        # DataLoader rejects these options when there are no workers.
        loader_kwargs.update(prefetch_factor=2, persistent_workers=True)

    for folder in patch_folders:
        # abspath drops a trailing separator, which would otherwise make
        # basename return an empty string.
        folder_name = os.path.basename(os.path.abspath(folder))
        png_files = sorted(glob.glob(os.path.join(folder, "*.png")))
        if not png_files:
            print(f"[SKIP] {folder}: no PNG patches found")
            continue

        mask_dir = os.path.join(output_dir, "mask_patches", f"{folder_name}")
        centroid_path = os.path.join(
            output_dir, "centroid", f"{folder_name}_centroid.csv"
        )
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(os.path.dirname(centroid_path), exist_ok=True)

        dataset = PatchDataset(png_files, transform=TRANSFORM)
        loader = DataLoader(dataset, **loader_kwargs)

        centroids = []
        for batch in tqdm(loader, desc=f"Inference: {folder_name}"):
            process_batch(
                batch, model, device, mask_dir, centroids, magnification
            )

        df = pd.DataFrame(centroids)
        if df.empty:
            # No detections: an empty frame has no columns, so give it the
            # record keys to keep the CSV header present.
            df = pd.DataFrame(
                columns=["slide_id", "x", "y", "celltype", "celltype_name"]
            )
        df.to_csv(centroid_path, index=False)
        print(f"[DONE] {folder_name} -> {centroid_path} ({len(df)} cells)")
# =============================== CLI =======================================
def main(argv=None):
    """Parse command-line options, load the model and run inference.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so the script entry point behaves
            as before; passing a list allows programmatic invocation.
    """
    parser = argparse.ArgumentParser(description="HNE2Cell inference")
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Directory containing patch folders (each with *.png)",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        required=True,
        help="Output directory for masks and centroid CSVs",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        required=True,
        help="Path to the TorchScript JIT model (.pt)",
    )
    parser.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Magnification of input patches. 40x recommended. (default: 40)",
    )
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device: 'cuda', 'cpu', or 'auto' (default: auto)",
    )
    args = parser.parse_args(argv)

    # Device: "auto" picks CUDA when available, otherwise CPU
    if args.device == "auto":
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load model
    print(f"Loading model: {args.model_path}")
    model = torch.jit.load(args.model_path, map_location=device)
    model.eval()

    if args.magnification == 20:
        print(
            "⚠️ Running at 20x. Results are usable but 40x is recommended "
            "for best accuracy, especially for small immune cells."
        )

    # Collect patch folders: every immediate sub-directory of input_dir
    patch_folders = sorted(
        p
        for p in glob.glob(os.path.join(args.input_dir, "*"))
        if os.path.isdir(p)
    )
    # Also check if input_dir itself contains patches. This fallback is only
    # used when input_dir has no sub-directories at all.
    if not patch_folders and glob.glob(os.path.join(args.input_dir, "*.png")):
        patch_folders = [args.input_dir]
    print(f"Found {len(patch_folders)} patch folder(s)")

    run_inference(
        patch_folders,
        model,
        device,
        args.output_dir,
        magnification=args.magnification,
        batch_size=args.batch_size,
        num_workers=args.num_workers,
    )
if __name__ == "__main__":
main()