"""Gradio Space app for SAM 3D Body: full-image 3D human body reconstruction.

Downloads the upstream sam-3d-body source from GitHub at runtime, pulls the
gated model checkpoint from the Hugging Face Hub (requires HF_TOKEN), runs
inference on an uploaded image, and renders a four-panel result strip
(original | keypoints+boxes | mesh overlay | side view).
"""

import gc
import os
import sys
import tempfile
import urllib.request
import zipfile
from functools import lru_cache
from pathlib import Path


# Set env vars before importing torch.
def _ensure_positive_int_env(name: str, default: int) -> None:
    """Force env var *name* to a positive integer, falling back to *default*.

    ``str.isdigit`` rejects empty strings, signs, and decimals in one check,
    so anything that is not a plain positive integer gets overwritten.
    """
    value = os.getenv(name, "").strip()
    if not value.isdigit() or int(value) < 1:
        os.environ[name] = str(default)


_ensure_positive_int_env("OMP_NUM_THREADS", 1)
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

import torch

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Must be set before importing pyrender / renderer modules.
if DEVICE.type == "cuda":
    os.environ["PYOPENGL_PLATFORM"] = "egl"
else:
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"

import cv2
import gradio as gr
import numpy as np
from gradio import Error as GradioError
from huggingface_hub import snapshot_download, whoami

HF_TOKEN = os.getenv("HF_TOKEN")
HF_REPO_ID = os.getenv("SAM3D_HF_REPO_ID", "facebook/sam-3d-body-dinov3")
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "1024"))

SRC_CACHE_DIR = Path(tempfile.gettempdir()) / "sam_3d_body_src"
SRC_ROOT = SRC_CACHE_DIR / "sam-3d-body-main"
SRC_ZIP = SRC_CACHE_DIR / "sam-3d-body.zip"

# NOTE(review): presumably the per-person vertex count of the MHR body mesh
# (the original code hard-coded 18439 inline) — confirm against the upstream
# model assets. Used only to pick a trailing window of vertices when centering
# the rendered scene.
_VERTS_PER_PERSON = 18439

print("Using device:", DEVICE)

if not HF_TOKEN:
    raise GradioError(
        "Missing HF_TOKEN. Add a Hugging Face user access token in "
        "Space Settings -> Repository secrets under the key HF_TOKEN."
    )

# Fail fast with a clear message if the token exists but is rejected.
try:
    me = whoami(token=HF_TOKEN)
    print(
        "Authenticated on Hugging Face as:",
        me.get("name") or me.get("fullname") or "unknown",
    )
except Exception as exc:
    raise GradioError(f"HF_TOKEN is present but invalid or unusable: {exc}") from exc


def _ensure_repo_on_path() -> None:
    """Make the ``sam_3d_body`` package importable.

    If it is not already installed, download the upstream repository's main
    branch as a zip from GitHub, extract it into a temp cache dir, and put
    the extracted root on ``sys.path``. The zip archive is always removed
    afterwards so a partial download from a failed attempt is never reused.
    """
    try:
        import sam_3d_body  # noqa: F401
        return
    except Exception:
        pass
    if not SRC_ROOT.exists():
        SRC_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        url = "https://codeload.github.com/facebookresearch/sam-3d-body/zip/refs/heads/main"
        try:
            urllib.request.urlretrieve(url, SRC_ZIP)
            with zipfile.ZipFile(SRC_ZIP, "r") as zf:
                zf.extractall(SRC_CACHE_DIR)
        finally:
            # Clean up the archive (possibly partial on failure) so the next
            # call re-downloads from scratch instead of tripping over it.
            SRC_ZIP.unlink(missing_ok=True)
    src_root_str = str(SRC_ROOT)
    if src_root_str not in sys.path:
        sys.path.insert(0, src_root_str)
    import sam_3d_body  # noqa: F401


def _resize_longest_side(img: np.ndarray, max_side: int = MAX_IMAGE_SIDE) -> np.ndarray:
    """Downscale *img* so its longest side is at most *max_side*.

    Never upscales; returns the input unchanged when it already fits.
    """
    h, w = img.shape[:2]
    scale = min(1.0, max_side / float(max(h, w)))
    if scale >= 1.0:
        return img
    new_w = max(1, int(round(w * scale)))
    new_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)


def _patch_estimator_for_cpu_only() -> None:
    """On CPU-only hosts, redirect the estimator's ``recursive_to(…, "cuda")``
    calls to ``"cpu"`` so inference does not crash when no GPU exists.

    Idempotent: a module-level flag prevents double-wrapping.
    """
    if DEVICE.type != "cpu":
        return
    _ensure_repo_on_path()
    import sam_3d_body.sam_3d_body_estimator as estimator_mod

    if getattr(estimator_mod, "_cpu_safe_patch_applied", False):
        return
    original_recursive_to = estimator_mod.recursive_to

    def recursive_to_safe(data, device):
        if device == "cuda":
            device = "cpu"
        return original_recursive_to(data, device)

    estimator_mod.recursive_to = recursive_to_safe
    estimator_mod._cpu_safe_patch_applied = True
    print("Applied CPU-only batch transfer patch.")


def _to_numpy(value):
    """Convert tensors/arrays/sequences to a NumPy array (tensors detached)."""
    if isinstance(value, np.ndarray):
        return value
    if isinstance(value, torch.Tensor):
        return value.detach().cpu().numpy()
    return np.asarray(value)


def _to_scalar_float(value) -> float:
    """Extract a single float from a scalar, 0-d array, or array-like.

    For non-scalar inputs, the first element (in flattened order) is used.
    """
    arr = _to_numpy(value)
    if np.isscalar(arr):
        return float(arr)
    arr = np.asarray(arr).reshape(-1)
    return float(arr[0])


def _normalize_person_output(person_output: dict) -> dict:
    """Return a copy of one person's model output with array-likes as NumPy.

    Key arrays used by rendering are forced into stable shapes/dtypes:
    vertices and 2-D keypoints as float32, camera translation as a flat
    3-vector, and every bbox as a flat ``[x1, y1, x2, y2]`` float32 vector.
    Values that cannot be converted are kept as-is (best effort).
    """
    normalized = {}
    for key, value in person_output.items():
        if isinstance(value, (torch.Tensor, np.ndarray, list, tuple)):
            try:
                normalized[key] = _to_numpy(value)
            except Exception:
                normalized[key] = value
        else:
            normalized[key] = value
    # Force key arrays into stable shapes/dtypes used by rendering.
    if "pred_vertices" in normalized:
        normalized["pred_vertices"] = np.asarray(normalized["pred_vertices"], dtype=np.float32)
    if "pred_cam_t" in normalized:
        normalized["pred_cam_t"] = np.asarray(normalized["pred_cam_t"], dtype=np.float32).reshape(-1)[:3]
    if "pred_keypoints_2d" in normalized:
        normalized["pred_keypoints_2d"] = np.asarray(normalized["pred_keypoints_2d"], dtype=np.float32)
    for bbox_key in ("bbox", "lhand_bbox", "rhand_bbox"):
        if bbox_key in normalized:
            normalized[bbox_key] = np.asarray(normalized[bbox_key], dtype=np.float32).reshape(-1)[:4]
    return normalized


def _draw_bbox(img: np.ndarray, bbox, color) -> None:
    """Draw a 2-px rectangle for ``[x1, y1, x2, y2]`` on *img* in place."""
    if bbox is None:
        return
    x1, y1, x2, y2 = [int(v) for v in np.asarray(bbox).tolist()]
    cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)


def _draw_keypoints(img: np.ndarray, keypoints_2d: np.ndarray, color=(0, 0, 255)) -> None:
    """Draw filled circles at each 2-D keypoint on *img* in place.

    Points with fewer than two coordinates or non-finite values are skipped
    — casting NaN to ``int`` raises ValueError, which previously crashed the
    whole visualization for a single bad keypoint.
    """
    if keypoints_2d is None or len(keypoints_2d) == 0:
        return
    pts = np.asarray(keypoints_2d)
    for pt in pts:
        if len(pt) < 2:
            continue
        x, y = float(pt[0]), float(pt[1])
        if not (np.isfinite(x) and np.isfinite(y)):
            continue
        cv2.circle(img, (int(x), int(y)), 3, color, -1)


def render_result_panorama(img_bgr: np.ndarray, outputs, faces) -> np.ndarray:
    """Render a 4-panel BGR strip: original | keypoints+boxes | mesh | side view.

    Args:
        img_bgr: the processed input image in BGR order.
        outputs: per-person output dicts from the estimator.
        faces: triangle index array of the body mesh (shared by all people).

    Raises:
        ValueError: if *outputs* is empty.
    """
    _ensure_repo_on_path()
    from sam_3d_body.visualization.renderer import Renderer

    people = [_normalize_person_output(person) for person in outputs]
    faces_np = np.asarray(_to_numpy(faces), dtype=np.int32)
    if not people:
        raise ValueError("No people to render.")

    img_orig = img_bgr.copy()
    img_kpts = img_bgr.copy()

    # Sort farthest to closest, matching upstream visualization logic.
    all_depths = np.stack([p["pred_cam_t"] for p in people], axis=0)[:, 2]
    people_sorted = [people[idx] for idx in np.argsort(-all_depths)]

    for person in people_sorted:
        _draw_keypoints(img_kpts, person.get("pred_keypoints_2d"))
        _draw_bbox(img_kpts, person.get("bbox"), (0, 255, 0))
        if "lhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("lhand_bbox"), (255, 0, 0))
        if "rhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("rhand_bbox"), (0, 0, 255))

    # Merge everyone into a single mesh; face indices are offset per person.
    all_pred_vertices = []
    all_faces = []
    for pid, person in enumerate(people_sorted):
        verts = np.asarray(person["pred_vertices"], dtype=np.float32)
        cam_t = np.asarray(person["pred_cam_t"], dtype=np.float32).reshape(1, 3)
        all_pred_vertices.append(verts + cam_t)
        all_faces.append(faces_np + verts.shape[0] * pid)
    all_pred_vertices = np.concatenate(all_pred_vertices, axis=0)
    all_faces = np.concatenate(all_faces, axis=0)

    # Center the scene on the bounding-box midpoint of (at most) the last two
    # people's vertices, then render relative to that synthetic camera offset.
    tail = min(all_pred_vertices.shape[0], 2 * _VERTS_PER_PERSON)
    fake_pred_cam_t = (
        np.max(all_pred_vertices[-tail:], axis=0)
        + np.min(all_pred_vertices[-tail:], axis=0)
    ) / 2.0
    fake_pred_cam_t = fake_pred_cam_t.astype(np.float32)
    all_pred_vertices = all_pred_vertices - fake_pred_cam_t[None, :]

    focal_length = _to_scalar_float(people_sorted[0]["focal_length"])
    renderer = Renderer(focal_length=focal_length, faces=all_faces)
    light_blue = (0.65098039, 0.74117647, 0.85882353)

    img_mesh = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        img_bgr.copy(),
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
    )
    img_mesh = np.clip(img_mesh * 255.0, 0, 255).astype(np.uint8)

    white_img = np.ones_like(img_bgr, dtype=np.uint8) * 255
    img_mesh_side = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        white_img,
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
        side_view=True,
    )
    img_mesh_side = np.clip(img_mesh_side * 255.0, 0, 255).astype(np.uint8)

    return np.concatenate([img_orig, img_kpts, img_mesh, img_mesh_side], axis=1)


@lru_cache(maxsize=1)
def get_estimator():
    """Build (once) and return the SAM3DBodyEstimator.

    Downloads the checkpoint, config, and MHR asset from the (gated) Hub
    repo, loads the model on DEVICE, and wires an estimator with detector /
    segmentor / FOV estimator disabled to keep the app lean.
    """
    _ensure_repo_on_path()
    _patch_estimator_for_cpu_only()
    from sam_3d_body import SAM3DBodyEstimator, load_sam_3d_body

    print("HF_TOKEN present:", bool(HF_TOKEN))
    print("Target repo:", HF_REPO_ID)
    snapshot_dir = snapshot_download(
        repo_id=HF_REPO_ID,
        token=HF_TOKEN,
        allow_patterns=[
            "model.ckpt",
            "model_config.yaml",
            "assets/mhr_model.pt",
        ],
    )
    checkpoint_path = os.path.join(snapshot_dir, "model.ckpt")
    mhr_path = os.path.join(snapshot_dir, "assets", "mhr_model.pt")
    model, model_cfg = load_sam_3d_body(
        checkpoint_path=checkpoint_path,
        device=str(DEVICE),
        mhr_path=mhr_path,
    )
    # Debug prints to confirm the model actually landed on the right device.
    try:
        print("Model parameter device:", next(model.parameters()).device)
    except StopIteration:
        print("Model parameter device: unavailable")
    if hasattr(model, "image_mean"):
        print("image_mean device:", model.image_mean.device)
    if hasattr(model, "image_std"):
        print("image_std device:", model.image_std.device)
    estimator = SAM3DBodyEstimator(
        sam_3d_body_model=model,
        model_cfg=model_cfg,
        human_detector=None,
        human_segmentor=None,
        fov_estimator=None,
    )
    return estimator


def run_inference(image: np.ndarray):
    """Gradio handler: run reconstruction on an RGB image.

    Returns:
        (rendered RGB image, status string). On visualization failure the
        resized input is returned with a diagnostic status instead of failing
        the whole request.

    Raises:
        gr.Error: on missing input, inference failure, or empty results.
    """
    if image is None:
        raise gr.Error("Please upload an image.")
    estimator = get_estimator()
    img_rgb = _resize_longest_side(image.astype(np.uint8))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)
    try:
        outputs = estimator.process_one_image(img_rgb)
    except Exception as exc:
        raise gr.Error(f"Inference failed: {exc}") from exc
    if not outputs:
        raise gr.Error(
            "No result was produced. Use an image with one clearly visible full-body person."
        )
    try:
        print("Num outputs:", len(outputs))
        print("Output keys:", list(outputs[0].keys()) if outputs else [])
    except Exception:
        pass
    try:
        rendered_bgr = render_result_panorama(img_bgr, outputs, estimator.faces)
        rendered = cv2.cvtColor(rendered_bgr, cv2.COLOR_BGR2RGB)
        status = (
            f"Done. Reconstructed people: {len(outputs)} | "
            f"Processed size: {img_rgb.shape[1]}x{img_rgb.shape[0]} | "
            f"Rendered size: {rendered.shape[1]}x{rendered.shape[0]} | "
            f"Device: {DEVICE.type.upper()}"
        )
    except Exception as vis_exc:
        # Visualization is best-effort: inference succeeded, so surface the
        # input back with an explanatory status rather than erroring out.
        rendered = img_rgb.copy()
        status = (
            f"Inference succeeded for {len(outputs)} person(s), "
            f"but visualization failed: {type(vis_exc).__name__}: {vis_exc}"
        )
        print("Visualization failed:", repr(vis_exc))
    # Free per-request memory promptly; CUDA cache is released when on GPU.
    del outputs
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()
    return rendered, status


DESCRIPTION = """
# SAM 3D Body — Gradio demo

Upload a photo and run full-image 3D body reconstruction.

Notes:
- Automatically uses GPU when available, otherwise CPU.
- Detector, segmentor, and FOV estimator are disabled to keep the app lean.
- Best results come from one clearly visible full-body person.
- The Space secret `HF_TOKEN` must be set after access to the gated model repo is approved.
- Optional env var: `SAM3D_HF_REPO_ID=facebook/sam-3d-body-vith` for the smaller checkpoint.
"""

with gr.Blocks(title="SAM 3D Body") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input image", type="numpy", image_mode="RGB")
            with gr.Row():
                run_btn = gr.Button("Run", variant="primary")
                gr.ClearButton([input_image], value="Clear")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Result", type="numpy")
            status_box = gr.Textbox(label="Status", interactive=False)
    run_btn.click(
        fn=run_inference,
        inputs=input_image,
        outputs=[output_image, status_box],
    )

# One request at a time: the model is large and not re-entrant-safe here.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)