zoo3d / mvp_complete.py
drozdgk's picture
init
4eeefd1
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import cv2
import torch
import numpy as np
import gradio as gr
import sys
import shutil
from datetime import datetime
import glob
import gc
import time
import open3d as o3d
import open_clip
from open_clip import tokenizer
import trimesh
import matplotlib.pyplot as plt
# Location of the MaskClustering checkout (CropFormer configs, weights and the
# clustering / query scripts). Defaults to the original developer path; set the
# ZOO3D_MK_PATH environment variable to run on another machine.
MK_PATH = os.environ.get(
    "ZOO3D_MK_PATH",
    "/home/jovyan/users/bulat/workspace/3drec/Indoor/MaskClustering",
)
# Make the local "vggt/" folder and the MaskClustering modules importable.
sys.path.append("vggt/")
sys.path.append(MK_PATH)
# Preload CropFormer model once on script import
# Best effort: any failure in this block is only reported as a warning and
# start-up continues.
try:
    from exts.cropformer_runner import preload_cropformer_model, make_cropformer_dir
    make_cropformer_dir(MK_PATH)
    preload_cropformer_model(
        config_file=os.path.join(MK_PATH, "third_party/detectron2/projects/CropFormer/configs/entityv2/entity_segmentation/mask2former_hornet_3x.yaml"),
        opts=[
            "MODEL.WEIGHTS",
            os.path.join(MK_PATH, "Mask2Former_hornet_3x_576d0b.pth"),
        ],
    )
except Exception as e:
    print(f"[Warning] Could not preload CropFormer model: {e}")
# NOTE(review): unlike the preload above, the imports below are not guarded, so
# a missing `exts` package still aborts start-up here.
from exts.ov_features import load as load_ov_features, main as main_ov_features
load_ov_features(MK_PATH)
from visual_util import predictions_to_glb
from vggt.models.vggt import VGGT
from vggt.utils.load_fn import load_and_preprocess_images
from vggt.utils.pose_enc import pose_encoding_to_extri_intri
from vggt.utils.geometry import unproject_depth_map_to_point_map
# Module-wide default device; run_model() performs its own CUDA check.
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Initializing and loading VGGT model...")
# model = VGGT.from_pretrained("facebook/VGGT-1B") # another way to load the model
model = VGGT()
# Weights are fetched through the torch.hub download helper.
_URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
model.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
model.eval()
model = model.to(device)
print("Initializing and loading Metric3D model...")
# Retry without `trust_repo` if torch.hub.load rejects that keyword with a TypeError.
try:
    metric3d_model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True, trust_repo=True)
except TypeError:
    metric3d_model = torch.hub.load('yvanyin/metric3d', 'metric3d_vit_small', pretrain=True)
metric3d_model.to(device)
metric3d_model.eval()
def load_clip():
    """
    Build the OpenCLIP ViT-H-14 model (laion2b_s32b_b79k weights), move it to
    CUDA and switch it to eval mode.

    Returns:
        (model, preprocess): the CLIP model and its image preprocessing transform.
    """
    print('[INFO] loading CLIP model...')
    created = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
    clip_net, image_transform = created[0], created[2]
    clip_net.cuda()
    clip_net.eval()
    print('[INFO]', ' finish loading CLIP model...')
    return clip_net, image_transform
def extract_text_feature(descriptions, clip_model, target_path):
    """
    Encode text labels with CLIP and save them next to the scene data.

    Args:
        descriptions: list of label strings to encode.
        clip_model: CLIP model exposing ``encode_text``.
        target_path: existing directory that receives ``text_features.npy``.

    Returns:
        dict mapping each description to its L2-normalised feature vector
        (numpy array). The same dict is written with ``np.save``; because it is
        a Python object it is stored pickled, so readers must pass
        ``allow_pickle=True`` and call ``.item()``.
    """
    # Tokenise on whatever device the model lives on instead of assuming CUDA.
    first_param = next(clip_model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    text_tokens = tokenizer.tokenize(descriptions).to(model_device)

    with torch.no_grad():
        text_features = clip_model.encode_text(text_tokens).float()
        # Unit-normalise each feature vector.
        text_features /= text_features.norm(dim=-1, keepdim=True)
    text_features = text_features.cpu().numpy()

    # Insertion order follows `descriptions`; a repeated label keeps its last row.
    text_features_dict = dict(zip(descriptions, text_features))
    np.save(os.path.join(target_path, "text_features.npy"), text_features_dict)
    return text_features_dict
# Load CLIP once at import time; reconstruct() and detect_objects() read these
# module-level globals.
clip_model, clip_preprocess = load_clip()
# -------------------------------------------------------------------------
# 1) Core model inference
# -------------------------------------------------------------------------
def _estimate_metric_scale(predictions, images, metric3d_model):
    """
    Run Metric3D frame by frame and derive a VGGT-to-metric scale factor.

    Stores the concatenated Metric3D output in ``predictions["metric3d_depth"]``
    and returns ``median(metric_depth / vggt_depth)`` over the pixels where both
    depths exceed 1e-6, or ``None`` when no pixel qualifies.
    """
    # images is (B, 3, H, W) in [0, 1]
    # Metric3D usually expects [0, 255] if input is tensor via inference dict
    metric3d_input = images * 255.0
    m_depths = []
    # Process one by one to avoid potential batch issues if inference doesn't support batch
    for i in range(metric3d_input.shape[0]):
        img = metric3d_input[i:i + 1]  # (1, 3, H, W)
        # Pad image to be divisible by 32 (standard for HourGlass/UNet architectures)
        _, _, h, w = img.shape
        ph = ((h - 1) // 32 + 1) * 32
        pw = ((w - 1) // 32 + 1) * 32
        needs_padding = ph != h or pw != w
        if needs_padding:
            # F.pad order for the last two dims: left, right, top, bottom
            img = torch.nn.functional.pad(img, (0, pw - w, 0, ph - h), mode='constant', value=0)
        with torch.no_grad():
            pred_depth, _confidence, _ = metric3d_model.inference({'input': img})
        # Crop back to original size
        if needs_padding:
            pred_depth = pred_depth[:, :, :h, :w]
        m_depths.append(pred_depth)
    predictions["metric3d_depth"] = torch.cat(m_depths, dim=0)

    # Bring the Metric3D output to channels-last so it lines up with VGGT's depth.
    # Assuming VGGT is (B, S, H, W, 1) and Metric3D is (S, 1, H, W) - TODO confirm
    metric_depth = predictions["metric3d_depth"]
    if metric_depth.dim() == 4 and metric_depth.shape[1] == 1:
        metric_depth = metric_depth.permute(0, 2, 3, 1)  # -> (S, H, W, 1)
    elif metric_depth.dim() == 3:
        metric_depth = metric_depth.unsqueeze(-1)  # -> (S, H, W, 1)

    # Same device/dtype; [0] drops VGGT's leading batch dimension
    vggt_depth = predictions["depth"].to(metric_depth.device).float()[0]
    metric_depth = metric_depth.float()
    print(f"Metric3D depth shape: {metric_depth.shape}")
    print(f"VGGT depth shape: {vggt_depth.shape}")

    # Only use valid depths (> 0) to avoid numerical issues in the ratio
    valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)
    if valid_mask.sum() == 0:
        return None
    return torch.median(metric_depth[valid_mask] / vggt_depth[valid_mask])


def run_model(target_dir, model, metric3d_model=None) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.

    Args:
        target_dir: directory containing an ``images`` sub-folder.
        model: VGGT model; moved to CUDA and put in eval mode here.
        metric3d_model: optional Metric3D model used to rescale depth and camera
            translations. When it is ``None`` (or yields no valid pixel) the
            scale factor is 1.0, i.e. VGGT's native scale is kept.

    Returns:
        dict of model outputs. Tensors are converted to numpy arrays, with a
        leading dimension of size 1 removed. Added/overwritten keys:
        ``metric3d_depth`` (only with Metric3D), ``extrinsic`` (inverse of
        ``pose``, top 3 rows), ``pose`` (4x4, relative to the first frame,
        translation scaled), ``intrinsic``, scaled ``depth`` and
        ``world_points_from_depth``.

    Raises:
        ValueError: if CUDA is unavailable or no images are found.
    """
    print(f"Processing images from {target_dir}")

    # Device check
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    # Move model to device
    model = model.to(device)
    model.eval()

    # Load and preprocess images
    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    # Run inference
    print("Running inference...")
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    with torch.no_grad():
        with torch.cuda.amp.autocast(dtype=dtype):
            predictions = model(images)

    # Metric scale alignment. The factor always has a value, so the pose and
    # depth scaling below also work without Metric3D.
    scale_tensor = None
    if metric3d_model is not None:
        print("Running Metric3D inference...")
        scale_tensor = _estimate_metric_scale(predictions, images, metric3d_model)
    if scale_tensor is None:
        scale = 1.0
        print("No metric scale available; keeping VGGT scale (factor 1.0).")
    else:
        scale = scale_tensor.item()
        print(f"Computed scale factor (Metric3D / VGGT): {scale:.4f}")

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    extrinsic = extrinsic[0]
    # Append the homogeneous row [0, 0, 0, 1] to turn each 3x4 into a 4x4.
    bottom_row = torch.zeros_like(extrinsic[:, 2:])
    bottom_row[..., -1] = 1
    extrinsic = torch.cat([extrinsic, bottom_row], dim=-2)

    # Pose of every frame relative to the first one. The reference is cloned:
    # a plain view of row 0 would be overwritten while the poses are rewritten.
    reference = extrinsic[0].clone()
    extrinsic = reference @ torch.linalg.inv(extrinsic)
    extrinsic[:, :3, 3] *= scale
    extrinsic_inv = torch.linalg.inv(extrinsic)
    print(f"Extrinsic: {extrinsic.shape}")

    predictions["extrinsic"] = extrinsic_inv[None, ..., :3, :]
    predictions["pose"] = extrinsic[None]
    predictions["intrinsic"] = intrinsic

    # Convert tensors to numpy. Every tensor is moved off the GPU; the leading
    # dimension is dropped only when it has size 1, so nothing is left behind
    # as a CUDA tensor (which np.savez cannot store).
    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            array = value.detach().cpu().numpy()
            if array.ndim > 0 and array.shape[0] == 1:
                array = array.squeeze(0)  # remove batch dimension
            predictions[key] = array

    # Generate world points from depth map
    print("Computing world points from depth map...")
    predictions["depth"] = predictions["depth"] * scale
    depth_map = predictions["depth"]
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Clean up
    torch.cuda.empty_cache()
    return predictions
# -------------------------------------------------------------------------
# 2) Handle uploaded video/images --> produce target_dir + images
# -------------------------------------------------------------------------
def _upload_path(item):
    """Return the filesystem path of an upload given as a path, a dict with "name", or a file-like object."""
    if isinstance(item, (str, os.PathLike)):
        return os.fspath(item)
    if isinstance(item, dict) and "name" in item:
        return item["name"]
    # Temp-file wrappers expose the path through `.name`.
    return getattr(item, "name", item)


def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).

    Args:
        input_video: path (or upload object) of a video, or None. One frame is
            kept per second of video, based on the FPS the container reports.
        input_images: iterable of image paths / upload objects, or None.

    Returns:
        (target_dir, image_paths): the new ``temp/input/<timestamp>`` directory
        and the sorted list of files written to its ``images`` sub-folder.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Create a unique folder name
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = f"temp/input/{timestamp}"
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            file_path = _upload_path(file_data)
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        video_path = _upload_path(input_video)
        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # 1 frame/sec. Missing metadata reports 0 (or NaN); fall back to
            # keeping every frame rather than dividing by zero below.
            try:
                frame_interval = max(1, int(fps))
            except (TypeError, ValueError, OverflowError):
                frame_interval = 1

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.jpg")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Always free the decoder / file handle.
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
# -------------------------------------------------------------------------
# 3) Update gallery on upload
# -------------------------------------------------------------------------
def update_gallery_on_upload(input_video, input_images):
    """
    Stage freshly uploaded files and report them to the UI.

    Returns a 4-tuple ``(viewer_value, target_dir, image_paths, log_message)``.
    The first element is always None. When neither a video nor images are
    supplied, all four elements are None and nothing is written to disk.
    """
    if input_video or input_images:
        staged_dir, staged_images = handle_uploads(input_video, input_images)
        status = "Upload complete. Click 'Detect Objects' to begin 3D processing."
        return None, staged_dir, staged_images, status
    return None, None, None, None
# -------------------------------------------------------------------------
# 4) Reconstruction: uses the target_dir plus any viz parameters
# -------------------------------------------------------------------------
def _write_depth_and_poses(predictions, target_dir, image_names):
    """Write per-frame depth PNGs (uint16 millimetres), per-frame poses and the mean intrinsic under target_dir."""
    depth_path = os.path.join(target_dir, "depth")
    pose_path = os.path.join(target_dir, "pose")
    intrinsic_path = os.path.join(target_dir, "intrinsic")
    for folder in (depth_path, pose_path, intrinsic_path):
        os.makedirs(folder, exist_ok=True)

    uint16_max = np.iinfo(np.uint16).max
    for i, d in enumerate(predictions["depth"]):
        # Clamp before the cast: values above 65.535 m (or NaN/inf) would
        # otherwise wrap around in uint16.
        depth_mm = np.nan_to_num(d[..., 0] * 1000, nan=0.0, posinf=uint16_max, neginf=0.0)
        depth_mm = np.clip(depth_mm, 0, uint16_max).astype(np.uint16)
        cv2.imwrite(os.path.join(depth_path, f"{image_names[i]}.png"), depth_mm)

    # One shared 4x4 intrinsic: the mean of the per-frame 3x3 matrices.
    intr = np.eye(4)
    intr[:3, :3] = np.mean(predictions["intrinsic"], axis=0)
    np.savetxt(os.path.join(intrinsic_path, "intrinsic_depth.txt"), intr)

    for i, p in enumerate(predictions["pose"]):
        np.savetxt(os.path.join(pose_path, f"{image_names[i]}.txt"), p)


def reconstruct(
    target_dir,
    conf_thres=50.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Depthmap and Camera Branch",
    text_labels="",
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Steps: run VGGT (+ Metric3D), save predictions / depth / pose / intrinsic
    files, export a GLB and a down-sampled ``point_cloud.ply``, then run
    CropFormer mask prediction, the MaskClustering ``main.py`` script and the
    open-vocabulary feature extraction for this sequence.

    Args:
        target_dir: upload directory that contains ``images``.
        conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam,
            mask_sky: forwarded to ``predictions_to_glb``.
        prediction_mode: ignored; always forced to "Depthmap and Camera Branch".
        text_labels: accepted for call compatibility; not used here.

    Returns:
        (glb_path, log_message, frame_filter_dropdown). On an invalid
        ``target_dir`` the tuple is ``(None, <error message>, None)``.
    """
    import subprocess

    prediction_mode = "Depthmap and Camera Branch"  # Force prediction mode

    # Same arity as the success path; also covers target_dir=None / "None".
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first.", None

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # Prepare frame_filter dropdown
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    image_names = [f.split(".")[0] for f in all_files]
    all_files = [f"{i}: {filename}" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        predictions = run_model(target_dir, model, metric3d_model=metric3d_model)

    # Save predictions (best effort: a failure is only reported)
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    try:
        np.savez(prediction_save_path, **predictions)
    except Exception as e:
        print(f"Warning: could not save predictions to npz: {e}")

    _write_depth_and_poses(predictions, target_dir, image_names)

    # Handle None frame_filter
    if frame_filter is None:
        frame_filter = "All"

    # Build a GLB file name
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Convert predictions to GLB
    glbscene, point_cloud_data = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )

    # Ensure colors are RGB (remove alpha if present) for Open3D
    v = np.asarray(point_cloud_data.vertices)
    c = np.asarray(point_cloud_data.colors) / 255.0
    if c.shape[1] == 4:
        c = c[:, :3]
    glbscene.export(file_obj=glbfile)

    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(v)
    pcd.colors = o3d.utility.Vector3dVector(c)
    pcd = pcd.voxel_down_sample(voxel_size=0.01)
    o3d.io.write_point_cloud(os.path.join(target_dir, "point_cloud.ply"), pcd)

    # Cleanup
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    # The downstream tools address a sequence as <root>/<seq_name>. Derive both
    # from target_dir so they always point at the data written above.
    abs_target = os.path.abspath(target_dir)
    scene_root = os.path.dirname(abs_target)
    seq_name = os.path.basename(abs_target)

    # Run CropFormer mask prediction via Python API (no system call)
    try:
        from exts.cropformer_runner import run_cropformer_mask_predict
    except ImportError:
        from .exts.cropformer_runner import run_cropformer_mask_predict  # if used as a module
    run_cropformer_mask_predict(
        config_file=os.path.join(MK_PATH, "third_party/detectron2/projects/CropFormer/configs/entityv2/entity_segmentation/mask2former_hornet_3x.yaml"),
        root=os.path.join(scene_root, ""),  # keep the trailing separator
        image_path_pattern="images/*.jpg",
        dataset="arkit_gt",
        seq_name_list=seq_name,
        confidence_threshold=0.5,
        opts=[
            "MODEL.WEIGHTS",
            os.path.join(MK_PATH, "Mask2Former_hornet_3x_576d0b.pth"),
        ],
    )

    # Argument list instead of a shell string: seq_name comes from a UI field
    # and must never be interpreted by a shell.
    clustering = subprocess.run(
        [
            sys.executable,
            os.path.join(MK_PATH, "main.py"),
            "--config", "wild",
            "--root", scene_root,
            "--seq_name_list", seq_name,
        ],
        check=False,
    )
    if clustering.returncode != 0:
        print(f"[Warning] MaskClustering main.py exited with code {clustering.returncode}")

    main_ov_features(clip_model, clip_preprocess, seq_name, scene_root)

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
def _box_wireframe_segments(bbox):
    """
    Return the edges of a box as an (E, 2, 3) float array, or None if the box
    does not have 8 vertices or no edge can be derived.

    Vertices are expressed in the box's own axes and tagged with a (+/-1) sign
    triple; two vertices share an edge when their triples differ on exactly one
    axis.
    """
    verts = np.asarray(bbox.vertices)
    if verts.shape[0] != 8:
        return None
    transform = np.asarray(bbox.transform)
    center = transform[:3, 3]
    rotation = transform[:3, :3]
    # Local coordinates (in the box axes)
    local = (verts - center) @ rotation
    signs = np.where(local >= 0.0, 1, -1).astype(int)
    sign_to_idx = {tuple(int(s) for s in row): i for i, row in enumerate(signs)}

    edges = set()
    for sign, i0 in sign_to_idx.items():
        for axis in range(3):
            flipped = list(sign)
            flipped[axis] *= -1
            i1 = sign_to_idx.get(tuple(flipped))
            if i1 is not None and i1 != i0:
                edges.add((min(i0, i1), max(i0, i1)))
    if not edges:
        return None
    # Sorted so the export order does not depend on set iteration order.
    return np.array([[verts[i], verts[j]] for (i, j) in sorted(edges)], dtype=float)


def _edge_cylinder(p1, p2, radius, color_u8):
    """Build a thin cylinder mesh from p1 to p2 coloured with color_u8, or None if it cannot be built."""
    v = p2 - p1
    length = float(np.linalg.norm(v))
    if length <= 1e-8:
        return None
    direction = v / length
    try:
        cyl = trimesh.creation.cylinder(radius=radius, height=length, sections=12)
        # Rotate the Z axis onto the edge direction, then move to the midpoint.
        cyl.apply_transform(trimesh.geometry.align_vectors([0, 0, 1], direction))
    except Exception as e:
        # A cylinder that could not be oriented would be drawn in the wrong
        # place, so the edge is dropped instead.
        print(f"Warning: could not build box edge: {e}")
        return None
    cyl.apply_translation((p1 + p2) / 2.0)

    # Lighting-independent look (unlit emulated through an emissive material);
    # fall back to plain face colours if the material cannot be created.
    try:
        cyl.visual.material = trimesh.visual.material.PBRMaterial(
            baseColorFactor=(0.0, 0.0, 0.0, 1.0),
            metallicFactor=0.0,
            roughnessFactor=1.0,
            emissiveFactor=(color_u8[:3] / 255.0).tolist(),
            doubleSided=True,
        )
    except Exception:
        cyl.visual.face_colors = np.tile(color_u8[None, :], (len(cyl.faces), 1))
    return cyl


def _load_label_names(target_dir):
    """Map class index -> label text using the key order of target_dir/text_features.npy ({} if unavailable)."""
    text_features_path = os.path.join(target_dir, "text_features.npy")
    if not os.path.exists(text_features_path):
        return {}
    try:
        # NOTE(review): allow_pickle executes pickled data; only safe because
        # the file is written by this application itself.
        text_features_dict = np.load(text_features_path, allow_pickle=True).item()
    except Exception as e:
        print(f"Warning: Could not load text features for label mapping: {e}")
        return {}
    return dict(enumerate(text_features_dict.keys()))


def _add_detection_boxes(scene, points, npz_path, target_dir, conf_thres):
    """Add wireframe boxes for detections scoring above conf_thres to scene; return the legend markdown."""
    # NOTE(review): allow_pickle executes pickled data; only safe because the
    # file is produced by the local detection pipeline.
    with np.load(npz_path, allow_pickle=True) as loaded:
        if 'pred_masks' not in loaded:
            return ""
        masks = loaded['pred_masks'].T
        labels = loaded['pred_classes']
        confs = np.asarray(loaded['pred_score'])

    label_to_name = _load_label_names(target_dir)

    valid_indices = np.where(confs > conf_thres)[0]
    if len(valid_indices) == 0:
        return ""

    legend_items = {}
    cmap = plt.get_cmap("tab10")
    detected_labels = np.unique(labels[valid_indices])
    label_to_color = {label: cmap(i % 10) for i, label in enumerate(detected_labels)}

    radius = 0.015  # constant frame thickness: 3 cm diameter
    mismatched = 0
    for idx in valid_indices:
        mask = masks[idx]
        if hasattr(mask, "toarray"):
            mask = mask.toarray().flatten()
        mask = mask.astype(bool)

        # A mask can only select points if it has one entry per point.
        if len(mask) != len(points):
            mismatched += 1
            continue
        obj_points = points[mask]
        if len(obj_points) < 4:
            continue

        obj_pcd = trimesh.PointCloud(obj_points)
        try:
            bbox = obj_pcd.bounding_box_oriented
        except Exception:
            bbox = obj_pcd.bounding_box  # axis-aligned fallback

        segments = _box_wireframe_segments(bbox)
        if segments is None:
            continue

        lbl_idx = labels[idx]
        lbl_name = label_to_name.get(lbl_idx, f"Class {lbl_idx}")
        color = label_to_color.get(lbl_idx, (1, 0, 0, 1))
        color_u8 = (np.array(color) * 255).astype(np.uint8)

        for seg in segments:
            cyl = _edge_cylinder(seg[0], seg[1], radius, color_u8)
            if cyl is not None:
                scene.add_geometry(cyl)
        legend_items[lbl_name] = color

    if mismatched:
        print(f"Warning: skipped {mismatched} detection(s) whose mask size does not match the point cloud.")

    legend_md = "### Legend\n"
    for lbl_name, color in legend_items.items():
        c_u8 = (np.array(color) * 255).astype(np.uint8)
        hex_c = "#{:02x}{:02x}{:02x}".format(c_u8[0], c_u8[1], c_u8[2])
        legend_md += f"- <span style='color:{hex_c}'>■</span> {lbl_name}\n"
    return legend_md


def visualize_detections(target_dir, conf_thres, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Depthmap and Camera Branch"):
    """
    Generate a GLB scene with bounding boxes for detected objects.

    The scene is ``target_dir/point_cloud.ply`` plus, when
    ``target_dir/output/object/prediction.npz`` exists, one wireframe box per
    detection whose score is above ``conf_thres``.

    Args:
        target_dir: reconstruction directory.
        conf_thres: minimum detection score for a box to be drawn.
        frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky,
            prediction_mode: accepted for call compatibility; not used here.

    Returns:
        (glb_path, legend_markdown), or ``(None, <error message>)`` when the
        directory or point cloud is missing or empty.
    """
    if not target_dir or not os.path.exists(target_dir):
        return None, "Target directory not found."

    ply_path = os.path.join(target_dir, "point_cloud.ply")
    npz_path = os.path.join(target_dir, "output", "object", "prediction.npz")

    # 1. Load the point cloud as the base of the scene
    if not os.path.exists(ply_path):
        return None, f"Point cloud not found at {ply_path}. Please run detection first."
    pcd = o3d.io.read_point_cloud(ply_path)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)
    if points.size == 0:
        return None, "Point cloud is empty."

    # Base scene built from the point cloud, colours as RGBA uint8
    scene = trimesh.Scene()
    if colors.size == 0:
        t_colors = np.ones((len(points), 4), dtype=np.uint8) * 255
    else:
        if colors.max() <= 1.0:
            t_colors = (colors * 255).astype(np.uint8)
        else:
            t_colors = colors.astype(np.uint8)
        if t_colors.shape[1] == 3:
            t_colors = np.hstack([t_colors, np.ones((len(t_colors), 1), dtype=np.uint8) * 255])
    scene.add_geometry(trimesh.PointCloud(vertices=points, colors=t_colors))

    # 2. Add boxes from the detection results, if there are any
    legend_md = ""
    if os.path.exists(npz_path):
        try:
            legend_md = _add_detection_boxes(scene, points, npz_path, target_dir, conf_thres)
        except Exception as e:
            print(f"Error loading detections: {e}")
            legend_md = f"Error loading detections: {e}"

    # Export combined scene (point cloud + boxes)
    out_path = os.path.join(target_dir, f"combined_viz_{conf_thres}.glb")
    scene.export(file_obj=out_path)
    return out_path, legend_md
def detect_objects(text_labels, target_dir, conf_thres, *viz_args):
    """
    Detect objects from text labels and return the detected objects.

    Runs the reconstruction first when ``target_dir/predictions.npz`` does not
    exist, encodes the labels with CLIP, runs the open-vocabulary query script
    and finally builds the combined visualization.

    Args:
        text_labels: labels separated by ';'.
        target_dir: upload directory created by ``handle_uploads``.
        conf_thres: detection score threshold passed to the visualization.
        *viz_args: frame_filter, mask_black_bg, mask_white_bg, show_cam,
            mask_sky, prediction_mode - forwarded unchanged.

    Returns:
        (glb_path, legend_markdown); ``(None, <message>)`` when the labels are
        empty or ``target_dir`` is not a usable directory.
    """
    import subprocess

    # Require non-empty text labels
    labels = [l.strip() for l in text_labels.split(";") if l.strip()] if isinstance(text_labels, str) else []
    if not labels:
        return None, "Please enter at least one text label (separated by ';')."

    # The hidden target-dir field holds "None" until something is uploaded.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No valid target directory found. Please upload first."

    # 1. Run reconstruction first if needed (checking if predictions exist).
    # The point-cloud confidence threshold is fixed at 50.0 here; `conf_thres`
    # is the detection threshold and is only used for the visualization.
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        print("Predictions not found, running reconstruction first...")
        reconstruct(target_dir, 50.0, *viz_args, text_labels=text_labels)

    # 2. Extract text features
    print(f"Extracting features for labels: {labels}")
    extract_text_feature(labels, clip_model, target_dir)

    # 3. Open-vocabulary query. The script addresses the scene as
    # <root>/<seq_name>, both derived from target_dir. An argument list is used
    # instead of a shell string because target_dir comes from a UI field.
    abs_target = os.path.abspath(target_dir)
    query = subprocess.run(
        [
            sys.executable,
            os.path.join(MK_PATH, "semantics", "wopen-voc_query.py"),
            "--config", "wild",
            "--root", os.path.dirname(abs_target),
            "--seq_name", os.path.basename(abs_target),
        ],
        env=dict(os.environ, PYTHONPATH=MK_PATH),
        check=False,
    )
    if query.returncode != 0:
        print(f"[Warning] open-vocabulary query exited with code {query.returncode}")

    return visualize_detections(target_dir, conf_thres, *viz_args)
# -------------------------------------------------------------------------
# 5) Helper functions for UI resets + re-visualization
# -------------------------------------------------------------------------
def clear_fields():
    """
    No-op callback used as the first step of an event chain.

    Takes no inputs and always returns None.
    """
    return
def update_log():
    """
    Return the fixed status text to show while a reconstruction is running.
    """
    waiting_message = "Loading and Reconstructing..."
    return waiting_message
def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.

    Args:
        target_dir: reconstruction directory holding ``predictions.npz``.
        conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam,
            mask_sky, prediction_mode: visualization settings; they are part of
            the cached GLB file name and are forwarded to ``predictions_to_glb``.
        is_example: the string "True" when triggered by an example click.

    Returns:
        (glb_path, log_message), or ``(None, <message>)`` when there is nothing
        to visualize.
    """
    # If it's an example click, skip as requested
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    # NOTE(review): allow_pickle executes pickled data; only safe because the
    # file is written by this application itself.
    with np.load(predictions_path, allow_pickle=True) as loaded:
        missing = [key for key in key_list if key not in loaded.files]
        if missing:
            return None, f"Saved predictions at {predictions_path} are missing: {', '.join(missing)}. Please run 'Reconstruct' again."
        predictions = {key: np.array(loaded[key]) for key in key_list}

    # Same fallback as reconstruct(): a cleared dropdown means "All".
    if frame_filter is None:
        frame_filter = "All"

    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    if not os.path.exists(glbfile):
        result = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        # reconstruct() unpacks (scene, point_cloud) from this helper; accept
        # that form as well as a bare scene.
        glbscene = result[0] if isinstance(result, tuple) else result
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
# -------------------------------------------------------------------------
# Example images
# -------------------------------------------------------------------------
# Relative paths of the bundled example videos.
# NOTE(review): none of these names is referenced elsewhere in the code shown
# here (the `examples` list in the UI section is empty) - confirm before removing.
great_wall_video = "examples/videos/great_wall.mp4"
colosseum_video = "examples/videos/Colosseum.mp4"
room_video = "examples/videos/room.mp4"
kitchen_video = "examples/videos/kitchen.mp4"
fern_video = "examples/videos/fern.mp4"
single_cartoon_video = "examples/videos/single_cartoon.mp4"
single_oil_painting_video = "examples/videos/single_oil_painting.mp4"
pyramid_video = "examples/videos/pyramid.mp4"
# -------------------------------------------------------------------------
# 6) Build Gradio UI
# -------------------------------------------------------------------------
# Ocean theme whose selected checkbox labels reuse the primary-button colours.
_selected_checkbox_style = {
    "checkbox_label_background_fill_selected": "*button_primary_background_fill",
    "checkbox_label_text_color_selected": "*button_primary_text_color",
}
theme = gr.themes.Ocean().set(**_selected_checkbox_style)
with gr.Blocks(
    theme=theme,
    css="""
.custom-log * {
font-style: italic;
font-size: 22px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
font-weight: bold !important;
color: transparent !important;
text-align: center !important;
}
.example-log * {
font-style: italic;
font-size: 16px !important;
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
-webkit-background-clip: text;
background-clip: text;
color: transparent !important;
}
#my_radio .wrap {
display: flex;
flex-wrap: nowrap;
justify-content: center;
align-items: center;
}
#my_radio .wrap label {
display: flex;
width: 50%;
justify-content: center;
align-items: center;
margin: 0;
padding: 10px 0;
box-sizing: border-box;
}
""",
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")

    gr.HTML(
        """
<h1>🦁 Zoo3D: Zero-Shot 3D Object Detection at Scene Level 🐼</h1>
<p>
<a href="https://github.com/col14m/zoo3d">GitHub Repository</a>
</p>
<div style="font-size: 16px; line-height: 1.5;">
<p>Upload a video or a set of images to create a 3D reconstruction and run open‑vocabulary 3D object detection from your text labels. The app builds a point cloud and draws colored wireframe bounding boxes for the detected objects.</p>
<h3>Getting Started:</h3>
<ol>
<li><strong>Upload Your Data:</strong> Use "Upload Video" or "Upload Images". Videos are sampled at 1 frame/sec.</li>
<li><strong>Enter Text Labels (Required):</strong> Provide one or more labels separated by semicolons, e.g. <code>chair; table; plant</code>.</li>
<li><strong>Detect:</strong> Click <strong>"Detect Objects"</strong>. The app will reconstruct the scene (if needed) and then run detection.</li>
<li><strong>Threshold (Optional):</strong> Tune the <em>Detection Cosine Similarity Threshold</em> (0–1). Higher = fewer, more confident detections.</li>
<li><strong>Visualize & Download:</strong> A single 3D view shows the point cloud and colored wireframe boxes. A legend maps colors to labels. You can download the GLB.</li>
</ol>
<p><strong style="color: #0ea5e9;">Notes:</strong> <span style="color: #0ea5e9; font-weight: bold;">Reconstruction is triggered automatically on first run. If no labels are provided, you'll see an error: </span><code>Please enter at least one text label (separated by ';').</code></p>
</div>
"""
    )

    # Holds the current upload directory; the string "None" means "nothing uploaded".
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        with gr.Column(scale=4):
            text_labels = gr.Textbox(label="Text Labels (separated by ;)", placeholder="cat; dog; car")

            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Detect Objects.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)

            with gr.Row():
                detect_btn = gr.Button("Detect Objects", scale=1, variant="primary")
                clear_btn = gr.ClearButton(
                    [input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, text_labels],
                    scale=1,
                )

            # We'll create a hidden component so the event handlers don't break
            prediction_mode = gr.Textbox(value="Depthmap and Camera Branch", visible=False)

            # Main reconstruction visualization parameters (all hidden in this UI)
            with gr.Row():
                conf_thres = gr.Slider(
                    minimum=0,
                    maximum=100,
                    value=50,
                    step=0.1,
                    label="Confidence Threshold (%)",
                    visible=False,
                )
                frame_filter = gr.Dropdown(
                    choices=["All"],
                    value="All",
                    label="Show Points from Frame",
                    visible=False,
                )
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True, visible=False)
                    mask_sky = gr.Checkbox(label="Filter Sky", value=False, visible=False)
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False, visible=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False, visible=False)

            # Detection threshold and the colour legend of the boxes
            detection_conf_thres = gr.Slider(
                minimum=0,
                maximum=1,
                value=0.6,
                step=0.01,
                label="Detection Cosine Similarity Threshold",
            )
            detection_legend = gr.Markdown("Legend will appear here")

    # ---------------------- Examples section ----------------------
    examples = [
    ]

    def example_pipeline(
        input_video,
        num_images_str,
        input_images,
        conf_thres,
        mask_black_bg,
        mask_white_bg,
        show_cam,
        mask_sky,
        prediction_mode,
        is_example_str,
        text_labels,
    ):
        """
        1) Copy example images to new target_dir
        2) Reconstruct (and Detect if labels present)
        3) Return model3D + logs + new_dir + updated dropdown + gallery
        We do NOT return is_example. It's just an input.
        """
        target_dir, image_paths = handle_uploads(input_video, input_images)

        # Always use "All" for frame_filter in examples
        frame_filter = "All"

        # The files were just staged, so no predictions exist yet and
        # detect_objects() runs the reconstruction before detecting.
        # `conf_thres` here is the point-cloud confidence from the example row,
        # not a detection score, so examples use a fixed detection threshold.
        # It is deliberately not tied to the slider, whose default is 0.6.
        detection_conf = 0.85

        glbfile, legend_md = detect_objects(
            text_labels,
            target_dir,
            detection_conf,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode,
        )

        log_msg = "Example loaded and processed."
        return glbfile, log_msg + "\n\n" + legend_md, target_dir, gr.Dropdown(choices=["All"], value="All", interactive=True), image_paths

    # Visualization settings shared by detection and re-visualization callbacks,
    # in the order detect_objects() / visualize_detections() expect them.
    viz_settings = [frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode]

    detect_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
        fn=detect_objects,
        inputs=[text_labels, target_dir_output, detection_conf_thres] + viz_settings,
        outputs=[reconstruction_output, detection_legend],
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]  # set is_example to "False"
    )

    detection_conf_thres.change(
        fn=visualize_detections,
        inputs=[target_dir_output, detection_conf_thres] + viz_settings,
        outputs=[reconstruction_output, detection_legend],
    )

    # -------------------------------------------------------------------------
    # Real-time Visualization Updates
    # -------------------------------------------------------------------------
    # Every reconstruction setting triggers the same callback with the same
    # inputs, so the handlers are registered in one loop. mask_sky is included
    # so that all settings passed to update_visualization also refresh it.
    update_inputs = [target_dir_output, conf_thres] + viz_settings + [is_example]
    for control in (conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode):
        control.change(
            update_visualization,
            update_inputs,
            [reconstruction_output, log_output],
        )

    # -------------------------------------------------------------------------
    # Auto-update gallery whenever user uploads or changes their files
    # -------------------------------------------------------------------------
    for upload_input in (input_video, input_images):
        upload_input.change(
            fn=update_gallery_on_upload,
            inputs=[input_video, input_images],
            outputs=[reconstruction_output, target_dir_output, image_gallery, log_output],
        )

demo.queue(max_size=20).launch(show_error=True, share=True)