"""Prediction pipeline for retinal segmentation. Usage: # Single image python -m src.predict --checkpoint best_model.pth --input image.png --output output/ # Directory of images python -m src.predict --checkpoint best_model.pth --input images/ --output output/ # With TTA and custom threshold python -m src.predict --checkpoint best_model.pth --input images/ --output output/ --tta --threshold 0.45 """ import argparse import os from pathlib import Path import albumentations as A import matplotlib.pyplot as plt import numpy as np import pandas as pd import torch from albumentations.pytorch import ToTensorV2 from PIL import Image from scipy import ndimage from scipy.ndimage import distance_transform_edt from skimage.measure import label as sk_label from skimage.measure import regionprops from torch.amp import autocast from src.config import Config from src.model import build_model MASK_COLORS = { "nv": (0.7, 0.0, 1.0), # purple (matches app.py) "vo": (0.0, 0.5, 1.0), # blue "retina": (0.0, 0.8, 0.0), # green } def load_model(checkpoint_path, config, device): """Load model from checkpoint, overriding architecture config from saved state.""" ckpt = torch.load(checkpoint_path, map_location=device, weights_only=False) # Override architecture fields from checkpoint if available if "config" in ckpt: saved = ckpt["config"] config.image_size = tuple(saved.get("image_size", config.image_size)) config.encoder_name = saved.get("encoder_name", config.encoder_name) config.decoder_attention = saved.get("decoder_attention", config.decoder_attention) config.num_classes = saved.get("num_classes", config.num_classes) config.mask_names = tuple(saved.get("mask_names", config.mask_names)) model = build_model(config) model.load_state_dict(ckpt["model_state_dict"]) model.to(device) model.eval() return model MAX_INPUT_SIZE = 1024 # images larger than this are downscaled before inference def get_preprocess(config): """Validation-style preprocessing: resize + normalize.""" return A.Compose( [ A.Resize(config.image_size[0], config.image_size[1]), A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)), ToTensorV2(), ] ) def resize_to_max(image_np, max_side=MAX_INPUT_SIZE): """Downscale image so its longest side <= max_side, preserving aspect ratio. Returns: resized_np: downscaled uint8 image scale: float, resized/original (same for both axes) """ h, w = image_np.shape[:2] if h <= max_side and w <= max_side: return image_np, 1.0 scale = max_side / max(h, w) new_h, new_w = int(round(h * scale)), int(round(w * scale)) resized = np.array(Image.fromarray(image_np).resize((new_w, new_h), Image.LANCZOS)) print(f" Resized {w}x{h} -> {new_w}x{new_h} (scale={scale:.4f})") return resized, scale def predict_single(model, image_np, preprocess, device, config, tta=False, threshold=0.5): """Run inference on a single image. Args: model: trained model in eval mode image_np: HxWx3 uint8 numpy array (RGB) preprocess: albumentations transform device: torch device config: Config object tta: if True, average predictions over flips threshold: binarization threshold Returns: masks_prob: [num_classes, H, W] float32 probabilities (original resolution) masks_binary: [num_classes, H, W] uint8 binary masks (original resolution) """ orig_h, orig_w = image_np.shape[:2] def _infer(img_np): t = preprocess(image=img_np)["image"].unsqueeze(0).to(device) with autocast(device_type=device.type, enabled=(device.type == "cuda")): logits = model(t) return logits.squeeze(0).detach().cpu() logits = _infer(image_np) if tta: # Horizontal flip l_hflip = _infer(image_np[:, ::-1].copy()) l_hflip = torch.flip(l_hflip, dims=[2]) # Vertical flip l_vflip = _infer(image_np[::-1, :].copy()) l_vflip = torch.flip(l_vflip, dims=[1]) # Both flips l_hvflip = _infer(image_np[::-1, ::-1].copy()) l_hvflip = torch.flip(l_hvflip, dims=[1, 2]) logits = (logits + l_hflip + l_vflip + l_hvflip) / 4.0 probs = torch.sigmoid(logits) # Resize probabilities back to original resolution probs_np = probs.numpy() masks_prob = np.zeros((config.num_classes, orig_h, orig_w), dtype=np.float32) for i in range(config.num_classes): resized = np.array(Image.fromarray(probs_np[i]).resize((orig_w, orig_h), Image.BILINEAR)) masks_prob[i] = resized masks_binary = (masks_prob > threshold).astype(np.uint8) return masks_prob, masks_binary def predict_tiled( model, image_np, preprocess, device, config, tta=False, threshold=0.5, tile_size=512, overlap=128, ): """Tiled inference for large images with overlap blending. Splits the image into overlapping tiles, runs inference on each, then stitches predictions back using a linear blend in the overlap zones. """ orig_h, orig_w = image_np.shape[:2] num_classes = config.num_classes stride = tile_size - overlap acc = np.zeros((num_classes, orig_h, orig_w), dtype=np.float64) weight = np.zeros((orig_h, orig_w), dtype=np.float64) # 1-D linear ramp for blending: 0→1 over overlap, 1 in center, 1→0 over overlap def make_blend_1d(size): w = np.ones(size, dtype=np.float64) ramp = np.linspace(0, 1, overlap, endpoint=False) w[:overlap] = ramp w[size - overlap :] = ramp[::-1] return w blend_h = make_blend_1d(tile_size) blend_w = make_blend_1d(tile_size) blend_2d = np.outer(blend_h, blend_w) # (tile_size, tile_size) # Build tile grid (top-left corners) ys = list(range(0, orig_h - tile_size, stride)) + [orig_h - tile_size] xs = list(range(0, orig_w - tile_size, stride)) + [orig_w - tile_size] ys = sorted(set(max(0, y) for y in ys)) xs = sorted(set(max(0, x) for x in xs)) total = len(ys) * len(xs) print(f" Tiled inference: {orig_h}x{orig_w} -> {len(ys)}x{len(xs)} = {total} tiles") def _infer_tile(tile_np): t = preprocess(image=tile_np)["image"].unsqueeze(0).to(device) with autocast(device_type=device.type, enabled=(device.type == "cuda")): logits = model(t) return logits.squeeze(0).detach().cpu().numpy() # (C, tile_size, tile_size) count = 0 for y in ys: for x in xs: tile = image_np[y : y + tile_size, x : x + tile_size] # Pad if tile is smaller than expected (edge case) th, tw = tile.shape[:2] if th < tile_size or tw < tile_size: padded = np.zeros((tile_size, tile_size, 3), dtype=np.uint8) padded[:th, :tw] = tile tile = padded logits_tile = _infer_tile(tile) if tta: l_hflip = _infer_tile(tile[:, ::-1].copy()) l_hflip = l_hflip[:, :, ::-1] l_vflip = _infer_tile(tile[::-1, :].copy()) l_vflip = l_vflip[:, ::-1, :] l_hvflip = _infer_tile(tile[::-1, ::-1].copy()) l_hvflip = l_hvflip[:, ::-1, ::-1] logits_tile = (logits_tile + l_hflip + l_vflip + l_hvflip) / 4.0 # Accumulate with blend weights actual_h = min(tile_size, orig_h - y) actual_w = min(tile_size, orig_w - x) b = blend_2d[:actual_h, :actual_w] acc[:, y : y + actual_h, x : x + actual_w] += logits_tile[:, :actual_h, :actual_w] * b weight[y : y + actual_h, x : x + actual_w] += b count += 1 if count % 50 == 0 or count == total: print(f" {count}/{total} tiles done") # Normalize by accumulated weights, then sigmoid to get probabilities weight = np.maximum(weight, 1e-8) masks_logits = (acc / weight).astype(np.float32) masks_prob = (1.0 / (1.0 + np.exp(-masks_logits))).astype(np.float32) masks_binary = (masks_prob > threshold).astype(np.uint8) return masks_prob, masks_binary # ── Post-processing ─────────────────────────────────────────────────────────── def postprocess_mask(mask: np.ndarray) -> np.ndarray: """Fill holes then keep only the largest connected component.""" filled = ndimage.binary_fill_holes(mask).astype(np.uint8) labeled, n = ndimage.label(filled) if n == 0: return filled largest = int(np.argmax(ndimage.sum(filled, labeled, range(1, n + 1)))) + 1 return (labeled == largest).astype(np.uint8) def postprocess_vo(mask: np.ndarray, close_radius: int = 15) -> np.ndarray: """Aggressive VO post-processing: close gaps, fill holes, keep largest component.""" struct = ndimage.generate_binary_structure(2, 1) struct = ndimage.iterate_structure(struct, close_radius) closed = ndimage.binary_closing(mask.astype(bool), structure=struct) filled = ndimage.binary_fill_holes(closed).astype(np.uint8) labeled, n = ndimage.label(filled) if n == 0: return filled largest = int(np.argmax(ndimage.sum(filled, labeled, range(1, n + 1)))) + 1 return (labeled == largest).astype(np.uint8) def postprocess_nv( nv_mask: np.ndarray, vo_mask: np.ndarray, vessel_mask: np.ndarray | None = None, outside_px: int = 520, inside_px: int = 260, min_area: int = 150, max_eccentricity: float = 0.985, vessel_suppression: bool = True, boundary_masking: bool = True, ) -> np.ndarray: """Post-process NV mask to reduce false positives from normal vessels. Three stages: A. VO-boundary spatial masking — zero out NV far from the VO edge B. Vessel mask suppression — zero out NV overlapping known vessels C. Morphological filtering — remove elongated/tiny connected components """ result = nv_mask.copy() # A. VO-boundary spatial masking if boundary_masking: vo_bool = vo_mask.astype(bool) if vo_bool.any(): # Distance from each non-VO pixel to nearest VO pixel dist_outside = distance_transform_edt(~vo_bool) # Distance from each VO pixel to nearest non-VO pixel (VO interior depth) dist_inside = distance_transform_edt(vo_bool) # Boundary zone = within outside_px of VO edge (outside) and within inside_px (inside) boundary_zone = (dist_outside <= outside_px) & (dist_inside <= inside_px) result = result & boundary_zone.astype(np.uint8) # B. Vessel mask suppression if vessel_suppression and vessel_mask is not None: if vessel_mask.shape != result.shape: vessel_mask = np.array( Image.fromarray(vessel_mask).resize( (result.shape[1], result.shape[0]), Image.NEAREST ) ) result = result & (~vessel_mask.astype(bool)).astype(np.uint8) # C. Morphological component filtering if result.any(): labeled = sk_label(result, connectivity=2) for region in regionprops(labeled): if region.area < min_area or region.eccentricity > max_eccentricity: result[labeled == region.label] = 0 return result def postprocess_all( masks_binary: np.ndarray, mask_names: tuple, vessel_mask: np.ndarray | None = None, config=None, ) -> np.ndarray: """Apply class-specific post-processing to all masks. Order matters: VO is cleaned first so NV boundary masking uses a clean VO. Args: masks_binary: [num_classes, H, W] uint8 binary masks mask_names: tuple of class names, e.g. ("nv", "vo", "retina") vessel_mask: optional [H, W] uint8 binary vessel mask config: Config object (uses defaults if None) """ from src.config import Config if config is None: config = Config() result = masks_binary.copy() names = list(mask_names) # 1. VO post-processing (must be first — NV needs clean VO) if "vo" in names: result[names.index("vo")] = postprocess_vo(result[names.index("vo")]) # 2. Retina post-processing if "retina" in names: result[names.index("retina")] = postprocess_mask(result[names.index("retina")]) # 3. NV post-processing (uses cleaned VO mask) if "nv" in names and "vo" in names: nv_idx = names.index("nv") vo_idx = names.index("vo") result[nv_idx] = postprocess_nv( result[nv_idx], result[vo_idx], vessel_mask=vessel_mask, outside_px=config.nv_outside_px, inside_px=config.nv_inside_px, min_area=config.nv_min_component_area, max_eccentricity=config.nv_max_eccentricity, vessel_suppression=config.nv_vessel_suppression, boundary_masking=config.nv_boundary_masking, ) return result # ── Vessel mask loading ─────────────────────────────────────────────────────── _manifest_cache: dict[str, pd.DataFrame] = {} def load_vessel_mask( image_stem: str, manifest_path: str, vessel_mask_root: str = "data/Training data", vessel_mask_fallback: str = "data/vessels mask", ) -> np.ndarray | None: """Load a ground-truth vessel mask by image stem, if available. Tries manifest vessel_mask_path first, then falls back to the loose vessel mask folder (data/vessels mask/) by stem name. Returns [H, W] uint8 binary mask, or None if not found. """ if manifest_path not in _manifest_cache: try: _manifest_cache[manifest_path] = pd.read_csv(manifest_path) except FileNotFoundError: return None df = _manifest_cache[manifest_path] rows = df[df["stem"] == image_stem] if rows.empty: return None # Try 1: manifest vessel_mask_path column row = rows.iloc[0] vessel_path = row.get("vessel_mask_path", "") if vessel_path and not (isinstance(vessel_path, float) and np.isnan(vessel_path)): full_path = Path(vessel_mask_root) / Path(str(vessel_path).replace("\\", "/")) if full_path.exists(): mask = np.array(Image.open(str(full_path)).convert("L")) return (mask > 127).astype(np.uint8) # Try 2: fallback folder by stem name (.jpg then .png) fallback_dir = Path(vessel_mask_fallback) if fallback_dir.is_dir(): for ext in (".jpg", ".png", ".JPG", ".PNG"): fallback_path = fallback_dir / f"{image_stem}{ext}" if fallback_path.exists(): mask = np.array(Image.open(str(fallback_path)).convert("L")) return (mask > 127).astype(np.uint8) return None def save_masks(masks_binary, mask_names, output_dir, stem): """Save individual binary masks as PNGs.""" for i, name in enumerate(mask_names): mask_img = Image.fromarray(masks_binary[i] * 255) mask_img.save(os.path.join(output_dir, f"{stem}_{name}.png")) def save_overlay_large( image_np, masks_binary, masks_prob, mask_names, output_dir, stem, max_side=4096 ): """Save 4-panel overlay for large images using PIL (matches save_overlay layout).""" from PIL import ImageDraw orig_h, orig_w = image_np.shape[:2] # Downscale each panel so longest side <= max_side / 2 (4 panels fit in ~2x width) panel_max = max_side // 2 scale = min(panel_max / orig_w, panel_max / orig_h, 1.0) pw = int(orig_w * scale) ph = int(orig_h * scale) base = Image.fromarray(image_np).resize((pw, ph), Image.LANCZOS) mask_colors_rgba = { "nv": (178, 0, 255), "vo": (0, 128, 255), "retina": (0, 204, 0), } title_h = 30 # pixels for title bar panel_names = ["Input"] + list(mask_names) n_panels = len(panel_names) canvas_w = pw * n_panels canvas_h = ph + title_h canvas = Image.new("RGB", (canvas_w, canvas_h), (0, 0, 0)) draw = ImageDraw.Draw(canvas) # Panel 0: Input (unmodified) canvas.paste(base, (0, title_h)) draw.text((pw // 2, title_h // 2), "Input", fill=(255, 255, 255), anchor="mm") # Panels 1–3: one mask each for i, name in enumerate(mask_names): panel = base.copy().convert("RGBA") color = mask_colors_rgba.get(name, (255, 255, 0)) mask_small = np.array( Image.fromarray(masks_binary[i].astype(np.uint8) * 255).resize((pw, ph), Image.NEAREST) ) color_f = tuple(c / 255.0 for c in color[:3]) base_np = np.array(base.convert("RGB")).astype(np.float32) / 255.0 alpha = (mask_small > 0).astype(np.float32) * 0.55 blended = base_np.copy() for c, cv in enumerate(color_f): blended[..., c] = base_np[..., c] * (1 - alpha) + cv * alpha blended_uint8 = (np.clip(blended, 0, 1) * 255).astype(np.uint8) panel = Image.fromarray(blended_uint8) x_offset = (i + 1) * pw canvas.paste(panel.convert("RGB"), (x_offset, title_h)) draw.text((x_offset + pw // 2, title_h // 2), name, fill=tuple(color), anchor="mm") out_path = os.path.join(output_dir, f"{stem}_overlay.png") canvas.save(out_path) print(" Overlay saved -> " + out_path + " (" + str(canvas_w) + "x" + str(canvas_h) + ")") def save_overlay(image_np, masks_binary, masks_prob, mask_names, output_dir, stem): """Save a visualization overlay with original image and colored masks.""" fig, axes = plt.subplots(1, 1 + len(mask_names), figsize=(5 * (1 + len(mask_names)), 5)) # Original image axes[0].imshow(image_np) axes[0].set_title("Input") axes[0].axis("off") # Individual mask predictions for i, name in enumerate(mask_names): color = MASK_COLORS.get(name, (1, 1, 0)) alpha = masks_binary[i].astype(np.float32) * 0.55 base = image_np.astype(np.float32) / 255.0 blended = base.copy() for c, cv in enumerate(color): blended[..., c] = base[..., c] * (1 - alpha) + cv * alpha blended = np.clip(blended, 0, 1) axes[i + 1].imshow(blended) axes[i + 1].set_title(f"{name}") axes[i + 1].axis("off") plt.tight_layout() fig.savefig(os.path.join(output_dir, f"{stem}_overlay.png"), dpi=150, bbox_inches="tight") plt.close(fig) def predict_directory(model, input_dir, output_dir, config, device, tta=False, threshold=0.5): """Run prediction on all images in a directory.""" preprocess = get_preprocess(config) input_path = Path(input_dir) extensions = {".png", ".jpg", ".jpeg", ".tif", ".tiff", ".bmp"} image_files = sorted( f for f in input_path.iterdir() if f.suffix.lower() in extensions and f.is_file() ) if not image_files: print(f"No images found in {input_dir}") return os.makedirs(output_dir, exist_ok=True) mask_dir = os.path.join(output_dir, "masks") overlay_dir = os.path.join(output_dir, "overlays") os.makedirs(mask_dir, exist_ok=True) os.makedirs(overlay_dir, exist_ok=True) Image.MAX_IMAGE_PIXELS = None print(f"Predicting {len(image_files)} images...") for i, img_path in enumerate(image_files): image_np = np.array(Image.open(img_path).convert("RGB")) orig_h, orig_w = image_np.shape[:2] masks_prob, masks_binary = predict_single( model, image_np, preprocess, device, config, tta=tta, threshold=threshold ) stem = img_path.stem save_masks(masks_binary, config.mask_names, mask_dir, stem) if orig_h > MAX_INPUT_SIZE or orig_w > MAX_INPUT_SIZE: save_overlay_large( image_np, masks_binary, masks_prob, config.mask_names, overlay_dir, stem ) else: save_overlay(image_np, masks_binary, masks_prob, config.mask_names, overlay_dir, stem) print(f" [{i + 1}/{len(image_files)}] {img_path.name}") print(f"Done. Masks saved to {mask_dir}, overlays to {overlay_dir}") def main(): parser = argparse.ArgumentParser(description="Retinal segmentation prediction") parser.add_argument("--checkpoint", required=True, help="Path to best_model.pth") parser.add_argument("--input", required=True, help="Path to image or directory") parser.add_argument("--output", default="predictions", help="Output directory") parser.add_argument("--tta", action="store_true", help="Enable test-time augmentation") parser.add_argument("--threshold", type=float, default=0.5, help="Binarization threshold") parser.add_argument("--device", default=None, help="Device (auto-detected if not set)") parser.add_argument( "--no-attention", action="store_true", help="Disable decoder attention (for checkpoints trained without scSE)", ) args = parser.parse_args() config = Config() if args.no_attention: config.decoder_attention = None if args.device: device = torch.device(args.device) else: device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") model = load_model(args.checkpoint, config, device) input_path = Path(args.input) if input_path.is_file(): preprocess = get_preprocess(config) os.makedirs(args.output, exist_ok=True) Image.MAX_IMAGE_PIXELS = None image_np = np.array(Image.open(input_path).convert("RGB")) orig_h, orig_w = image_np.shape[:2] masks_prob, masks_binary = predict_single( model, image_np, preprocess, device, config, tta=args.tta, threshold=args.threshold ) stem = input_path.stem mask_dir = os.path.join(args.output, "masks") overlay_dir = os.path.join(args.output, "overlays") os.makedirs(mask_dir, exist_ok=True) os.makedirs(overlay_dir, exist_ok=True) save_masks(masks_binary, config.mask_names, mask_dir, stem) if orig_h > MAX_INPUT_SIZE or orig_w > MAX_INPUT_SIZE: save_overlay_large( image_np, masks_binary, masks_prob, config.mask_names, overlay_dir, stem ) else: save_overlay(image_np, masks_binary, masks_prob, config.mask_names, overlay_dir, stem) print(f"Saved to {args.output}") elif input_path.is_dir(): predict_directory( model, args.input, args.output, config, device, tta=args.tta, threshold=args.threshold ) else: print(f"Error: {args.input} is not a valid file or directory") if __name__ == "__main__": main()