| import gc |
| import os |
| import sys |
| import tempfile |
| import urllib.request |
| import zipfile |
| from functools import lru_cache |
| from pathlib import Path |
|
|
| |
| def _ensure_positive_int_env(name: str, default: int) -> None: |
| value = os.getenv(name, "").strip() |
| if not value.isdigit() or int(value) < 1: |
| os.environ[name] = str(default) |
|
|
|
|
# Tame OpenMP thread fan-out and disable HF telemetry BEFORE torch / hub
# libraries are imported, since both read these variables at import time.
_ensure_positive_int_env("OMP_NUM_THREADS", 1)
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")
|
|
| import torch |
|
|
# Resolved once at import time; everything downstream keys off DEVICE.type.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Choose the OpenGL backend before any GL-using library is imported:
# EGL for headless GPU rendering, OSMesa for software rendering on CPU.
if DEVICE.type == "cuda":
    os.environ["PYOPENGL_PLATFORM"] = "egl"
else:
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"
|
|
| import cv2 |
| import gradio as gr |
| import numpy as np |
| from gradio import Error as GradioError |
| from huggingface_hub import snapshot_download, whoami |
|
|
# Runtime configuration pulled from the environment.
HF_TOKEN = os.getenv("HF_TOKEN")  # user access token for the gated model repo
HF_REPO_ID = os.getenv("SAM3D_HF_REPO_ID", "facebook/sam-3d-body-dinov3")  # checkpoint repo
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "1024"))  # longest input edge after resize

# Where the sam-3d-body source tree is cached when the package is not installed.
SRC_CACHE_DIR = Path(tempfile.gettempdir()) / "sam_3d_body_src"
SRC_ROOT = SRC_CACHE_DIR / "sam-3d-body-main"  # top folder inside the GitHub zip
SRC_ZIP = SRC_CACHE_DIR / "sam-3d-body.zip"

print("Using device:", DEVICE)
|
|
# Fail fast at startup when the token is missing or rejected, so the problem
# surfaces in the Space logs instead of at first inference.
if not HF_TOKEN:
    raise GradioError(
        "Missing HF_TOKEN. Add a Hugging Face user access token in "
        "Space Settings -> Repository secrets under the key HF_TOKEN."
    )

try:
    # whoami() round-trips to the Hub and raises if the token is unusable.
    me = whoami(token=HF_TOKEN)
    print(
        "Authenticated on Hugging Face as:",
        me.get("name") or me.get("fullname") or "unknown",
    )
except Exception as exc:
    raise GradioError(f"HF_TOKEN is present but invalid or unusable: {exc}") from exc
|
|
|
|
def _ensure_repo_on_path() -> None:
    """Make the ``sam_3d_body`` package importable.

    Tries a plain import first; on ImportError, downloads the GitHub ``main``
    branch as a zip into a temp cache, extracts it, prepends the source root
    to ``sys.path`` and imports again (letting any remaining error propagate).

    Raises:
        RuntimeError: If the archive extracted but the expected source root
            directory is still missing.
    """
    try:
        import sam_3d_body  # noqa: F401
        return
    except ImportError:
        # Not installed / not on sys.path — fall through to source download.
        # (Catching ImportError only, so genuine import-time failures inside
        # the package are not silently masked by a pointless re-download.)
        pass

    if not SRC_ROOT.exists():
        SRC_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        url = "https://codeload.github.com/facebookresearch/sam-3d-body/zip/refs/heads/main"
        urllib.request.urlretrieve(url, SRC_ZIP)
        try:
            with zipfile.ZipFile(SRC_ZIP, "r") as zf:
                zf.extractall(SRC_CACHE_DIR)
        finally:
            # The archive is only needed once; keep the temp dir lean even if
            # extraction fails partway.
            SRC_ZIP.unlink(missing_ok=True)
        if not SRC_ROOT.exists():
            raise RuntimeError(
                f"Extracted sam-3d-body archive but {SRC_ROOT} is missing."
            )

    src_root_str = str(SRC_ROOT)
    if src_root_str not in sys.path:
        sys.path.insert(0, src_root_str)

    import sam_3d_body  # noqa: F401
|
|
|
|
def _resize_longest_side(img: np.ndarray, max_side: int = MAX_IMAGE_SIDE) -> np.ndarray:
    """Downscale so the longer edge is at most ``max_side``; never upscale."""
    h, w = img.shape[:2]
    longest = float(max(h, w))
    if longest <= max_side:
        return img
    scale = max_side / longest
    resized_w = max(1, int(round(w * scale)))
    resized_h = max(1, int(round(h * scale)))
    return cv2.resize(img, (resized_w, resized_h), interpolation=cv2.INTER_AREA)
|
|
|
|
def _patch_estimator_for_cpu_only() -> None:
    """On CPU-only hosts, patch the estimator so batches never go to CUDA.

    The upstream estimator moves batches with ``recursive_to(data, device)``;
    wrap that helper so any CUDA destination is rewritten to CPU. No-op when
    a GPU is available, and idempotent across calls (guarded by a module
    flag).
    """
    if DEVICE.type != "cpu":
        return

    _ensure_repo_on_path()
    import sam_3d_body.sam_3d_body_estimator as estimator_mod

    # Already patched on a previous call — nothing to do.
    if getattr(estimator_mod, "_cpu_safe_patch_applied", False):
        return

    original_recursive_to = estimator_mod.recursive_to

    def recursive_to_safe(data, device):
        # Catch "cuda", "cuda:0", and torch.device variants — the previous
        # plain equality test against "cuda" missed indexed devices.
        if str(device).startswith("cuda"):
            device = "cpu"
        return original_recursive_to(data, device)

    estimator_mod.recursive_to = recursive_to_safe
    estimator_mod._cpu_safe_patch_applied = True
    print("Applied CPU-only batch transfer patch.")
|
|
|
|
| def _to_numpy(value): |
| if isinstance(value, np.ndarray): |
| return value |
| if isinstance(value, torch.Tensor): |
| return value.detach().cpu().numpy() |
| return np.asarray(value) |
|
|
|
|
| def _to_scalar_float(value) -> float: |
| arr = _to_numpy(value) |
| if np.isscalar(arr): |
| return float(arr) |
| arr = np.asarray(arr).reshape(-1) |
| return float(arr[0]) |
|
|
|
|
| def _normalize_person_output(person_output: dict) -> dict: |
| normalized = {} |
| for key, value in person_output.items(): |
| if isinstance(value, (torch.Tensor, np.ndarray, list, tuple)): |
| try: |
| normalized[key] = _to_numpy(value) |
| except Exception: |
| normalized[key] = value |
| else: |
| normalized[key] = value |
|
|
| |
| if "pred_vertices" in normalized: |
| normalized["pred_vertices"] = np.asarray(normalized["pred_vertices"], dtype=np.float32) |
| if "pred_cam_t" in normalized: |
| normalized["pred_cam_t"] = np.asarray(normalized["pred_cam_t"], dtype=np.float32).reshape(-1)[:3] |
| if "pred_keypoints_2d" in normalized: |
| normalized["pred_keypoints_2d"] = np.asarray(normalized["pred_keypoints_2d"], dtype=np.float32) |
| if "bbox" in normalized: |
| normalized["bbox"] = np.asarray(normalized["bbox"], dtype=np.float32).reshape(-1)[:4] |
| if "lhand_bbox" in normalized: |
| normalized["lhand_bbox"] = np.asarray(normalized["lhand_bbox"], dtype=np.float32).reshape(-1)[:4] |
| if "rhand_bbox" in normalized: |
| normalized["rhand_bbox"] = np.asarray(normalized["rhand_bbox"], dtype=np.float32).reshape(-1)[:4] |
|
|
| return normalized |
|
|
|
|
| def _draw_bbox(img: np.ndarray, bbox, color) -> None: |
| if bbox is None: |
| return |
| x1, y1, x2, y2 = [int(v) for v in np.asarray(bbox).tolist()] |
| cv2.rectangle(img, (x1, y1), (x2, y2), color, 2) |
|
|
|
|
| def _draw_keypoints(img: np.ndarray, keypoints_2d: np.ndarray, color=(0, 0, 255)) -> None: |
| if keypoints_2d is None or len(keypoints_2d) == 0: |
| return |
| pts = np.asarray(keypoints_2d) |
| for pt in pts: |
| if len(pt) < 2: |
| continue |
| x, y = int(pt[0]), int(pt[1]) |
| cv2.circle(img, (x, y), 3, color, -1) |
|
|
|
|
def render_result_panorama(img_bgr: np.ndarray, outputs, faces) -> np.ndarray:
    """Build a 4-panel strip: original | keypoints+boxes | mesh overlay | side view.

    Args:
        img_bgr: Input image in BGR channel order (OpenCV convention).
        outputs: Per-person estimator outputs; each entry is a dict expected
            to contain at least ``pred_vertices``, ``pred_cam_t`` and
            ``focal_length``.
        faces: Triangle index array shared by every predicted mesh.

    Returns:
        A single BGR uint8 image with the four panels concatenated
        horizontally (4x the input width).

    Raises:
        ValueError: If ``outputs`` is empty.
    """
    _ensure_repo_on_path()
    from sam_3d_body.visualization.renderer import Renderer

    people = [_normalize_person_output(person) for person in outputs]
    faces_np = np.asarray(_to_numpy(faces), dtype=np.int32)

    if not people:
        raise ValueError("No people to render.")

    img_orig = img_bgr.copy()
    img_kpts = img_bgr.copy()

    # Painter's order: sort by descending camera z so the farthest person is
    # drawn first and nearer people end up on top.
    all_depths = np.stack([p["pred_cam_t"] for p in people], axis=0)[:, 2]
    people_sorted = [people[idx] for idx in np.argsort(-all_depths)]

    for person in people_sorted:
        _draw_keypoints(img_kpts, person.get("pred_keypoints_2d"))
        _draw_bbox(img_kpts, person.get("bbox"), (0, 255, 0))  # body box: green
        if "lhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("lhand_bbox"), (255, 0, 0))  # blue (BGR)
        if "rhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("rhand_bbox"), (0, 0, 255))  # red (BGR)

    # Merge every person's mesh into one vertex/face buffer, offsetting face
    # indices by pid * per-person vertex count.
    # NOTE(review): the offset assumes all people share this person's vertex
    # count — TODO confirm meshes are homogeneous.
    all_pred_vertices = []
    all_faces = []
    for pid, person in enumerate(people_sorted):
        verts = np.asarray(person["pred_vertices"], dtype=np.float32)
        cam_t = np.asarray(person["pred_cam_t"], dtype=np.float32).reshape(1, 3)
        all_pred_vertices.append(verts + cam_t)
        all_faces.append(faces_np + verts.shape[0] * pid)

    all_pred_vertices = np.concatenate(all_pred_vertices, axis=0)
    all_faces = np.concatenate(all_faces, axis=0)

    # Recenter the scene on the bounding-box midpoint of the trailing chunk of
    # vertices and feed that offset to the renderer as a synthetic camera
    # translation. 18439 is presumably the per-person vertex count of the
    # body mesh (tail = last two people) — TODO confirm against the MHR model.
    tail = min(all_pred_vertices.shape[0], 2 * 18439)
    fake_pred_cam_t = (
        np.max(all_pred_vertices[-tail:], axis=0) + np.min(all_pred_vertices[-tail:], axis=0)
    ) / 2.0
    fake_pred_cam_t = fake_pred_cam_t.astype(np.float32)
    all_pred_vertices = all_pred_vertices - fake_pred_cam_t[None, :]

    # One renderer for the whole merged mesh; focal length taken from the
    # (nearest-last sorted) first person.
    focal_length = _to_scalar_float(people_sorted[0]["focal_length"])
    renderer = Renderer(focal_length=focal_length, faces=all_faces)

    light_blue = (0.65098039, 0.74117647, 0.85882353)

    # Front view composited over the input photo. Renderer output appears to
    # be float in [0, 1] given the 255 rescale below.
    img_mesh = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        img_bgr.copy(),
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
    )
    img_mesh = np.clip(img_mesh * 255.0, 0, 255).astype(np.uint8)

    # Side view rendered over a plain white canvas.
    white_img = np.ones_like(img_bgr, dtype=np.uint8) * 255
    img_mesh_side = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        white_img,
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
        side_view=True,
    )
    img_mesh_side = np.clip(img_mesh_side * 255.0, 0, 255).astype(np.uint8)

    return np.concatenate([img_orig, img_kpts, img_mesh, img_mesh_side], axis=1)
|
|
|
|
@lru_cache(maxsize=1)
def get_estimator():
    """Download the checkpoint once, then build and cache the estimator.

    Cached via lru_cache so the model is loaded a single time per process.
    Detector, segmentor and FOV estimator are intentionally disabled.
    """
    _ensure_repo_on_path()
    _patch_estimator_for_cpu_only()

    from sam_3d_body import SAM3DBodyEstimator, load_sam_3d_body

    print("HF_TOKEN present:", bool(HF_TOKEN))
    print("Target repo:", HF_REPO_ID)

    # Only fetch the files the loader actually needs.
    wanted = [
        "model.ckpt",
        "model_config.yaml",
        "assets/mhr_model.pt",
    ]
    repo_dir = snapshot_download(
        repo_id=HF_REPO_ID,
        token=HF_TOKEN,
        allow_patterns=wanted,
    )

    ckpt_path = os.path.join(repo_dir, "model.ckpt")
    mhr_path = os.path.join(repo_dir, "assets", "mhr_model.pt")

    model, model_cfg = load_sam_3d_body(
        checkpoint_path=ckpt_path,
        device=str(DEVICE),
        mhr_path=mhr_path,
    )

    # Debug output: confirm which device the weights actually landed on.
    first_param = next(model.parameters(), None)
    if first_param is None:
        print("Model parameter device: unavailable")
    else:
        print("Model parameter device:", first_param.device)

    for attr in ("image_mean", "image_std"):
        if hasattr(model, attr):
            print(f"{attr} device:", getattr(model, attr).device)

    return SAM3DBodyEstimator(
        sam_3d_body_model=model,
        model_cfg=model_cfg,
        human_detector=None,
        human_segmentor=None,
        fov_estimator=None,
    )
|
|
|
|
def run_inference(image: np.ndarray):
    """Run 3D body estimation on an uploaded RGB image.

    Returns a (rendered RGB image, status string) pair for the Gradio UI.
    Raises gr.Error for user-facing failures (no image, inference error,
    empty result); visualization failures degrade to the resized input.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    estimator = get_estimator()

    # Gradio delivers RGB; keep a BGR copy for the OpenCV-based renderer.
    img_rgb = _resize_longest_side(image.astype(np.uint8))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    try:
        outputs = estimator.process_one_image(img_rgb)
    except Exception as exc:
        raise gr.Error(f"Inference failed: {exc}") from exc

    if not outputs:
        raise gr.Error(
            "No result was produced. Use an image with one clearly visible full-body person."
        )

    # Best-effort debug logging; never let it break the request.
    try:
        print("Num outputs:", len(outputs))
        print("Output keys:", list(outputs[0].keys()) if outputs else [])
    except Exception:
        pass

    try:
        rendered_bgr = render_result_panorama(img_bgr, outputs, estimator.faces)
        rendered = cv2.cvtColor(rendered_bgr, cv2.COLOR_BGR2RGB)
        proc_h, proc_w = img_rgb.shape[:2]
        out_h, out_w = rendered.shape[:2]
        status = (
            f"Done. Reconstructed people: {len(outputs)} | "
            f"Processed size: {proc_w}x{proc_h} | "
            f"Rendered size: {out_w}x{out_h} | "
            f"Device: {DEVICE.type.upper()}"
        )
    except Exception as vis_exc:
        # Inference still succeeded — show the input and explain the failure.
        rendered = img_rgb.copy()
        status = (
            f"Inference succeeded for {len(outputs)} person(s), "
            f"but visualization failed: {type(vis_exc).__name__}: {vis_exc}"
        )
        print("Visualization failed:", repr(vis_exc))

    # Release per-request memory promptly (models stay cached).
    del outputs
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()

    return rendered, status
|
|
|
|
# Markdown shown at the top of the UI (rendered by gr.Markdown below).
DESCRIPTION = """
# SAM 3D Body — Gradio demo

Upload a photo and run full-image 3D body reconstruction.

Notes:
- Automatically uses GPU when available, otherwise CPU.
- Detector, segmentor, and FOV estimator are disabled to keep the app lean.
- Best results come from one clearly visible full-body person.
- The Space secret `HF_TOKEN` must be set after access to the gated model repo is approved.
- Optional env var: `SAM3D_HF_REPO_ID=facebook/sam-3d-body-vith` for the smaller checkpoint.
"""
|
|
|
|
# --- UI wiring: two-column layout, single Run button driving run_inference ---
with gr.Blocks(title="SAM 3D Body") as demo:
    gr.Markdown(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input image", type="numpy", image_mode="RGB")
            with gr.Row():
                run_btn = gr.Button("Run", variant="primary")
                gr.ClearButton([input_image], value="Clear")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Result", type="numpy")
            status_box = gr.Textbox(label="Status", interactive=False)

    run_btn.click(
        fn=run_inference,
        inputs=input_image,
        outputs=[output_image, status_box],
    )

# Single-worker queue: inference is heavyweight, so avoid concurrent runs.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)