"""Sapiens2 multi-task CPU: seg / normal / pointmap / pose at 0.4b/0.8b/1b plus 5B INT8 ONNX. 5B (seg, normal, pointmap) runs via INT8 ONNX from WeReCooking/sapiens2-onnx; pose-5b not shipped. Pose top-down: DETR finds people, sapiens2 estimates 308 keypoints per crop. Lazy-load with LRU cache (keeps 2 dense models + 1 pose model resident). Per-task API endpoint via Gradio's auto-API (curl-able with Bearer token). Also exposes a standalone ONNX CLI mode that does not need PyTorch or sapiens2: python app.py onnx seg 0.4b photo.jpg --output seg.png python app.py onnx pointmap 5b photo.jpg --output depth.png """ # Block mmpretrain: mmdet's reid modules try to import it via try/except ImportError, # but mmpretrain raises TypeError on import (transformers API drift) which escapes # the except and kills the process. import sys sys.modules["mmpretrain"] = None # --- ONNX CLI (standalone, no PyTorch/sapiens2 import) ---------------------- def _onnx_cli(): """Run a published sapiens2 ONNX model on a local image. Only needs numpy, onnxruntime, huggingface_hub, opencv-python-headless.""" import argparse import os import time from pathlib import Path import numpy as np import cv2 import onnxruntime as ort from huggingface_hub import hf_hub_download DEFAULT_REPO = "WeReCooking/sapiens2-onnx" PRECISIONS = {("seg", "0.4b"): "fp16"} # only seg-0.4b is fp16; rest fp32 or int8 for 5B INPUT_HW = (1024, 768) parser = argparse.ArgumentParser(prog="app.py onnx") parser.add_argument("task", choices=["seg", "normal", "pointmap", "pose"]) parser.add_argument("size", choices=["0.4b", "0.8b", "1b", "5b"]) parser.add_argument("image", help="Local image path") parser.add_argument("--cache-dir", default="./onnx_cache") parser.add_argument("--token", default=os.environ.get("HF_TOKEN")) parser.add_argument("--output", default=None, help="Save the visualization here") parser.add_argument("--repo", default=DEFAULT_REPO) args = parser.parse_args(sys.argv[2:]) precision = PRECISIONS.get((args.task, args.size), "int8" if args.size == "5b" else "fp32") filename = f"{args.task}/{args.task}_{args.size}_{precision}.onnx" print(f"[1/3] downloading {filename} from {args.repo}", flush=True) t0 = time.time() onnx_path = hf_hub_download(repo_id=args.repo, filename=filename, local_dir=args.cache_dir, token=args.token) hf_hub_download(repo_id=args.repo, filename=f"{filename}.data", local_dir=args.cache_dir, token=args.token) print(f" ready in {time.time()-t0:.1f}s", flush=True) img = cv2.imread(args.image, cv2.IMREAD_COLOR) if img is None: raise FileNotFoundError(args.image) H, W = INPUT_HW h0, w0 = img.shape[:2] scale = min(W / w0, H / h0) new_w, new_h = int(round(w0 * scale)), int(round(h0 * scale)) resized = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR) canvas = np.zeros((H, W, 3), dtype=np.uint8) top = (H - new_h) // 2 left = (W - new_w) // 2 canvas[top:top + new_h, left:left + new_w] = resized mean = (123.675, 116.28, 103.53) std = (58.395, 57.12, 57.375) x = canvas.astype(np.float32) for c in range(3): x[:, :, c] = (x[:, :, c] - mean[c]) / std[c] x = x.transpose(2, 0, 1)[None] print(f"[2/3] ORT forward (input {x.shape} {x.dtype})", flush=True) sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) t0 = time.time() out = sess.run(None, {sess.get_inputs()[0].name: x}) print(f" forward {time.time()-t0:.1f}s, outputs={[o.shape for o in out]}", flush=True) print(f"[3/3] postprocess + preview", flush=True) if args.task == "pose": heatmaps = out[0][0] K, hH, hW = heatmaps.shape flat = heatmaps.reshape(K, -1) peak = flat.argmax(axis=1) ys, xs = np.unravel_index(peak, (hH, hW)) scores = flat.max(axis=1) inp_y = ys * (INPUT_HW[0] / hH) inp_x = xs * (INPUT_HW[1] / hW) scale_y = h0 / new_h scale_x = w0 / new_w img_y = (inp_y - top) * scale_y img_x = (inp_x - left) * scale_x n_visible = int((scores > 0.3).sum()) print(f" {n_visible}/{K} keypoints above 0.3 confidence (range {scores.min():.3f} to {scores.max():.3f})") if args.output: for i in range(K): if scores[i] < 0.3: continue cv2.circle(img, (int(img_x[i]), int(img_y[i])), 4, (0, 255, 0), -1) cv2.imwrite(args.output, img) print(f" saved {args.output}") return if args.task == "seg": logits = out[0][0] class_map = logits.argmax(axis=0).astype(np.int32) class_map_crop = class_map[top:top + new_h, left:left + new_w] class_map_full = cv2.resize(class_map_crop, (w0, h0), interpolation=cv2.INTER_NEAREST) classes = np.unique(class_map_full).tolist() print(f" classes detected: {classes[:15]}") if args.output: palette = (np.random.RandomState(42).rand(29, 3) * 255).astype(np.uint8) cv2.imwrite(args.output, palette[class_map_full]) print(f" saved {args.output}") return if args.task == "normal": normal_raw = out[0][0].transpose(1, 2, 0) norm = np.linalg.norm(normal_raw, axis=2, keepdims=True) normal_unit = normal_raw / np.maximum(norm, 1e-8) normal_crop = normal_unit[top:top + new_h, left:left + new_w] normal_full = cv2.resize(normal_crop, (w0, h0), interpolation=cv2.INTER_LINEAR) if args.output: rgb = (((normal_full + 1.0) / 2.0) * 255).clip(0, 255).astype(np.uint8) cv2.imwrite(args.output, rgb) print(f" saved {args.output}") return # pointmap pointmap_rel = out[0][0].transpose(1, 2, 0) s = out[1][0, 0] if len(out) > 1 else 1.0 pointmap_metric = pointmap_rel / max(float(s), 1e-8) z = pointmap_metric[..., 2] z_crop = z[top:top + new_h, left:left + new_w] z_full = cv2.resize(z_crop, (w0, h0), interpolation=cv2.INTER_LINEAR) zmin, zmax = float(z_full.min()), float(z_full.max()) print(f" Z range: [{zmin:.2f}, {zmax:.2f}] meters") if args.output: z_norm = ((z_full - zmin) / max(zmax - zmin, 1e-8) * 255).astype(np.uint8) cv2.imwrite(args.output, z_norm) print(f" saved {args.output}") if len(sys.argv) > 1 and sys.argv[1] == "onnx": _onnx_cli() sys.exit(0) # --- Gradio path ----------------------------------------------------------- import glob import os import time import traceback from pathlib import Path import gradio as gr import numpy as np from PIL import Image # --- Catalog ---------------------------------------------------------------- VARIANTS = { ("seg", "0.4b"): {"repo": "facebook/sapiens2-seg-0.4b", "filename": "sapiens2_0.4b_seg.safetensors", "config_glob": "**/sapiens2_0.4b_seg*shutterstock*1024x768*.py", "kind": "seg"}, ("seg", "0.8b"): {"repo": "facebook/sapiens2-seg-0.8b", "filename": "sapiens2_0.8b_seg.safetensors", "config_glob": "**/sapiens2_0.8b_seg*shutterstock*1024x768*.py", "kind": "seg"}, ("seg", "1b"): {"repo": "facebook/sapiens2-seg-1b", "filename": "sapiens2_1b_seg.safetensors", "config_glob": "**/sapiens2_1b_seg*shutterstock*1024x768*.py", "kind": "seg"}, ("normal", "0.4b"): {"repo": "facebook/sapiens2-normal-0.4b", "filename": "sapiens2_0.4b_normal.safetensors", "config_glob": "**/sapiens2_0.4b_normal*metasim*1024x768*.py", "kind": "normal"}, ("normal", "0.8b"): {"repo": "facebook/sapiens2-normal-0.8b", "filename": "sapiens2_0.8b_normal.safetensors", "config_glob": "**/sapiens2_0.8b_normal*metasim*1024x768*.py", "kind": "normal"}, ("normal", "1b"): {"repo": "facebook/sapiens2-normal-1b", "filename": "sapiens2_1b_normal.safetensors", "config_glob": "**/sapiens2_1b_normal*metasim*1024x768*.py", "kind": "normal"}, ("pointmap", "0.4b"): {"repo": "facebook/sapiens2-pointmap-0.4b", "filename": "sapiens2_0.4b_pointmap.safetensors", "config_glob": "**/sapiens2_0.4b_pointmap*render_people*1024x768*.py", "kind": "pointmap"}, ("pointmap", "0.8b"): {"repo": "facebook/sapiens2-pointmap-0.8b", "filename": "sapiens2_0.8b_pointmap.safetensors", "config_glob": "**/sapiens2_0.8b_pointmap*render_people*1024x768*.py", "kind": "pointmap"}, ("pointmap", "1b"): {"repo": "facebook/sapiens2-pointmap-1b", "filename": "sapiens2_1b_pointmap.safetensors", "config_glob": "**/sapiens2_1b_pointmap*render_people*1024x768*.py", "kind": "pointmap"}, ("pose", "0.4b"): {"repo": "facebook/sapiens2-pose-0.4b", "filename": "sapiens2_0.4b_pose.safetensors", "config_glob": "**/sapiens2_0.4b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"}, ("pose", "0.8b"): {"repo": "facebook/sapiens2-pose-0.8b", "filename": "sapiens2_0.8b_pose.safetensors", "config_glob": "**/sapiens2_0.8b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"}, ("pose", "1b"): {"repo": "facebook/sapiens2-pose-1b", "filename": "sapiens2_1b_pose.safetensors", "config_glob": "**/sapiens2_1b_keypoints308*shutterstock_goliath*1024x768*.py", "kind": "pose"}, # 5B variants run via prebuilt INT8 ONNX from WeReCooking/sapiens2-onnx. # fp32 5B PyTorch (~20 GB) won't fit in the free CPU Space's 16 GB; INT8 ONNX is ~5-6 GB. # pose-5b is intentionally absent — INT8 wasn't successfully built for it. ("seg", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "seg/seg_5b_int8.onnx", "kind": "seg"}, ("normal", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "normal/normal_5b_int8.onnx", "kind": "normal"}, ("pointmap", "5b"): {"onnx_repo": "WeReCooking/sapiens2-onnx", "onnx_filename": "pointmap/pointmap_5b_int8.onnx", "kind": "pointmap"}, } DENSE_KINDS = {"seg", "normal", "pointmap"} _MODELS: dict = {} # (task, size) -> dense model (LRU) _POSE_MODELS: dict = {} # (task, size) -> pose model (separate cache so DETR survives) _DETECTOR = None # tuple(processor, model) — lazily loaded once _POSE_METAINFO = None _ORT_SESSIONS: dict = {} # (task, "5b") -> onnxruntime InferenceSession _MAX_CACHED = 2 _DOME_CLASSES_29 = None _SAPIENS_PKG_ROOT = None def _sapiens_root() -> Path: """Return the directory containing the installed sapiens package.""" global _SAPIENS_PKG_ROOT if _SAPIENS_PKG_ROOT is None: import sapiens # imported lazily because it has side effects (mmdet etc.) _SAPIENS_PKG_ROOT = Path(sapiens.__file__).resolve().parent return _SAPIENS_PKG_ROOT def _find_config(pattern: str) -> str: # cfg_glob comes in as "**/sapiens2_..._1024x768*.py"; rglob applies the leading ** implicitly leaf = pattern.split("/")[-1] root = _sapiens_root() matches = list(root.rglob(leaf)) if not matches: raise FileNotFoundError(f"No config matching {leaf} under {root}") return str(matches[0]) def _get_dense_model(task: str, size: str): """Lazy-load + LRU-cache for seg/normal/pointmap.""" key = (task, size) if key in _MODELS: _MODELS[key] = _MODELS.pop(key) return _MODELS[key] spec = VARIANTS[key] from sapiens.dense.models import init_model if spec["kind"] == "normal": from sapiens.dense.models import NormalEstimator # noqa: F401 elif spec["kind"] == "pointmap": from sapiens.dense.models import PointmapEstimator # noqa: F401 config = _find_config(spec["config_glob"]) from huggingface_hub import hf_hub_download local_dir = f"/tmp/sapiens_models/{task}-{size}" os.makedirs(local_dir, exist_ok=True) ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir) # cpu-basic has 16 GB. Loading a 1B dense (~6 GB fp32) on top of cached 0.8b/0.4b dense (~5 GB each) + a 1B pose + DETR OOMs. # So before init_model allocates a 1B's weights, evict ALL caches it would race with. import gc if size == "1b": _MODELS.clear() _POSE_MODELS.clear() _ORT_SESSIONS.clear() gc.collect() else: while len(_MODELS) >= _MAX_CACHED: oldest = next(iter(_MODELS)) del _MODELS[oldest] gc.collect() model = init_model(config, ckpt, device="cpu") _MODELS[key] = model return model def _get_pose_metainfo(): global _POSE_METAINFO if _POSE_METAINFO is None: from sapiens.pose.datasets import parse_pose_metainfo meta_cfg = _find_config("**/pose/configs/**/keypoints308.py") import importlib.util spec_obj = importlib.util.spec_from_file_location("keypoints308_meta", meta_cfg) mod = importlib.util.module_from_spec(spec_obj) spec_obj.loader.exec_module(mod) ds_info = getattr(mod, "dataset_info", None) if ds_info is None: raise RuntimeError(f"No dataset_info in {meta_cfg}") _POSE_METAINFO = parse_pose_metainfo(ds_info) return _POSE_METAINFO def _get_pose_model(size: str): key = ("pose", size) if key in _POSE_MODELS: return _POSE_MODELS[key] spec = VARIANTS[key] from sapiens.pose.models import init_model from sapiens.pose.datasets import UDPHeatmap config = _find_config(spec["config_glob"]) from huggingface_hub import hf_hub_download local_dir = f"/tmp/sapiens_models/pose-{size}" os.makedirs(local_dir, exist_ok=True) ckpt = hf_hub_download(repo_id=spec["repo"], filename=spec["filename"], local_dir=local_dir) # Same hard eviction as the dense 1B path: clear every other resident model before init_model allocates. import gc if size == "1b": _MODELS.clear() _POSE_MODELS.clear() _ORT_SESSIONS.clear() else: _POSE_MODELS.clear() # cap=1 gc.collect() model = init_model(config, ckpt, device="cpu") codec_cfg = dict(model.cfg.codec) assert codec_cfg.pop("type") == "UDPHeatmap" model.codec = UDPHeatmap(**codec_cfg) model.pose_metainfo = _get_pose_metainfo() _POSE_MODELS[key] = model return model def _get_detector(): global _DETECTOR if _DETECTOR is None: import torch # noqa: F401 from transformers import DetrImageProcessor, DetrForObjectDetection proc = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50") det = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50").eval() _DETECTOR = (proc, det) return _DETECTOR def _load_dome_classes(): global _DOME_CLASSES_29 if _DOME_CLASSES_29 is None: from sapiens.dense.src.datasets.seg.seg_utils import DOME_CLASSES_29 _DOME_CLASSES_29 = DOME_CLASSES_29 return _DOME_CLASSES_29 def _get_padding(data_samples): ds = data_samples[0] if isinstance(data_samples, list) and data_samples else data_samples if hasattr(ds, "padding_size"): return tuple(ds.padding_size) if hasattr(ds, "metainfo") and isinstance(ds.metainfo, dict): if "padding_size" in ds.metainfo: return tuple(ds.metainfo["padding_size"]) if "pad_shape" in ds.metainfo and "img_shape" in ds.metainfo: ph, pw = ds.metainfo["pad_shape"][:2] ih, iw = ds.metainfo["img_shape"][:2] return (0, pw - iw, 0, ph - ih) if isinstance(ds, dict): meta = ds.get("meta") or ds if "padding_size" in meta: return tuple(meta["padding_size"]) return (0, 0, 0, 0) # --- Per-task inference ----------------------------------------------------- def _infer_seg(image_bgr, model): import torch import torch.nn.functional as F import cv2 h0, w0 = image_bgr.shape[:2] data = model.pipeline(dict(img=image_bgr)) data = model.data_preprocessor(data) with torch.no_grad(): logits = model(data["inputs"]) logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False) label_map = logits.argmax(dim=1).squeeze(0).cpu().numpy().astype(np.int32) classes = _load_dome_classes() palette = np.zeros((256, 3), dtype=np.uint8) for cid, meta in classes.items(): palette[cid] = meta["color"][::-1] color_mask = palette[label_map] overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0) overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB) uniq = sorted(int(c) for c in np.unique(label_map)) labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes] return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}" def _infer_normal(image_bgr, model): import torch data = model.pipeline(dict(img=image_bgr)) data = model.data_preprocessor(data) inputs, data_samples = data["inputs"], data["data_samples"] if inputs.ndim == 3: inputs = inputs.unsqueeze(0) with torch.no_grad(): normal = model(inputs) normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8) pl, pr, pt, pb = _get_padding(data_samples) normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] normal_hwc = normal.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8) return Image.fromarray(rgb), f"normal map {rgb.shape}" def _infer_pointmap(image_bgr, model): import torch data = model.pipeline(dict(img=image_bgr)) data = model.data_preprocessor(data) inputs, data_samples = data["inputs"], data["data_samples"] if inputs.ndim == 3: inputs = inputs.unsqueeze(0) with torch.no_grad(): out = model(inputs) if isinstance(out, tuple) and len(out) == 2: pointmap, scale = out pointmap = pointmap / scale.clamp_min(1e-8) else: pointmap = out pl, pr, pt, pb = _get_padding(data_samples) pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] pmap_hwc = pointmap.squeeze(0).cpu().float().numpy().transpose(1, 2, 0) z = pmap_hwc[..., 2] z_min, z_max = float(z.min()), float(z.max()) z_norm = (z - z_min) / max(z_max - z_min, 1e-8) z_rgb = (z_norm * 255).astype(np.uint8) rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1) return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]" # --- 5B INT8 ONNX path ------------------------------------------------------- def _get_ort_session(task: str): """Lazy-load + cache an ORT session for {task}_5b_int8.onnx. Each 5B session is 5-6 GB RAM. cpu-basic has 16 GB total, so keep at most one 5B session live and evict cached dense/pose PyTorch models that would push us OOM.""" key = (task, "5b") sess = _ORT_SESSIONS.get(key) if sess is not None: return sess import onnxruntime as ort from huggingface_hub import hf_hub_download spec = VARIANTS[key] cache_dir = os.environ.get("ONNX_5B_CACHE", "/app/onnx_5b") os.makedirs(cache_dir, exist_ok=True) fn = spec["onnx_filename"] onnx_path = hf_hub_download(repo_id=spec["onnx_repo"], filename=fn, local_dir=cache_dir) hf_hub_download(repo_id=spec["onnx_repo"], filename=fn + ".data", local_dir=cache_dir) # Evict any prior 5B ORT session and any 1b dense models — they together exceed 16 GB. import gc if _ORT_SESSIONS: _ORT_SESSIONS.clear() gc.collect() for k in list(_MODELS.keys()): if k[1] in ("1b", "0.8b"): del _MODELS[k] for k in list(_POSE_MODELS.keys()): if k[1] in ("1b", "0.8b"): del _POSE_MODELS[k] gc.collect() sess = ort.InferenceSession(onnx_path, providers=["CPUExecutionProvider"]) _ORT_SESSIONS[key] = sess return sess def _infer_dense_5b(image_bgr, task: str): """5B inference: preprocess via the 0.4b PyTorch pipeline (cached), forward via ORT INT8.""" import torch import torch.nn.functional as F import cv2 # Use the 0.4b model's pipeline+preprocessor for image prep — it's already in cache for warm calls. proxy = _get_dense_model(task, "0.4b") data = proxy.pipeline(dict(img=image_bgr)) data = proxy.data_preprocessor(data) inputs, data_samples = data["inputs"], data["data_samples"] if inputs.ndim == 3: inputs = inputs.unsqueeze(0) sess = _get_ort_session(task) out = sess.run(None, {sess.get_inputs()[0].name: inputs.float().cpu().numpy()}) if task == "seg": logits = torch.from_numpy(out[0]) h0, w0 = image_bgr.shape[:2] logits = F.interpolate(logits, size=(h0, w0), mode="bilinear", align_corners=False) label_map = logits.argmax(dim=1).squeeze(0).numpy().astype(np.int32) classes = _load_dome_classes() palette = np.zeros((256, 3), dtype=np.uint8) for cid, meta in classes.items(): palette[cid] = meta["color"][::-1] color_mask = palette[label_map] overlay_bgr = cv2.addWeighted(image_bgr, 0.5, color_mask, 0.5, 0) overlay_rgb = cv2.cvtColor(overlay_bgr, cv2.COLOR_BGR2RGB) uniq = sorted(int(c) for c in np.unique(label_map)) labels = [classes[c]["name"].replace("_", " ") for c in uniq if c in classes] return Image.fromarray(overlay_rgb), f"classes: {', '.join(labels)}" if task == "normal": normal = torch.from_numpy(out[0]) normal = normal / normal.norm(dim=1, keepdim=True).clamp_min(1e-8) pl, pr, pt, pb = _get_padding(data_samples) normal = normal[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] normal_hwc = normal.squeeze(0).numpy().transpose(1, 2, 0) rgb = (((normal_hwc + 1.0) / 2.0) * 255.0).clip(0, 255).astype(np.uint8) return Image.fromarray(rgb), f"normal map {rgb.shape}" # pointmap — ONNX produces (pointmap [1,3,H,W], scale [1,1]); divide to recover metric depths. pointmap = torch.from_numpy(out[0]) if len(out) > 1: scale = torch.from_numpy(out[1]) pointmap = pointmap / scale.clamp_min(1e-8) pl, pr, pt, pb = _get_padding(data_samples) pointmap = pointmap[:, :, pt:inputs.shape[2] - pb, pl:inputs.shape[3] - pr] pmap_hwc = pointmap.squeeze(0).numpy().transpose(1, 2, 0) z = pmap_hwc[..., 2] z_min, z_max = float(z.min()), float(z.max()) z_norm = (z - z_min) / max(z_max - z_min, 1e-8) z_rgb = (z_norm * 255).astype(np.uint8) rgb = np.stack([z_rgb, z_rgb, z_rgb], axis=-1) return Image.fromarray(rgb), f"pointmap {pmap_hwc.shape} | Z [{z_min:.2f}, {z_max:.2f}]" # Inlined from the upstream Meta sample (pose keypoint render). # Draws skeleton links + colored keypoints; thickness/radius are picked by the caller. def visualize_keypoints( image: np.ndarray, keypoints, keypoints_visible, keypoint_scores, *, radius: int = 4, thickness: int = -1, color=(255, 0, 0), kpt_thr: float = 0.3, skeleton: list | None = None, kpt_color=None, link_color=None, show_kpt_idx: bool = False, ) -> np.ndarray: import cv2 img = image.copy() H, W = img.shape[:2] if skeleton is None: skeleton = [] if kpt_color is None: kpt_color = color if link_color is None: link_color = (0, 255, 0) def _as_color_list(c, n): if hasattr(c, "detach"): c = c.detach().cpu().numpy() if isinstance(c, np.ndarray): if c.ndim == 2 and c.shape[1] == 3: return [tuple(int(v) for v in row) for row in c.tolist()] if c.size == 3: return [tuple(int(v) for v in c.tolist())] * max(1, n) if isinstance(c, (list, tuple)): if n and len(c) == n and isinstance(c[0], (list, tuple, np.ndarray)): out = [] for cc in c: cc = np.asarray(cc).reshape(-1) out.append(tuple(int(v) for v in cc.tolist())) return out c_arr = np.asarray(c).reshape(-1) if c_arr.size == 3: return [tuple(int(v) for v in c_arr.tolist())] * max(1, n) return [(255, 0, 0)] * max(1, n) J = keypoints[0].shape[0] if keypoints else 0 kpt_colors = _as_color_list(kpt_color, J) link_colors = _as_color_list(link_color, len(skeleton)) def in_bounds(x, y): return 0 <= x < W and 0 <= y < H for kpts, vis, score in zip(keypoints, keypoints_visible, keypoint_scores): kpts = np.asarray(kpts, float) vis = np.asarray(vis).reshape(-1).astype(bool) score = np.asarray(score).reshape(-1) for lk, (i, j) in enumerate(skeleton): if i >= len(kpts) or j >= len(kpts): continue if not (vis[i] and vis[j]): continue if score[i] < kpt_thr or score[j] < kpt_thr: continue x1, y1 = map(int, np.round(kpts[i])) x2, y2 = map(int, np.round(kpts[j])) if not (in_bounds(x1, y1) and in_bounds(x2, y2)): continue cv2.line(img, (x1, y1), (x2, y2), link_colors[lk % len(link_colors)], thickness=max(1, thickness), lineType=cv2.LINE_AA) for j_idx, (xy, v, s) in enumerate(zip(kpts, vis, score)): if not v or s < kpt_thr: continue x, y = map(int, np.round(xy)) if not in_bounds(x, y): continue c = kpt_colors[min(j_idx, len(kpt_colors) - 1)] cv2.circle(img, (x, y), radius, c, thickness=-1, lineType=cv2.LINE_AA) if show_kpt_idx: cv2.putText(img, str(j_idx), (x + radius, y - radius), cv2.FONT_HERSHEY_SIMPLEX, 0.4, c, 1, cv2.LINE_AA) return img def _detect_persons(image_rgb: np.ndarray, threshold: float = 0.5): import torch proc, det = _get_detector() pil_img = Image.fromarray(image_rgb) inputs = proc(images=pil_img, return_tensors="pt") with torch.no_grad(): outputs = det(**inputs) target_sizes = torch.tensor([image_rgb.shape[:2]]) results = proc.post_process_object_detection( outputs, target_sizes=target_sizes, threshold=threshold )[0] person_mask = results["labels"] == 1 # COCO class 1 = person boxes = results["boxes"][person_mask].cpu().numpy() scores = results["scores"][person_mask].cpu().numpy().reshape(-1, 1) if len(boxes) == 0: h, w = image_rgb.shape[:2] return np.array([[0, 0, w - 1, h - 1, 1.0]], dtype=np.float32) return np.concatenate([boxes, scores], axis=1).astype(np.float32) def _infer_pose(image_bgr, model, kpt_thr: float = 0.3): import torch import cv2 image_rgb = cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB) bboxes = _detect_persons(image_rgb) inputs_list, samples_list = [], [] for bbox in bboxes: data_info = dict(img=image_bgr, bbox=bbox[None, :4], bbox_score=np.ones(1, dtype=np.float32)) data = model.pipeline(data_info) data = model.data_preprocessor(data) inputs_list.append(data["inputs"]) samples_list.append(data["data_samples"]) inputs = torch.cat(inputs_list, dim=0) with torch.no_grad(): pred = model(inputs).cpu().numpy() keypoints, scores = [], [] for i, sample in enumerate(samples_list): kpts_i, scr_i = model.codec.decode(pred[i]) meta = sample["meta"] if isinstance(sample, dict) else sample.metainfo kpts_i = kpts_i / np.array(meta["input_size"]) * meta["bbox_scale"] + meta["bbox_center"] - 0.5 * meta["bbox_scale"] keypoints.append(kpts_i[0]) scores.append(scr_i[0]) pmeta = model.pose_metainfo vis_rgb = image_rgb.copy() # Scale render thickness so 308-keypoint dense pose stays visible on high-res input short_side = min(vis_rgb.shape[:2]) radius_px = max(3, short_side // 200) thick_px = max(2, short_side // 250) box_thick = max(2, short_side // 300) for bbox, kpts, scr in zip(bboxes, keypoints, scores): x1, y1, x2, y2 = map(int, bbox[:4]) cv2.rectangle(vis_rgb, (x1, y1), (x2, y2), (0, 255, 0), box_thick) vis_rgb = visualize_keypoints( image=vis_rgb, keypoints=[kpts], keypoints_visible=[np.ones(len(scr), dtype=bool)], keypoint_scores=[scr], radius=radius_px, thickness=thick_px, kpt_thr=kpt_thr, skeleton=pmeta["skeleton_links"], kpt_color=pmeta["keypoint_colors"], link_color=pmeta["skeleton_link_colors"], ) return Image.fromarray(vis_rgb), f"persons={len(bboxes)} | kpts/person={len(keypoints[0]) if keypoints else 0}" # --- Predict entry point ---------------------------------------------------- def predict(image: Image.Image, task: str, size: str): if image is None: return None, "No image provided" key = (task, size) if key not in VARIANTS: return None, f"Unknown variant {task}-{size}. Allowed: {sorted(VARIANTS.keys())}" t0 = time.time() try: import cv2 image_pil = image.convert("RGB") in_w, in_h = image_pil.size image_bgr = cv2.cvtColor(np.array(image_pil), cv2.COLOR_RGB2BGR) kind = VARIANTS[key]["kind"] if size == "5b": out_img, info = _infer_dense_5b(image_bgr, task) elif kind == "pose": model = _get_pose_model(size) out_img, info = _infer_pose(image_bgr, model) else: model = _get_dense_model(task, size) if kind == "seg": out_img, info = _infer_seg(image_bgr, model) elif kind == "normal": out_img, info = _infer_normal(image_bgr, model) elif kind == "pointmap": out_img, info = _infer_pointmap(image_bgr, model) else: return None, f"Unhandled kind: {kind}" elapsed = time.time() - t0 out_w, out_h = out_img.size return out_img, f"{task}-{size}: done in {elapsed:.1f}s | {in_w}×{in_h} → 1024×768 → {out_w}×{out_h} | {info}" except Exception as e: return None, f"{type(e).__name__}: {e}\n\n{traceback.format_exc()[:1500]}" def health(): return ( f"Service up | dense cache: {list(_MODELS.keys())} | pose cache: {list(_POSE_MODELS.keys())} | " f"detector_loaded={_DETECTOR is not None} | variants={len(VARIANTS)} " f"({sorted(set(t for t, _ in VARIANTS))} × {sorted(set(s for _, s in VARIANTS))})" ) DEMO_IMAGES = sorted(str(p) for p in Path("/app/assets/images").glob("*.jpg")) with gr.Blocks(title="Sapiens2 CPU", css=""" #img-in,#img-out{max-height:220px} #status-box textarea{max-height:60px!important;min-height:60px!important} #status-box{flex-grow:0!important} """) as demo: with gr.Row(equal_height=False): with gr.Column(scale=1): img_in = gr.Image(type="pil", label="Input", height=200, elem_id="img-in") with gr.Row(): task_in = gr.Dropdown(choices=["seg", "normal", "pointmap", "pose"], value="seg", label="Task", scale=1) size_in = gr.Dropdown(choices=["0.4b", "0.8b", "1b", "5b"], value="0.4b", label="Size", scale=1) run_btn = gr.Button("Predict - 1024×768 native", variant="primary") gr.Examples( examples=[[u] for u in DEMO_IMAGES], inputs=[img_in], examples_per_page=6, cache_examples=False, label="Meta demo images", ) with gr.Column(scale=1): img_out = gr.Image(type="pil", label="Output", height=200, elem_id="img-out") status = gr.Textbox(show_label=False, lines=2, max_lines=2, interactive=False, container=False, placeholder="Status will show here after Predict", elem_id="status-box") run_btn.click( fn=predict, inputs=[img_in, task_in, size_in], outputs=[img_out, status], api_name="predict" ) # Keep health endpoint accessible via API (no UI button — useless in browser) gr.Button(visible=False).click(fn=health, outputs=[gr.Textbox(visible=False)], api_name="health") demo.queue(default_concurrency_limit=1) if __name__ == "__main__": demo.launch()