# app.py — SAM 3D Body demo Space (Gradio app)
import gc
import os
import sys
import tempfile
import urllib.request
import zipfile
from functools import lru_cache
from pathlib import Path
# Set env vars before importing torch.
def _ensure_positive_int_env(name: str, default: int) -> None:
value = os.getenv(name, "").strip()
if not value.isdigit() or int(value) < 1:
os.environ[name] = str(default)
# Cap intra-op threading before torch is imported (OpenMP reads this at load time).
_ensure_positive_int_env("OMP_NUM_THREADS", 1)
os.environ.setdefault("HF_HUB_DISABLE_TELEMETRY", "1")

import torch

# Pick CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Must be set before importing pyrender / renderer modules.
# EGL enables headless GPU rendering; OSMesa is the software path for CPU hosts.
if DEVICE.type == "cuda":
    os.environ["PYOPENGL_PLATFORM"] = "egl"
else:
    os.environ["PYOPENGL_PLATFORM"] = "osmesa"
import cv2
import gradio as gr
import numpy as np
from gradio import Error as GradioError
from huggingface_hub import snapshot_download, whoami
# Configuration pulled from the Space environment.
HF_TOKEN = os.getenv("HF_TOKEN")  # required for the gated model repo
HF_REPO_ID = os.getenv("SAM3D_HF_REPO_ID", "facebook/sam-3d-body-dinov3")
MAX_IMAGE_SIDE = int(os.getenv("MAX_IMAGE_SIDE", "1024"))  # longest side after resize
# Where the upstream sam-3d-body source tree is downloaded and unzipped.
SRC_CACHE_DIR = Path(tempfile.gettempdir()) / "sam_3d_body_src"
SRC_ROOT = SRC_CACHE_DIR / "sam-3d-body-main"
SRC_ZIP = SRC_CACHE_DIR / "sam-3d-body.zip"

print("Using device:", DEVICE)

# Fail fast at import time when the token is missing or unusable, so the
# Space shows a clear error instead of failing later inside model download.
if not HF_TOKEN:
    raise GradioError(
        "Missing HF_TOKEN. Add a Hugging Face user access token in "
        "Space Settings -> Repository secrets under the key HF_TOKEN."
    )
try:
    me = whoami(token=HF_TOKEN)
    print(
        "Authenticated on Hugging Face as:",
        me.get("name") or me.get("fullname") or "unknown",
    )
except Exception as exc:
    raise GradioError(f"HF_TOKEN is present but invalid or unusable: {exc}") from exc
def _ensure_repo_on_path() -> None:
    """Make the `sam_3d_body` package importable, fetching its source if needed.

    Tries a plain import first; on failure downloads the upstream GitHub zip
    into a temp cache, extracts it, and prepends the tree to sys.path before
    importing again.
    """
    try:
        import sam_3d_body  # noqa: F401
    except Exception:
        pass
    else:
        return

    if not SRC_ROOT.exists():
        SRC_CACHE_DIR.mkdir(parents=True, exist_ok=True)
        archive_url = "https://codeload.github.com/facebookresearch/sam-3d-body/zip/refs/heads/main"
        urllib.request.urlretrieve(archive_url, SRC_ZIP)
        with zipfile.ZipFile(SRC_ZIP, "r") as archive:
            archive.extractall(SRC_CACHE_DIR)

    repo_path = str(SRC_ROOT)
    if repo_path not in sys.path:
        sys.path.insert(0, repo_path)

    import sam_3d_body  # noqa: F401
def _resize_longest_side(img: np.ndarray, max_side: int = MAX_IMAGE_SIDE) -> np.ndarray:
h, w = img.shape[:2]
scale = min(1.0, max_side / float(max(h, w)))
if scale >= 1.0:
return img
new_w = max(1, int(round(w * scale)))
new_h = max(1, int(round(h * scale)))
return cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_AREA)
def _patch_estimator_for_cpu_only() -> None:
    """On CPU-only hosts, monkeypatch the estimator so batches never target CUDA.

    The upstream estimator moves batches with recursive_to(data, device); this
    wraps it so any CUDA destination is redirected to the CPU. Idempotent via
    a sentinel attribute on the estimator module. No-op when DEVICE is CUDA.
    """
    if DEVICE.type != "cpu":
        return
    _ensure_repo_on_path()
    import sam_3d_body.sam_3d_body_estimator as estimator_mod

    if getattr(estimator_mod, "_cpu_safe_patch_applied", False):
        return

    original_recursive_to = estimator_mod.recursive_to

    def recursive_to_safe(data, device):
        # Redirect any CUDA target ("cuda", "cuda:0", torch.device("cuda"), ...)
        # to the CPU. The previous `device == "cuda"` equality only caught the
        # bare string and let other CUDA specs through to a host with no GPU.
        if str(device).startswith("cuda"):
            device = "cpu"
        return original_recursive_to(data, device)

    estimator_mod.recursive_to = recursive_to_safe
    estimator_mod._cpu_safe_patch_applied = True
    print("Applied CPU-only batch transfer patch.")
def _to_numpy(value):
if isinstance(value, np.ndarray):
return value
if isinstance(value, torch.Tensor):
return value.detach().cpu().numpy()
return np.asarray(value)
def _to_scalar_float(value) -> float:
arr = _to_numpy(value)
if np.isscalar(arr):
return float(arr)
arr = np.asarray(arr).reshape(-1)
return float(arr[0])
def _normalize_person_output(person_output: dict) -> dict:
    """Convert one estimator person dict to plain numpy with stable shapes.

    Tensor / ndarray / list / tuple values become numpy arrays (best effort;
    values that fail conversion pass through unchanged), everything else is
    kept as-is. Key arrays used by rendering are then coerced to fixed
    dtypes/shapes: float32 vertices and keypoints, a 3-vector camera
    translation, and 4-vector bounding boxes.
    """
    normalized = {}
    for key, value in person_output.items():
        if not isinstance(value, (torch.Tensor, np.ndarray, list, tuple)):
            normalized[key] = value
            continue
        try:
            normalized[key] = _to_numpy(value)
        except Exception:
            normalized[key] = value

    # Force key arrays into stable shapes/dtypes used by rendering.
    for key in ("pred_vertices", "pred_keypoints_2d"):
        if key in normalized:
            normalized[key] = np.asarray(normalized[key], dtype=np.float32)
    if "pred_cam_t" in normalized:
        cam_t = np.asarray(normalized["pred_cam_t"], dtype=np.float32)
        normalized["pred_cam_t"] = cam_t.reshape(-1)[:3]
    for key in ("bbox", "lhand_bbox", "rhand_bbox"):
        if key in normalized:
            box = np.asarray(normalized[key], dtype=np.float32)
            normalized[key] = box.reshape(-1)[:4]
    return normalized
def _draw_bbox(img: np.ndarray, bbox, color) -> None:
if bbox is None:
return
x1, y1, x2, y2 = [int(v) for v in np.asarray(bbox).tolist()]
cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)
def _draw_keypoints(img: np.ndarray, keypoints_2d: np.ndarray, color=(0, 0, 255)) -> None:
if keypoints_2d is None or len(keypoints_2d) == 0:
return
pts = np.asarray(keypoints_2d)
for pt in pts:
if len(pt) < 2:
continue
x, y = int(pt[0]), int(pt[1])
cv2.circle(img, (x, y), 3, color, -1)
def render_result_panorama(img_bgr: np.ndarray, outputs, faces) -> np.ndarray:
    """Render a 4-panel BGR strip: original | keypoints+boxes | mesh overlay | side view.

    Args:
        img_bgr: Input image in BGR channel order.
        outputs: Per-person estimator dicts; each must carry at least
            "pred_vertices", "pred_cam_t", and "focal_length".
        faces: Mesh face indices shared by all people (tensor or array).

    Returns:
        One image with the four panels concatenated horizontally.

    Raises:
        ValueError: When *outputs* is empty.
    """
    _ensure_repo_on_path()
    from sam_3d_body.visualization.renderer import Renderer
    people = [_normalize_person_output(person) for person in outputs]
    faces_np = np.asarray(_to_numpy(faces), dtype=np.int32)
    if not people:
        raise ValueError("No people to render.")
    img_orig = img_bgr.copy()
    img_kpts = img_bgr.copy()
    # Sort farthest to closest, matching upstream visualization logic.
    all_depths = np.stack([p["pred_cam_t"] for p in people], axis=0)[:, 2]
    people_sorted = [people[idx] for idx in np.argsort(-all_depths)]
    # Panel 2: 2D keypoints plus body (green) and optional hand (blue/red) boxes.
    for person in people_sorted:
        _draw_keypoints(img_kpts, person.get("pred_keypoints_2d"))
        _draw_bbox(img_kpts, person.get("bbox"), (0, 255, 0))
        if "lhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("lhand_bbox"), (255, 0, 0))
        if "rhand_bbox" in person:
            _draw_bbox(img_kpts, person.get("rhand_bbox"), (0, 0, 255))
    # Merge everyone into one mesh: shift each person's vertices by their
    # camera translation and offset face indices into the concatenated buffer.
    all_pred_vertices = []
    all_faces = []
    for pid, person in enumerate(people_sorted):
        verts = np.asarray(person["pred_vertices"], dtype=np.float32)
        cam_t = np.asarray(person["pred_cam_t"], dtype=np.float32).reshape(1, 3)
        all_pred_vertices.append(verts + cam_t)
        all_faces.append(faces_np + verts.shape[0] * pid)
    all_pred_vertices = np.concatenate(all_pred_vertices, axis=0)
    all_faces = np.concatenate(all_faces, axis=0)
    # Recenter the scene on the bounding-box midpoint of the trailing vertices
    # (i.e. the closest people after the depth sort above).
    # NOTE(review): 18439 looks like a per-person vertex count of the body mesh,
    # so the window would cover roughly the two closest people — confirm against
    # the upstream mesh definition.
    tail = min(all_pred_vertices.shape[0], 2 * 18439)
    fake_pred_cam_t = (
        np.max(all_pred_vertices[-tail:], axis=0) + np.min(all_pred_vertices[-tail:], axis=0)
    ) / 2.0
    fake_pred_cam_t = fake_pred_cam_t.astype(np.float32)
    all_pred_vertices = all_pred_vertices - fake_pred_cam_t[None, :]
    # All people share one camera; use the first (farthest) person's focal length.
    focal_length = _to_scalar_float(people_sorted[0]["focal_length"])
    renderer = Renderer(focal_length=focal_length, faces=all_faces)
    light_blue = (0.65098039, 0.74117647, 0.85882353)
    # Panel 3: mesh rendered over the input; output is rescaled to uint8 below.
    img_mesh = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        img_bgr.copy(),
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
    )
    img_mesh = np.clip(img_mesh * 255.0, 0, 255).astype(np.uint8)
    # Panel 4: the same mesh viewed from the side, on a white canvas.
    white_img = np.ones_like(img_bgr, dtype=np.uint8) * 255
    img_mesh_side = renderer(
        all_pred_vertices,
        fake_pred_cam_t,
        white_img,
        mesh_base_color=light_blue,
        scene_bg_color=(1, 1, 1),
        side_view=True,
    )
    img_mesh_side = np.clip(img_mesh_side * 255.0, 0, 255).astype(np.uint8)
    return np.concatenate([img_orig, img_kpts, img_mesh, img_mesh_side], axis=1)
@lru_cache(maxsize=1)
def get_estimator():
    """Build (once) and return the SAM3DBodyEstimator for this process.

    Downloads only the required files from the (gated) HF repo, loads the
    model onto DEVICE, and wraps it with detector, segmentor, and FOV
    estimator disabled. lru_cache(maxsize=1) makes every request reuse the
    same estimator instance.
    """
    _ensure_repo_on_path()
    _patch_estimator_for_cpu_only()
    from sam_3d_body import SAM3DBodyEstimator, load_sam_3d_body
    print("HF_TOKEN present:", bool(HF_TOKEN))
    print("Target repo:", HF_REPO_ID)
    # Restrict the snapshot to the checkpoint, its config, and the MHR asset.
    snapshot_dir = snapshot_download(
        repo_id=HF_REPO_ID,
        token=HF_TOKEN,
        allow_patterns=[
            "model.ckpt",
            "model_config.yaml",
            "assets/mhr_model.pt",
        ],
    )
    checkpoint_path = os.path.join(snapshot_dir, "model.ckpt")
    mhr_path = os.path.join(snapshot_dir, "assets", "mhr_model.pt")
    model, model_cfg = load_sam_3d_body(
        checkpoint_path=checkpoint_path,
        device=str(DEVICE),
        mhr_path=mhr_path,
    )
    # Debug prints to confirm the weights actually landed on DEVICE.
    try:
        print("Model parameter device:", next(model.parameters()).device)
    except StopIteration:
        print("Model parameter device: unavailable")
    if hasattr(model, "image_mean"):
        print("image_mean device:", model.image_mean.device)
    if hasattr(model, "image_std"):
        print("image_std device:", model.image_std.device)
    estimator = SAM3DBodyEstimator(
        sam_3d_body_model=model,
        model_cfg=model_cfg,
        human_detector=None,
        human_segmentor=None,
        fov_estimator=None,
    )
    return estimator
def run_inference(image: np.ndarray):
    """Gradio handler: run 3D body reconstruction on *image* (RGB numpy).

    Returns a (rendered RGB image, status string) pair. Raises gr.Error for
    missing input, inference failure, or an empty result; a visualization
    failure degrades gracefully to the resized input image plus a note.
    """
    if image is None:
        raise gr.Error("Please upload an image.")

    estimator = get_estimator()
    img_rgb = _resize_longest_side(image.astype(np.uint8))
    img_bgr = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2BGR)

    try:
        outputs = estimator.process_one_image(img_rgb)
    except Exception as exc:
        raise gr.Error(f"Inference failed: {exc}") from exc
    if not outputs:
        raise gr.Error(
            "No result was produced. Use an image with one clearly visible full-body person."
        )

    # Best-effort debug logging; never let it break the request.
    try:
        print("Num outputs:", len(outputs))
        print("Output keys:", list(outputs[0].keys()) if outputs else [])
    except Exception:
        pass

    try:
        rendered_bgr = render_result_panorama(img_bgr, outputs, estimator.faces)
        rendered = cv2.cvtColor(rendered_bgr, cv2.COLOR_BGR2RGB)
        status = (
            f"Done. Reconstructed people: {len(outputs)} | "
            f"Processed size: {img_rgb.shape[1]}x{img_rgb.shape[0]} | "
            f"Rendered size: {rendered.shape[1]}x{rendered.shape[0]} | "
            f"Device: {DEVICE.type.upper()}"
        )
    except Exception as vis_exc:
        # Visualization is optional: report the failure but still return a result.
        rendered = img_rgb.copy()
        status = (
            f"Inference succeeded for {len(outputs)} person(s), "
            f"but visualization failed: {type(vis_exc).__name__}: {vis_exc}"
        )
        print("Visualization failed:", repr(vis_exc))

    # Release per-request memory before handing control back to Gradio.
    del outputs
    gc.collect()
    if DEVICE.type == "cuda":
        torch.cuda.empty_cache()
    return rendered, status
# Markdown rendered at the top of the UI (runtime string — do not edit casually).
DESCRIPTION = """
# SAM 3D Body — Gradio demo
Upload a photo and run full-image 3D body reconstruction.
Notes:
- Automatically uses GPU when available, otherwise CPU.
- Detector, segmentor, and FOV estimator are disabled to keep the app lean.
- Best results come from one clearly visible full-body person.
- The Space secret `HF_TOKEN` must be set after access to the gated model repo is approved.
- Optional env var: `SAM3D_HF_REPO_ID=facebook/sam-3d-body-vith` for the smaller checkpoint.
"""

# UI layout: image input + Run/Clear on the left, result image + status on the right.
with gr.Blocks(title="SAM 3D Body") as demo:
    gr.Markdown(DESCRIPTION)
    with gr.Row():
        with gr.Column(scale=1):
            input_image = gr.Image(label="Input image", type="numpy", image_mode="RGB")
            with gr.Row():
                run_btn = gr.Button("Run", variant="primary")
                gr.ClearButton([input_image], value="Clear")
        with gr.Column(scale=2):
            output_image = gr.Image(label="Result", type="numpy")
            status_box = gr.Textbox(label="Status", interactive=False)
    run_btn.click(
        fn=run_inference,
        inputs=input_image,
        outputs=[output_image, status_box],
    )

# Serve one request at a time against the single cached estimator.
demo.queue(default_concurrency_limit=1)

if __name__ == "__main__":
    demo.launch(ssr_mode=False)