| """ |
| HNE2Cell — Step 3: Cell Detection & Classification Inference |
| |
| Run the HNE2Cell model on extracted patches to detect and classify cells. |
| Outputs: per-patch cell masks (PNG) and one centroid CSV per patch folder, with cell type annotations. |
| |
| Usage: |
| python inference.py \ |
| --input_dir /path/to/patch_folders \ |
| --output_dir /path/to/results \ |
| --model_path ./HNE2cell_all_patch73_jit.pt \ |
| --magnification 40 \ |
| --batch_size 32 |
| |
| Cell Types (16 classes): |
| 0: Background 4: B 8: DC 12: Epithelial |
| 1: Malignant 5: Plasma 9: Fibroblast 13: Immune_Other |
| 2: CD4T 6: Macrophage 10: Endothelial 14: Stromal_Other |
| 3: CD8T 7: Myeloid 11: Pericyte 15: Dead |
| """ |
|
|
| import os |
| import argparse |
| import glob |
|
|
| import cv2 |
| import numpy as np |
| import pandas as pd |
| import torch |
| from PIL import Image |
| from torch.cuda.amp import autocast |
| from torch.utils.data import DataLoader, Dataset |
| from torchvision import transforms |
| from tqdm import tqdm |
|
|
| from post_processing import DetectionCellPostProcessor |
|
|
| |
|
|
# Class index -> cell type label. The index of each label in the tuple below
# is its class ID; these labels are what process_batch writes to the
# "celltype_name" column of the centroid CSV.
CELL_TYPES = dict(
    enumerate(
        (
            "Background",  # 0
            "Malignant",  # 1
            "CD4T",  # 2
            "CD8T",  # 3
            "B",  # 4
            "Plasma",  # 5
            "Macrophage",  # 6
            "Myeloid",  # 7
            "DC",  # 8
            "Fibroblast",  # 9
            "Endothelial",  # 10
            "Pericyte",  # 11
            "Epithelial",  # 12
            "Immune_Other",  # 13
            "Stromal_Other",  # 14
            "Dead",  # 15
        )
    )
)
|
|
| |
# Class index -> [R, G, B, A] fill colour used when rendering cell masks.
# process_batch reads only the first three channels (reordered to BGR for
# OpenCV); the alpha channel is not used there.
# NOTE(review): classes 6 (Macrophage) and 13 (Immune_Other) have the same
# colour, so they cannot be told apart in a rendered mask — confirm intended.
CELL_COLORS = {
    class_id: [red, green, blue, alpha]
    for class_id, (red, green, blue, alpha) in enumerate(
        (
            (0, 0, 0, 0),  # 0  Background
            (255, 0, 0, 255),  # 1  Malignant
            (30, 144, 255, 255),  # 2  CD4T
            (65, 105, 225, 255),  # 3  CD8T
            (0, 0, 255, 255),  # 4  B
            (100, 149, 237, 255),  # 5  Plasma
            (176, 224, 230, 255),  # 6  Macrophage
            (70, 130, 180, 255),  # 7  Myeloid
            (0, 191, 255, 255),  # 8  DC
            (34, 139, 34, 255),  # 9  Fibroblast
            (60, 179, 113, 255),  # 10 Endothelial
            (50, 205, 50, 255),  # 11 Pericyte
            (255, 140, 0, 255),  # 12 Epithelial
            (176, 224, 230, 255),  # 13 Immune_Other
            (107, 142, 35, 255),  # 14 Stromal_Other
            (128, 128, 128, 255),  # 15 Dead
        )
    )
}
|
|
| |
# Preprocessing applied to each patch before it is fed to the model:
# resize to 224 x 224, convert the PIL image to a tensor, then normalise
# every channel with the fixed statistics below.
# NOTE(review): the mean/std values are presumably the training-set
# statistics — confirm before changing them.
TRANSFORM = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            [0.707223, 0.578729, 0.703617],  # per-channel mean
            [0.211883, 0.230117, 0.177517],  # per-channel std
        ),
    ]
)
|
|
|
|
| |
|
|
|
|
class PatchDataset(Dataset):
    """Dataset yielding ``(path, image)`` pairs for a list of patch files.

    Each file is read with OpenCV and converted from BGR to RGB. When a
    transform is supplied, the RGB array is wrapped in a PIL image and passed
    through it; otherwise the raw array is returned. A file that cannot be
    read is reported on stdout and replaced by an all-black 256 x 256 RGB
    image so that iteration can continue.
    """

    def __init__(self, file_paths, transform=None):
        self.transform = transform
        self.file_paths = file_paths

    def __len__(self):
        return len(self.file_paths)

    def __getitem__(self, idx):
        path = self.file_paths[idx]

        bgr = cv2.imread(path)
        if bgr is None:
            # cv2.imread signals failure by returning None instead of raising.
            print(f"[WARN] Failed to load: {path}")
            bgr = np.zeros((256, 256, 3), dtype=np.uint8)
        rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

        if not self.transform:
            return path, rgb
        return path, self.transform(Image.fromarray(rgb))
|
|
|
|
| |
|
|
|
|
def process_batch(
    batch,
    model,
    device,
    mask_output_dir,
    centroid_records,
    magnification=40,
):
    """Run inference on one batch and save masks + centroid info.

    Args:
        batch: ``(file_paths, images)`` pair, as yielded by a DataLoader over
            ``PatchDataset``.
        model: Callable returning a mapping with the keys ``cell_type_map``,
            ``nuclei_binary_map``, ``hv_map`` and ``tissue_type_map``; each
            value is indexed by position in the batch.
        device: Device (``torch.device`` or device string) the images are
            moved to before the forward pass.
        mask_output_dir: Directory receiving ``<patch name>_mask.png`` files.
            It is not created here.
        centroid_records: List extended in place with one dict per detected
            cell (keys ``slide_id``, ``x``, ``y``, ``celltype``,
            ``celltype_name``).
        magnification: Forwarded to ``DetectionCellPostProcessor``.
    """
    file_paths, images = batch

    # Autocast is only entered when the target device is CUDA. Entering CUDA
    # autocast unconditionally (the deprecated torch.cuda.amp.autocast entry
    # point) emits a warning on CPU-only hosts and has no effect there.
    on_cuda = torch.device(device).type == "cuda"
    with torch.no_grad():
        images = images.to(device, non_blocking=True)
        if on_cuda:
            with torch.autocast(device_type="cuda"):
                outputs = model(images)
        else:
            outputs = model(images)

    for i, fpath in enumerate(file_paths):
        # Despite the key name, this is the patch file's base name.
        slide_id = os.path.splitext(os.path.basename(fpath))[0]

        # Cast to float32 on the CPU before handing the maps to NumPy.
        cell_type_map = outputs["cell_type_map"][i].float().detach().cpu()
        nuclei_binary_map = outputs["nuclei_binary_map"][i].float().detach().cpu()
        hv_map = outputs["hv_map"][i].float().detach().cpu()
        tissue_type_map = outputs["tissue_type_map"][i].float().detach().cpu()

        # Channel-last stack, in this order: tissue class, cell class, nuclei
        # class, then the hv_map channels.
        pred_map = np.concatenate(
            [
                torch.argmax(tissue_type_map, dim=0)[..., None].numpy(),
                torch.argmax(cell_type_map, dim=0)[..., None].numpy(),
                torch.argmax(nuclei_binary_map, dim=0)[..., None].numpy(),
                hv_map.permute(1, 2, 0).numpy(),
            ],
            axis=-1,
        )

        post_processor = DetectionCellPostProcessor(
            nr_types=cell_type_map.shape[0],
            magnification=magnification,
            gt=False,
        )
        _, type_pred = post_processor.post_process_cell_segmentation(pred_map)

        # White canvas; every detected cell is filled with its class colour.
        # NOTE(review): the canvas is fixed at 256 x 256 while TRANSFORM
        # resizes inputs to 224 x 224 — confirm the contours returned by the
        # post-processor are expressed in 256 x 256 coordinates.
        mask = np.full((256, 256, 3), 255, dtype=np.uint8)
        for cell in type_pred.values():
            ctype = cell["type"]
            rgba = CELL_COLORS.get(ctype, [255, 255, 255, 255])
            bgr = [rgba[2], rgba[1], rgba[0]]  # OpenCV expects BGR order
            cv2.fillPoly(mask, [cell["contour"]], bgr)

            centroid_records.append(
                {
                    "slide_id": slide_id,
                    "x": cell["centroid"][0],
                    "y": cell["centroid"][1],
                    "celltype": ctype,
                    "celltype_name": CELL_TYPES.get(ctype, "Unknown"),
                }
            )

        # Skip patches on which nothing was drawn.
        if (mask != 255).any():
            mask_path = os.path.join(mask_output_dir, f"{slide_id}_mask.png")
            # cv2.imwrite reports failure through its return value only.
            if not cv2.imwrite(mask_path, mask):
                print(f"[WARN] Failed to write mask: {mask_path}")
|
|
|
|
def run_inference(
    patch_folders: list[str],
    model,
    device,
    output_dir: str,
    magnification: int = 40,
    batch_size: int = 32,
    num_workers: int = 4,
):
    """Run inference over a list of patch folders.

    For every folder that contains ``*.png`` files, masks are written to
    ``<output_dir>/mask_patches/<folder name>/`` and all detected cells are
    written to ``<output_dir>/centroid/<folder name>_centroid.csv``. Folders
    without PNG files are skipped with a message.

    Args:
        patch_folders: Directories holding the patch images.
        model: Model to evaluate; it is moved to ``device`` and put in eval mode.
        device: Target device (``torch.device`` or device string).
        output_dir: Root directory for masks and centroid CSVs.
        magnification: Forwarded to ``process_batch``.
        batch_size: Number of patches per batch.
        num_workers: DataLoader worker processes; ``0`` loads in the main
            process.
    """
    model.to(device).eval()

    # Fixed column order, so a folder without any detected cell still yields
    # a CSV with a header row instead of an empty file.
    centroid_columns = ["slide_id", "x", "y", "celltype", "celltype_name"]

    loader_options = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        # Pinned host memory only speeds up copies to a CUDA device.
        "pin_memory": torch.device(device).type == "cuda",
        "shuffle": False,
    }
    if num_workers > 0:
        # DataLoader rejects prefetch_factor (and persistent_workers) when
        # num_workers == 0. persistent_workers is left off because each
        # loader is iterated exactly once.
        loader_options["prefetch_factor"] = 2

    for folder in patch_folders:
        # abspath strips a trailing separator; basename("dir/") would be "".
        folder_name = os.path.basename(os.path.abspath(folder))
        png_files = sorted(glob.glob(os.path.join(folder, "*.png")))

        if not png_files:
            print(f"[SKIP] {folder}: no PNG patches found")
            continue

        mask_dir = os.path.join(output_dir, "mask_patches", f"{folder_name}")
        centroid_path = os.path.join(
            output_dir, "centroid", f"{folder_name}_centroid.csv"
        )
        os.makedirs(mask_dir, exist_ok=True)
        os.makedirs(os.path.dirname(centroid_path), exist_ok=True)

        dataset = PatchDataset(png_files, transform=TRANSFORM)
        loader = DataLoader(dataset, **loader_options)

        centroids = []
        for batch in tqdm(loader, desc=f"Inference: {folder_name}"):
            process_batch(
                batch, model, device, mask_dir, centroids, magnification
            )

        df = pd.DataFrame(centroids, columns=centroid_columns)
        df.to_csv(centroid_path, index=False)
        print(f"[DONE] {folder_name} → {centroid_path} ({len(df)} cells)")
|
|
|
|
| |
|
|
|
|
def main():
    """Command-line entry point: parse options, load the model, run inference.

    The input directory is expected to contain one sub-directory per set of
    patches. If it has no sub-directories but holds ``*.png`` files itself,
    it is treated as a single patch folder.
    """
    cli = argparse.ArgumentParser(description="HNE2Cell inference")

    # The three mandatory path options share the same shape.
    required_paths = (
        ("--input_dir", "Directory containing patch folders (each with *.png)"),
        ("--output_dir", "Output directory for masks and centroid CSVs"),
        ("--model_path", "Path to the TorchScript JIT model (.pt)"),
    )
    for flag, help_text in required_paths:
        cli.add_argument(flag, type=str, required=True, help=help_text)

    cli.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Magnification of input patches. 40x recommended. (default: 40)",
    )
    cli.add_argument("--batch_size", type=int, default=32)
    cli.add_argument("--num_workers", type=int, default=4)
    cli.add_argument(
        "--device",
        type=str,
        default="auto",
        help="Device: 'cuda', 'cpu', or 'auto' (default: auto)",
    )
    opts = cli.parse_args()

    # "auto" picks CUDA when available, otherwise the CPU.
    requested = opts.device
    if requested == "auto":
        requested = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(requested)
    print(f"Using device: {device}")

    print(f"Loading model: {opts.model_path}")
    model = torch.jit.load(opts.model_path, map_location=device)
    model.eval()

    if opts.magnification == 20:
        print(
            "⚠️ Running at 20x. Results are usable but 40x is recommended "
            "for best accuracy, especially for small immune cells."
        )

    # Every sub-directory of the input directory is one patch folder.
    entries = glob.glob(os.path.join(opts.input_dir, "*"))
    patch_folders = sorted(filter(os.path.isdir, entries))
    if not patch_folders:
        # No sub-directories: fall back to PNGs lying directly in input_dir.
        loose_pngs = glob.glob(os.path.join(opts.input_dir, "*.png"))
        if loose_pngs:
            patch_folders = [opts.input_dir]

    print(f"Found {len(patch_folders)} patch folder(s)")

    run_inference(
        patch_folders,
        model,
        device,
        opts.output_dir,
        magnification=opts.magnification,
        batch_size=opts.batch_size,
        num_workers=opts.num_workers,
    )


if __name__ == "__main__":
    main()
|
|