""" HNE2Cell — Step 3: Cell Detection & Classification Inference Run the HNE2Cell model on extracted patches to detect and classify cells. Outputs: per-patch cell masks (PNG) and centroid CSVs with cell type annotations. Usage: python inference.py \ --input_dir /path/to/patch_folders \ --output_dir /path/to/results \ --model_path ./HNE2cell_all_patch73_jit.pt \ --magnification 40 \ --batch_size 32 Cell Types (16 classes): 0: Background 4: B 8: DC 12: Epithelial 1: Malignant 5: Plasma 9: Fibroblast 13: Immune_Other 2: CD4T 6: Macrophage 10: Endothelial 14: Stromal_Other 3: CD8T 7: Myeloid 11: Pericyte 15: Dead """ import os import argparse import glob import cv2 import numpy as np import pandas as pd import torch from PIL import Image from torch.cuda.amp import autocast from torch.utils.data import DataLoader, Dataset from torchvision import transforms from tqdm import tqdm from post_processing import DetectionCellPostProcessor # ========================== Constants ====================================== CELL_TYPES = { 0: "Background", 1: "Malignant", 2: "CD4T", 3: "CD8T", 4: "B", 5: "Plasma", 6: "Macrophage", 7: "Myeloid", 8: "DC", 9: "Fibroblast", 10: "Endothelial", 11: "Pericyte", 12: "Epithelial", 13: "Immune_Other", 14: "Stromal_Other", 15: "Dead", } # RGBA colors for mask visualization CELL_COLORS = { 0: [0, 0, 0, 0], 1: [255, 0, 0, 255], 2: [30, 144, 255, 255], 3: [65, 105, 225, 255], 4: [0, 0, 255, 255], 5: [100, 149, 237, 255], 6: [176, 224, 230, 255], 7: [70, 130, 180, 255], 8: [0, 191, 255, 255], 9: [34, 139, 34, 255], 10: [60, 179, 113, 255], 11: [50, 205, 50, 255], 12: [255, 140, 0, 255], 13: [176, 224, 230, 255], 14: [107, 142, 35, 255], 15: [128, 128, 128, 255], } # ImageNet-style normalization fitted to H&E data TRANSFORM = transforms.Compose( [ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize( mean=[0.707223, 0.578729, 0.703617], std=[0.211883, 0.230117, 0.177517], ), ] ) # ========================== Dataset ======================================== class PatchDataset(Dataset): def __init__(self, file_paths, transform=None): self.file_paths = file_paths self.transform = transform def __len__(self): return len(self.file_paths) def __getitem__(self, idx): fpath = self.file_paths[idx] img = cv2.imread(fpath) if img is None: print(f"[WARN] Failed to load: {fpath}") img = np.zeros((256, 256, 3), dtype=np.uint8) img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) if self.transform: img = self.transform(Image.fromarray(img)) return fpath, img # ========================== Inference ====================================== def process_batch( batch, model, device, mask_output_dir, centroid_records, magnification=40, ): """Run inference on one batch and save masks + centroid info.""" file_paths, images = batch with torch.no_grad(): with autocast(): outputs = model(images.to(device, non_blocking=True)) for i, fpath in enumerate(file_paths): slide_id = os.path.splitext(os.path.basename(fpath))[0] # Extract per-sample predictions cell_type_map = outputs["cell_type_map"][i].float().detach().cpu() nuclei_binary_map = outputs["nuclei_binary_map"][i].float().detach().cpu() hv_map = outputs["hv_map"][i].float().detach().cpu() tissue_type_map = outputs["tissue_type_map"][i].float().detach().cpu() # Build prediction map [H, W, 5] pred_map = np.concatenate( [ torch.argmax(tissue_type_map, dim=0)[..., None].numpy(), torch.argmax(cell_type_map, dim=0)[..., None].numpy(), torch.argmax(nuclei_binary_map, dim=0)[..., None].numpy(), hv_map.permute(1, 2, 0).numpy(), ], axis=-1, ) # Post-processing post_processor = DetectionCellPostProcessor( nr_types=cell_type_map.shape[0], magnification=magnification, gt=False, ) _, type_pred = post_processor.post_process_cell_segmentation(pred_map) # Create mask image mask = np.ones((256, 256, 3), dtype=np.uint8) * 255 for cell in type_pred.values(): ctype = cell["type"] rgba = CELL_COLORS.get(ctype, [255, 255, 255, 255]) bgr = [rgba[2], rgba[1], rgba[0]] cv2.fillPoly(mask, [cell["contour"]], bgr) centroid_records.append( { "slide_id": slide_id, "x": cell["centroid"][0], "y": cell["centroid"][1], "celltype": ctype, "celltype_name": CELL_TYPES.get(ctype, "Unknown"), } ) # Save mask only if non-trivial if not np.all(mask == 255): cv2.imwrite( os.path.join(mask_output_dir, f"{slide_id}_mask.png"), mask ) def run_inference( patch_folders: list[str], model, device, output_dir: str, magnification: int = 40, batch_size: int = 32, num_workers: int = 4, ): """Run inference over a list of patch folders.""" model.to(device).eval() for folder in patch_folders: folder_name = os.path.basename(folder) png_files = sorted(glob.glob(os.path.join(folder, "*.png"))) if not png_files: print(f"[SKIP] {folder}: no PNG patches found") continue mask_dir = os.path.join(output_dir, "mask_patches", f"{folder_name}") centroid_path = os.path.join(output_dir, "centroid", f"{folder_name}_centroid.csv") os.makedirs(mask_dir, exist_ok=True) os.makedirs(os.path.dirname(centroid_path), exist_ok=True) dataset = PatchDataset(png_files, transform=TRANSFORM) loader = DataLoader( dataset, batch_size=batch_size, num_workers=num_workers, pin_memory=True, shuffle=False, prefetch_factor=2, persistent_workers=True, ) centroids = [] for batch in tqdm(loader, desc=f"Inference: {folder_name}"): process_batch( batch, model, device, mask_dir, centroids, magnification ) df = pd.DataFrame(centroids) df.to_csv(centroid_path, index=False) print(f"[DONE] {folder_name} → {centroid_path} ({len(df)} cells)") # =============================== CLI ======================================= def main(): parser = argparse.ArgumentParser(description="HNE2Cell inference") parser.add_argument( "--input_dir", type=str, required=True, help="Directory containing patch folders (each with *.png)", ) parser.add_argument( "--output_dir", type=str, required=True, help="Output directory for masks and centroid CSVs", ) parser.add_argument( "--model_path", type=str, required=True, help="Path to the TorchScript JIT model (.pt)", ) parser.add_argument( "--magnification", type=int, default=40, choices=[20, 40], help="Magnification of input patches. 40x recommended. (default: 40)", ) parser.add_argument("--batch_size", type=int, default=32) parser.add_argument("--num_workers", type=int, default=4) parser.add_argument( "--device", type=str, default="auto", help="Device: 'cuda', 'cpu', or 'auto' (default: auto)", ) args = parser.parse_args() # Device if args.device == "auto": device = torch.device("cuda" if torch.cuda.is_available() else "cpu") else: device = torch.device(args.device) print(f"Using device: {device}") # Load model print(f"Loading model: {args.model_path}") model = torch.jit.load(args.model_path, map_location=device) model.eval() if args.magnification == 20: print( "⚠️ Running at 20x. Results are usable but 40x is recommended " "for best accuracy, especially for small immune cells." ) # Collect patch folders patch_folders = sorted( p for p in glob.glob(os.path.join(args.input_dir, "*")) if os.path.isdir(p) ) # Also check if input_dir itself contains patches if not patch_folders and glob.glob(os.path.join(args.input_dir, "*.png")): patch_folders = [args.input_dir] print(f"Found {len(patch_folders)} patch folder(s)") run_inference( patch_folders, model, device, args.output_dir, magnification=args.magnification, batch_size=args.batch_size, num_workers=args.num_workers, ) if __name__ == "__main__": main()