davidclara committed on
Commit
c33081c
·
verified ·
1 Parent(s): 4a9bcd8

Upload inference.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. inference.py +102 -0
inference.py ADDED
@@ -0,0 +1,102 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Run a davidclara/siegfried-maps-segmentation model from Hugging Face on a map image.
2
+
3
+ Writes one binary PNG per class to --out-dir.
4
+
5
+ Example:
6
+ python inference.py \\
7
+ --hf-repo davidclara/siegfried-maps-segmentation \\
8
+ --model-name unetpp \\
9
+ --image map.png \\
10
+ --out-dir predictions/
11
+ """
12
+
13
+ import argparse
14
+ import inspect
15
+ import json
16
+ from pathlib import Path
17
+
18
+ import numpy as np
19
+ import segmentation_models_pytorch as smp
20
+ import torch
21
+ from huggingface_hub import hf_hub_download
22
+ from PIL import Image
23
+ from safetensors.torch import load_file
24
+
25
# Setting the limit to None turns off Pillow's decompression-bomb size check,
# so arbitrarily large images can be opened (presumably because map scans are
# very large -- confirm before reusing this on untrusted images).
Image.MAX_IMAGE_PIXELS = None

# Per-channel RGB mean / standard deviation applied to the [0, 1]-scaled image
# before inference.
IMAGENET_MEAN = np.asarray((0.485, 0.456, 0.406), dtype=np.float32)
IMAGENET_STD = np.asarray((0.229, 0.224, 0.225), dtype=np.float32)

# Maps the architecture name stored in the config to its
# segmentation_models_pytorch constructor.
SMP = dict(
    unet=smp.Unet,
    unetpp=smp.UnetPlusPlus,
    deeplabv3p=smp.DeepLabV3Plus,
    fpn=smp.FPN,
    pan=smp.PAN,
)
35
+
36
+
37
def build_model(model_cfg: dict) -> torch.nn.Module:
    """Instantiate the segmentation architecture described by ``model_cfg``.

    The mapping is copied first, so the caller's dict is never mutated.

    Args:
        model_cfg: Model section of the config. Must contain ``"name"`` (a key
            of ``SMP``) and the number of classes, either as ``"num_classes"``
            or directly as ``"classes"`` (the keyword handed to the
            constructor). When both are present ``"num_classes"`` wins. Any
            other entry is forwarded as a keyword argument only if the chosen
            constructor's signature names it; everything else is dropped.

    Returns:
        The constructed model. ``encoder_weights`` is always forced to
        ``None`` so no pretrained encoder weights are requested; the caller
        is expected to load weights separately.

    Raises:
        KeyError: If ``"name"`` is missing or not a supported architecture,
            or if no class count is present.
    """
    cfg = dict(model_cfg)
    name = cfg.pop("name")
    if name not in SMP:
        # Same exception type as the plain dict lookup, but with a message
        # that tells the user which names are valid.
        raise KeyError(
            f"unknown model name {name!r}; expected one of {sorted(SMP)}"
        )
    if "num_classes" in cfg:
        cfg["classes"] = cfg.pop("num_classes")
    elif "classes" not in cfg:
        raise KeyError("num_classes")
    cfg["encoder_weights"] = None  # weights come from safetensors
    cls = SMP[name]
    # Keep only the keys the constructor actually declares so that extra
    # bookkeeping entries in the config do not raise a TypeError.
    accepted = set(inspect.signature(cls).parameters)
    return cls(**{k: v for k, v in cfg.items() if k in accepted})
45
+
46
+
47
def sliding_window_predict(model, img, patch, stride, n_classes, device):
    """Run ``model`` over ``img`` tile by tile and average overlapping logits.

    Tiles of ``patch`` x ``patch`` pixels are taken every ``stride`` pixels,
    plus one extra tile flush with the bottom/right edge so the border is
    always covered. Tiles that are smaller than ``patch`` (image smaller than
    the patch) are zero-padded on the bottom/right before the forward pass and
    the padded part of the output is discarded.

    Args:
        model: Callable taking a ``(1, C, patch, patch)`` tensor on ``device``
            and returning a tensor whose first item is indexed as
            ``(n_classes, patch, patch)`` logits.
        img: NumPy array of shape ``(C, H, W)``.
        patch: Tile edge length in pixels; must be >= 1.
        stride: Step between tile origins in pixels; must be >= 1. If it is
            larger than ``patch`` the pixels between tiles are never
            predicted and come back as 0.
        n_classes: Number of output channels produced by ``model``.
        device: Torch device the tile is moved to before the forward pass.

    Returns:
        ``float32`` array of shape ``(n_classes, H, W)`` holding the mean
        logit of every tile that covered each pixel.

    Raises:
        ValueError: If ``patch`` or ``stride`` is not positive.
    """
    # A non-positive stride either crashes inside range() with an opaque
    # message (0) or silently predicts a single tile (negative); fail early.
    if patch < 1 or stride < 1:
        raise ValueError(
            f"patch and stride must be positive, got patch={patch}, stride={stride}"
        )

    _, height, width = img.shape
    logit_sum = np.zeros((n_classes, height, width), dtype=np.float32)
    hits = np.zeros((height, width), dtype=np.float32)

    # Origin of the last tile along each axis; clamped to 0 when the image is
    # smaller than one patch.
    last_row = max(height - patch, 0)
    last_col = max(width - patch, 0)
    rows = sorted({*range(0, last_row + 1, stride), last_row})
    cols = sorted({*range(0, last_col + 1, stride), last_col})

    with torch.no_grad():
        for r in rows:
            for c in cols:
                tile = img[:, r : r + patch, c : c + patch]
                h, w = tile.shape[1:]
                if h < patch or w < patch:
                    tile = np.pad(tile, ((0, 0), (0, patch - h), (0, patch - w)))
                batch = torch.from_numpy(tile).unsqueeze(0).to(device)
                out = model(batch).cpu().numpy()[0]
                # Only the un-padded region contributes to the result.
                logit_sum[:, r : r + h, c : c + w] += out[:, :h, :w]
                hits[r : r + h, c : c + w] += 1.0

    # The epsilon guards against division by zero for pixels no tile covered.
    return logit_sum / np.maximum(hits, 1e-8)
65
+
66
+
67
def main() -> None:
    """Command-line entry point: download a model, segment one image, save masks.

    Steps:
      1. Parse ``--hf-repo``, ``--model-name``, ``--image`` and ``--out-dir``.
      2. Download ``<model-name>/config.json`` and
         ``<model-name>/model.safetensors`` from the Hugging Face repo.
      3. Build the model from the config, load the weights and pick the
         device (CUDA, then MPS, then CPU).
      4. Normalise the image, run sliding-window inference with a stride of
         half the patch size, apply a sigmoid and the per-class thresholds
         (0.5 when the config gives none for a class).
      5. Write one 0/255 PNG per class name into ``--out-dir``.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--hf-repo", required=True)
    ap.add_argument("--model-name", required=True)
    ap.add_argument("--image", required=True)
    ap.add_argument("--out-dir", default="predictions")
    args = ap.parse_args()

    cfg_path = hf_hub_download(args.hf_repo, f"{args.model_name}/config.json")
    wts_path = hf_hub_download(args.hf_repo, f"{args.model_name}/model.safetensors")
    cfg = json.loads(Path(cfg_path).read_text())

    # torch builds without the MPS backend do not expose torch.backends.mps,
    # so look it up defensively instead of assuming the attribute exists.
    mps = getattr(torch.backends, "mps", None)
    if torch.cuda.is_available():
        device = "cuda"
    elif mps is not None and mps.is_available():
        device = "mps"
    else:
        device = "cpu"
    model = build_model(cfg["model"]).to(device).eval()
    model.load_state_dict(load_file(wts_path))

    classes = cfg["class_names"]
    thrs = cfg.get("thresholds") or {}
    # Shape (n_classes, 1, 1) so the thresholds broadcast over H and W.
    thr = np.array([thrs.get(n, 0.5) for n in classes], dtype=np.float32).reshape(-1, 1, 1)
    patch = cfg["patch_size"]

    # Close the image file as soon as the pixels are copied into the array.
    with Image.open(args.image) as src:
        img = np.asarray(src.convert("RGB"), dtype=np.float32) / 255.0
    img = ((img - IMAGENET_MEAN) / IMAGENET_STD).transpose(2, 0, 1)
    logits = sliding_window_predict(model, img, patch, patch // 2, len(classes), device)
    # Clip before exponentiating so float32 exp() cannot overflow.
    probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -88, 88)))
    binary = (probs > thr).astype(np.uint8)

    out_dir = Path(args.out_dir)
    out_dir.mkdir(parents=True, exist_ok=True)
    for i, name in enumerate(classes):
        # A 2-D uint8 array is already interpreted as mode "L"; the explicit
        # mode argument is redundant, so it is omitted.
        Image.fromarray(binary[i] * 255).save(out_dir / f"{name}.png")
    print(f"Wrote {len(classes)} masks to {out_dir}/")


if __name__ == "__main__":
    main()