File size: 3,761 Bytes
"""Run a davidclara/siegfried-maps-segmentation model from Hugging Face on a map image.

Writes one binary PNG mask per class to --out-dir.

Example:
    python inference.py \\
        --hf-repo davidclara/siegfried-maps-segmentation \\
        --model-name unetpp \\
        --image map.png \\
        --out-dir predictions/
"""
import argparse
import inspect
import json
from pathlib import Path
import numpy as np
import segmentation_models_pytorch as smp
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from safetensors.torch import load_file
# Setting MAX_IMAGE_PIXELS to None switches off Pillow's decompression-bomb
# size limit (presumably so very large map scans can be opened).
Image.MAX_IMAGE_PIXELS = None

# Per-channel statistics applied in main() to normalise RGB input in [0, 1].
IMAGENET_MEAN = np.array((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.array((0.229, 0.224, 0.225), dtype=np.float32)

# Architecture name (as used in the model config) -> constructor in
# segmentation_models_pytorch.
SMP = {
    "unet": smp.Unet,
    "unetpp": smp.UnetPlusPlus,
    "deeplabv3p": smp.DeepLabV3Plus,
    "fpn": smp.FPN,
    "pan": smp.PAN,
}
def build_model(model_cfg: dict) -> torch.nn.Module:
    """Instantiate the architecture described by ``model_cfg``.

    ``model_cfg`` must contain ``"name"`` (a key of ``SMP``) and
    ``"num_classes"``; the latter is passed to the constructor as ``classes``.
    Any remaining entries are forwarded only if the constructor's signature
    has a parameter of the same name. The input dict is not modified.

    Raises:
        KeyError: if ``"name"`` or ``"num_classes"`` is missing, or the name
            is not a key of ``SMP``.
    """
    options = dict(model_cfg)  # work on a copy so the caller's dict survives
    arch = options.pop("name")
    options["classes"] = options.pop("num_classes")
    # Weights are loaded separately from a safetensors file, so the encoder
    # is built without pretrained weights.
    options["encoder_weights"] = None

    factory = SMP[arch]
    supported = inspect.signature(factory).parameters
    kwargs = {}
    for key, value in options.items():
        if key in supported:
            kwargs[key] = value
    return factory(**kwargs)
def sliding_window_predict(model, img, patch, stride, n_classes, device):
    """Run ``model`` over ``img`` in square tiles and average overlapping outputs.

    Tile origins are placed every ``stride`` pixels along each axis, plus one
    extra origin flush with the bottom/right edge so the whole image is
    covered. Tiles that are smaller than ``patch`` (image smaller than the
    patch) are zero-padded on the bottom/right before being fed to the model,
    and only the un-padded part of the output is kept.

    Args:
        model: Callable that receives a ``(1, C, patch, patch)`` tensor on
            ``device`` and returns a tensor whose first batch item can be
            indexed as ``(n_classes, patch, patch)``.
        img: NumPy array laid out as ``(C, H, W)``.
        patch: Tile side length in pixels; must be at least 1.
        stride: Step between tile origins. Values below 1 are treated as 1
            (a zero step would otherwise crash ``range`` and a negative one
            would silently skip almost every tile).
        n_classes: Number of output channels produced by ``model``.
        device: Torch device (or device string) the tiles are moved to.

    Returns:
        ``float32`` array of shape ``(n_classes, H, W)`` holding, per pixel,
        the mean of all model outputs that covered it.

    Raises:
        ValueError: if ``patch`` is smaller than 1.
    """
    if patch < 1:
        raise ValueError(f"patch must be a positive integer, got {patch!r}")
    step = max(stride, 1)

    _, height, width = img.shape
    logits = np.zeros((n_classes, height, width), dtype=np.float32)
    weights = np.zeros((height, width), dtype=np.float32)

    # Last valid origin on each axis; always included so the image border is
    # covered even when (size - patch) is not a multiple of the step.
    last_row = max(height - patch, 0)
    last_col = max(width - patch, 0)
    rows = sorted({*range(0, last_row + 1, step), last_row})
    cols = sorted({*range(0, last_col + 1, step), last_col})

    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                h, w = tile.shape[1], tile.shape[2]
                if h < patch or w < patch:
                    # Only happens when the image is smaller than the patch.
                    tile = np.pad(tile, ((0, 0), (0, patch - h), (0, patch - w)))
                batch = torch.from_numpy(tile).unsqueeze(0).to(device)
                out = model(batch).cpu().numpy()[0]
                logits[:, r : r + h, c : c + w] += out[:, :h, :w]
                weights[r : r + h, c : c + w] += 1.0

    # The floor on the divisor avoids 0/0 for any pixel no tile covered.
    return logits / np.maximum(weights, 1e-8)
def _pick_device() -> str:
    """Return ``"cuda"``, ``"mps"`` or ``"cpu"``, preferring them in that order."""
    if torch.cuda.is_available():
        return "cuda"
    # Not every torch build exposes torch.backends.mps; look it up defensively
    # so such builds fall back to CPU instead of raising AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available():
        return "mps"
    return "cpu"


def main() -> None:
    """Command-line entry point: segment one image and write per-class masks.

    Downloads ``<model-name>/config.json`` and ``<model-name>/model.safetensors``
    from the given Hugging Face repo, builds the model, runs sliding-window
    inference over ``--image`` and writes one 0/255 PNG per entry of the
    config's ``class_names`` into ``--out-dir``.

    Config keys read: ``model``, ``class_names``, ``patch_size`` and the
    optional ``thresholds`` mapping (class name -> probability cut-off,
    defaulting to 0.5 for classes not listed).
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--hf-repo", required=True)
    ap.add_argument("--model-name", required=True)
    ap.add_argument("--image", required=True)
    ap.add_argument("--out-dir", default="predictions")
    args = ap.parse_args()

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    wts_path = hf_hub_download(args.hf_repo, f"{args.model_name}/model.safetensors")
    cfg = json.loads(Path(cfg_path).read_text())

    device = _pick_device()
    model = build_model(cfg["model"]).to(device).eval()
    model.load_state_dict(load_file(wts_path))

    classes = cfg["class_names"]
    thrs = cfg.get("thresholds") or {}
    # Shape (n_classes, 1, 1) so the thresholds broadcast over (n_classes, H, W).
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)
    patch = cfg["patch_size"]

    # Use a context manager so the image file handle is released as soon as
    # the pixel data has been copied into the array.
    with Image.open(args.image) as src:
        img = np.asarray(src.convert("RGB"), dtype=np.float32) / 255.0
    # HWC -> CHW after per-channel normalisation.
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)

    # Half-patch overlap between tiles; never let the step drop to 0.
    stride = max(patch // 2, 1)
    logits = sliding_window_predict(model, img, patch, stride, len(classes), device)
    # Clip before exp() to keep the sigmoid free of float32 overflow.
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
    binary = (probs > thr).astype(np.uint8)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, name in enumerate(classes):
        Image.fromarray(binary[i] * 255, "L").save(out_dir / f"{name}.png")
    print(f"Wrote {len(classes)} masks to {out_dir}/")


if __name__ == "__main__":
    main()
|