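"""Run a davidclara/siegfried-maps-segmentation model from Hugging Face on a map image.

Downloads the chosen model's ``config.json`` and ``model.safetensors`` from the
repo, predicts the whole image with overlapping sliding-window tiles, thresholds
each class (per-class threshold from the config, default 0.5) and writes one
binary PNG mask per class to --out-dir.

Example:
    python inference.py \\
        --hf-repo davidclara/siegfried-maps-segmentation \\
        --model-name unetpp \\
        --image map.png \\
        --out-dir predictions/
"""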
|
|
| import argparse |
| import inspect |
| import json |
| from pathlib import Path |
|
|
| import numpy as np |
| import segmentation_models_pytorch as smp |
| import torch |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
| from safetensors.torch import load_file |
|
|
# Large map scans can exceed Pillow's decompression-bomb pixel limit, so the
# limit is lifted entirely. NOTE: this also removes that safeguard for any
# image opened by this process.
Image.MAX_IMAGE_PIXELS = None

# Per-channel RGB normalisation statistics (ImageNet); main() applies them to
# images scaled to [0, 1].
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Architecture key (the config's model "name") -> segmentation_models_pytorch class.
SMP = dict(
    unet=smp.Unet,
    unetpp=smp.UnetPlusPlus,
    deeplabv3p=smp.DeepLabV3Plus,
    fpn=smp.FPN,
    pan=smp.PAN,
)
|
|
|
|
def build_model(model_cfg: dict) -> torch.nn.Module:
    """Instantiate an untrained segmentation network described by ``model_cfg``.

    ``model_cfg["name"]`` selects the architecture class from ``SMP`` and
    ``model_cfg["num_classes"]`` is passed as the constructor's ``classes``
    argument. Every other key is forwarded only if the constructor's signature
    accepts it; unknown keys are silently ignored. ``encoder_weights`` is forced
    to ``None`` so no pre-trained encoder weights are fetched (main() loads the
    full trained checkpoint right after construction).

    Args:
        model_cfg: the ``"model"`` section of the config. It is not mutated.

    Returns:
        A freshly constructed ``torch.nn.Module``.

    Raises:
        KeyError: if ``"name"`` or ``"num_classes"`` is missing.
        ValueError: if ``"name"`` is not one of the architectures in ``SMP``.
    """
    cfg = dict(model_cfg)  # shallow copy so the caller's dict is left intact
    name = cfg.pop("name")
    try:
        cls = SMP[name]
    except KeyError:
        # A bare KeyError('xyz') is cryptic; say what is supported instead.
        raise ValueError(
            f"Unknown model architecture {name!r}; expected one of {sorted(SMP)}"
        ) from None
    cfg["classes"] = cfg.pop("num_classes")
    cfg["encoder_weights"] = None
    # Drop config keys the chosen architecture does not take as arguments.
    accepted = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in cfg.items() if k in accepted})
|
|
|
|
def _tile_origins(size: int, patch: int, stride: int) -> list[int]:
    """Return sorted start offsets of ``patch``-long windows covering ``[0, size)``.

    Offsets step by ``stride`` from 0, and the last window is pinned flush with
    the far edge so the final pixels are always covered. When ``size <= patch``
    the single offset 0 is returned (the caller pads that window).
    """
    last = max(size - patch, 0)
    return sorted({*range(0, last + 1, stride), last})


def sliding_window_predict(
    model, img: np.ndarray, patch: int, stride: int, n_classes: int, device
) -> np.ndarray:
    """Predict per-pixel logits for a whole image by averaging overlapping tiles.

    The image is covered by ``patch x patch`` windows whose top-left corners
    step by ``stride`` along each axis. The last window on each axis is aligned
    with the image edge. If the image is smaller than ``patch`` along an axis,
    the window is zero-padded before inference (zero is the mean colour when
    ``img`` is mean-normalised) and the padding is cropped from the output.
    Logits of overlapping windows are averaged.

    Args:
        model: callable taking a ``(1, C, patch, patch)`` tensor and returning a
            ``(1, n_classes, patch, patch)`` tensor. The caller is responsible
            for putting it in eval mode on ``device``.
        img: ``(C, H, W)`` float array, already normalised; may be a strided
            view (e.g. a transposed HWC array).
        patch: window side length in pixels (>= 1).
        stride: step between window origins; must satisfy
            ``1 <= stride <= patch`` so that every pixel is covered.
        n_classes: number of output channels produced by ``model``.
        device: torch device the tiles are moved to before inference.

    Returns:
        ``(n_classes, H, W)`` float32 array of averaged logits.

    Raises:
        ValueError: if ``patch`` < 1 or ``stride`` is outside ``[1, patch]``.
    """
    if patch < 1:
        raise ValueError(f"patch must be >= 1, got {patch}")
    # stride > patch would leave never-predicted gaps (weight 0 -> logit 0), and
    # stride <= 0 would break the origin range; reject both explicitly.
    if not 0 < stride <= patch:
        raise ValueError(f"stride must be in [1, {patch}], got {stride}")

    _, height, width = img.shape
    logits = np.zeros((n_classes, height, width), dtype=np.float32)
    weights = np.zeros((height, width), dtype=np.float32)
    rows = _tile_origins(height, patch, stride)
    cols = _tile_origins(width, patch, stride)

    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                h, w = tile.shape[1:]  # real (unpadded) extent of this tile
                if (h, w) != (patch, patch):
                    tile = np.pad(tile, ((0, 0), (0, patch - h), (0, patch - w)))
                # Dense copy so the model receives a contiguous tensor even when
                # img is a strided view.
                batch = torch.from_numpy(np.ascontiguousarray(tile)).unsqueeze(0)
                out = model(batch.to(device)).cpu().numpy()[0]
                logits[:, r : r + h, c : c + w] += out[:, :h, :w]
                weights[r : r + h, c : c + w] += 1.0
    # Every pixel is covered at least once; the floor only guards the division.
    return logits / np.maximum(weights, 1e-8)
|
|
|
|
def main() -> None:
    """CLI entry point: fetch a model from the Hub, segment one image, save masks.

    Reads ``<model-name>/config.json`` (keys used: ``model``, ``class_names``,
    ``patch_size`` and optional ``thresholds`` mapping class name -> threshold)
    and ``<model-name>/model.safetensors`` from ``--hf-repo``. The result is one
    0/255 PNG per class, named ``<class>.png``, in ``--out-dir``.
    """
    ap = argparse.ArgumentParser(
        description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter
    )
    ap.add_argument("--hf-repo", required=True, help="Hugging Face repo id")
    ap.add_argument(
        "--model-name",
        required=True,
        help="repo subfolder holding config.json and model.safetensors",
    )
    ap.add_argument("--image", required=True, help="path to the input map image")
    ap.add_argument(
        "--out-dir", default="predictions", help="directory for the per-class masks"
    )
    args = ap.parse_args()

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    wts_path = hf_hub_download(args.hf_repo, f"{args.model_name}/model.safetensors")
    # JSON is UTF-8; don't depend on the platform's default locale encoding.
    cfg = json.loads(Path(cfg_path).read_text(encoding="utf-8"))

    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    model = build_model(cfg["model"]).to(device).eval()
    model.load_state_dict(load_file(wts_path))

    classes = cfg["class_names"]
    thrs = cfg.get("thresholds") or {}
    # One threshold per class, shaped (C, 1, 1) to broadcast over (C, H, W).
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)
    patch = cfg["patch_size"]

    # Close the file handle as soon as the pixels are decoded.
    with Image.open(args.image) as im:
        img = np.asarray(im.convert("RGB"), dtype=np.float32) / 255.0
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)  # HWC -> CHW
    logits = sliding_window_predict(model, img, patch, patch // 2, len(classes), device)
    # Clip keeps exp() within float32 range before the sigmoid.
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
    binary = (probs > thr).astype(np.uint8)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, name in enumerate(classes):
        # 2-D uint8 arrays map to mode "L"; passing mode explicitly is deprecated.
        Image.fromarray(binary[i] * 255).save(out_dir / f"{name}.png")
    print(f"Wrote {len(classes)} masks to {out_dir}/")
|
|
|
|
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
|
|