# zoo3d / mvp.py
# Last commit by bulatko (dcb1ea4):
# "Runtime fallback: install local detectron2 if import fails (no build isolation)"
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
import open3d as o3d
import open_clip
from open_clip import tokenizer
import trimesh
import matplotlib.pyplot as plt
import subprocess
import tempfile
import contextlib
from huggingface_hub import hf_hub_download
# Optional dependency: gdown is only used to fetch the Mask2Former weights from Google Drive.
try:
    import gdown
except Exception:
    gdown = None


def _install_gradio_client_schema_patch():
    """Make gradio_client tolerate boolean JSON schemas.

    Some gradio_client versions crash on a JSON schema whose ``additionalProperties``
    is a bare boolean. The wrapper maps boolean schemas to ``"Any"`` and delegates
    everything else to the original converter.
    """
    import gradio_client.utils as gc_utils

    if not hasattr(gc_utils, "_json_schema_to_python_type"):
        return
    original = gc_utils._json_schema_to_python_type

    def _json_schema_to_python_type_patched(schema, defs=None):
        return "Any" if isinstance(schema, bool) else original(schema, defs)

    gc_utils._json_schema_to_python_type = _json_schema_to_python_type_patched


# Best-effort: a missing or incompatible gradio_client must never stop the app from starting.
with contextlib.suppress(Exception):
    _install_gradio_client_schema_patch()
# Limit parallel compile jobs for native extension builds (read by torch's C++ extension builder).
os.environ.setdefault("MAX_JOBS", "1")
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(REPO_ROOT, "vggt"))
MK_PATH = os.path.join(REPO_ROOT, "MaskClustering")
DETECTRON2_ROOT = os.path.join(MK_PATH, "third_party", "detectron2")
# Expose MaskClustering's third-party tree to child processes. Use os.pathsep (not a
# hard-coded ":") and skip the separator when PYTHONPATH is unset: a leading separator
# creates an empty entry, which Python treats as the child's current directory.
_third_party = os.path.join(MK_PATH, "third_party")
_existing_pythonpath = os.environ.get("PYTHONPATH", "")
os.environ["PYTHONPATH"] = (
    _existing_pythonpath + os.pathsep + _third_party if _existing_pythonpath else _third_party
)
def _add_vendored_detectron2_to_path():
    """Append the vendored detectron2 source tree to sys.path if it exists and is not there yet."""
    if os.path.isdir(DETECTRON2_ROOT) and DETECTRON2_ROOT not in sys.path:
        sys.path.append(DETECTRON2_ROOT)


# Ensure detectron2 is importable even if postBuild did not install it.
try:
    import detectron2  # noqa: F401
except Exception:
    print("[runtime] detectron2 not found. Installing local detectron2 (editable, no build isolation)...")
    # Install with the running interpreter (not whatever `python` is on PATH) and an
    # absolute path, so the result does not depend on the current working directory.
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", "--no-build-isolation", "-e", DETECTRON2_ROOT],
        check=False,
    )
    if _pip_result.returncode != 0:
        print(
            f"[runtime] pip install of detectron2 failed (exit code {_pip_result.returncode}); "
            "trying the vendored source tree."
        )
    import importlib

    importlib.invalidate_caches()
    # An editable install registers itself through a .pth file, which only takes effect at
    # interpreter start-up; make the vendored tree importable for this process right away.
    _add_vendored_detectron2_to_path()
    import detectron2  # noqa: F401
# If detectron2 isn't installed as a package, allow importing from vendored source.
_add_vendored_detectron2_to_path()
# Scratch directory for uploads and outputs. Overridable with ZOO3D_WORKDIR; the default
# sits under the system temp dir because that is writable on HF Spaces.
_default_workdir = os.path.join(tempfile.gettempdir(), "zoo3d")
WORK_DIR = os.environ.get("ZOO3D_WORKDIR", _default_workdir)
os.makedirs(WORK_DIR, exist_ok=True)
from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")


def _env_flag(name, default):
    """Return True when environment variable *name* (or *default* if unset) equals "1"."""
    return os.environ.get(name, default) == "1"


# CPU debug / compatibility switches. VGGT-1B inference on CPU is usually impractical, so
# CPU runs fall back to a lightweight dummy pipeline whose minimal predictions dict is
# still accepted by `predictions_to_glb`.
ZOO3D_ALLOW_CPU = _env_flag("ZOO3D_ALLOW_CPU", "1")
ZOO3D_CPU_DUMMY = _env_flag("ZOO3D_CPU_DUMMY", "1")
ZOO3D_SKIP_DOWNLOADS = _env_flag("ZOO3D_SKIP_DOWNLOADS", "0")

# Lazily created model singletons, filled in by `_init_models`.
_VGGT_MODEL = _METRIC3D_MODEL = _CLIP_MODEL = None

# Google Drive file id of the Mask2Former (CropFormer) checkpoint.
_MASK2FORMER_GDRIVE_FILE_ID = "10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"
def _ensure_mask2former_weights(dst_path: str) -> str:
    """Make sure the Mask2Former/CropFormer checkpoint exists at *dst_path*.

    Sources are tried in this order:
      1) an existing non-empty file at ``dst_path``;
      2) a local file named by the ``MASK2FORMER_WEIGHTS_PATH`` environment variable;
      3) Google Drive via ``gdown`` (file id ``_MASK2FORMER_GDRIVE_FILE_ID``);
      4) the ``qqlu1992/Adobe_EntitySeg`` dataset on the Hugging Face Hub.

    Returns:
        ``dst_path``.

    Raises:
        Whatever ``hf_hub_download`` / ``shutil.copyfile`` raise if the last fallback fails.
    """

    def _nonempty_file(path):
        # isfile (not exists): a directory must not be mistaken for a checkpoint.
        return bool(path) and os.path.isfile(path) and os.path.getsize(path) > 0

    if _nonempty_file(dst_path):
        return dst_path
    parent = os.path.dirname(dst_path)
    if parent:  # makedirs("") raises for a bare file name
        os.makedirs(parent, exist_ok=True)

    # 1) Local override
    override_path = os.environ.get("MASK2FORMER_WEIGHTS_PATH")
    if _nonempty_file(override_path):
        shutil.copyfile(override_path, dst_path)
        return dst_path

    # 2) Google Drive
    if gdown is not None:
        url = f"https://drive.google.com/uc?id={_MASK2FORMER_GDRIVE_FILE_ID}"
        try:
            out = gdown.download(url, dst_path, quiet=False)
        except Exception as e:
            # Newer gdown raises (quota, permissions, ...) instead of returning None;
            # fall back to the HF mirror instead of aborting.
            print(f"Warning: gdown raised {e!r}")
            out = None
        if out is not None and _nonempty_file(dst_path):
            return dst_path
        # Never leave a truncated file behind: the fast path above would accept it next time.
        with contextlib.suppress(OSError):
            os.remove(dst_path)
        print("Warning: gdown download failed for Mask2Former weights; falling back to HF dataset...")
    else:
        print("Warning: gdown is not available; falling back to HF dataset for Mask2Former weights...")

    # 3) Hugging Face Hub fallback
    cached = hf_hub_download(
        repo_id="qqlu1992/Adobe_EntitySeg",
        repo_type="dataset",
        filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
    )
    shutil.copyfile(cached, dst_path)
    return dst_path
def _load_clip_model(target_device, log_suffix=""):
    """Create the open_clip ViT-H-14 model (laion2b_s32b_b79k weights) in eval mode on *target_device*."""
    print(f"[INFO] loading CLIP model{log_suffix}...")
    clip, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    clip.to(target_device)
    clip.eval()
    print(f"[INFO] finish loading CLIP model{log_suffix}...")
    return clip


def _init_models():
    """Lazily load the heavy models so the UI can start quickly on HF Spaces.

    Models are cached in the module globals ``_VGGT_MODEL``, ``_METRIC3D_MODEL`` and
    ``_CLIP_MODEL`` and reused on later calls.

    Returns:
        ``(vggt, metric3d, clip)``. Without CUDA (and with ``ZOO3D_ALLOW_CPU``) only CLIP is
        loaded, on CPU, and ``(None, None, clip)`` is returned.

    Raises:
        RuntimeError: CUDA is unavailable and CPU mode is disabled.
    """
    global _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL
    if not torch.cuda.is_available():
        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA недоступна. Для этого Space нужен GPU (CUDA).")
        # CPU debug mode: VGGT/Metric3D are skipped, CLIP still works on CPU.
        if _CLIP_MODEL is None:
            _CLIP_MODEL = _load_clip_model("cpu", " (CPU)")
        return None, None, _CLIP_MODEL

    if _VGGT_MODEL is None:
        print("Initializing and loading VGGT model...")
        try:
            vggt = VGGT.from_pretrained("facebook/VGGT-1B")
        except Exception:
            # Fall back to loading the raw checkpoint from its URL.
            vggt = VGGT()
            url = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
            vggt.load_state_dict(torch.hub.load_state_dict_from_url(url))
        vggt.eval()
        _VGGT_MODEL = vggt.to(device)

    if _METRIC3D_MODEL is None:
        print("Initializing and loading Metric3D model...")
        try:
            metric3d = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True, trust_repo=True)
        except TypeError:
            # Older torch.hub.load versions do not accept `trust_repo`.
            metric3d = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True)
        metric3d.to(device)
        metric3d.eval()
        _METRIC3D_MODEL = metric3d

    if _CLIP_MODEL is None:
        _CLIP_MODEL = _load_clip_model(device)
    return _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL
cropformer_name = "Mask2Former_hornet_3x_576d0b.pth"


def check_weights():
    """Make sure the CropFormer checkpoint is present under ``MK_PATH``.

    Only logs when ``ZOO3D_SKIP_DOWNLOADS`` is set. Otherwise, if the file is missing,
    delegates fetching it to `_ensure_mask2former_weights`.
    """
    if ZOO3D_SKIP_DOWNLOADS:
        print("[INFO] ZOO3D_SKIP_DOWNLOADS=1: skipping Mask2Former weights download.")
        return
    weights_path = os.path.join(MK_PATH, cropformer_name)
    if os.path.exists(weights_path):
        print(f"{cropformer_name} already exists...")
        return
    print(f"Downloading {cropformer_name}...")
    os.makedirs(MK_PATH, exist_ok=True)
    _ensure_mask2former_weights(weights_path)
    print(f"Downloaded {cropformer_name}...")


# IMPORTANT (HF Spaces): large weights are intentionally NOT downloaded at import time
# (startup). check_weights() is called lazily by the detection/reconstruction paths
# that actually need them.
def extract_text_feature(descriptions, clip_model, target_path):
    """Encode text prompts with CLIP and persist them.

    The prompts are tokenized, encoded with ``clip_model.encode_text`` and
    L2-normalised. A ``{description: feature_vector}`` dict is saved to
    ``<target_path>/text_features.npy`` (pickled by ``np.save``) and returned.
    """
    tokens = tokenizer.tokenize(descriptions).to(device)
    with torch.no_grad():
        encoded = clip_model.encode_text(tokens).float()
        encoded /= encoded.norm(dim=-1, keepdim=True)
    feature_rows = encoded.cpu().numpy()
    features_by_text = dict(zip(descriptions, feature_rows))
    np.save(os.path.join(target_path, "text_features.npy"), features_by_text)
    return features_by_text


# Module-level name; not reassigned in this module's visible code. detect_objects gets
# its CLIP model from `_init_models()` instead.
clip_model = None
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def _load_rgb_images_u8(image_names):
    """Read *image_names* as RGB uint8 and stack them into an (S, H, W, 3) array.

    Unreadable files are skipped; all images are resized to the first readable one's size.
    """
    frames = []
    for path in image_names:
        bgr = cv2.imread(path, cv2.IMREAD_COLOR)
        if bgr is not None:
            frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    if not frames:
        raise ValueError("No readable images found. Check your upload.")
    H, W = frames[0].shape[:2]
    frames = [f if f.shape[:2] == (H, W) else cv2.resize(f, (W, H)) for f in frames]
    return np.stack(frames, axis=0)


def _cpu_dummy_predictions(images_u8):
    """Build a minimal predictions dict for CPU debugging (no model inference).

    Every frame gets the same planar point grid at z = 1, depth 1, identity 3x4
    extrinsics, identity intrinsics and identity poses; only the keys needed by
    `predictions_to_glb` are filled in.
    """
    S, H, W = images_u8.shape[:3]
    uu, vv = np.meshgrid(np.arange(W), np.arange(H))
    x = (uu - (W / 2.0)) / float(max(W, 1))
    y = -(vv - (H / 2.0)) / float(max(W, 1))
    z = np.ones_like(x, dtype=np.float32)
    pts = np.stack([x, y, z], axis=-1).astype(np.float32)  # (H, W, 3)
    return {
        "images": images_u8,
        "extrinsic": np.tile(np.eye(3, 4, dtype=np.float32)[None], (S, 1, 1)),
        "intrinsic": np.tile(np.eye(3, dtype=np.float32)[None], (S, 1, 1)),
        "pose": np.tile(np.eye(4, dtype=np.float32)[None], (S, 1, 1)),
        "depth": np.ones((S, H, W, 1), dtype=np.float32),
        "depth_conf": np.ones((S, H, W), dtype=np.float32),
        "world_points_from_depth": np.repeat(pts[None], S, axis=0),
    }


# Metric3D's reference preprocessing normalises 0..255 RGB with these statistics before
# calling `inference`, and predicts depth for a canonical camera with a 1000 px focal length.
_METRIC3D_MEAN = (123.675, 116.28, 103.53)
_METRIC3D_STD = (58.395, 57.12, 57.375)
_METRIC3D_CANONICAL_FOCAL = 1000.0


def _run_metric3d(metric3d_model, images):
    """Run Metric3D one frame at a time on (S, 3, H, W) images in [0, 1].

    Returns the concatenated per-frame depth, still in Metric3D's canonical camera space
    (presumably (S, 1, H, W) — TODO confirm the model's output rank).
    """
    mean = torch.tensor(_METRIC3D_MEAN, device=images.device, dtype=images.dtype).view(1, 3, 1, 1)
    std = torch.tensor(_METRIC3D_STD, device=images.device, dtype=images.dtype).view(1, 3, 1, 1)
    normalized = (images * 255.0 - mean) / std
    depths = []
    with torch.no_grad():
        for frame in normalized.split(1, dim=0):  # (1, 3, H, W)
            _, _, h, w = frame.shape
            # Pad right/bottom to a multiple of 32. Zero in normalised space equals the
            # mean colour, which is what Metric3D's reference code pads with.
            ph = ((h - 1) // 32 + 1) * 32
            pw = ((w - 1) // 32 + 1) * 32
            padded = (ph, pw) != (h, w)
            if padded:
                frame = torch.nn.functional.pad(frame, (0, pw - w, 0, ph - h), mode="constant", value=0)
            pred_depth, _confidence, _ = metric3d_model.inference({"input": frame})
            if padded:
                pred_depth = pred_depth[:, :, :h, :w]
            depths.append(pred_depth)
    return torch.cat(depths, dim=0)


def _metric_scale_factor(vggt_depth, metric_depth):
    """Return median(metric / VGGT) depth over pixels where both are > 1e-6 (0-d tensor).

    ``vggt_depth`` is (S, H, W, 1); ``metric_depth`` may be (S, 1, H, W) or (S, H, W)
    and is brought to (S, H, W, 1). Falls back to 1.0 when no pixel is valid.
    """
    if metric_depth.dim() == 4 and metric_depth.shape[1] == 1:
        metric_depth = metric_depth.permute(0, 2, 3, 1)
    elif metric_depth.dim() == 3:
        metric_depth = metric_depth.unsqueeze(-1)
    metric_depth = metric_depth.float()
    vggt_depth = vggt_depth.to(metric_depth.device).float()
    print(f"Metric3D depth shape: {metric_depth.shape}")
    print(f"VGGT depth shape: {vggt_depth.shape}")
    valid = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
    if not valid.any():
        print("Warning: could not compute scale factor; falling back to 1.0")
        return torch.tensor(1.0, device=metric_depth.device)
    scale = torch.median(metric_depth[valid] / vggt_depth[valid])
    print(f"Computed scale factor (Metric3D / VGGT): {scale.item():.4f}")
    return scale


def _poses_relative_to_first(extrinsic, scale_factor):
    """Turn (S, 3, 4) extrinsics into (S, 4, 4) poses relative to frame 0.

    pose_i = E_0 @ inv(E_i), with the translation column multiplied by ``scale_factor``.
    """
    bottom = torch.zeros_like(extrinsic[:, 2:])
    bottom[..., -1] = 1
    extrinsic_h = torch.cat([extrinsic, bottom], dim=-2)  # (S, 4, 4)
    # Computed into a fresh tensor: the former in-place loop aliased E_0 (a view of the
    # tensor it was writing) and turned it into identity on its first iteration.
    poses = extrinsic_h[0] @ torch.linalg.inv(extrinsic_h)
    poses[:, :3, 3] *= scale_factor
    return poses


def run_model(target_dir, model, metric3d_model=None) -> dict:
    """Run VGGT (optionally scale-aligned with Metric3D) on ``target_dir/images``.

    Args:
        target_dir: directory that contains an ``images`` sub-folder.
        model: VGGT model (unused in CPU dummy mode).
        metric3d_model: optional Metric3D model; when given, VGGT depths and camera
            translations are multiplied by median(Metric3D depth / VGGT depth).

    Returns:
        dict of numpy arrays including "images", "depth", "depth_conf", "extrinsic"
        (S, 3, 4), "intrinsic" (S, 3, 3), "pose" (S, 4, 4) and
        "world_points_from_depth", plus the other tensors VGGT returns.

    Raises:
        RuntimeError: no CUDA and CPU mode (or CPU dummy mode) disabled.
        ValueError: no (readable) images found.
    """
    print(f"Processing images from {target_dir}")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA недоступна. Для этого Space нужен GPU (CUDA).")
        if not ZOO3D_CPU_DUMMY:
            raise RuntimeError(
                "CPU режим включен, но ZOO3D_CPU_DUMMY=0. "
                "Для отладки поставь ZOO3D_CPU_DUMMY=1 или включи GPU."
            )

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if not image_names:
        raise ValueError("No images found. Check your upload.")

    if device == "cpu":
        # Original-resolution RGB frames are used for colouring in `predictions_to_glb`;
        # the VGGT preprocessing is not needed for the dummy pipeline.
        images_u8 = _load_rgb_images_u8(image_names)
        print(f"CPU dummy: loaded images shape: {images_u8.shape}")
        return _cpu_dummy_predictions(images_u8)

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {tuple(images.shape)}")

    model = model.to(device)
    model.eval()
    print("Running inference...")
    amp_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    with torch.no_grad(), torch.amp.autocast("cuda", dtype=amp_dtype):
        predictions = model(images)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])

    scale_factor = torch.tensor(1.0, device=device)
    if metric3d_model is not None:
        print("Running Metric3D inference...")
        metric_depth = _run_metric3d(metric3d_model, images)
        # De-canonicalise: rescale by each frame's focal length in the same pixel grid
        # (the Metric3D input is the VGGT input, only padded).
        fx = intrinsic[0, :, 0, 0].to(device=metric_depth.device, dtype=metric_depth.dtype)
        fx = fx.view(-1, *([1] * (metric_depth.dim() - 1)))
        metric_depth = metric_depth * (fx / _METRIC3D_CANONICAL_FOCAL)
        predictions["metric3d_depth"] = metric_depth
        scale_factor = _metric_scale_factor(predictions["depth"][0], metric_depth)

    poses = _poses_relative_to_first(extrinsic[0], scale_factor)
    print(f"Extrinsic: {poses.shape}")
    predictions["extrinsic"] = torch.linalg.inv(poses)[None, ..., :3, :]
    predictions["pose"] = poses[None]
    predictions["intrinsic"] = intrinsic

    # Convert every tensor to numpy and drop a leading batch dim of size 1. Tensors without
    # one (e.g. metric3d_depth with S > 1) must still be converted; left as CUDA tensors
    # they made np.savez fail downstream.
    for key in list(predictions):
        value = predictions[key]
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu()
            if value.dtype == torch.bfloat16:  # numpy has no bfloat16
                value = value.float()
            array = value.numpy()
            predictions[key] = array.squeeze(0) if array.ndim > 0 and array.shape[0] == 1 else array

    print("Computing world points from depth map...")
    predictions["depth"] = predictions["depth"] * float(scale_factor.item())
    predictions["world_points_from_depth"] = unproject_depth_map_to_point_map(
        predictions["depth"], predictions["extrinsic"], predictions["intrinsic"]
    )

    torch.cuda.empty_cache()
    return predictions
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def _upload_path(item):
    """Return the filesystem path of a Gradio upload (a dict with "name", or a plain path)."""
    return item["name"] if isinstance(item, dict) and "name" in item else item


def _sample_video_frames(video_path, out_dir):
    """Save about one frame per second of *video_path* to *out_dir* as 000000.jpg, 000001.jpg, ...

    Every ``int(fps)``-th frame is kept; when the FPS is unreadable (0, < 1 or NaN) every
    frame is kept instead of dividing by zero. Returns the written paths.
    """
    capture = cv2.VideoCapture(video_path)
    written = []
    try:
        fps = capture.get(cv2.CAP_PROP_FPS)
        frame_interval = int(fps) if np.isfinite(fps) and fps >= 1 else 1
        count = 0
        while True:
            ok, frame = capture.read()
            if not ok:
                break
            count += 1
            if count % frame_interval == 0:
                path = os.path.join(out_dir, f"{len(written):06}.jpg")
                cv2.imwrite(path, frame)
                written.append(path)
    finally:
        capture.release()  # the capture was previously never released
    return written


def handle_uploads(input_video, input_images):
    """Copy an upload into a fresh ``WORK_DIR/input/<timestamp>/images`` folder.

    Uploaded images are copied as-is; a video is sampled at roughly one frame per second.

    Returns:
        ``(target_dir, image_paths)`` with the image paths sorted.
    """
    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = os.path.join(WORK_DIR, "input", timestamp)
    target_dir_images = os.path.join(target_dir, "images")
    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    for file_data in input_images or []:
        src = _upload_path(file_data)
        dst = os.path.join(target_dir_images, os.path.basename(src))
        shutil.copy(src, dst)
        image_paths.append(dst)

    if input_video is not None:
        image_paths.extend(_sample_video_frames(_upload_path(input_video), target_dir_images))

    image_paths = sorted(image_paths)
    print(f"Files copied to {target_dir_images}; took {time.time() - start_time:.3f} seconds")
    return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """Copy freshly uploaded files into a new working directory for the gallery.

    Returns a 4-tuple ``(None, target_dir, image_paths, status_message)``; when nothing
    was uploaded every element is ``None``.
    """
    nothing_uploaded = not (input_video or input_images)
    if nothing_uploaded:
        return (None,) * 4
    target_dir, image_paths = handle_uploads(input_video, input_images)
    status = "Upload complete. Click 'Detect Objects' to begin 3D processing."
    return None, target_dir, image_paths, status
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def _write_depth_pose_intrinsics(target_dir, predictions, image_stems):
    """Write per-frame depth PNGs, per-frame pose txt files and one averaged intrinsic txt.

    Layout: ``depth/<stem>.png`` (uint16, depth * 1000 — millimetres if depth is in
    metres), ``pose/<stem>.txt`` (4x4) and ``intrinsic/intrinsic_depth.txt`` (4x4 with the
    mean 3x3 intrinsic). Presumably consumed by the MaskClustering scripts — confirm.
    """
    depth_dir = os.path.join(target_dir, "depth")
    pose_dir = os.path.join(target_dir, "pose")
    intrinsic_dir = os.path.join(target_dir, "intrinsic")
    for directory in (depth_dir, pose_dir, intrinsic_dir):
        os.makedirs(directory, exist_ok=True)

    uint16_max = np.iinfo(np.uint16).max
    for stem, depth in zip(image_stems, predictions["depth"]):
        # Clamp before casting: negative, NaN or > 65.535 m values would otherwise wrap to
        # arbitrary depths in the uint16 PNG.
        depth_mm = np.nan_to_num(depth[..., 0] * 1000, nan=0.0, posinf=uint16_max, neginf=0.0)
        depth_mm = np.clip(depth_mm, 0, uint16_max)
        cv2.imwrite(os.path.join(depth_dir, f"{stem}.png"), depth_mm.astype(np.uint16))

    intr = np.eye(4)
    intr[:3, :3] = np.mean(predictions["intrinsic"], axis=0)
    np.savetxt(os.path.join(intrinsic_dir, "intrinsic_depth.txt"), intr)
    for stem, pose in zip(image_stems, predictions["pose"]):
        np.savetxt(os.path.join(pose_dir, f"{stem}.txt"), pose)


def _save_point_cloud_ply(point_cloud_data, ply_path, voxel_size=0.01):
    """Voxel-downsample the GLB point cloud (0-255 colours) and write it as an RGB .ply with Open3D."""
    vertices = np.asarray(point_cloud_data.vertices)
    colors = np.asarray(point_cloud_data.colors) / 255.0
    if colors.ndim == 2 and colors.shape[1] == 4:
        colors = colors[:, :3]  # Open3D wants RGB, not RGBA
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(vertices)
    if len(colors):
        pcd.colors = o3d.utility.Vector3dVector(colors)
    pcd = pcd.voxel_down_sample(voxel_size=voxel_size)
    o3d.io.write_point_cloud(ply_path, pcd)


def _run_maskclustering(target_dir):
    """Best-effort run of the external CropFormer + MaskClustering pipeline on *target_dir*.

    Runs CropFormer mask prediction, MaskClustering ``main.py`` and open-vocabulary feature
    extraction as subprocesses. Failures are logged, never raised: these steps need compiled
    ops that are often missing on Spaces.
    """
    root_input_dir = os.path.dirname(target_dir)
    seq_name = os.path.basename(target_dir)
    cropformer_dir = os.path.join(MK_PATH, "third_party", "detectron2", "projects", "CropFormer")
    env = dict(os.environ)
    env["PYTHONPATH"] = MK_PATH + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
    commands = [
        [
            sys.executable,
            os.path.join(cropformer_dir, "demo_cropformer", "mask_predict.py"),
            "--config-file",
            os.path.join(cropformer_dir, "configs", "entityv2", "entity_segmentation", "mask2former_hornet_3x.yaml"),
            "--root", root_input_dir,
            # NOTE(review): only *.jpg frames are segmented; uploaded PNGs are skipped.
            "--image_path_pattern", "images/*.jpg",
            "--dataset", "arkit_gt",
            "--seq_name_list", seq_name,
            "--opts", "MODEL.WEIGHTS", os.path.join(MK_PATH, cropformer_name),
        ],
        [
            sys.executable, os.path.join(MK_PATH, "main.py"),
            "--config", "wild", "--root", root_input_dir, "--seq_name_list", seq_name,
        ],
        [
            sys.executable, os.path.join(MK_PATH, "semantics", "get_open-voc_features.py"),
            "--config", "wild", "--root", root_input_dir, "--seq_name_list", seq_name,
        ],
    ]
    try:
        for command in commands:
            subprocess.run(command, check=True, env=env)
    except Exception as e:
        print(f"Warning: external MaskClustering pipeline failed: {e}")


def reconstruct(
    target_dir,
    conf_thres=50.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Depthmap and Camera Branch",
    text_labels="",
):
    """Reconstruct the scene in ``target_dir/images`` and export its artefacts.

    Writes ``predictions.npz``, ``depth/``, ``pose/``, ``intrinsic/``, a GLB of the scene
    and a voxel-downsampled ``point_cloud.ply`` into *target_dir*. It then runs the external
    CropFormer/MaskClustering pipeline on a best-effort basis. ``prediction_mode`` is
    ignored (always the depth-map branch) and ``text_labels`` is unused.

    Returns:
        ``(glb_path, log_message, frame_filter_dropdown)``, or
        ``(None, error_message, None)`` when *target_dir* is not a directory.
    """
    prediction_mode = "Depthmap and Camera Branch"  # Force prediction mode
    # Check for None before os.path.isdir, which raises TypeError on None.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Same file set as run_model's glob("images/*") (no dot-files), in the same order.
    target_dir_images = os.path.join(target_dir, "images")
    image_files = (
        sorted(f for f in os.listdir(target_dir_images) if not f.startswith("."))
        if os.path.isdir(target_dir_images)
        else []
    )
    # splitext keeps dotted names like "img.v2.jpg" intact ("img.v2", not "img").
    image_stems = [os.path.splitext(f)[0] for f in image_files]
    frame_filter_choices = ["All"] + [f"{i}: {name}" for i, name in enumerate(image_files)]

    print("Running run_model...")
    with torch.no_grad():
        try:
            check_weights()
        except Exception as e:
            print(f"Warning: could not ensure Mask2Former weights at startup: {e}")
        vggt_model, metric3d_model, _ = _init_models()
        predictions = run_model(target_dir, vggt_model, metric3d_model=metric3d_model)

    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    try:
        np.savez(prediction_save_path, **predictions)
    except Exception as e:
        print(f"Warning: could not save predictions to npz: {e}")

    _write_depth_pose_intrinsics(target_dir, predictions, image_stems)

    if frame_filter is None:
        frame_filter = "All"
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )
    glbscene, point_cloud_data = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)
    _save_point_cloud_ply(point_cloud_data, os.path.join(target_dir, "point_cloud.ply"))

    del predictions
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    print(f"Total time: {time.time() - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(image_files)} frames). Waiting for visualization."

    _run_maskclustering(target_dir)
    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
def _rgba_u8(colors, n_points):
    """Convert Open3D colours (floats in [0, 1] or already 0-255) to (N, 4) uint8 RGBA; white if absent."""
    if colors.size == 0:
        return np.full((n_points, 4), 255, dtype=np.uint8)
    rgba = (colors * 255).astype(np.uint8) if colors.max() <= 1.0 else colors.astype(np.uint8)
    if rgba.shape[1] == 3:
        rgba = np.hstack([rgba, np.full((len(rgba), 1), 255, dtype=np.uint8)])
    return rgba


def _box_edges(bbox):
    """Return the edges of an 8-vertex trimesh box as a (K, 2, 3) array, or None.

    Each vertex is labelled by the signs of its coordinates in the box frame
    (``bbox.transform``); two vertices share an edge when their sign triples differ on
    exactly one axis.
    """
    verts = np.asarray(bbox.vertices)
    if verts.shape[0] != 8:
        return None
    transform = np.asarray(bbox.transform)
    local = (verts - transform[:3, 3]) @ transform[:3, :3]
    signs = np.where(local >= 0.0, 1, -1).astype(int)
    sign_to_idx = {tuple(s): i for i, s in enumerate(signs)}
    edges = set()
    for sign, i0 in sign_to_idx.items():
        for axis in range(3):
            flipped = list(sign)
            flipped[axis] *= -1
            i1 = sign_to_idx.get(tuple(flipped))
            if i1 is not None and i1 != i0:
                edges.add(tuple(sorted((i0, i1))))
    if not edges:
        return None
    return np.array([[verts[i], verts[j]] for i, j in edges], dtype=float)


def _edge_cylinder(p1, p2, radius, color_u8):
    """Build a thin cylinder from p1 to p2 with an emissive (unlit-looking) material; None if degenerate."""
    vec = p2 - p1
    length = float(np.linalg.norm(vec))
    if length <= 1e-8:
        return None
    try:
        cyl = trimesh.creation.cylinder(radius=radius, height=length, sections=12)
    except Exception:
        return None
    # Rotate the cylinder's Z axis onto the edge direction, then move it to the edge midpoint.
    try:
        cyl.apply_transform(trimesh.geometry.align_vectors([0, 0, 1], vec / length))
    except Exception:
        pass
    cyl.apply_translation((p1 + p2) / 2.0)
    # Black base colour + emissive colour emulates an unlit material.
    try:
        cyl.visual.material = trimesh.visual.material.PBRMaterial(
            baseColorFactor=(0.0, 0.0, 0.0, 1.0),
            metallicFactor=0.0,
            roughnessFactor=1.0,
            emissiveFactor=(color_u8[:3] / 255.0).tolist(),
            doubleSided=True,
        )
    except Exception:
        cyl.visual.face_colors = np.tile(color_u8[None, :], (len(cyl.faces), 1))
    return cyl


def _load_label_names(target_dir):
    """Map class index -> prompt text from the key order of ``text_features.npy`` ({} if unavailable)."""
    path = os.path.join(target_dir, "text_features.npy")
    if not os.path.exists(path):
        return {}
    try:
        # Pickled dict written by this app's extract_text_feature.
        text_features_dict = np.load(path, allow_pickle=True).item()
        return dict(enumerate(text_features_dict.keys()))
    except Exception as e:
        print(f"Warning: Could not load text features for label mapping: {e}")
        return {}


def _add_detection_boxes(scene, npz_path, points, conf_thres, label_to_name, radius=0.015, max_extent=2.5):
    """Add a coloured wireframe box per confident detection to *scene*; return the legend markdown.

    Detections come from ``pred_masks`` / ``pred_classes`` / ``pred_score`` in *npz_path*.
    A box is drawn only when the mask length matches ``len(points)``, the object has at
    least 4 points and no box side exceeds ``max_extent``. ``radius`` is the cylinder
    radius of the box edges (0.015 -> 3 cm thick).
    """
    # Context manager closes the NpzFile (it was previously leaked).
    with np.load(npz_path, allow_pickle=True) as loaded:
        if "pred_masks" not in loaded:
            return ""
        masks = loaded["pred_masks"].T
        labels = loaded["pred_classes"]
        confs = np.asarray(loaded["pred_score"])

    valid_indices = np.where(confs > conf_thres)[0]
    if len(valid_indices) == 0:
        return ""

    cmap = plt.get_cmap("tab10")
    detected_labels = np.unique(labels[valid_indices])
    label_to_color = {label: cmap(i % 10) for i, label in enumerate(detected_labels)}
    legend_items = {}
    for idx in valid_indices:
        mask = masks[idx]
        if hasattr(mask, "toarray"):
            mask = mask.toarray().flatten()
        mask = mask.astype(bool)
        # Masks index the point cloud they were computed on. If the counts differ the
        # alignment is unknown, so no box is drawn.
        if len(mask) != len(points):
            continue
        obj_points = points[mask]
        if len(obj_points) < 4:
            continue
        obj_pcd = trimesh.PointCloud(obj_points)
        try:
            bbox = obj_pcd.bounding_box_oriented
        except Exception:
            bbox = obj_pcd.bounding_box
        # Skip implausibly large boxes (longest side > max_extent).
        try:
            if float(np.max(np.asarray(bbox.extents).astype(float))) > max_extent:
                continue
        except Exception:
            pass
        segments = _box_edges(bbox)
        if segments is None:
            continue

        lbl_idx = labels[idx]
        lbl_name = label_to_name.get(lbl_idx, f"Class {lbl_idx}")
        color = label_to_color.get(lbl_idx, (1, 0, 0, 1))
        color_u8 = (np.array(color) * 255).astype(np.uint8)
        for p1, p2 in segments:
            cyl = _edge_cylinder(p1, p2, radius, color_u8)
            if cyl is not None:
                scene.add_geometry(cyl)
        legend_items[lbl_name] = color

    lines = ["### Legend\n"]
    for lbl_name, color in legend_items.items():
        c_u8 = (np.array(color) * 255).astype(np.uint8)
        hex_c = "#{:02x}{:02x}{:02x}".format(c_u8[0], c_u8[1], c_u8[2])
        lines.append(f"- <span style='color:{hex_c}'>■</span> {lbl_name}\n")
    return "".join(lines)


def visualize_detections(target_dir, conf_thres, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Depthmap and Camera Branch"):
    """Render ``point_cloud.ply`` plus wireframe boxes for detected objects into a GLB.

    Detections are read from ``output/object/prediction.npz`` when it exists; only those
    with score > ``conf_thres`` are drawn. The frame/mask/camera/prediction-mode arguments
    are accepted for signature compatibility and are not used.

    Returns:
        ``(glb_path, legend_markdown)`` or ``(None, error_message)``.
    """
    if not target_dir or not os.path.exists(target_dir):
        return None, "Target directory not found."
    ply_path = os.path.join(target_dir, "point_cloud.ply")
    npz_path = os.path.join(target_dir, "output", "object", "prediction.npz")

    # 1. The point cloud is the base of the scene.
    if not os.path.exists(ply_path):
        return None, f"Point cloud not found at {ply_path}. Please run detection first."
    pcd = o3d.io.read_point_cloud(ply_path)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    if points.size == 0:
        return None, "Point cloud is empty."
    scene = trimesh.Scene()
    scene.add_geometry(trimesh.PointCloud(vertices=points, colors=_rgba_u8(colors, len(points))))

    # 2. Overlay detection boxes if the detection pipeline produced any.
    legend_md = ""
    if os.path.exists(npz_path):
        try:
            legend_md = _add_detection_boxes(scene, npz_path, points, conf_thres, _load_label_names(target_dir))
        except Exception as e:
            print(f"Error loading detections: {e}")
            legend_md = f"Error loading detections: {e}"

    # Export the combined scene (point cloud + boxes).
    out_path = os.path.join(target_dir, f"combined_viz_{conf_thres}.glb")
    scene.export(file_obj=out_path)
    return out_path, legend_md
def detect_objects(text_labels, target_dir, conf_thres, *viz_args):
    """Open-vocabulary 3D detection for ';'-separated text prompts.

    Runs `reconstruct` first when ``predictions.npz`` is missing. It then encodes the
    prompts with CLIP into ``text_features.npy``, runs MaskClustering's
    ``wopen-voc_query.py`` (best-effort) and renders the result with `visualize_detections`.

    Args:
        text_labels: prompts separated by ';' (blank entries ignored, duplicates dropped).
        target_dir: upload directory created by `handle_uploads`.
        conf_thres: detection score threshold forwarded to `visualize_detections`.
        *viz_args: (frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky,
            prediction_mode), forwarded to `reconstruct` and `visualize_detections`.

    Returns:
        ``(glb_path, legend_markdown)`` or ``(None, message)``.
    """
    labels = []
    if isinstance(text_labels, str):
        # dict.fromkeys drops duplicates while keeping order, so class index i keeps
        # matching the i-th key of text_features.npy (a dict cannot hold duplicate keys).
        labels = list(dict.fromkeys(l.strip() for l in text_labels.split(";") if l.strip()))
    if not labels:
        return None, "Please enter at least one text label (separated by ';')."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first."

    # check_weights() itself honours ZOO3D_SKIP_DOWNLOADS.
    try:
        check_weights()
    except Exception as e:
        print(f"Warning: could not ensure Mask2Former weights: {e}")

    # 1. Reconstruct first if this upload has no saved predictions yet.
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        print("Predictions not found, running reconstruction first...")
        # Reconstruction uses a fixed point-confidence threshold of 50.0,
        # independent of the detection threshold.
        reconstruct(target_dir, 50.0, *viz_args, text_labels=text_labels)

    # 2. Encode the prompts with CLIP.
    print(f"Extracting features for labels: {labels}")
    _, _, clip_model = _init_models()
    extract_text_feature(labels, clip_model, target_dir)

    # 3. Query the clustered masks with the text features (best-effort).
    try:
        env = dict(os.environ)
        env["PYTHONPATH"] = (
            DETECTRON2_ROOT
            + os.pathsep
            + MK_PATH
            + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
        )
        subprocess.run(
            [
                sys.executable,
                os.path.join(MK_PATH, "semantics", "wopen-voc_query.py"),
                "--config",
                "wild",
                "--root",
                os.path.dirname(target_dir),
                "--seq_name",
                os.path.basename(target_dir),
            ],
            env=env,
            check=True,
        )
    except Exception as e:
        print(f"Warning: open-voc query failed: {e}")

    return visualize_detections(target_dir, conf_thres, *viz_args)
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    No-op first step of the "Detect Objects" click chain.

    Despite its name, this function does not clear any component: it is
    registered with ``inputs=[]`` and ``outputs=[]`` and simply returns
    ``None``, so no UI state is modified by it.

    Returns:
        None
    """
    return None
def update_log():
    """Return the status text shown in the log panel while processing runs."""
    status_text = "Loading and Reconstructing..."
    return status_text
def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Rebuild the point-cloud GLB from saved predictions for the current view settings.

    Loads ``predictions.npz`` from ``target_dir``, renders it with
    ``predictions_to_glb`` using the given visualization parameters, exports
    the scene to a GLB whose file name encodes those parameters, and returns
    its path for the 3D viewer. The GLB is regenerated on every call; an
    existing file with the same name is overwritten.

    Args:
        target_dir: Directory expected to contain ``predictions.npz``. ``None``,
            ``""`` or the string ``"None"`` mean nothing has been staged yet.
        conf_thres: Confidence threshold forwarded to ``predictions_to_glb``.
        frame_filter: Frame-selection string (e.g. ``"All"``) forwarded as
            ``filter_by_frames``; also sanitized into the GLB file name.
        mask_black_bg, mask_white_bg, show_cam, mask_sky: Display toggles
            forwarded to ``predictions_to_glb``.
        prediction_mode: Prediction-branch name forwarded to ``predictions_to_glb``.
        is_example: ``"True"`` when triggered by an example click; the update
            is skipped in that case.

    Returns:
        tuple: ``(glb_path_or_None, status_message)``.

    Raises:
        KeyError: if ``predictions.npz`` lacks one of the expected arrays.
    """
    # Example clicks are skipped on purpose.
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]
    # np.load on an .npz returns an NpzFile that keeps the archive open; use it
    # as a context manager so the handle is released. np.array() copies each
    # entry, so the arrays stay valid after the file is closed.
    # NOTE: allow_pickle=True is only acceptable because this file is written
    # by the app itself; never point this at untrusted uploads.
    with np.load(predictions_path, allow_pickle=True) as loaded:
        predictions = {key: np.array(loaded[key]) for key in key_list}

    # Make the parameters filesystem-friendly before embedding them in the name.
    safe_frame = frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')
    safe_mode = prediction_mode.replace(' ', '_')
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{safe_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{safe_mode}.glb",
    )
    glbscene = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )
    glbscene.export(file_obj=glbfile)
    return glbfile, "Updating Visualization"
# -------------------------------------------------------------------------
# Example videos
# -------------------------------------------------------------------------
# Bundled example clips, given as relative paths under one shared folder.
_EXAMPLE_VIDEO_DIR = "examples/videos"
great_wall_video = f"{_EXAMPLE_VIDEO_DIR}/great_wall.mp4"
colosseum_video = f"{_EXAMPLE_VIDEO_DIR}/Colosseum.mp4"
room_video = f"{_EXAMPLE_VIDEO_DIR}/room.mp4"
kitchen_video = f"{_EXAMPLE_VIDEO_DIR}/kitchen.mp4"
fern_video = f"{_EXAMPLE_VIDEO_DIR}/fern.mp4"
single_cartoon_video = f"{_EXAMPLE_VIDEO_DIR}/single_cartoon.mp4"
single_oil_painting_video = f"{_EXAMPLE_VIDEO_DIR}/single_oil_painting.mp4"
pyramid_video = f"{_EXAMPLE_VIDEO_DIR}/pyramid.mp4"
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
# Ocean base theme; selected checkbox labels borrow the primary-button colours.
theme = gr.themes.Ocean()
_selected_checkbox_style = {
    "checkbox_label_background_fill_selected": "*button_primary_background_fill",
    "checkbox_label_text_color_selected": "*button_primary_text_color",
}
theme.set(**_selected_checkbox_style)
# Root Blocks container for the whole UI: every component and event listener
# that follows is created inside this context and exposed as `demo`, which
# main() queues and launches. The inline CSS defines the .custom-log and
# .example-log gradient-text styles and the #my_radio layout rules.
with gr.Blocks(
theme=theme,
css="""
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
""",
) as demo:
# Hidden Textboxes act as string-valued state instead of gr.State.
# is_example starts as "None", is set to "False" after a detection run, and
# update_visualization skips its work when it equals "True".
# NOTE(review): num_images is not wired to any listener in this section.
is_example = gr.Textbox(label="is_example", visible=False, value="None")
num_images = gr.Textbox(label="num_images", visible=False, value="None")
# Static page header: title, repository link and usage instructions.
gr.HTML(
"""
<h1>🦁 Zoo3D: Zero-Shot 3D Object Detection at Scene Level 🐼</h1>
<p>
<a href="https://github.com/col14m/zoo3d">GitHub Repository</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a 3D reconstruction and run open‑vocabulary 3D object detection from your text labels. The app builds a point cloud and draws colored wireframe bounding boxes for the detected objects.</p>
<h3>Getting Started:</h3>
<ol>
<li><strong>Upload Your Data:</strong> Use "Upload Video" or "Upload Images". Videos are sampled at 1 frame/sec.</li>
<li><strong>Enter Text Labels (Required):</strong> Provide one or more labels separated by semicolons, e.g. <code>chair; table; plant</code>.</li>
<li><strong>Detect:</strong> Click <strong>"Detect Objects"</strong>. The app will reconstruct the scene (if needed) and then run detection.</li>
<li><strong>Threshold (Optional):</strong> Tune the <em>Detection Threshold</em> (0–1). Higher = fewer, more confident detections.</li>
<li><strong>Visualize & Download:</strong> A single 3D view shows the point cloud and colored wireframe boxes. A legend maps colors to labels. You can download the GLB.</li>
</ol>
<p><strong style="color: #0ea5e9;">Notes:</strong> <span style="color: #0ea5e9; font-weight: bold;">Reconstruction is triggered automatically on first run. If no labels are provided, you'll see an error: </span><code>Please enter at least one text label (separated by ';').</code></p>
</div>
"""
)
# Hidden Textbox holding the current scene directory. It is written by the
# upload listeners (update_gallery_on_upload) and read by the detection and
# visualization listeners.
target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
# Upload widgets; a change to either one triggers update_gallery_on_upload.
with gr.Row():
with gr.Column(scale=2):
input_video = gr.Video(label="Upload Video", interactive=True)
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
def _safe_gallery(**kwargs):
    """
    Build a ``gr.Gallery``, discarding keyword arguments this Gradio build rejects.

    The Gallery constructor differs between Gradio releases (HF Spaces may run
    Gradio 6.x). On every ``TypeError`` the keyword named in the error message
    is removed and construction is retried. If no keyword can be identified
    that way, the first remaining version-sensitive option is dropped instead.
    The ``TypeError`` is re-raised once there is nothing left to drop.
    """
    import re

    version_sensitive = ["show_download_button", "preview", "object_fit", "columns", "height"]
    while True:
        try:
            return gr.Gallery(**kwargs)
        except TypeError as err:
            # Typical message: "got an unexpected keyword argument 'show_download_button'"
            found = re.search(r"unexpected keyword argument '([^']+)'", str(err))
            offending = found.group(1) if found else None
            if offending and offending in kwargs:
                del kwargs[offending]
                continue
            fallback = next((name for name in version_sensitive if name in kwargs), None)
            if fallback is None:
                raise
            del kwargs[fallback]
# Preview gallery for the staged frames; _safe_gallery drops any keyword
# argument the installed Gradio version does not accept.
image_gallery = _safe_gallery(
label="Preview",
columns=4,
height="300px",
show_download_button=True,
object_fit="contain",
preview=True,
)
# Column holding the semicolon-separated labels passed to detect_objects.
with gr.Column(scale=4):
text_labels = gr.Textbox(label="Text Labels (separated by ;)", placeholder="cat; dog; car")
# Column with the status log, the 3D viewer and the action buttons.
with gr.Column():
gr.Markdown("**3D Reconstruction & detection (Point Cloud and Bounding Boxes)**")
log_output = gr.Markdown(
"Please upload a video or images, then click Detect Objects.", elem_classes=["custom-log"]
)
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
with gr.Row():
detect_btn = gr.Button("Detect Objects", scale=1, variant="primary")
# ClearButton resets each listed component (uploads, viewer, log, target
# dir, gallery, labels) when clicked.
clear_btn = gr.ClearButton(
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, text_labels],
scale=1,
)
# with gr.Row():
# prediction_mode = gr.Textbox(
# value="Depthmap and Camera Branch",
# visible=False,
# label="Prediction Mode"
# )
# Hidden fixed-value component: listeners below take prediction_mode as an
# input, so it must exist even though the user cannot change it.
prediction_mode = gr.Textbox(value="Depthmap and Camera Branch", visible=False)
# Core reconstruction visualization parameters. All are hidden
# (visible=False) but are still passed as inputs to the detection and
# update_visualization listeners.
with gr.Row():
conf_thres = gr.Slider(
minimum=0,
maximum=100,
value=50,
step=0.1,
label="Confidence Threshold (%)",
visible=False,
)
frame_filter = gr.Dropdown(
choices=["All"],
value="All",
label="Show Points from Frame",
visible=False,
)
with gr.Column():
show_cam = gr.Checkbox(label="Show Camera", value=True, visible=False)
mask_sky = gr.Checkbox(label="Filter Sky", value=False, visible=False)
mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False, visible=False)
mask_white_bg = gr.Checkbox(label="Filter White Background", value=False, visible=False)
# Detection threshold and bounding-box color legend. The slider spans 0-1
# (default 0.6); changing it re-runs visualize_detections.
detection_conf_thres = gr.Slider(
minimum=0,
maximum=1,
value=0.6,
step=0.01,
label="Detection Threshold",
)
detection_legend = gr.Markdown("Legend will appear here")
# ---------------------- Examples section ----------------------
# No examples are registered at the moment.
examples = [
]
def example_pipeline(
    input_video,
    num_images_str,
    input_images,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    is_example_str,
    text_labels,
):
    """
    Run the full detection pipeline on an example input.

    The example video/images are staged with ``handle_uploads`` to obtain a
    target directory and image list. ``detect_objects`` then runs on that
    directory with every frame shown ("All") and a fixed detection threshold
    of 0.85. ``num_images_str``, ``conf_thres`` and ``is_example_str`` are
    accepted for interface compatibility but are not used.

    Returns:
        tuple: (GLB path, log text with the legend appended, target_dir,
        a reset frame-filter ``gr.Dropdown``, image paths).
    """
    scene_dir, frame_paths = handle_uploads(input_video, input_images)
    glb_path, legend_text = detect_objects(
        text_labels,
        scene_dir,
        0.85,
        "All",
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
    )
    status = "Example loaded and processed." + "\n\n" + legend_text
    frame_dropdown = gr.Dropdown(choices=["All"], value="All", interactive=True)
    return glb_path, status, scene_dir, frame_dropdown, frame_paths
# Inputs used to render detections; detect_objects takes the same list
# prefixed by the text labels.
_detection_view_inputs = [
    target_dir_output,
    detection_conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
]
_detection_outputs = [reconstruction_output, detection_legend]

# "Detect Objects": no-op first step, then detection (which reconstructs the
# scene first when predictions are missing), then mark the run as not an example.
_detect_chain = detect_btn.click(fn=clear_fields, inputs=[], outputs=[])
_detect_chain = _detect_chain.then(
    fn=detect_objects,
    inputs=[text_labels, *_detection_view_inputs],
    outputs=list(_detection_outputs),
)
_detect_chain.then(fn=lambda: "False", inputs=[], outputs=[is_example])

# Moving the detection threshold calls visualize_detections with the new value.
detection_conf_thres.change(
    fn=visualize_detections,
    inputs=list(_detection_view_inputs),
    outputs=list(_detection_outputs),
)
# -------------------------------------------------------------------------
# Real-time Visualization Updates
# -------------------------------------------------------------------------
# Every reconstruction-view control re-renders the point cloud through
# update_visualization with the same inputs and outputs. mask_sky was
# previously the only such input with no listener, so toggling it never
# refreshed the view; it is registered here alongside the others (appended
# last so the existing listeners keep their registration order).
_viz_inputs = [
    target_dir_output,
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    is_example,
]
for _viz_control in (
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    prediction_mode,
    mask_sky,
):
    _viz_control.change(
        update_visualization,
        list(_viz_inputs),
        [reconstruction_output, log_output],
    )
# -------------------------------------------------------------------------
# Auto-update the gallery whenever the user uploads or changes their files
# -------------------------------------------------------------------------
# Either upload widget changing re-stages both inputs via
# update_gallery_on_upload, refreshing the viewer, target dir, gallery and log.
for _upload_widget in (input_video, input_images):
    _upload_widget.change(
        fn=update_gallery_on_upload,
        inputs=[input_video, input_images],
        outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
    )
def main():
    """Launch the Gradio app behind a request queue holding at most 20 pending jobs."""
    queued_app = demo.queue(max_size=20)
    queued_app.launch(show_error=True, share=False, show_api=False)


if __name__ == "__main__":
    main()