| """Standalone inference helpers for MALUNet on CVC-ClinicDB. |
| |
| `load_model` accepts either a local checkpoint path or an "<owner>/<repo>" |
| reference to a Hugging Face model repository (it downloads `best.pth`). |
| |
| CLI: |
| python infer.py --weights ./best.pth --image polyp.png --out mask.png |
| python infer.py --weights jane-l/malunet-cvc --image polyp.png --out mask.png |
| """ |
| import argparse |
| import io |
| import os |
| from pathlib import Path |
| from typing import Tuple, Union |
|
|
| import numpy as np |
| import torch |
| from PIL import Image |
|
|
| from models.malunet import MALUNet |
|
|
# MALUNet constructor arguments. load_model() loads weights with strict=True,
# so these must describe exactly the architecture the checkpoint was trained with.
DEFAULT_MODEL_CONFIG = dict(
    num_classes=1,
    input_channels=3,
    c_list=[8, 16, 24, 32, 48, 64],
    split_att="fc",
    bridge=True,
)
# Square side length (pixels) every image is resized to before inference.
INPUT_SIZE = 256
# Scalar standardization constants used by _preprocess (same for all channels).
# NOTE(review): presumably dataset intensity statistics from training — confirm.
NORM_MEAN = 109.0
NORM_STD = 75.0
|
|
|
|
def _build():
    """Construct a fresh MALUNet from the five architecture keys in DEFAULT_MODEL_CONFIG."""
    cfg = DEFAULT_MODEL_CONFIG
    # Pick the keys explicitly (in a fixed order) so extra entries in the config
    # dict are never forwarded to the constructor.
    arch_keys = ("num_classes", "input_channels", "c_list", "split_att", "bridge")
    kwargs = {key: cfg[key] for key in arch_keys}
    return MALUNet(**kwargs)
|
|
|
|
def _is_hf_repo_id(s: str) -> bool:
    """Return True if *s* should be treated as a Hugging Face "<owner>/<repo>" id.

    An existing filesystem path always wins (returns False). Otherwise *s* must
    look like exactly one ``owner/repo`` pair whose segments start with a word
    character and contain only word characters, "." or "-", and it must not end
    in a checkpoint suffix (".pth"/".pt").

    Anything else (``a/b/c``, ``./ckpt``, ``/abs/path``, ``C:/x``, ``owner/``) is
    treated as a local path. A mistyped local path then fails in ``torch.load``
    with a clear FileNotFoundError instead of being sent to the Hub.
    """
    import re

    if os.path.exists(s):
        return False
    if s.endswith((".pth", ".pt")):
        return False
    return re.fullmatch(r"\w[\w.-]*/\w[\w.-]*", s) is not None
|
|
|
|
def _strip_module_prefix(state_dict):
    """Return a new dict with one leading "module." removed from each key.

    Keys without the prefix are kept unchanged. Insertion order is preserved, and
    if two keys collapse to the same name the later entry wins.
    """
    prefix = "module."
    cut = len(prefix)
    stripped = {}
    for key, value in state_dict.items():
        new_key = key[cut:] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped
|
|
|
|
def load_model(weights: str, device: Union[str, torch.device, None] = None) -> torch.nn.Module:
    """Build MALUNet, load checkpoint weights into it and return it in eval mode.

    Args:
        weights: Local checkpoint path, or an "<owner>/<repo>" Hugging Face repo
            id from which ``best.pth`` is downloaded.
        device: Target device. ``None`` selects CUDA when available, otherwise
            CPU. A string is converted with ``torch.device``.

    Returns:
        The model moved to ``device`` with ``eval()`` applied.

    Raises:
        Propagates errors from ``torch.load`` and from
        ``load_state_dict(strict=True)`` (e.g. missing/unexpected keys).
    """
    if device is None:
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    elif isinstance(device, str):
        target = torch.device(device)
    else:
        target = device

    checkpoint_path = weights
    if _is_hf_repo_id(weights):
        # Imported lazily: huggingface_hub is only required for Hub references.
        from huggingface_hub import hf_hub_download

        checkpoint_path = hf_hub_download(repo_id=weights, filename="best.pth")

    # NOTE(review): torch.load unpickles the file, and Hub checkpoints are external
    # data. On torch versions where weights_only defaults to False this can execute
    # arbitrary code; consider weights_only=True if the checkpoint format allows it.
    payload = torch.load(checkpoint_path, map_location="cpu")
    # Training checkpoints may wrap the weights under "model_state_dict".
    is_wrapped = isinstance(payload, dict) and "model_state_dict" in payload
    raw_state = payload["model_state_dict"] if is_wrapped else payload
    # Drop the DataParallel/DDP "module." prefix so keys match a bare model.
    state_dict = _strip_module_prefix(raw_state)

    network = _build()
    network.load_state_dict(state_dict, strict=True)
    network.to(target)
    network.eval()
    return network
|
|
|
|
def _preprocess(img: Image.Image) -> Tuple[torch.Tensor, Tuple[int, int]]:
    """Convert a PIL image into the (1, 3, INPUT_SIZE, INPUT_SIZE) float tensor fed to the model.

    Pipeline:
    1. Convert to RGB.
    2. Standardize with the scalar NORM_MEAN/NORM_STD.
    3. Min-max rescale the whole array (all channels jointly) to [0, 255].
    4. Truncate to uint8.
    5. Bilinearly resize to INPUT_SIZE x INPUT_SIZE.

    The tensor therefore holds float values in [0, 255], not zero-mean values.
    A constant image becomes all zeros.

    Returns:
        ``(tensor, (height, width))``, where ``(height, width)`` is the size before resizing.
    """
    rgb = img.convert("RGB")
    width, height = rgb.size
    pixels = np.asarray(rgb, dtype=np.float32)
    pixels = (pixels - NORM_MEAN) / NORM_STD
    # NOTE(review): one scalar mean/std for every channel followed by a global
    # min-max means the standardization cancels out (up to float rounding);
    # presumably kept to mirror the training pipeline — confirm.
    lo = pixels.min()
    hi = pixels.max()
    if not hi > lo:
        # Flat (or NaN) image: avoid dividing by zero.
        pixels = np.zeros_like(pixels)
    else:
        pixels = (pixels - lo) / (hi - lo) * 255.0
    as_image = Image.fromarray(pixels.astype(np.uint8))
    resized = as_image.resize((INPUT_SIZE, INPUT_SIZE), Image.BILINEAR)
    hwc = np.asarray(resized, dtype=np.float32)
    chw = torch.from_numpy(hwc).permute(2, 0, 1)
    return chw.unsqueeze(0), (height, width)
|
|
|
|
@torch.no_grad()
def predict_mask(
    model: torch.nn.Module,
    image: Union[str, Path, Image.Image, bytes],
    threshold: float = 0.5,
    return_prob: bool = False,
) -> np.ndarray:
    """Segment one image and return the result at the image's original resolution.

    Args:
        model: Network whose output ``out[0, 0]`` is used directly as a per-pixel
            foreground probability. The model is assumed to apply its own sigmoid
            (TODO: confirm for other checkpoints).
        image: File path, encoded image bytes, or an already-open PIL image.
        threshold: Probability at or above which a pixel counts as foreground.
        return_prob: If True, return the float32 probability map instead of a mask.

    Returns:
        An ``(H, W)`` uint8 mask with values 0/255, or an ``(H, W)`` float32
        probability map when ``return_prob`` is True. ``(H, W)`` is the input size.

    Raises:
        TypeError: if ``image`` is not one of the supported types.
    """
    if isinstance(image, (str, Path)):
        # Use a context manager so the file handle is released. convert() returns
        # a fully loaded copy that stays valid after the source is closed.
        with Image.open(image) as src:
            img = src.convert("RGB")
    elif isinstance(image, bytes):
        with Image.open(io.BytesIO(image)) as src:
            img = src.convert("RGB")
    elif isinstance(image, Image.Image):
        img = image
    else:
        raise TypeError(f"unsupported image type: {type(image)}")

    device = next(model.parameters()).device
    t, (h, w) = _preprocess(img)
    out = model(t.to(device).float())
    prob = np.ascontiguousarray(out[0, 0].float().cpu().numpy(), dtype=np.float32)
    # Upsample in 32-bit float ("F" mode) instead of through an 8-bit image.
    # The uint8 round-trip truncated probabilities to 1/255 steps (a pixel at
    # exactly 0.5 fell below threshold=0.5) and wrapped values outside [0, 1].
    # np.array (not asarray) returns a writable copy for callers.
    prob_full = np.array(
        Image.fromarray(prob).resize((w, h), Image.BILINEAR),
        dtype=np.float32,
    )
    if return_prob:
        return prob_full
    return (prob_full >= threshold).astype(np.uint8) * 255
|
|
|
|
def overlay(image: Image.Image, mask: np.ndarray, alpha: float = 0.45) -> Image.Image:
    """Blend a red tint over the foreground pixels of *image*.

    Args:
        image: Base image; converted to RGB.
        mask: 2-D array in which any value > 0 marks foreground. The 0/255
            masks from ``predict_mask``, 0/1 masks and bool masks all work.
            If its shape differs from the image, it is resized with
            nearest-neighbour sampling.
        alpha: Weight of the red tint. 0 leaves the image unchanged and 1 paints
            foreground pure red.

    Returns:
        A new RGB image. Background pixels are unchanged.
    """
    base = image.convert("RGB")
    bw, bh = base.size
    # Binarize first so any dtype works (Image.fromarray rejects e.g. int64).
    fg = np.asarray(mask) > 0
    if fg.shape != (bh, bw):
        resized = Image.fromarray(fg.astype(np.uint8) * 255).resize((bw, bh), Image.NEAREST)
        fg = np.array(resized) > 0
    base_arr = np.asarray(base, dtype=np.float32)
    blended = base_arr.copy()
    # Tint with full-intensity red regardless of the mask's value range. Using the
    # mask value as the red level made 0/1 or bool masks darken pixels instead.
    red = np.array([255.0, 0.0, 0.0])
    blended[fg] = (1 - alpha) * base_arr[fg] + alpha * red
    return Image.fromarray(blended.astype(np.uint8))
|
|
|
|
def main():
    """CLI entry point: segment one image and save the mask, plus an optional overlay."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--weights", required=True, help="Local .pth path OR <owner>/<repo> on HF")
    parser.add_argument("--image", required=True)
    parser.add_argument("--out", default="mask.png")
    parser.add_argument("--overlay-out", default=None, help="optional overlay PNG path")
    parser.add_argument("--threshold", type=float, default=0.5)
    opts = parser.parse_args()

    net = load_model(opts.weights)
    source = Image.open(opts.image)
    binary = predict_mask(net, source, threshold=opts.threshold)
    Image.fromarray(binary).save(opts.out)
    print(f"wrote {opts.out}")

    # The overlay is optional; stop once the mask has been written.
    if not opts.overlay_out:
        return
    overlay(source, binary).save(opts.overlay_out)
    print(f"wrote {opts.overlay_out}")
|
|
|
|
# Run the CLI only when executed as a script; importing the module (to use
# load_model / predict_mask) has no side effects.
if __name__ == "__main__":
    main()
|
|