Upload inference.py with huggingface_hub
Browse files- inference.py +102 -0
inference.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Run a davidclara/siegfried-maps-segmentation model from Hugging Face on a map image.
|
| 2 |
+
|
| 3 |
+
Writes one binary PNG per class to --out-dir.
|
| 4 |
+
|
| 5 |
+
Example:
|
| 6 |
+
python inference.py \\
|
| 7 |
+
--hf-repo davidclara/siegfried-maps-segmentation \\
|
| 8 |
+
--model-name unetpp \\
|
| 9 |
+
--image map.png \\
|
| 10 |
+
--out-dir predictions/
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import argparse
|
| 14 |
+
import inspect
|
| 15 |
+
import json
|
| 16 |
+
from pathlib import Path
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import segmentation_models_pytorch as smp
|
| 20 |
+
import torch
|
| 21 |
+
from huggingface_hub import hf_hub_download
|
| 22 |
+
from PIL import Image
|
| 23 |
+
from safetensors.torch import load_file
|
| 24 |
+
|
| 25 |
+
Image.MAX_IMAGE_PIXELS = None
|
| 26 |
+
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
|
| 27 |
+
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)
|
| 28 |
+
SMP = {
|
| 29 |
+
"unet": smp.Unet,
|
| 30 |
+
"unetpp": smp.UnetPlusPlus,
|
| 31 |
+
"deeplabv3p": smp.DeepLabV3Plus,
|
| 32 |
+
"fpn": smp.FPN,
|
| 33 |
+
"pan": smp.PAN,
|
| 34 |
+
}
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def build_model(model_cfg: dict) -> torch.nn.Module:
|
| 38 |
+
cfg = dict(model_cfg)
|
| 39 |
+
name = cfg.pop("name")
|
| 40 |
+
cfg["classes"] = cfg.pop("num_classes")
|
| 41 |
+
cfg["encoder_weights"] = None # weights come from safetensors
|
| 42 |
+
cls = SMP[name]
|
| 43 |
+
accepted = set(inspect.signature(cls).parameters)
|
| 44 |
+
return cls(**{k: v for k, v in cfg.items() if k in accepted})
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def sliding_window_predict(model, img, patch, stride, n_classes, device):
|
| 48 |
+
_, H, W = img.shape
|
| 49 |
+
logits = np.zeros((n_classes, H, W), dtype=np.float32)
|
| 50 |
+
weights = np.zeros((H, W), dtype=np.float32)
|
| 51 |
+
rows = sorted({*range(0, max(H - patch, 0) + 1, stride), max(H - patch, 0)})
|
| 52 |
+
cols = sorted({*range(0, max(W - patch, 0) + 1, stride), max(W - patch, 0)})
|
| 53 |
+
with torch.no_grad():
|
| 54 |
+
for r in rows:
|
| 55 |
+
for c in cols:
|
| 56 |
+
tile = img[:, r : r + patch, c : c + patch]
|
| 57 |
+
ph, pw = patch - tile.shape[1], patch - tile.shape[2]
|
| 58 |
+
if ph or pw:
|
| 59 |
+
tile = np.pad(tile, ((0, 0), (0, ph), (0, pw)))
|
| 60 |
+
out = model(torch.from_numpy(tile).unsqueeze(0).to(device)).cpu().numpy()[0]
|
| 61 |
+
h, w = patch - ph, patch - pw
|
| 62 |
+
logits[:, r : r + h, c : c + w] += out[:, :h, :w]
|
| 63 |
+
weights[r : r + h, c : c + w] += 1.0
|
| 64 |
+
return logits / np.maximum(weights, 1e-8)
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def main() -> None:
|
| 68 |
+
ap = argparse.ArgumentParser()
|
| 69 |
+
ap.add_argument("--hf-repo", required=True)
|
| 70 |
+
ap.add_argument("--model-name", required=True)
|
| 71 |
+
ap.add_argument("--image", required=True)
|
| 72 |
+
ap.add_argument("--out-dir", default="predictions")
|
| 73 |
+
args = ap.parse_args()
|
| 74 |
+
|
| 75 |
+
cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
|
| 76 |
+
wts_path = hf_hub_download(args.hf_repo, f"{args.model_name}/model.safetensors")
|
| 77 |
+
cfg = json.loads(Path(cfg_path).read_text())
|
| 78 |
+
|
| 79 |
+
device = "cuda" if torch.cuda.is_available() else ("mps" if torch.backends.mps.is_available() else "cpu")
|
| 80 |
+
model = build_model(cfg["model"]).to(device).eval()
|
| 81 |
+
model.load_state_dict(load_file(wts_path))
|
| 82 |
+
|
| 83 |
+
classes = cfg["class_names"]
|
| 84 |
+
thrs = cfg.get("thresholds") or {}
|
| 85 |
+
thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)
|
| 86 |
+
patch = cfg["patch_size"]
|
| 87 |
+
|
| 88 |
+
img = np.asarray(Image.open(args.image).convert("RGB"), dtype=np.float32) / 255.0
|
| 89 |
+
img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)
|
| 90 |
+
logits = sliding_window_predict(model, img, patch, patch // 2, len(classes), device)
|
| 91 |
+
probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
|
| 92 |
+
binary = (probs > thr).astype(np.uint8)
|
| 93 |
+
|
| 94 |
+
out_dir = Path(args.out_dir)
|
| 95 |
+
out_dir.mkdir(parents=True, exist_ok=True)
|
| 96 |
+
for i, name in enumerate(classes):
|
| 97 |
+
Image.fromarray(binary[i] * 255, "L").save(out_dir / f"{name}.png")
|
| 98 |
+
print(f"Wrote {len(classes)} masks to {out_dir}/")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
if __name__ == "__main__":
|
| 102 |
+
main()
|