|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
|
import cv2 |
|
|
import torch |
|
|
import numpy as np |
|
|
import gradio as gr |
|
|
import sys |
|
|
import shutil |
|
|
from datetime import datetime |
|
|
import glob |
|
|
import gc |
|
|
import time |
|
|
import open3d as o3d |
|
|
import open_clip |
|
|
from open_clip import tokenizer |
|
|
import trimesh |
|
|
import matplotlib.pyplot as plt |
|
|
import subprocess |
|
|
import tempfile |
|
|
import contextlib |
|
|
from huggingface_hub import hf_hub_download |
|
|
|
|
|
# gdown is optional: it is only used to pull the Mask2Former/CropFormer
# checkpoint from Google Drive. When it is missing we fall back to the
# Hugging Face dataset mirror (see _ensure_mask2former_weights).
try:
    import gdown
except Exception:
    gdown = None
|
|
|
|
|
|
|
|
# Monkeypatch for gradio_client: JSON Schema permits a bare boolean
# (`true`/`false`) as a schema, but _json_schema_to_python_type assumes a
# dict and crashes on bools. Wrap it so boolean schemas map to "Any".
try:
    import gradio_client.utils as _gcu

    if hasattr(_gcu, "_json_schema_to_python_type"):
        # Keep a reference to the original so non-bool schemas are untouched.
        _orig = _gcu._json_schema_to_python_type

        def _json_schema_to_python_type_patched(schema, defs=None):
            # Boolean schema: `true` accepts anything, `false` accepts nothing;
            # "Any" is a safe rendering for the docs UI in both cases.
            if isinstance(schema, bool):
                return "Any"
            return _orig(schema, defs)

        _gcu._json_schema_to_python_type = _json_schema_to_python_type_patched
except Exception:
    # Best-effort patch: if gradio_client changes its internals, just skip it.
    pass
|
|
|
|
|
# Limit parallel build jobs (affects JIT/CUDA extension compilation).
os.environ.setdefault("MAX_JOBS", "1")

# Absolute repository root = directory containing this file.
REPO_ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.join(REPO_ROOT, "vggt"))

# MaskClustering checkout and its vendored detectron2 copy.
MK_PATH = os.path.join(REPO_ROOT, "MaskClustering")
DETECTRON2_ROOT = os.path.join(REPO_ROOT, "MaskClustering", "third_party", "detectron2")

# Expose third_party packages to child processes spawned via subprocess.
# NOTE(review): uses ":" as separator, which is POSIX-only — confirm this
# app never runs on Windows (os.pathsep would be portable).
os.environ["PYTHONPATH"] = os.environ.get("PYTHONPATH", "") + ":" + os.path.join(REPO_ROOT, "MaskClustering", "third_party")
|
|
|
|
|
|
|
|
# detectron2 is required by the CropFormer/MaskClustering pipeline; if it is
# missing, install the vendored copy in editable mode at startup.
try:
    import detectron2
except Exception:
    print("[runtime] detectron2 not found. Installing local detectron2 (editable, no build isolation)...")
    # BUGFIX: the previous os.system("python -m pip ...") invoked whatever
    # `python` is on PATH (possibly a different interpreter than the one
    # running this app) and went through a shell. Use sys.executable with an
    # argv list so the install targets this exact environment.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "pip",
            "install",
            "--no-build-isolation",
            "-e",
            "./MaskClustering/third_party/detectron2",
        ],
        check=False,  # keep previous best-effort behavior: do not abort startup
    )
    import importlib

    # Drop stale import-finder caches so the just-installed package is found.
    importlib.invalidate_caches()
    import detectron2
|
|
|
|
|
|
|
|
# Make the vendored detectron2 importable in this process as well.
if os.path.isdir(DETECTRON2_ROOT) and DETECTRON2_ROOT not in sys.path:
    sys.path.append(DETECTRON2_ROOT)

# Scratch directory for uploads and intermediate pipeline outputs.
WORK_DIR = os.environ.get("ZOO3D_WORKDIR", os.path.join(tempfile.gettempdir(), "zoo3d"))
os.makedirs(WORK_DIR, exist_ok=True)
|
|
from visual_util import predictions_to_glb |
|
|
from vggt.models.vggt import VGGT |
|
|
from vggt.utils.load_fn import load_and_preprocess_images |
|
|
from vggt.utils.pose_enc import pose_encoding_to_extri_intri |
|
|
from vggt.utils.geometry import unproject_depth_map_to_point_map |
|
|
|
|
|
# Global default device; functions below re-check CUDA availability locally.
device = "cuda" if torch.cuda.is_available() else 'cpu'

print(f"Using device: {device}")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Feature flags, all overridable through the environment:
#   ZOO3D_ALLOW_CPU      - allow running on a host without CUDA (default: on)
#   ZOO3D_CPU_DUMMY      - on CPU, return dummy geometry instead of inference
#   ZOO3D_SKIP_DOWNLOADS - skip downloading segmentation weights at startup
ZOO3D_ALLOW_CPU = os.environ.get("ZOO3D_ALLOW_CPU", "1") == "1"
ZOO3D_CPU_DUMMY = os.environ.get("ZOO3D_CPU_DUMMY", "1") == "1"
ZOO3D_SKIP_DOWNLOADS = os.environ.get("ZOO3D_SKIP_DOWNLOADS", "0") == "1"

# Lazily initialized model singletons, populated by _init_models().
_VGGT_MODEL = None
_METRIC3D_MODEL = None
_CLIP_MODEL = None

# Google Drive file id for the CropFormer/Mask2Former checkpoint.
_MASK2FORMER_GDRIVE_FILE_ID = "10G7s6bVMwN__bcrR2fBal3goo69Y5Do4"
|
|
|
|
|
|
|
|
def _ensure_mask2former_weights(dst_path: str) -> str:
    """Make sure the Mask2Former/CropFormer checkpoint exists at *dst_path*.

    Sources are tried in order until one yields a non-empty file:
      1. an already-present file at dst_path,
      2. a local file named by the MASK2FORMER_WEIGHTS_PATH env var,
      3. Google Drive via gdown (if installed),
      4. the qqlu1992/Adobe_EntitySeg dataset on the Hugging Face Hub.

    Returns dst_path on success.
    """
    # Fast path: a non-empty checkpoint is already in place.
    if os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
        return dst_path

    os.makedirs(os.path.dirname(dst_path), exist_ok=True)

    # Operator-supplied local file takes precedence over any download.
    override_path = os.environ.get("MASK2FORMER_WEIGHTS_PATH")
    if override_path and os.path.exists(override_path) and os.path.getsize(override_path) > 0:
        shutil.copyfile(override_path, dst_path)
        return dst_path

    # Google Drive attempt (only when the optional gdown package imported).
    if gdown is None:
        print("Warning: gdown is not available; falling back to HF dataset for Mask2Former weights...")
    else:
        url = f"https://drive.google.com/uc?id={_MASK2FORMER_GDRIVE_FILE_ID}"
        out = gdown.download(url, dst_path, quiet=False)
        if out is not None and os.path.exists(dst_path) and os.path.getsize(dst_path) > 0:
            return dst_path
        print("Warning: gdown download failed for Mask2Former weights; falling back to HF dataset...")

    # Last resort: pull the checkpoint from the Hub cache and copy it over.
    cached = hf_hub_download(
        repo_id="qqlu1992/Adobe_EntitySeg",
        repo_type="dataset",
        filename="CropFormer_model/Entity_Segmentation/Mask2Former_hornet_3x/Mask2Former_hornet_3x_576d0b.pth",
    )
    shutil.copyfile(cached, dst_path)
    return dst_path
|
|
|
|
|
|
|
|
def _init_models():
    """
    Lazy-load heavy models so the UI can start quickly on HF Spaces.

    Populates the module-level singletons _VGGT_MODEL, _METRIC3D_MODEL and
    _CLIP_MODEL on first call and returns them as (vggt, metric3d, clip).
    On a CPU-only host only CLIP is loaded and (None, None, clip) is
    returned; run_model() then short-circuits to dummy geometry.
    """
    global _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL

    if not torch.cuda.is_available():

        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA недоступна. Для этого Space нужен GPU (CUDA).")

        # CPU fallback: only the CLIP text encoder is needed (for text-query
        # features); VGGT/Metric3D are skipped entirely.
        if _CLIP_MODEL is None:
            print("[INFO] loading CLIP model (CPU)...")
            cm, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
            cm.to("cpu")
            cm.eval()
            print("[INFO] finish loading CLIP model (CPU)...")
            # Equivalent to `_CLIP_MODEL = cm` given the `global` declaration.
            globals()["_CLIP_MODEL"] = cm
        return None, None, _CLIP_MODEL

    if _VGGT_MODEL is None:
        print("Initializing and loading VGGT model...")
        # Prefer the hub snapshot; fall back to a manual state-dict download
        # if from_pretrained fails (e.g. offline hub metadata).
        try:
            m = VGGT.from_pretrained("facebook/VGGT-1B")
        except Exception:
            m = VGGT()
            _URL = "https://huggingface.co/facebook/VGGT-1B/resolve/main/model.pt"
            m.load_state_dict(torch.hub.load_state_dict_from_url(_URL))
        m.eval()
        _VGGT_MODEL = m.to(device)

    if _METRIC3D_MODEL is None:
        print("Initializing and loading Metric3D model...")
        # Older torch.hub versions do not accept trust_repo; retry without it.
        try:
            mm = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True, trust_repo=True)
        except TypeError:
            mm = torch.hub.load("yvanyin/metric3d", "metric3d_vit_small", pretrain=True)
        mm.to(device)
        mm.eval()
        _METRIC3D_MODEL = mm

    if _CLIP_MODEL is None:
        print("[INFO] loading CLIP model...")
        cm, _, _ = open_clip.create_model_and_transforms("ViT-H-14", pretrained="laion2b_s32b_b79k")
        cm.to(device)
        cm.eval()
        print("[INFO] finish loading CLIP model...")
        _CLIP_MODEL = cm

    return _VGGT_MODEL, _METRIC3D_MODEL, _CLIP_MODEL
|
|
|
|
|
# Filename of the CropFormer (Mask2Former) checkpoint expected under MK_PATH.
cropformer_name = "Mask2Former_hornet_3x_576d0b.pth"
|
|
|
|
|
def check_weights():
    """Ensure the CropFormer checkpoint is present under MK_PATH.

    Honors ZOO3D_SKIP_DOWNLOADS so startup can stay fully offline.
    """
    if ZOO3D_SKIP_DOWNLOADS:
        print("[INFO] ZOO3D_SKIP_DOWNLOADS=1: skipping Mask2Former weights download.")
        return

    dst = os.path.join(MK_PATH, cropformer_name)
    if os.path.exists(dst):
        print(f"{cropformer_name} already exists...")
        return

    print(f"Downloading {cropformer_name}...")
    os.makedirs(MK_PATH, exist_ok=True)
    _ensure_mask2former_weights(dst)
    print(f"Downloaded {cropformer_name}...")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def extract_text_feature(descriptions, clip_model, target_path):
    """Encode text descriptions with CLIP and cache the features on disk.

    Returns a dict mapping each description to its L2-normalized feature
    vector; the same dict is written to <target_path>/text_features.npy.
    """
    tokens = tokenizer.tokenize(descriptions).to(device)
    with torch.no_grad():
        feats = clip_model.encode_text(tokens).float()
        # Normalize so downstream cosine similarity is a plain dot product.
        feats /= feats.norm(dim=-1, keepdim=True)
        feats = feats.cpu().numpy()

    text_features_dict = {desc: feats[i] for i, desc in enumerate(descriptions)}

    np.save(os.path.join(target_path, "text_features.npy"), text_features_dict)
    return text_features_dict
|
|
|
|
|
|
|
|
# NOTE(review): this module-level handle appears unused in this file — CLIP is
# managed through the _CLIP_MODEL singleton in _init_models(); confirm no
# external code reads it before removing.
clip_model = None
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_model(target_dir, model, metric3d_model=None) -> dict:
    """
    Run the VGGT model on images in the 'target_dir/images' folder and return predictions.

    When metric3d_model is given, Metric3D depths are used to estimate a
    single median scale factor that converts VGGT's relative depth/poses to
    metric units. On a CPU-only host (with ZOO3D_CPU_DUMMY) a dummy flat
    scene is returned instead of running inference.

    Returns a dict of numpy arrays: images, extrinsic, intrinsic, pose,
    depth, depth_conf, world_points_from_depth (plus raw VGGT outputs on
    the GPU path).
    """
    print(f"Processing images from {target_dir}")

    # Re-check the device locally; the module-level `device` may be stale.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    if device != "cuda":
        if not ZOO3D_ALLOW_CPU:
            raise RuntimeError("CUDA недоступна. Для этого Space нужен GPU (CUDA).")
        if not ZOO3D_CPU_DUMMY:
            raise RuntimeError(
                "CPU режим включен, но ZOO3D_CPU_DUMMY=0. "
                "Для отладки поставь ZOO3D_CPU_DUMMY=1 или включи GPU."
            )

    image_names = glob.glob(os.path.join(target_dir, "images", "*"))
    image_names = sorted(image_names)
    print(f"Found {len(image_names)} images")
    if len(image_names) == 0:
        raise ValueError("No images found. Check your upload.")

    # CPU dummy path: load raw RGB frames only (no model inference).
    cpu_images_u8 = None
    if device == "cpu":
        imgs = []
        for p in image_names:
            im = cv2.imread(p, cv2.IMREAD_COLOR)
            if im is None:
                continue  # skip unreadable files rather than failing
            im = cv2.cvtColor(im, cv2.COLOR_BGR2RGB)
            imgs.append(im)
        if len(imgs) == 0:
            raise ValueError("No readable images found. Check your upload.")

        # Force all frames to the first frame's resolution so they stack.
        H, W = imgs[0].shape[:2]
        imgs2 = []
        for im in imgs:
            if im.shape[:2] != (H, W):
                im = cv2.resize(im, (W, H))
            imgs2.append(im)
        cpu_images_u8 = np.stack(imgs2, axis=0)
        print(f"CPU dummy: loaded images shape: {cpu_images_u8.shape}")

    images = load_and_preprocess_images(image_names)
    print(f"Preprocessed images shape: {tuple(images.shape)}")
    if device == "cuda":
        images = images.to(device)

    if device == "cpu":
        # Fabricate a flat unit-depth scene with identity cameras so the
        # rest of the UI pipeline has something to visualize.
        S, H, W = cpu_images_u8.shape[0], cpu_images_u8.shape[1], cpu_images_u8.shape[2]

        # Pixel grid -> normalized camera rays at depth 1.
        uu, vv = np.meshgrid(np.arange(W), np.arange(H))
        x = (uu - (W / 2.0)) / float(max(W, 1))
        y = -(vv - (H / 2.0)) / float(max(W, 1))
        z = np.ones_like(x, dtype=np.float32) * 1.0
        pts = np.stack([x, y, z], axis=-1).astype(np.float32)
        world_points_from_depth = np.repeat(pts[None, ...], S, axis=0)
        depth = np.ones((S, H, W, 1), dtype=np.float32)
        depth_conf = np.ones((S, H, W), dtype=np.float32)
        extrinsic = np.tile(np.array([[1, 0, 0, 0],
                                      [0, 1, 0, 0],
                                      [0, 0, 1, 0]], dtype=np.float32)[None, ...], (S, 1, 1))
        intrinsic = np.tile(np.eye(3, dtype=np.float32)[None, ...], (S, 1, 1))
        pose = np.tile(np.eye(4, dtype=np.float32)[None, ...], (S, 1, 1))
        return {
            "images": cpu_images_u8,
            "extrinsic": extrinsic,
            "intrinsic": intrinsic,
            "pose": pose,
            "depth": depth,
            "depth_conf": depth_conf,
            "world_points_from_depth": world_points_from_depth,
        }

    model = model.to(device)
    model.eval()

    print("Running inference...")
    # bf16 on Ampere+ (compute capability >= 8), fp16 otherwise.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    amp_ctx = torch.cuda.amp.autocast(dtype=dtype) if device == "cuda" else contextlib.nullcontext()

    with torch.no_grad():
        with amp_ctx:
            predictions = model(images)

    # Default scale (identity); refined below when Metric3D is available.
    scale_factor = torch.tensor(1.0, device=device)

    if metric3d_model is not None:
        print("Running Metric3D inference...")

        # Metric3D expects 0..255 inputs; VGGT preprocessing yields 0..1.
        # (assumes `images` is normalized to [0, 1] — TODO confirm against
        # load_and_preprocess_images)
        metric3d_input = images * 255.0

        m_depths = []

        for i in range(metric3d_input.shape[0]):
            img = metric3d_input[i:i+1]

            # Pad H/W up to multiples of 32 as required by the ViT backbone.
            _, _, h, w = img.shape
            ph = ((h - 1) // 32 + 1) * 32
            pw = ((w - 1) // 32 + 1) * 32

            padding = (0, pw - w, 0, ph - h)
            if ph != h or pw != w:
                img = torch.nn.functional.pad(img, padding, mode='constant', value=0)

            with torch.no_grad():
                pred_depth, confidence, _ = metric3d_model.inference({'input': img})

            # Crop the padding back off.
            if ph != h or pw != w:
                pred_depth = pred_depth[:, :, :h, :w]

            m_depths.append(pred_depth)

        predictions["metric3d_depth"] = torch.cat(m_depths, dim=0)

        vggt_depth = predictions["depth"][0]
        metric_depth = predictions["metric3d_depth"]

        # Normalize Metric3D output to (S, H, W, 1) to match VGGT's layout.
        if metric_depth.dim() == 4 and metric_depth.shape[1] == 1:
            metric_depth = metric_depth.permute(0, 2, 3, 1)
        elif metric_depth.dim() == 3:
            metric_depth = metric_depth.unsqueeze(-1)

        vggt_depth = vggt_depth.to(metric_depth.device).float()
        metric_depth = metric_depth.float()

        print(f"Metric3D depth shape: {metric_depth.shape}")
        print(f"VGGT depth shape: {vggt_depth.shape}")
        # Only compare pixels where both depths are strictly positive.
        valid_mask = (metric_depth > 1e-6) & (vggt_depth > 1e-6)

        if valid_mask.sum() > 0:
            print(f"Valid mask shape: {valid_mask.shape}")
            print(f"Metric depth shape: {metric_depth.shape}")
            print(f"VGGT depth shape: {vggt_depth.shape}")
            # Median ratio is robust to depth outliers.
            ratio = metric_depth[valid_mask] / vggt_depth[valid_mask]
            scale_factor = torch.median(ratio)
            print(f"Computed scale factor (VGGT / Metric3D): {scale_factor.item():.4f}")
        else:
            print("Warning: could not compute scale factor; falling back to 1.0")
    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    extrinsic = extrinsic[0]
    # Append the homogeneous [0, 0, 0, 1] row to each 3x4 extrinsic.
    add = torch.zeros_like(extrinsic[:, 2:])
    add[..., -1] = 1
    extrinsic = torch.cat([extrinsic, add], dim=-2)
    # NOTE(review): zero_extrinsic is a *view* of extrinsic[0]; the first loop
    # iteration overwrites extrinsic[0] in place, so later iterations see the
    # updated (near-identity) matrix, not the original first-frame extrinsic.
    # If the original value was intended, this needs `.clone()` — confirm.
    zero_extrinsic = extrinsic[0]
    for i, e in enumerate(extrinsic):
        # Re-express each camera relative to the first, then apply the
        # metric scale to the translation only.
        extrinsic[i] = zero_extrinsic @ torch.linalg.inv(e)
        extrinsic[i, :3, 3] *= scale_factor
    extrinsic_inv = torch.linalg.inv(extrinsic)
    print(f"Extrinsic: {extrinsic.shape}")
    # Keep only the 3x4 part, with a leading batch dim for downstream code.
    extrinsic_inv = extrinsic_inv[None, ..., :3, :]
    predictions["extrinsic"] = extrinsic_inv
    predictions["pose"] = extrinsic[None]
    print(f"Extrinsic: {extrinsic.shape} {extrinsic}")
    predictions["intrinsic"] = intrinsic

    # Move everything to CPU numpy and drop the batch dim where possible;
    # tensors whose first dim != 1 raise ValueError and are left batched.
    for key in predictions.keys():
        if isinstance(predictions[key], torch.Tensor):
            try:
                predictions[key] = predictions[key].cpu().numpy().squeeze(0)
            except ValueError:
                pass

    print("Computing world points from depth map...")
    # Apply the metric scale to depth, then unproject to 3D world points.
    predictions["depth"] = predictions["depth"] * float(scale_factor.item())
    depth_map = predictions["depth"]
    world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"])
    predictions["world_points_from_depth"] = world_points

    # Clean up GPU memory before returning to the UI thread.
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    return predictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def handle_uploads(input_video, input_images):
    """
    Create a new 'target_dir' + 'images' subfolder, and place user-uploaded
    images or extracted frames from video into it. Return (target_dir, image_paths).

    Video frames are sampled at roughly one per second of footage.
    """
    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Unique per-upload directory; microsecond timestamp avoids collisions.
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    target_dir = os.path.join(WORK_DIR, "input", timestamp)
    target_dir_images = os.path.join(target_dir, "images")

    # Start from a clean directory even in the unlikely timestamp collision.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # Copy individually uploaded images. Depending on the Gradio version the
    # items are either plain file paths or dicts with a "name" key.
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # Extract frames from an uploaded video (same dict-or-path handling).
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # BUGFIX: CAP_PROP_FPS can return 0 (or NaN) for unreadable or
            # headerless streams, which previously caused a modulo-by-zero
            # below. Clamp the sampling interval to at least one frame.
            frame_interval = max(int(fps * 1), 1)  # ~1 frame per second

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.jpg")
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # BUGFIX: release the capture handle (it was previously leaked).
            vs.release()

    # Sort final images for gallery display and deterministic ordering.
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def update_gallery_on_upload(input_video, input_images):
    """
    Whenever user uploads or changes files, immediately handle them
    and show in the gallery. Return (target_dir, image_paths).
    If nothing is uploaded, returns "None" and empty list.
    """
    # No upload at all -> clear every output slot.
    if not (input_video or input_images):
        return None, None, None, None

    target_dir, image_paths = handle_uploads(input_video, input_images)
    return None, target_dir, image_paths, "Upload complete. Click 'Detect Objects' to begin 3D processing."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def reconstruct(
    target_dir,
    conf_thres=50.0,
    frame_filter="All",
    mask_black_bg=False,
    mask_white_bg=False,
    show_cam=True,
    mask_sky=False,
    prediction_mode="Depthmap and Camera Branch",
    text_labels="",
):
    """
    Perform reconstruction using the already-created target_dir/images.

    Runs VGGT (+Metric3D scaling), saves per-frame depth/pose/intrinsic
    files and a down-sampled point cloud, exports a GLB for the viewer, and
    finally kicks off the external MaskClustering pipeline (CropFormer mask
    prediction, clustering, open-vocab feature extraction) as subprocesses.

    Returns (glb_path, log_message, frame_filter_dropdown).
    """
    # Only the depth/camera branch is supported by this app; override input.
    prediction_mode = "Depthmap and Camera Branch"
    # NOTE(review): this early return has 4 elements while the final return
    # has 3 — one of the two is likely wrong for the Gradio output wiring;
    # confirm against the UI event bindings.
    if not os.path.isdir(target_dir) or target_dir == "None":
        return None, "No valid target directory found. Please upload first.", None, None

    start_time = time.time()
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Build the frame-filter dropdown choices from the uploaded images.
    target_dir_images = os.path.join(target_dir, "images")
    all_files = sorted(os.listdir(target_dir_images)) if os.path.isdir(target_dir_images) else []
    image_names = [f.split(".")[0] for f in all_files]
    all_files = [f"{i}: (unknown)" for i, filename in enumerate(all_files)]
    frame_filter_choices = ["All"] + all_files

    print("Running run_model...")
    with torch.no_grad():
        # Best-effort weight check; reconstruction itself does not need the
        # segmentation weights, only the subprocess pipeline below does.
        try:
            check_weights()
        except Exception as e:
            print(f"Warning: could not ensure Mask2Former weights at startup: {e}")
        vggt_model, metric3d_model, _ = _init_models()
        predictions = run_model(target_dir, vggt_model, metric3d_model=metric3d_model)

    # Persist raw predictions for update_visualization() / detect_objects().
    prediction_save_path = os.path.join(target_dir, "predictions.npz")
    try:
        np.savez(prediction_save_path, **predictions)
    except Exception as e:
        print(f"Warning: could not save predictions to npz: {e}")

    # Write per-frame depth (mm, uint16 PNG), poses, and mean intrinsics in
    # the directory layout the MaskClustering pipeline expects.
    depth_path = os.path.join(target_dir, "depth")
    pose_path = os.path.join(target_dir, "pose")
    intrinsic_path = os.path.join(target_dir, "intrinsic")
    os.makedirs(depth_path, exist_ok=True)
    os.makedirs(pose_path, exist_ok=True)
    os.makedirs(intrinsic_path, exist_ok=True)
    for i, d in enumerate(predictions["depth"]):
        print(d.shape)
        # Meters -> millimeters, stored as 16-bit PNG.
        cv2.imwrite(os.path.join(depth_path, f"{image_names[i]}.png"), (d[..., 0] * 1000).astype(np.uint16))
    intr = np.eye(4)
    intr[:3, :3] = np.mean(predictions["intrinsic"], axis=0)
    np.savetxt(os.path.join(intrinsic_path, "intrinsic_depth.txt"), intr)

    for i, p in enumerate(predictions["pose"]):
        np.savetxt(os.path.join(pose_path, f"{image_names[i]}.txt"), p)

    if frame_filter is None:
        frame_filter = "All"

    # Cache filename encodes all visualization parameters.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    glbscene, point_cloud_data = predictions_to_glb(
        predictions,
        conf_thres=conf_thres,
        filter_by_frames=frame_filter,
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        target_dir=target_dir,
        prediction_mode=prediction_mode,
    )

    # Convert trimesh point cloud (0-255 colors, possibly RGBA) to Open3D
    # conventions (0-1 RGB).
    v = np.asarray(point_cloud_data.vertices)
    c = np.asarray(point_cloud_data.colors) / 255.0
    if c.shape[1] == 4:
        c = c[:, :3]

    glbscene.export(file_obj=glbfile)
    pcd = o3d.geometry.PointCloud()
    pcd.points = o3d.utility.Vector3dVector(v)
    pcd.colors = o3d.utility.Vector3dVector(c)

    # 1 cm voxel down-sampling keeps the .ply small for later stages.
    pcd = pcd.voxel_down_sample(voxel_size=0.01)
    o3d.io.write_point_cloud(os.path.join(target_dir, "point_cloud.ply"), pcd)

    # Free the (large) prediction arrays before spawning subprocesses.
    del predictions
    gc.collect()
    torch.cuda.empty_cache()

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    log_msg = f"Reconstruction Success ({len(all_files)} frames). Waiting for visualization."

    # External MaskClustering pipeline: (1) CropFormer 2D masks,
    # (2) 3D mask clustering, (3) open-vocab CLIP features per object.
    # All three are best-effort; failure only logs a warning.
    root_input_dir = os.path.dirname(target_dir)
    seq_name = os.path.basename(target_dir)
    try:
        subprocess.run(
            [
                sys.executable,
                os.path.join(
                    MK_PATH,
                    "third_party",
                    "detectron2",
                    "projects",
                    "CropFormer",
                    "demo_cropformer",
                    "mask_predict.py",
                ),
                "--config-file",
                os.path.join(
                    MK_PATH,
                    "third_party",
                    "detectron2",
                    "projects",
                    "CropFormer",
                    "configs",
                    "entityv2",
                    "entity_segmentation",
                    "mask2former_hornet_3x.yaml",
                ),
                "--root",
                root_input_dir,
                "--image_path_pattern",
                "images/*.jpg",
                "--dataset",
                "arkit_gt",
                "--seq_name_list",
                seq_name,
                "--opts",
                "MODEL.WEIGHTS",
                os.path.join(MK_PATH, cropformer_name),
            ],
            check=True,
            env={
                **os.environ,
                # Child process must be able to import the MaskClustering pkg.
                "PYTHONPATH": MK_PATH
                + (os.pathsep + os.environ["PYTHONPATH"] if os.environ.get("PYTHONPATH") else ""),
            },
        )

        subprocess.run(
            [
                sys.executable,
                os.path.join(MK_PATH, "main.py"),
                "--config",
                "wild",
                "--root",
                root_input_dir,
                "--seq_name_list",
                seq_name,
            ],
            check=True,
        )

        env = dict(os.environ)
        env["PYTHONPATH"] = MK_PATH + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
        subprocess.run(
            [
                sys.executable,
                os.path.join(MK_PATH, "semantics", "get_open-voc_features.py"),
                "--config",
                "wild",
                "--root",
                root_input_dir,
                "--seq_name_list",
                seq_name,
            ],
            env=env,
            check=True,
        )
    except Exception as e:
        print(f"Warning: external MaskClustering pipeline failed: {e}")

    return glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True)
|
|
|
|
|
def visualize_detections(target_dir, conf_thres, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Depthmap and Camera Branch"):
    """
    Generate a GLB scene with bounding boxes for detected objects.

    Loads the cached point cloud and the MaskClustering object predictions,
    draws an oriented bounding box (as emissive cylinders along its edges)
    for each detection whose score exceeds conf_thres, and returns
    (glb_path, legend_markdown). Only target_dir and conf_thres affect the
    output; the remaining parameters mirror the viewer controls' signature.
    """
    if not target_dir or not os.path.exists(target_dir):
        return None, "Target directory not found."

    ply_path = os.path.join(target_dir, "point_cloud.ply")
    npz_path = os.path.join(target_dir, "output", "object", "prediction.npz")

    if not os.path.exists(ply_path):
        return None, f"Point cloud not found at {ply_path}. Please run detection first."

    pcd = o3d.io.read_point_cloud(ply_path)
    points = np.asarray(pcd.points)
    colors = np.asarray(pcd.colors)

    if points.size == 0:
        return None, "Point cloud is empty."

    scene = trimesh.Scene()

    # Convert Open3D colors (float 0-1 or already uint8) to RGBA uint8 for
    # trimesh; default to opaque white when no colors are stored.
    if colors.size == 0:
        t_colors = np.ones((len(points), 4), dtype=np.uint8) * 255
    else:
        if colors.max() <= 1.0:
            t_colors = (colors * 255).astype(np.uint8)
        else:
            t_colors = colors.astype(np.uint8)
        if t_colors.shape[1] == 3:
            t_colors = np.hstack([t_colors, np.ones((len(t_colors), 1), dtype=np.uint8) * 255])

    base_pc = trimesh.PointCloud(vertices=points, colors=t_colors)
    scene.add_geometry(base_pc)

    # Overlay detections (if the MaskClustering pipeline produced any).
    legend_md = ""
    if os.path.exists(npz_path):
        try:
            loaded = np.load(npz_path, allow_pickle=True)

            if 'pred_masks' in loaded:
                # Transpose so masks index as [object, point].
                masks = loaded['pred_masks'].T
                labels = loaded['pred_classes']
                confs = loaded['pred_score']

                # Map class index -> query text via the cached CLIP features
                # (dict insertion order matches the query order).
                text_features_path = os.path.join(target_dir, "text_features.npy")
                label_to_name = {}
                if os.path.exists(text_features_path):
                    try:
                        text_features_dict = np.load(text_features_path, allow_pickle=True).item()
                        feature_keys = list(text_features_dict.keys())
                        for i, name in enumerate(feature_keys):
                            label_to_name[i] = name
                    except Exception as e:
                        print(f"Warning: Could not load text features for label mapping: {e}")

                if isinstance(confs, (list, tuple)):
                    confs = np.array(confs)

                # Keep only detections above the confidence threshold.
                valid_indices = np.where(confs > conf_thres)[0]

                if len(valid_indices) > 0:
                    legend_items = {}
                    cmap = plt.get_cmap("tab10")

                    # One stable color per detected class.
                    detected_labels = np.unique(labels[valid_indices])
                    label_to_color = {label: cmap(i % 10) for i, label in enumerate(detected_labels)}

                    for idx in valid_indices:
                        mask = masks[idx]
                        # Masks may be scipy sparse; densify if needed.
                        if hasattr(mask, "toarray"):
                            mask = mask.toarray().flatten()
                        mask = mask.astype(bool)

                        # NOTE(review): dead branch — a mask whose length does
                        # not match the point count is silently skipped by the
                        # check below; this `pass` does nothing.
                        if len(mask) != len(points):
                            pass

                        if len(mask) == len(points):
                            obj_points = points[mask]
                            # Need at least 4 points for a 3D bounding box.
                            if len(obj_points) >= 4:
                                obj_pcd = trimesh.PointCloud(obj_points)
                                # Prefer an oriented box; fall back to AABB.
                                try:
                                    bbox = obj_pcd.bounding_box_oriented
                                except Exception:
                                    bbox = obj_pcd.bounding_box

                                # Heuristic: drop implausibly large boxes
                                # (> 2.5 in any dimension; units follow the
                                # reconstruction's metric scale).
                                try:
                                    ext = np.asarray(bbox.extents).astype(float)
                                    if float(np.max(ext)) > 2.5:
                                        continue
                                except Exception:
                                    pass

                                verts = np.asarray(bbox.vertices)
                                if verts.shape[0] != 8:
                                    continue
                                T = np.asarray(bbox.transform)
                                center = T[:3, 3]
                                R = T[:3, :3]

                                # Express corners in the box's local frame so
                                # each corner maps to a +/- sign triple.
                                local = (verts - center) @ R

                                signs = np.where(local >= 0.0, 1, -1).astype(int)
                                sign_to_idx = {tuple(s): i for i, s in enumerate(signs)}

                                # Box edges connect corners differing in
                                # exactly one sign component.
                                edges_idx = set()
                                for sx in (-1, 1):
                                    for sy in (-1, 1):
                                        for sz in (-1, 1):
                                            s = (sx, sy, sz)
                                            if s not in sign_to_idx:
                                                continue
                                            for axis in range(3):
                                                s2 = list(s)
                                                s2[axis] *= -1
                                                s2 = tuple(s2)
                                                if s2 in sign_to_idx:
                                                    i0 = sign_to_idx[s]
                                                    i1 = sign_to_idx[s2]
                                                    if i0 != i1:
                                                        edges_idx.add(tuple(sorted((i0, i1))))
                                if not edges_idx:
                                    continue
                                segments = np.array([[verts[i], verts[j]] for (i, j) in edges_idx], dtype=float)

                                lbl_idx = labels[idx]
                                lbl_name = label_to_name.get(lbl_idx, f"Class {lbl_idx}")
                                color = label_to_color.get(lbl_idx, (1, 0, 0, 1))

                                color_u8 = (np.array(color) * 255).astype(np.uint8)

                                # Render each edge as a thin cylinder so the
                                # box remains visible in the GLB viewer.
                                radius = 0.015
                                for seg in segments:
                                    p1, p2 = seg[0], seg[1]
                                    v = p2 - p1
                                    length = float(np.linalg.norm(v))
                                    if length <= 1e-8:
                                        continue
                                    direction = v / length
                                    try:
                                        cyl = trimesh.creation.cylinder(radius=radius, height=length, sections=12)
                                    except Exception:
                                        continue

                                    # Rotate the cylinder's +Z axis onto the
                                    # edge direction, then move to the middle.
                                    try:
                                        align = trimesh.geometry.align_vectors([0, 0, 1], direction)
                                        cyl.apply_transform(align)
                                    except Exception:
                                        pass
                                    midpoint = (p1 + p2) / 2.0
                                    cyl.apply_translation(midpoint)

                                    # Emissive material keeps the box color
                                    # constant regardless of scene lighting;
                                    # fall back to plain face colors.
                                    try:
                                        emissive = (color_u8[:3] / 255.0).tolist()
                                        mat = trimesh.visual.material.PBRMaterial(
                                            baseColorFactor=(0.0, 0.0, 0.0, 1.0),
                                            metallicFactor=0.0,
                                            roughnessFactor=1.0,
                                            emissiveFactor=emissive,
                                            doubleSided=True,
                                        )
                                        cyl.visual.material = mat
                                    except Exception:
                                        cyl.visual.face_colors = np.tile(color_u8[None, :], (len(cyl.faces), 1))
                                    scene.add_geometry(cyl)
                                legend_items[lbl_name] = color

                    # Markdown legend mapping each class name to its color.
                    legend_md = "### Legend\n"
                    for lbl_name, color in legend_items.items():
                        c_u8 = (np.array(color) * 255).astype(np.uint8)
                        hex_c = "#{:02x}{:02x}{:02x}".format(c_u8[0], c_u8[1], c_u8[2])
                        legend_md += f"- <span style='color:{hex_c}'>■</span> {lbl_name}\n"

        except Exception as e:
            print(f"Error loading detections: {e}")
            legend_md = f"Error loading detections: {e}"

    out_path = os.path.join(target_dir, f"combined_viz_{conf_thres}.glb")
    scene.export(file_obj=out_path)

    return out_path, legend_md
|
|
|
|
|
def detect_objects(text_labels, target_dir, conf_thres, *viz_args):
    """
    Detect objects from text labels and return the detected objects.

    text_labels is a ';'-separated list of open-vocabulary queries. Ensures
    a reconstruction exists (running one if needed), extracts CLIP features
    for the queries, runs the external open-voc query script, and returns
    visualize_detections(...) = (glb_path, legend_markdown).

    viz_args forwards the viewer controls (frame_filter, mask flags, etc.)
    to reconstruct() and visualize_detections() in their declared order.
    """
    # Reject empty / whitespace-only query strings early.
    if not text_labels or not isinstance(text_labels, str) or len([l.strip() for l in text_labels.split(";") if l.strip()]) == 0:
        return None, "Please enter at least one text label (separated by ';')."

    # Best-effort weights check (the subprocess below needs them).
    if torch.cuda.is_available() or not ZOO3D_SKIP_DOWNLOADS:
        try:
            check_weights()
        except Exception as e:
            print(f"Warning: could not ensure Mask2Former weights: {e}")

    # Lazily run the full reconstruction if it has not been done yet.
    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        print("Predictions not found, running reconstruction first...")
        reconstruct(target_dir, 50.0, *viz_args, text_labels=text_labels)

    if text_labels:
        labels = [l.strip() for l in text_labels.split(";") if l.strip()]
        if labels:
            print(f"Extracting features for labels: {labels}")
            _, _, clip_model = _init_models()
            # Writes <target_dir>/text_features.npy for the query script and
            # for visualize_detections' label->name mapping.
            text_features = extract_text_feature(labels, clip_model, target_dir)
            print(f"Text features: {text_features}")
            try:
                # Child process needs detectron2 and MaskClustering on its path.
                env = dict(os.environ)
                env["PYTHONPATH"] = (
                    DETECTRON2_ROOT
                    + os.pathsep
                    + MK_PATH
                    + (os.pathsep + env["PYTHONPATH"] if env.get("PYTHONPATH") else "")
                )
                root_input_dir = os.path.dirname(target_dir)
                seq_name = os.path.basename(target_dir)
                subprocess.run(
                    [
                        sys.executable,
                        os.path.join(MK_PATH, "semantics", "wopen-voc_query.py"),
                        "--config",
                        "wild",
                        "--root",
                        root_input_dir,
                        "--seq_name",
                        seq_name,
                    ],
                    env=env,
                    check=True,
                )
            except Exception as e:
                print(f"Warning: open-voc query failed: {e}")

    return visualize_detections(target_dir, conf_thres, *viz_args)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def clear_fields():
    """Gradio callback: blank out the 3D viewer component.

    Returning None clears the viewer, the stored target_dir, and the gallery.
    """
    return None
|
|
|
|
|
|
|
|
def update_log():
    """Gradio callback: interim status text shown while processing runs."""
    return "Loading and Reconstructing..."
|
|
|
|
|
|
|
|
def update_visualization(
    target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, is_example
):
    """
    Reload saved predictions from npz, create (or reuse) the GLB for new parameters,
    and return it for the 3D viewer. If is_example == "True", skip.

    Returns:
        tuple: (glb_path_or_None, status_message).
    """
    # Example entries have no per-session reconstruction to re-render.
    if is_example == "True":
        return None, "No reconstruction available. Please click the Reconstruct button first."

    # target_dir is a hidden Textbox; "None" is its unset sentinel value.
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, "No reconstruction available. Please click the Reconstruct button first."

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    # Keys expected inside predictions.npz, as written by the reconstruction step.
    key_list = [
        "pose_enc",
        "depth",
        "depth_conf",
        "world_points",
        "world_points_conf",
        "images",
        "extrinsic",
        "intrinsic",
        "world_points_from_depth",
    ]

    loaded = np.load(predictions_path, allow_pickle=True)
    predictions = {key: np.array(loaded[key]) for key in key_list}

    # The filename encodes every rendering parameter, so an existing file is
    # guaranteed to match the current settings.
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{frame_filter.replace('.', '_').replace(':', '').replace(' ', '_')}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}.glb",
    )

    # Reuse a previously exported GLB for an identical parameter combination;
    # otherwise build and export a fresh scene (this was previously rebuilt
    # unconditionally, despite the docstring promising reuse).
    if not os.path.exists(glbfile):
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Bundled example clips, all living under the same repo-relative directory.
_EXAMPLE_VIDEO_DIR = "examples/videos"

great_wall_video = f"{_EXAMPLE_VIDEO_DIR}/great_wall.mp4"
colosseum_video = f"{_EXAMPLE_VIDEO_DIR}/Colosseum.mp4"
room_video = f"{_EXAMPLE_VIDEO_DIR}/room.mp4"
kitchen_video = f"{_EXAMPLE_VIDEO_DIR}/kitchen.mp4"
fern_video = f"{_EXAMPLE_VIDEO_DIR}/fern.mp4"
single_cartoon_video = f"{_EXAMPLE_VIDEO_DIR}/single_cartoon.mp4"
single_oil_painting_video = f"{_EXAMPLE_VIDEO_DIR}/single_oil_painting.mp4"
pyramid_video = f"{_EXAMPLE_VIDEO_DIR}/pyramid.mp4"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Base theme, with selected checkboxes restyled to reuse the primary-button palette.
theme = gr.themes.Ocean()
_checkbox_selected_overrides = {
    "checkbox_label_background_fill_selected": "*button_primary_background_fill",
    "checkbox_label_text_color_selected": "*button_primary_text_color",
}
theme.set(**_checkbox_selected_overrides)
|
|
|
|
|
with gr.Blocks( |
|
|
theme=theme, |
|
|
css=""" |
|
|
.custom-log * { |
|
|
font-style: italic; |
|
|
font-size: 22px !important; |
|
|
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); |
|
|
-webkit-background-clip: text; |
|
|
background-clip: text; |
|
|
font-weight: bold !important; |
|
|
color: transparent !important; |
|
|
text-align: center !important; |
|
|
} |
|
|
|
|
|
.example-log * { |
|
|
font-style: italic; |
|
|
font-size: 16px !important; |
|
|
background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); |
|
|
-webkit-background-clip: text; |
|
|
background-clip: text; |
|
|
color: transparent !important; |
|
|
} |
|
|
|
|
|
#my_radio .wrap { |
|
|
display: flex; |
|
|
flex-wrap: nowrap; |
|
|
justify-content: center; |
|
|
align-items: center; |
|
|
} |
|
|
|
|
|
#my_radio .wrap label { |
|
|
display: flex; |
|
|
width: 50%; |
|
|
justify-content: center; |
|
|
align-items: center; |
|
|
margin: 0; |
|
|
padding: 10px 0; |
|
|
box-sizing: border-box; |
|
|
} |
|
|
""", |
|
|
) as demo: |
|
|
|
|
|
    # Hidden Textboxes used as plain string state shared between event
    # handlers; "None" is the unset sentinel value.
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
|
|
|
|
|
gr.HTML( |
|
|
""" |
|
|
<h1>🦁 Zoo3D: Zero-Shot 3D Object Detection at Scene Level 🐼</h1> |
|
|
<p> |
|
|
<a href="https://github.com/col14m/zoo3d">GitHub Repository</a> |
|
|
</p> |
|
|
|
|
|
<div style="font-size: 16px; line-height: 1.5;"> |
|
|
<p>Upload a video or a set of images to create a 3D reconstruction and run open‑vocabulary 3D object detection from your text labels. The app builds a point cloud and draws colored wireframe bounding boxes for the detected objects.</p> |
|
|
|
|
|
<h3>Getting Started:</h3> |
|
|
<ol> |
|
|
<li><strong>Upload Your Data:</strong> Use "Upload Video" or "Upload Images". Videos are sampled at 1 frame/sec.</li> |
|
|
<li><strong>Enter Text Labels (Required):</strong> Provide one or more labels separated by semicolons, e.g. <code>chair; table; plant</code>.</li> |
|
|
<li><strong>Detect:</strong> Click <strong>"Detect Objects"</strong>. The app will reconstruct the scene (if needed) and then run detection.</li> |
|
|
<li><strong>Threshold (Optional):</strong> Tune the <em>Detection Threshold</em> (0–1). Higher = fewer, more confident detections.</li> |
|
|
<li><strong>Visualize & Download:</strong> A single 3D view shows the point cloud and colored wireframe boxes. A legend maps colors to labels. You can download the GLB.</li> |
|
|
</ol> |
|
|
<p><strong style="color: #0ea5e9;">Notes:</strong> <span style="color: #0ea5e9; font-weight: bold;">Reconstruction is triggered automatically on first run. If no labels are provided, you'll see an error: </span><code>Please enter at least one text label (separated by ';').</code></p> |
|
|
</div> |
|
|
""" |
|
|
) |
|
|
|
|
|
    # Hidden holder for the per-session working directory created on upload.
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
|
|
|
|
|
with gr.Row(): |
|
|
with gr.Column(scale=2): |
|
|
input_video = gr.Video(label="Upload Video", interactive=True) |
|
|
input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True) |
|
|
|
|
|
def _safe_gallery(**kwargs): |
|
|
|
|
|
|
|
|
while True: |
|
|
try: |
|
|
return gr.Gallery(**kwargs) |
|
|
except TypeError as e: |
|
|
msg = str(e) |
|
|
|
|
|
bad = None |
|
|
import re |
|
|
m = re.search(r"unexpected keyword argument '([^']+)'", msg) |
|
|
if m: |
|
|
bad = m.group(1) |
|
|
if bad and bad in kwargs: |
|
|
kwargs.pop(bad) |
|
|
continue |
|
|
|
|
|
for k in ["show_download_button", "preview", "object_fit", "columns", "height"]: |
|
|
if k in kwargs: |
|
|
kwargs.pop(k) |
|
|
break |
|
|
else: |
|
|
raise |
|
|
|
|
|
image_gallery = _safe_gallery( |
|
|
label="Preview", |
|
|
columns=4, |
|
|
height="300px", |
|
|
show_download_button=True, |
|
|
object_fit="contain", |
|
|
preview=True, |
|
|
) |
|
|
|
|
|
with gr.Column(scale=4): |
|
|
text_labels = gr.Textbox(label="Text Labels (separated by ;)", placeholder="cat; dog; car") |
|
|
with gr.Column(): |
|
|
|
|
|
|
|
|
gr.Markdown("**3D Reconstruction & detection (Point Cloud and Bounding Boxes)**") |
|
|
log_output = gr.Markdown( |
|
|
"Please upload a video or images, then click Detect Objects.", elem_classes=["custom-log"] |
|
|
) |
|
|
reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5) |
|
|
|
|
|
with gr.Row(): |
|
|
detect_btn = gr.Button("Detect Objects", scale=1, variant="primary") |
|
|
clear_btn = gr.ClearButton( |
|
|
[input_video, input_images, reconstruction_output, log_output, target_dir_output, image_gallery, text_labels], |
|
|
scale=1, |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
    # Fixed prediction mode carried as hidden state (single supported branch).
    prediction_mode = gr.Textbox(value="Depthmap and Camera Branch", visible=False)

    # Visualization controls. All are created hidden; they feed
    # update_visualization / visualize_detections as inputs even when the
    # user cannot toggle them in this build.
    with gr.Row():
        conf_thres = gr.Slider(
            minimum=0,
            maximum=100,
            value=50,
            step=0.1,
            label="Confidence Threshold (%)",
            visible=False,
        )
        frame_filter = gr.Dropdown(
            choices=["All"],
            value="All",
            label="Show Points from Frame",
            visible=False,
        )
        with gr.Column():
            show_cam = gr.Checkbox(label="Show Camera", value=True, visible=False)
            mask_sky = gr.Checkbox(label="Filter Sky", value=False, visible=False)
            mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False, visible=False)
            mask_white_bg = gr.Checkbox(label="Filter White Background", value=False, visible=False)

    # User-facing detection threshold (0-1) and the color/label legend.
    detection_conf_thres = gr.Slider(
        minimum=0,
        maximum=1,
        value=0.6,
        step=0.01,
        label="Detection Threshold",
    )
    detection_legend = gr.Markdown("Legend will appear here")
|
|
|
|
|
|
|
|
    # Example presets (currently none are registered).
    examples = [
    ]
|
|
|
|
|
def example_pipeline( |
|
|
input_video, |
|
|
num_images_str, |
|
|
input_images, |
|
|
conf_thres, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example_str, |
|
|
text_labels, |
|
|
): |
|
|
""" |
|
|
1) Copy example images to new target_dir |
|
|
2) Reconstruct (and Detect if labels present) |
|
|
3) Return model3D + logs + new_dir + updated dropdown + gallery |
|
|
We do NOT return is_example. It's just an input. |
|
|
""" |
|
|
target_dir, image_paths = handle_uploads(input_video, input_images) |
|
|
|
|
|
frame_filter = "All" |
|
|
|
|
|
detection_conf = 0.85 |
|
|
|
|
|
glbfile, legend_md = detect_objects( |
|
|
text_labels, |
|
|
target_dir, |
|
|
detection_conf, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode |
|
|
) |
|
|
|
|
|
log_msg = "Example loaded and processed." |
|
|
|
|
|
return glbfile, log_msg + "\n\n" + legend_md, target_dir, gr.Dropdown(choices=["All"], value="All", interactive=True), image_paths |
|
|
|
|
|
    # Detect button: clear the viewer first, then run detection (which
    # reconstructs on first run), then mark the result as non-example so the
    # visualization controls become live.
    detect_btn.click(fn=clear_fields, inputs=[], outputs=[]).then(
        fn=detect_objects,
        inputs=[
            text_labels,
            target_dir_output,
            detection_conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode
        ],
        outputs=[reconstruction_output, detection_legend]
    ).then(
        fn=lambda: "False", inputs=[], outputs=[is_example]
    )

    # Moving the detection threshold only re-visualizes from saved results;
    # it does not re-run reconstruction or the open-vocabulary query.
    detection_conf_thres.change(
        fn=visualize_detections,
        inputs=[
            target_dir_output,
            detection_conf_thres,
            frame_filter,
            mask_black_bg,
            mask_white_bg,
            show_cam,
            mask_sky,
            prediction_mode
        ],
        outputs=[reconstruction_output, detection_legend]
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
conf_thres.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
frame_filter.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
mask_black_bg.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
mask_white_bg.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
show_cam.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
prediction_mode.change( |
|
|
update_visualization, |
|
|
[ |
|
|
target_dir_output, |
|
|
conf_thres, |
|
|
frame_filter, |
|
|
mask_black_bg, |
|
|
mask_white_bg, |
|
|
show_cam, |
|
|
mask_sky, |
|
|
prediction_mode, |
|
|
is_example, |
|
|
], |
|
|
[reconstruction_output, log_output], |
|
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
input_video.change( |
|
|
fn=update_gallery_on_upload, |
|
|
inputs=[input_video, input_images], |
|
|
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], |
|
|
) |
|
|
input_images.change( |
|
|
fn=update_gallery_on_upload, |
|
|
inputs=[input_video, input_images], |
|
|
outputs=[reconstruction_output, target_dir_output, image_gallery, log_output], |
|
|
) |
|
|
|
|
|
def main():
    """Start the Gradio app with a bounded request queue."""
    app = demo.queue(max_size=20)
    app.launch(show_error=True, share=False, show_api=False)
|
|
|
|
|
|
|
|
# Launch the UI only when executed as a script (not on import).
if __name__ == "__main__":
    main()
|
|
|