"""Run a davidclara/siegfried-maps-segmentation model from Hugging Face on a map image.

Writes one binary PNG per class to --out-dir.

Example:
    python inference.py \\
        --hf-repo davidclara/siegfried-maps-segmentation \\
        --model-name unetpp \\
        --image map.png \\
        --out-dir predictions/
"""

import argparse
import inspect
import json
from pathlib import Path

import numpy as np
import segmentation_models_pytorch as smp
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file

# Disable Pillow's pixel-count limit so arbitrarily large images can be opened.
Image.MAX_IMAGE_PIXELS = None

# Per-channel statistics applied to the RGB input (scaled to [0, 1]) in main().
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Architecture name (as stored in the model config) -> model class.
SMP = {
    "unet": smp.Unet,
    "unetpp": smp.UnetPlusPlus,
    "deeplabv3p": smp.DeepLabV3Plus,
    "fpn": smp.FPN,
    "pan": smp.PAN,
}


def build_model(model_cfg: dict) -> torch.nn.Module:
    """Instantiate the architecture described by *model_cfg*.

    Args:
        model_cfg: Mapping holding at least ``name`` (a key of ``SMP``) and
            ``num_classes``. Every other entry is passed to the constructor
            only if the chosen class lists a parameter with that name;
            unknown entries are silently dropped.

    Returns:
        The constructed model. No pretrained encoder weights are requested;
        the caller is expected to load a state dict afterwards.

    Raises:
        KeyError: If ``name`` or ``num_classes`` is missing, or ``name`` is
            not one of the architectures registered in ``SMP``.
    """
    cfg = dict(model_cfg)  # work on a copy so the caller's dict is untouched
    name = cfg.pop("name")
    cfg["classes"] = cfg.pop("num_classes")
    cfg["encoder_weights"] = None  # weights come from safetensors
    try:
        cls = SMP[name]
    except KeyError:
        # Same exception type as a plain lookup, but say what is supported.
        raise KeyError(
            f"unknown model name {name!r}; expected one of {sorted(SMP)}"
        ) from None
    accepted = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in cfg.items() if k in accepted})


def sliding_window_predict(model, img, patch, stride, n_classes, device):
    """Run *model* over *img* in overlapping square tiles and average the outputs.

    Args:
        model: Callable taking a ``(1, C, patch, patch)`` float tensor. Its
            output is indexed as ``(1, n_classes, patch, patch)`` -- assumed,
            not checked here.
        img: NumPy array with three axes, unpacked as ``(_, H, W)``.
        patch: Side length of the square tile fed to the model.
        stride: Step between tile origins; must be positive. A stride larger
            than ``patch`` leaves pixels uncovered, and their logits stay 0.
        n_classes: Number of channels allocated in the output array.
        device: Torch device the tiles are moved to before the forward pass.

    Returns:
        ``float32`` array of shape ``(n_classes, H, W)`` holding, per pixel,
        the mean of the model outputs of all tiles covering it.

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    if stride <= 0:
        # range() rejects a zero step with an unrelated message, and a
        # negative step would silently yield one window only.
        raise ValueError(f"stride must be a positive integer, got {stride}")

    _, H, W = img.shape
    logits = np.zeros((n_classes, H, W), dtype=np.float32)
    weights = np.zeros((H, W), dtype=np.float32)

    # Always include the last valid origin so the bottom/right edges are
    # covered even when (H - patch) is not a multiple of stride.
    last_row = max(H - patch, 0)
    last_col = max(W - patch, 0)
    rows = sorted({*range(0, last_row + 1, stride), last_row})
    cols = sorted({*range(0, last_col + 1, stride), last_col})

    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                # Only non-zero when the image is smaller than one patch.
                ph, pw = patch - tile.shape[1], patch - tile.shape[2]
                if ph or pw:
                    tile = np.pad(tile, ((0, 0), (0, ph), (0, pw)))
                batch = torch.from_numpy(tile).unsqueeze(0).to(device)
                out = model(batch).cpu().numpy()[0]
                # Discard predictions for the padded border.
                h, w = patch - ph, patch - pw
                logits[:, r : r + h, c : c + w] += out[:, :h, :w]
                weights[r : r + h, c : c + w] += 1.0

    # The floor avoids a division by zero for pixels no tile touched.
    return logits / np.maximum(weights, 1e-8)


def _select_device() -> str:
    """Return "cuda", "mps" or "cpu", in that order of preference."""
    if torch.cuda.is_available():
        return "cuda"
    # Guard the attribute lookup so torch builds without an MPS backend
    # fall through to the CPU instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def main() -> None:
    """Download a model, segment the given image and write one mask per class."""
    ap = argparse.ArgumentParser()
    ap.add_argument("--hf-repo", required=True)
    ap.add_argument("--model-name", required=True)
    ap.add_argument("--image", required=True)
    ap.add_argument("--out-dir", default="predictions")
    args = ap.parse_args()

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    wts_path = hf_hub_download(args.hf_repo, f"{args.model_name}/model.safetensors")
    # Decode explicitly as UTF-8 so non-ASCII class names do not depend on
    # the platform's locale encoding.
    cfg = json.loads(Path(cfg_path).read_text(encoding="utf-8"))

    device = _select_device()
    model = build_model(cfg["model"]).to(device).eval()
    model.load_state_dict(load_file(wts_path))

    classes = cfg["class_names"]
    thrs = cfg.get("thresholds") or {}
    # One threshold per class (0.5 when unspecified), shaped to broadcast
    # against the (n_classes, H, W) probability array.
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)
    patch = cfg["patch_size"]
    stride = max(1, patch // 2)  # half-patch overlap; never 0 for patch == 1

    with Image.open(args.image) as src:
        rgb = src.convert("RGB")
    img = np.asarray(rgb, dtype=np.float32) / 255.0
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)  # HWC -> CHW

    logits = sliding_window_predict(model, img, patch, stride, len(classes), device)
    # Clip before exponentiating to keep float32 exp() from overflowing.
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
    binary = (probs > thr).astype(np.uint8)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, name in enumerate(classes):
        # A 2-D uint8 array is read as mode "L"; the explicit mode argument
        # is omitted because recent Pillow releases deprecate it.
        mask = binary[i] * np.uint8(255)
        Image.fromarray(mask).save(out_dir / f"{name}.png")
    print(f"Wrote {len(classes)} masks to {out_dir}/")


if __name__ == "__main__":
    main()