Spaces:
Configuration error
Configuration error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import json | |
| import os | |
| import cv2 | |
| import torch | |
| import numpy as np | |
| from scipy.spatial.transform import Rotation | |
| import gradio as gr | |
| import shutil | |
| from datetime import datetime | |
| import glob | |
| import gc | |
| import time | |
| import zipfile | |
| import functools | |
| from typing import Any, Dict, Optional | |
| from stream3r.models.stream3r import STream3R | |
| from stream3r.stream_session import StreamSession | |
| from stream3r.models.components.utils.load_fn import load_and_preprocess_images | |
| from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri | |
| from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map | |
| from stream3r.utils.visual_utils import predictions_to_glb | |
| import warnings | |
| warnings.filterwarnings("ignore", message="Can't initialize NVML") | |
| import spaces | |
# Module-level default device name; run_model and localize_new_image choose
# their own device locally.
device = "cuda"

# Checkpoint identifier handed to STream3R.from_pretrained (presumably a
# Hugging Face Hub repo id). The model is loaded once at import time; gradio_demo
# passes this instance to run_model and localize_new_image uses it directly.
_STREAM3R_CHECKPOINT = "yslan/STream3R"
model = STream3R.from_pretrained(_STREAM3R_CHECKPOINT)
def handle_zip(zip_file):
    """Extract an uploaded ZIP archive into the local ``uploads`` directory.

    Args:
        zip_file: The upload: a filesystem path, a dict carrying the path under
            ``"name"``, or an object exposing a ``.name`` attribute
            (e.g. a temp-file wrapper).

    Returns:
        A status message with the number of file entries contained in this
        archive (directory entries are not counted).
    """
    outdir = "uploads"
    os.makedirs(outdir, exist_ok=True)
    # Accept the different shapes an upload value can take instead of assuming
    # a ``.name`` attribute (plain path strings have none).
    if isinstance(zip_file, dict):
        zip_path = zip_file["name"]
    else:
        zip_path = getattr(zip_file, "name", zip_file)
    with zipfile.ZipFile(zip_path, "r") as zf:
        file_count = sum(1 for member in zf.infolist() if not member.is_dir())
        # zipfile strips absolute prefixes and ".." components on extraction,
        # so entries cannot land outside *outdir*.
        zf.extractall(outdir)
    # Count this archive's files only; listing *outdir* would also count files
    # left by earlier uploads as well as directories.
    return f"Extracted {file_count} files"
# Lower-case file extensions (with the leading dot) that
# extract_images_from_zip keeps; every other archive entry is skipped.
ALLOWED_IMG_EXT = {
    ".png",
    ".jpg",
    ".jpeg",
    ".bmp",
    ".webp",
}
def _is_within_dir(base_dir: str, path: str) -> bool:
    """Return True if *path* is *base_dir* itself or lies somewhere beneath it.

    Used as a zip-slip guard before writing extracted archive members. Both
    paths are made absolute (``os.path.abspath``, which also collapses ``..``)
    before comparison; symlinks are not resolved.
    """
    base_dir = os.path.abspath(base_dir)
    path = os.path.abspath(path)
    try:
        return os.path.commonpath([base_dir]) == os.path.commonpath([base_dir, path])
    except ValueError:
        # commonpath raises when the paths share no root (e.g. different
        # drives on Windows); such a path is certainly not inside base_dir.
        return False
def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]:
    """Extract only the image files from a ZIP archive into *outdir*.

    The archive's folder structure is flattened: each image is written directly
    into *outdir* under its base name, with ``_<n>`` appended before the
    extension when that name is already taken.

    An entry is skipped when any of these holds:

    * it is a directory;
    * its extension is not in ``ALLOWED_IMG_EXT``;
    * its base name starts with ``.``;
    * it lives under a ``__MACOSX`` folder.

    The last two rules drop the AppleDouble ``._name.jpg`` metadata files that
    macOS adds to archives. Those files carry image extensions but are not images.

    Args:
        zip_path: Anything ``zipfile.ZipFile`` accepts (path or file object).
        outdir: Destination directory; created if missing.

    Returns:
        Paths of the extracted files, in archive order.
    """
    os.makedirs(outdir, exist_ok=True)
    extracted: list[str] = []
    with zipfile.ZipFile(zip_path, "r") as zf:
        for member in zf.infolist():
            if member.is_dir():
                continue
            # Archives created on Windows may use backslashes as separators.
            parts = member.filename.replace("\\", "/").split("/")
            base_name = parts[-1]
            if not base_name or base_name.startswith(".") or "__MACOSX" in parts[:-1]:
                continue
            name_root, name_ext = os.path.splitext(base_name)
            if name_ext.lower() not in ALLOWED_IMG_EXT:
                continue
            dest_path = os.path.join(outdir, base_name)
            counter = 1
            while os.path.exists(dest_path):
                dest_path = os.path.join(outdir, f"{name_root}_{counter}{name_ext}")
                counter += 1
            # Zip-slip guard; base_name has no separators, but on Windows a
            # drive-qualified name such as "C:x.png" could still escape outdir.
            if not _is_within_dir(outdir, dest_path):
                continue
            with zf.open(member) as src, open(dest_path, "wb") as dst:
                shutil.copyfileobj(src, dst)
            extracted.append(dest_path)
    return extracted
def extract_session_state(zip_path: str, extract_root: str) -> str:
    """Unpack a saved session archive into a freshly emptied *extract_root*.

    Any existing *extract_root* is deleted first. When the archive holds exactly
    one top-level directory and no top-level files, that directory is returned.
    This is the layout package_session_state writes. Otherwise *extract_root*
    itself is returned.
    """
    if os.path.exists(extract_root):
        shutil.rmtree(extract_root)
    os.makedirs(extract_root, exist_ok=True)
    with zipfile.ZipFile(zip_path, "r") as archive:
        archive.extractall(extract_root)

    top_dirs: list[str] = []
    top_files: list[str] = []
    for entry_name in os.listdir(extract_root):
        full_path = os.path.join(extract_root, entry_name)
        if os.path.isdir(full_path):
            top_dirs.append(full_path)
        elif os.path.isfile(full_path):
            top_files.append(full_path)

    wrapped_in_single_dir = len(top_dirs) == 1 and len(top_files) == 0
    return top_dirs[0] if wrapped_in_single_dir else extract_root
def package_session_state(target_dir: str) -> str:
    """Zip a whole session directory so it can be downloaded and restored later.

    The archive is written to ``demo_cache/<dirname>_session.zip``, replacing any
    previous one. Every file is stored under a single top-level ``<dirname>/``
    folder, which is the layout extract_session_state unwraps.

    Raises:
        ValueError: if *target_dir* is not an existing directory.
    """
    if not os.path.isdir(target_dir):
        raise ValueError(f"Target directory does not exist: {target_dir}")
    # normpath drops a trailing separator so the basename is never empty; the
    # same name is used for the archive file and its top-level folder.
    root_name = os.path.basename(os.path.normpath(target_dir))
    os.makedirs("demo_cache", exist_ok=True)
    archive_path = os.path.join("demo_cache", f"{root_name}_session.zip")
    if os.path.exists(archive_path):
        os.remove(archive_path)
    archive_abs = os.path.abspath(archive_path)
    with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf:
        for root, _, files in os.walk(target_dir):
            for fname in files:
                file_path = os.path.join(root, fname)
                # Never add the archive to itself if it lives under target_dir.
                if os.path.abspath(file_path) == archive_abs:
                    continue
                arcname = os.path.join(root_name, os.path.relpath(file_path, target_dir))
                zf.write(file_path, arcname)
    return archive_path
def _copy_with_unique_name(src_path: str, dst_dir: str) -> str:
    """Copy *src_path* into *dst_dir* and return the path of the copy.

    The source file name is kept when it is free. Otherwise ``<stem>_1<ext>``,
    ``<stem>_2<ext>``, ... are tried until an unused name is found.
    """
    file_name = os.path.basename(src_path)
    stem, suffix = os.path.splitext(file_name)
    target = os.path.join(dst_dir, file_name)
    attempt = 0
    while os.path.exists(target):
        attempt += 1
        target = os.path.join(dst_dir, stem + "_" + str(attempt) + suffix)
    shutil.copy(src_path, target)
    return target
def _resolve_path(file_data) -> Optional[str]:
    """Return a filesystem path for an upload value, or ``None`` if there is none.

    Accepted shapes:

    * a plain path (``str`` or ``os.PathLike``, returned as a string);
    * a dict payload carrying the path under ``"name"`` or ``"path"``;
    * an object exposing a ``.name`` attribute (e.g. a temp-file wrapper).

    Anything else is converted with ``str()``.
    """
    if file_data is None:
        return None
    # Check real paths first: pathlib.Path also has a ``.name`` attribute, but
    # it holds only the final component, not the full path.
    if isinstance(file_data, (str, os.PathLike)):
        return os.fspath(file_data)
    if isinstance(file_data, dict):
        for key in ("name", "path"):
            if key in file_data:
                return file_data[key]
    if hasattr(file_data, "name"):
        return getattr(file_data, "name")
    return str(file_data)
STREAM_SESSION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "stream_session.json")


def load_stream_session_config() -> Dict[str, Any]:
    """Read optional demo settings from ``configs/stream_session.json`` beside this file.

    Returns the parsed JSON object. Returns ``{}`` when the file is missing,
    unreadable, not valid UTF-8 JSON, or not a JSON object. A missing file is
    the normal case and stays silent; the other failures print a message so a
    broken config is not ignored without a trace.
    """
    try:
        with open(STREAM_SESSION_CONFIG_PATH, "r", encoding="utf-8") as handle:
            data = json.load(handle)
    except FileNotFoundError:
        return {}
    except (OSError, ValueError) as exc:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Failed to load stream session config from {STREAM_SESSION_CONFIG_PATH}: {exc}")
        return {}
    if not isinstance(data, dict):
        print(f"Ignoring stream session config at {STREAM_SESSION_CONFIG_PATH}: expected a JSON object")
        return {}
    return data
def load_session_settings(target_dir: str) -> Dict[str, Any]:
    """Return the settings dict stored in ``<target_dir>/session_settings.json``.

    Returns ``{}`` in these cases:

    * the file does not exist;
    * it cannot be read or parsed (a message is printed);
    * it contains something other than a JSON object.
    """
    settings_path = os.path.join(target_dir, "session_settings.json")
    if os.path.exists(settings_path):
        try:
            with open(settings_path, "r", encoding="utf-8") as handle:
                payload = json.load(handle)
        except (json.JSONDecodeError, OSError) as exc:
            print(f"Failed to load session settings from {settings_path}: {exc}")
        else:
            if isinstance(payload, dict):
                return payload
    return {}
def build_frame_filter_choices(target_dir_images: str) -> list[str]:
    """Build the frame-filter dropdown options for an images directory.

    Returns ``["All"]`` followed by ``"<index>: <file name>"`` for every visible
    entry, sorted by name. A missing directory yields just ``["All"]``.

    Hidden entries (names starting with ``.``, e.g. ``.DS_Store``) are excluded.
    run_model collects frames with ``glob("*")``, which skips them, so listing
    them would shift every later index away from the prediction frame index.
    """
    if not os.path.isdir(target_dir_images):
        return ["All"]
    files = sorted(name for name in os.listdir(target_dir_images) if not name.startswith("."))
    return ["All"] + [f"{idx}: {name}" for idx, name in enumerate(files)]
def sanitize_frame_filter_label(label: Optional[str]) -> str:
    """Turn a frame-filter dropdown label into a token usable in file names.

    Empty or ``None`` labels map to ``"All"``. Otherwise dots and spaces become
    underscores and colons are removed, e.g. ``"3: img.png"`` -> ``"3_img_png"``.
    """
    if label:
        return label.translate(str.maketrans({".": "_", ":": None, " ": "_"}))
    return "All"
def select_top_k_frames(predictions: Dict[str, np.ndarray], images_dir: str, top_k: int) -> list[Dict[str, Any]]:
    """Pick up to *top_k* high-confidence, viewpoint-diverse frames.

    Scoring: each frame's confidence map is taken from ``world_points_conf``,
    falling back to ``depth_conf``. Only finite values count. The score is
    ``mean_conf * (0.5 + 0.5 * coverage)``. Coverage is the fraction of pixels at
    or above that frame's 75th-percentile confidence, divided by the largest
    coverage over all frames. Without any confidence map every frame scores 1.0.

    Selection: frames are visited best score first.
    * The first ``max(1, top_k // 3)`` frames are always accepted.
    * After that, a frame is accepted only if it adds a new viewpoint: its camera
      centre is at least ``pos_threshold`` from the nearest selected centre, or
      its viewing direction differs by at least ``ori_threshold`` degrees from
      the closest selected direction.
    * If this yields fewer than *top_k* frames, the best remaining frames fill
      the gap.

    Args:
        predictions: Numpy predictions. ``"extrinsic"`` must be indexable as
            (S, 3, 4) ``[R | t]``; the camera centre is computed as ``-R^T t``.
        images_dir: Directory whose sorted, non-hidden file names label frames.
        top_k: Maximum number of frames; ``<= 0`` returns ``[]``.

    Returns:
        One record per selected frame, sorted by frame index, with keys
        ``index``, ``filename``, ``score``, ``mean_confidence`` and
        ``coverage_ratio`` (the normalised coverage).
    """
    if top_k <= 0 or not os.path.isdir(images_dir):
        return []
    image_files = sorted(fname for fname in os.listdir(images_dir) if not fname.startswith("."))
    extrinsics = predictions.get("extrinsic")
    if extrinsics is None:
        return []
    num_frames = extrinsics.shape[0]
    if num_frames == 0:
        return []
    top_k = min(top_k, num_frames)

    # Camera centres (-R^T t) and unit viewing directions (third row of R).
    positions = np.array([(-extr[:, :3].T @ extr[:, 3]).astype(np.float64) for extr in extrinsics])
    forward_vectors = np.array([extr[2, :3] for extr in extrinsics])
    forward_norms = np.linalg.norm(forward_vectors, axis=1, keepdims=True)
    forward_vectors = np.divide(forward_vectors, forward_norms, out=np.zeros_like(forward_vectors), where=forward_norms > 0)

    conf_tensor = predictions.get("world_points_conf")
    if conf_tensor is None:
        conf_tensor = predictions.get("depth_conf")
    quality_scores = np.zeros(num_frames, dtype=np.float64)
    coverage_scores = np.zeros(num_frames, dtype=np.float64)
    for idx in range(num_frames):
        if conf_tensor is None:
            quality_scores[idx] = 1.0
            coverage_scores[idx] = 1.0
            continue
        conf = conf_tensor[idx].reshape(-1)
        # Drop NaN and +/-inf so a single bad pixel cannot dominate the mean.
        conf = conf[np.isfinite(conf)]
        if conf.size:
            quality_scores[idx] = float(np.mean(conf))
            high_thresh = np.percentile(conf, 75)
            coverage_scores[idx] = float(np.mean(conf >= high_thresh))

    max_cov = coverage_scores.max()
    coverage_scores = coverage_scores / max_cov if max_cov > 0 else np.ones_like(coverage_scores)
    base_scores = quality_scores * (0.5 + 0.5 * coverage_scores)
    # Stable sort: ties keep their original frame order.
    order = sorted(range(num_frames), key=lambda i: base_scores[i], reverse=True)

    bbox_min = positions.min(axis=0)
    bbox_max = positions.max(axis=0)
    scene_scale = float(np.linalg.norm(bbox_max - bbox_min))
    pos_threshold = max(0.1, 0.1 * scene_scale)
    ori_threshold = 15.0
    always_accept = max(1, top_k // 3)

    def _angle_deg(a: int, b: int) -> float:
        cos_angle = np.clip(np.dot(forward_vectors[a], forward_vectors[b]), -1.0, 1.0)
        return float(np.degrees(np.arccos(cos_angle)))

    selected: list[int] = []
    for idx in order:
        if len(selected) < always_accept:
            selected.append(idx)
        else:
            nearest_dist = min(np.linalg.norm(positions[idx] - positions[j]) for j in selected)
            # Novel orientation means far from *every* selected view, i.e. the
            # closest (minimum) angle must exceed the threshold.
            nearest_angle = min(_angle_deg(idx, j) for j in selected)
            if nearest_dist >= pos_threshold or nearest_angle >= ori_threshold:
                selected.append(idx)
        if len(selected) >= top_k:
            break

    if len(selected) < top_k:
        chosen = set(selected)
        for idx in order:
            if idx not in chosen:
                selected.append(idx)
                chosen.add(idx)
                if len(selected) >= top_k:
                    break

    records = []
    for idx in sorted(selected[:top_k]):
        filename = image_files[idx] if idx < len(image_files) else f"frame_{idx:06d}"
        records.append(
            {
                "index": int(idx),
                "filename": filename,
                "score": float(base_scores[idx]),
                "mean_confidence": float(quality_scores[idx]),
                "coverage_ratio": float(coverage_scores[idx]),
            }
        )
    return records
| # ------------------------------------------------------------------------- | |
| # 1) Core model inference | |
| # ------------------------------------------------------------------------- | |
def _count_cached_frames(predictions: dict) -> int:
    """Return the largest size of dim 1 over the tensors in *predictions*.

    Dim 1 is treated as the frame axis (dim 0 is the batch dimension that is
    squeezed away after inference).
    """
    frames = 0
    for value in predictions.values():
        if isinstance(value, torch.Tensor) and value.dim() >= 2:
            frames = max(frames, value.shape[1])
    return frames


def _tensors_to_numpy(predictions: dict) -> None:
    """Convert every tensor in *predictions* to numpy in place, dropping the batch dim."""
    for key, value in list(predictions.items()):
        if isinstance(value, torch.Tensor):
            # numpy has no bfloat16, so .numpy() would raise; upcast first.
            if value.dtype == torch.bfloat16:
                value = value.float()
            predictions[key] = value.cpu().numpy().squeeze(0)


# On Hugging Face ZeroGPU hardware this decorator allocates a GPU for the
# duration of each call; per the ZeroGPU documentation it has no effect elsewhere.
@spaces.GPU
def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> tuple[dict, str | None]:
    """
    Run the STream3R model on the images in ``target_dir/images``.

    Args:
        target_dir: Directory containing the ``images`` subfolder.
        model: STream3R model instance.
        mode: Attention mode ("causal", "window" or "full"). Streaming does not
            support "full" and falls back to "causal".
        streaming: If True, feed frames one at a time through a StreamSession.
            The session resumes from ``target_dir/kv_cache.pt`` when it exists
            and saves the updated cache there. If False, run a single batch pass.

    Returns:
        ``(predictions, session_cache_path)``. *predictions* is a dict of numpy
        arrays without the batch dimension and includes ``extrinsic``,
        ``intrinsic`` and ``world_points_from_depth``. *session_cache_path* is the
        saved cache path in streaming mode, otherwise ``None``.

    Raises:
        ValueError: if CUDA is unavailable, no images are found, or the cache
            holds more frames than there are images.
    """
    print(f"Processing images from {target_dir}")
    if not torch.cuda.is_available():
        raise ValueError("CUDA is not available. Check your environment.")
    device = "cuda"

    model = model.to(device)
    model.eval()

    image_names = sorted(glob.glob(os.path.join(target_dir, "images", "*")))
    print(f"Found {len(image_names)} images")
    if not image_names:
        raise ValueError("No images found. Check your upload.")
    images = load_and_preprocess_images(image_names).to(device)
    print(f"Preprocessed images shape: {images.shape}")

    print(f"Running inference in {'streaming' if streaming else 'batch'} mode...")
    # bfloat16 on compute capability >= 8 GPUs, float16 on older ones.
    dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
    session_cache_path: Optional[str] = None
    with torch.no_grad(), torch.amp.autocast(dtype=dtype, device_type=device):
        if streaming:
            if mode == "full":
                print("Warning: Streaming mode does not support 'full' attention mode. Switching to 'causal' mode.")
                mode = "causal"
            session = StreamSession(model, mode=mode)
            kv_cache_path = os.path.join(target_dir, "kv_cache.pt")
            existing_frames = 0
            if os.path.exists(kv_cache_path):
                print(f"Loading existing session cache from {kv_cache_path}")
                session.load_cache(kv_cache_path, device=images.device)
                existing_frames = _count_cached_frames(session.get_all_predictions())
            total_frames = images.shape[0]
            if existing_frames > total_frames:
                raise ValueError(
                    "Session cache contains more frames than available images. Please ensure the images folder "
                    "matches the saved session state."
                )
            if existing_frames == total_frames:
                print("No new frames detected; reusing cached predictions.")
            # Only frames beyond those already in the cache are streamed in.
            for i in range(existing_frames, total_frames):
                session.forward_stream(images[i : i + 1])
            predictions = session.get_all_predictions()
            session.save_cache(kv_cache_path)
            session_cache_path = kv_cache_path
        else:
            predictions = model(images, mode=mode)

    print("Converting pose encoding to extrinsic and intrinsic matrices...")
    extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:])
    predictions["extrinsic"] = extrinsic
    predictions["intrinsic"] = intrinsic

    _tensors_to_numpy(predictions)
    predictions.pop("pose_enc_list", None)

    print("Computing world points from depth map...")
    depth_map = predictions["depth"]  # (S, H, W, 1)
    predictions["world_points_from_depth"] = unproject_depth_map_to_point_map(
        depth_map, predictions["extrinsic"], predictions["intrinsic"]
    )

    torch.cuda.empty_cache()
    return predictions, session_cache_path
| # ------------------------------------------------------------------------- | |
| # 2) Handle uploaded video/images --> produce target_dir + images | |
| # ------------------------------------------------------------------------- | |
def handle_uploads(input_video, input_images, input_zip=None, session_state=None, current_target_dir: str | None = None):
    """Choose the working directory for a reconstruction and add uploaded frames to it.

    The working directory ``target_dir`` is chosen as follows:
      * if *session_state* (a saved session archive) is given, it is extracted
        into a new ``demo_cache/session_<timestamp>`` folder and that session
        directory is used;
      * else if *current_target_dir* is an existing directory it is reused, so
        new uploads are appended to the frames already there;
      * otherwise a new ``demo_cache/input_images_<timestamp>`` folder is made.

    New frames are then added to ``<target_dir>/images``:
      * uploaded images are copied, and renamed on name collisions;
      * images are extracted from *input_zip*;
      * one frame per second of *input_video* is saved as a zero-padded
        ``NNNNNN.png`` that never overwrites an existing file.

    Returns:
        ``(target_dir, image_paths, session_loaded)``. *image_paths* lists every
        visible file now in the images folder (sorted), not only the new ones.
        *session_loaded* tells whether a session archive was restored.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    session_loaded = False
    if session_state:
        extract_root = os.path.join("demo_cache", f"session_{timestamp}")
        target_dir = extract_session_state(_resolve_path(session_state), extract_root)
        session_loaded = True
    elif current_target_dir and os.path.isdir(current_target_dir):
        target_dir = current_target_dir
    else:
        target_dir = os.path.join("demo_cache", f"input_images_{timestamp}")
        if os.path.exists(target_dir):
            shutil.rmtree(target_dir)
    os.makedirs(target_dir, exist_ok=True)
    target_dir_images = os.path.join(target_dir, "images")
    os.makedirs(target_dir_images, exist_ok=True)

    # --- Uploaded images ---
    for file_data in input_images or []:
        src_path = _resolve_path(file_data)
        if src_path is not None:
            _copy_with_unique_name(src_path, target_dir_images)

    # --- ZIP of images ---
    if input_zip:
        extract_images_from_zip(_resolve_path(input_zip), target_dir_images)

    # --- Video: keep one frame per second ---
    if input_video:
        capture = cv2.VideoCapture(_resolve_path(input_video))
        try:
            fps = capture.get(cv2.CAP_PROP_FPS) or 30.0
            frame_interval = max(1, int(fps))
            next_index = len(os.listdir(target_dir_images))
            frame_count = 0
            while True:
                grabbed, frame = capture.read()
                if not grabbed:
                    break
                frame_count += 1
                if frame_count % frame_interval != 0:
                    continue
                frame_path = os.path.join(target_dir_images, f"{next_index:06}.png")
                # Skip numbers already in use (earlier video frames or user files)
                # instead of silently overwriting them.
                while os.path.exists(frame_path):
                    next_index += 1
                    frame_path = os.path.join(target_dir_images, f"{next_index:06}.png")
                cv2.imwrite(frame_path, frame)
                next_index += 1
        finally:
            capture.release()

    # The gallery shows everything in the images folder, not just this upload.
    image_paths = sorted(glob.glob(os.path.join(target_dir_images, "*")))
    print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {time.time() - start_time:.3f}s")
    return target_dir, image_paths, session_loaded
| # ------------------------------------------------------------------------- | |
| # 3) Update gallery on upload | |
| # ------------------------------------------------------------------------- | |
def update_gallery_on_upload(input_video, input_images, input_zip, session_state, current_target_dir):
    """Ingest new uploads (video, images, ZIP or a saved session) and refresh the UI.

    With no upload at all nothing is processed. Every output keeps its current
    value except the gallery, the log message and the unnamed fifth output, which
    become ``None``, and the localization text, which is emptied.

    Otherwise handle_uploads prepares the working directory. When a saved session
    was restored:
      * its ``session_settings.json`` (if any) restores the viewer controls;
      * the matching cached GLB is shown if it exists on disk;
      * the session-upload component is reset.

    Returns:
        A 16-tuple, in order: reconstruction, target_dir, gallery paths, log
        message, None, session_state reset, localization text, then updates for
        streaming, mode, conf_thres, frame_filter, mask_black_bg, mask_white_bg,
        show_cam, mask_sky and prediction_mode.
    """
    if not (input_video or input_images or input_zip or session_state):
        untouched_tail = [gr.update() for _ in range(9)]
        return (gr.update(), current_target_dir, None, None, None, gr.update(), gr.update(value=""), *untouched_tail)

    target_dir, image_paths, session_loaded = handle_uploads(
        input_video,
        input_images,
        input_zip,
        session_state=session_state,
        current_target_dir=current_target_dir,
    )
    frame_filter_choices = build_frame_filter_choices(os.path.join(target_dir, "images"))

    # Every control defaults to "leave unchanged"; restored settings override below.
    control_updates = {
        key: gr.update()
        for key in (
            "streaming",
            "mode",
            "conf_thres",
            "mask_black_bg",
            "mask_white_bg",
            "show_cam",
            "mask_sky",
            "prediction_mode",
        )
    }
    frame_value = "All"
    reconstruction_update = gr.update()
    session_state_reset = gr.update()

    if not session_loaded:
        message = "Upload complete. Click 'Reconstruct' to begin 3D processing."
    else:
        message = "Session state loaded. Add new frames and click 'Reconstruct' to continue."
        settings = load_session_settings(target_dir)
        if settings:
            if "frame_filter" in settings and settings["frame_filter"] in frame_filter_choices:
                frame_value = settings["frame_filter"]
            if "streaming" in settings:
                control_updates["streaming"] = gr.update(value=bool(settings["streaming"]))
            if settings.get("mode") in {"causal", "window", "full"}:
                control_updates["mode"] = gr.update(value=settings["mode"])
            if "conf_thres" in settings:
                try:
                    control_updates["conf_thres"] = gr.update(value=float(settings["conf_thres"]))
                except (TypeError, ValueError):
                    pass
            for flag in ("mask_black_bg", "mask_white_bg", "show_cam", "mask_sky"):
                if flag in settings:
                    control_updates[flag] = gr.update(value=bool(settings[flag]))
            if settings.get("prediction_mode") in {"Depthmap and Camera Branch", "Pointmap Branch"}:
                control_updates["prediction_mode"] = gr.update(value=settings["prediction_mode"])

            # Rebuild the GLB file name gradio_demo would have produced for
            # these settings and show that file if it is already cached.
            try:
                conf_val = settings["conf_thres"]
                mode_val = settings["mode"]
                pred_mode_val = settings["prediction_mode"]
                mask_black_val = bool(settings.get("mask_black_bg", False))
                mask_white_val = bool(settings.get("mask_white_bg", False))
                show_cam_val = bool(settings.get("show_cam", True))
                mask_sky_val = bool(settings.get("mask_sky", False))
                glb_candidate = os.path.join(
                    target_dir,
                    f"glbscene_{conf_val}_{sanitize_frame_filter_label(frame_value)}_maskb{mask_black_val}_maskw{mask_white_val}_cam{show_cam_val}_sky{mask_sky_val}_pred{pred_mode_val.replace(' ', '_')}_mode{mode_val}.glb",
                )
            except (KeyError, AttributeError):
                glb_candidate = None
            if glb_candidate is not None and os.path.exists(glb_candidate):
                reconstruction_update = glb_candidate
        session_state_reset = gr.update(value=None)

    frame_update = gr.update(choices=frame_filter_choices, value=frame_value)
    return (
        reconstruction_update,
        target_dir,
        image_paths,
        message,
        None,
        session_state_reset,
        gr.update(value=""),
        control_updates["streaming"],
        control_updates["mode"],
        control_updates["conf_thres"],
        frame_update,
        control_updates["mask_black_bg"],
        control_updates["mask_white_bg"],
        control_updates["show_cam"],
        control_updates["mask_sky"],
        control_updates["prediction_mode"],
    )
def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir):
    """Same as update_gallery_on_upload, for UI events that never carry a saved session."""
    return update_gallery_on_upload(
        input_video,
        input_images,
        input_zip,
        session_state=None,
        current_target_dir=current_target_dir,
    )
def _format_vector(vec: np.ndarray) -> str:
    """Render a 1-D array as ``[a, b, ...]`` with four decimals per entry."""
    pieces = [format(value, ".4f") for value in vec.tolist()]
    return "[{}]".format(", ".join(pieces))
def _tensor_to_numpy(tensor: torch.Tensor) -> np.ndarray:
    """Detach *tensor* and return it as a numpy array, upcasting bfloat16 (numpy lacks it)."""
    tensor = tensor.detach()
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.float()
    return tensor.cpu().numpy()


def _session_frame_to_numpy(tensor: torch.Tensor, index: int) -> np.ndarray:
    """Return frame *index* (dim 1) of a session prediction tensor as numpy.

    Tensors with fewer than two dims are returned whole. The leading batch dim
    is kept, so callers index ``[0]`` to drop it.
    """
    if tensor.dim() >= 2:
        tensor = tensor[:, index:index + 1]
    return _tensor_to_numpy(tensor)


def localize_new_image(
    target_dir,
    localization_image,
    mode,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
):
    """Estimate the camera pose of one extra image against a saved streaming session.

    Steps:
      1. Load the session KV cache (``<target_dir>/kv_cache.pt``) into a
         StreamSession and stream the image in as one new frame.
      2. Summarise the pose, intrinsics and confidence-filtered world points
         predicted for that frame.
      3. Write ``<target_dir>/localization_preview.glb``: the stored
         reconstruction (``predictions.npz``) plus the new camera in red.

    The cache file itself is not rewritten. ``show_cam`` is accepted for
    interface compatibility; the preview never draws the reconstruction's own
    cameras.

    Returns:
        ``(summary_text, glb_update)``: a human-readable report or an error
        message, and either the preview GLB path or ``gr.update()`` to leave the
        viewer unchanged.
    """
    import contextlib

    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return ("No reconstruction available. Please upload and reconstruct first.", gr.update())
    image_path = _resolve_path(localization_image)
    if not image_path or not os.path.exists(image_path):
        return ("Please select an image to localize.", gr.update())

    kv_cache_path = os.path.join(target_dir, "kv_cache.pt")
    predictions_path = os.path.join(target_dir, "predictions.npz")
    # Fail fast, before any preprocessing or model/GPU work, when the outputs of
    # a streaming reconstruction are missing.
    if not os.path.exists(kv_cache_path):
        return ("Streaming cache not found. Run reconstruction in streaming mode first.", gr.update())
    if not os.path.exists(predictions_path):
        return ("Predictions not found. Run reconstruction before localizing.", gr.update())

    device = "cuda" if torch.cuda.is_available() else "cpu"
    try:
        image_tensor = load_and_preprocess_images([image_path]).to(device)
    except Exception as exc:
        return (f"Failed to preprocess image: {exc}", gr.update())

    # Streaming sessions do not support "full" attention; mirror run_model,
    # which builds the cache in "causal" mode in that case.
    if mode == "full":
        mode = "causal"
    # Run at the same precision run_model used to build the cache.
    if device == "cuda":
        amp_dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16
        precision_ctx = torch.amp.autocast(device_type="cuda", dtype=amp_dtype)
    else:
        precision_ctx = contextlib.nullcontext()

    model.eval()
    model.to(device)
    # One StreamSession is kept on the function object and rebuilt only when the
    # mode changes. NOTE(review): it is shared by all calls; confirm the Gradio
    # queue serialises this handler, or concurrent requests will contend for it.
    session = getattr(localize_new_image, "_session", None)
    if session is None or session.mode != mode:
        session = StreamSession(model, mode=mode)
        localize_new_image._session = session
    session.clear()

    try:
        try:
            session.load_cache(kv_cache_path, device=image_tensor.device)
            # The new frame is appended after the cached ones (dim 1 = frames).
            new_index = 0
            for value in session.get_all_predictions().values():
                if isinstance(value, torch.Tensor) and value.dim() >= 2:
                    new_index = max(new_index, value.shape[1])
            with torch.no_grad(), precision_ctx:
                session.forward_stream(image_tensor)
            localized_predictions = session.get_all_predictions()
        except Exception as exc:
            return (f"Localization failed: {exc}", gr.update())

        pose_enc = localized_predictions.get("pose_enc")
        if pose_enc is None:
            return ("pose_enc not available in session predictions.", gr.update())
        pose_slice = pose_enc[:, new_index:new_index + 1]
        H, W = image_tensor.shape[-2:]
        extrinsics_t, intrinsics_t = pose_encoding_to_extri_intri(
            pose_slice,
            image_size_hw=(H, W),
            build_intrinsics=True,
        )
        extrinsic = _tensor_to_numpy(extrinsics_t[0, 0])
        intrinsic = _tensor_to_numpy(intrinsics_t[0, 0]) if intrinsics_t is not None else None

        rotation_matrix = extrinsic[:, :3]
        translation = extrinsic[:, 3]
        camera_position = -rotation_matrix.T @ translation
        try:
            yaw_pitch_roll = Rotation.from_matrix(rotation_matrix).as_euler("zyx", degrees=True)
        except ValueError:
            yaw_pitch_roll = np.zeros(3)

        world_points_tensor = localized_predictions.get("world_points")
        if world_points_tensor is None:
            world_points_tensor = localized_predictions.get("world_points_from_depth")
        if world_points_tensor is None:
            return ("World points not available in session predictions.", gr.update())
        world_points = _session_frame_to_numpy(world_points_tensor, new_index)[0]

        world_conf_tensor = localized_predictions.get("world_points_conf")
        if world_conf_tensor is None:
            world_conf_tensor = localized_predictions.get("depth_conf")
        if world_conf_tensor is not None:
            confidence = _session_frame_to_numpy(world_conf_tensor, new_index)[0]
        else:
            confidence = np.ones(world_points.shape[:-1], dtype=np.float32)

        flattened_points = world_points.reshape(-1, 3)
        flattened_conf = confidence.reshape(-1)
        # conf_thres is a percentile of this frame's confidences; None/0 keeps all.
        conf_pct = 0.0 if conf_thres is None else float(conf_thres)
        conf_threshold = 0.0 if conf_pct == 0.0 else np.percentile(flattened_conf, conf_pct)
        visibility_mask = (flattened_conf >= conf_threshold) & (flattened_conf > 1e-5)
        visible_points = flattened_points[visibility_mask]
        if visible_points.size == 0:
            visible_points = flattened_points
        if visible_points.size > 0:
            bbox_min = visible_points.min(axis=0)
            bbox_max = visible_points.max(axis=0)
        else:
            bbox_min = np.zeros(3)
            bbox_max = np.zeros(3)

        summary_lines = [
            f"Localized image: {os.path.basename(image_path)}",
            f"Translation (t): {_format_vector(translation)}",
            f"Camera position (world): {_format_vector(camera_position)}",
            f"Orientation yaw/pitch/roll (deg): {_format_vector(yaw_pitch_roll)}",
            f"Visible points (>= percentile {conf_pct:.1f}%): {visible_points.shape[0]}",
            f"World-space bounding box min: {_format_vector(bbox_min)}",
            f"World-space bounding box max: {_format_vector(bbox_max)}",
        ]
        if intrinsic is not None:
            summary_lines.append(
                "Intrinsics matrix:\n" + "\n".join(" " + _format_vector(row) for row in intrinsic)
            )

        try:
            try:
                stored = np.load(predictions_path)
            except ValueError:
                # SECURITY(review): predictions.npz may come from a user-uploaded
                # session archive (handle_uploads -> extract_session_state). With
                # allow_pickle=True a crafted pickle file executes arbitrary code on
                # load. np.savez of plain arrays never needs this fallback;
                # consider removing it.
                stored = np.load(predictions_path, allow_pickle=True)
            try:
                base_predictions = {key: stored[key] for key in stored.files}
            finally:
                close = getattr(stored, "close", None)
                if close is not None:
                    close()
        except Exception as exc:
            return (f"Failed to load stored predictions: {exc}", gr.update())

        localization_glb_path = os.path.join(target_dir, "localization_preview.glb")
        world_to_camera = np.eye(4)
        world_to_camera[:3, :4] = extrinsic
        try:
            glbscene = predictions_to_glb(
                base_predictions,
                conf_thres=conf_thres,
                filter_by_frames="All",
                mask_black_bg=mask_black_bg,
                mask_white_bg=mask_white_bg,
                show_cam=False,
                mask_sky=mask_sky,
                target_dir=target_dir,
                prediction_mode=prediction_mode,
                extra_cameras=[world_to_camera],
                extra_camera_color=(255, 0, 0),
            )
            glbscene.export(file_obj=localization_glb_path)
        except Exception as exc:
            localization_glb_path = None
            summary_lines.append(f"Warning: failed to update GLB preview ({exc})")

        return ("\n".join(summary_lines), localization_glb_path if localization_glb_path else gr.update())
    finally:
        # Always drop the loaded cache and free GPU memory, including on early returns.
        session.clear()
        if device == "cuda":
            torch.cuda.empty_cache()
| # ------------------------------------------------------------------------- | |
| # 4) Reconstruction: uses the target_dir plus any viz parameters | |
| # ------------------------------------------------------------------------- | |
| def gradio_demo( | |
| target_dir, | |
| conf_thres=3.0, | |
| frame_filter="All", | |
| mask_black_bg=False, | |
| mask_white_bg=False, | |
| show_cam=True, | |
| mask_sky=False, | |
| prediction_mode="Pointmap Regression", | |
| mode="causal", | |
| streaming=False, | |
| ): | |
| """ | |
| Perform reconstruction using the already-created target_dir/images. | |
| """ | |
| if not os.path.isdir(target_dir) or target_dir == "None": | |
| return None, "No valid target directory found. Please upload first.", None, None | |
| start_time = time.time() | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| # Prepare frame_filter dropdown | |
| target_dir_images = os.path.join(target_dir, "images") | |
| frame_filter_choices = build_frame_filter_choices(target_dir_images) | |
| config = load_stream_session_config() | |
| top_k_frames = int(config.get("top_k_frames", 0) or 0) | |
| print("Running run_model...") | |
| with torch.no_grad(): | |
| predictions, session_cache_path = run_model(target_dir, model, mode=mode, streaming=streaming) | |
| # Save predictions | |
| prediction_save_path = os.path.join(target_dir, "predictions.npz") | |
| np.savez(prediction_save_path, **predictions) | |
| selected_frames = select_top_k_frames(predictions, target_dir_images, top_k_frames) | |
| selected_frames_path = os.path.join(target_dir, "selected_frames.json") | |
| if selected_frames: | |
| try: | |
| with open(selected_frames_path, "w", encoding="utf-8") as handle: | |
| json.dump({"top_k": top_k_frames, "frames": selected_frames}, handle, indent=2) | |
| except OSError as exc: | |
| print(f"Failed to write selected frames: {exc}") | |
| elif os.path.exists(selected_frames_path): | |
| try: | |
| os.remove(selected_frames_path) | |
| except OSError: | |
| pass | |
| frame_filter_value = frame_filter if frame_filter is not None else "All" | |
| session_settings = { | |
| "streaming": bool(streaming), | |
| "mode": mode, | |
| "conf_thres": float(conf_thres), | |
| "frame_filter": frame_filter_value, | |
| "mask_black_bg": bool(mask_black_bg), | |
| "mask_white_bg": bool(mask_white_bg), | |
| "show_cam": bool(show_cam), | |
| "mask_sky": bool(mask_sky), | |
| "prediction_mode": prediction_mode, | |
| } | |
| session_settings["top_k_frames"] = top_k_frames | |
| if selected_frames: | |
| session_settings["selected_frames"] = [frame["filename"] for frame in selected_frames] | |
| try: | |
| with open(os.path.join(target_dir, "session_settings.json"), "w", encoding="utf-8") as handle: | |
| json.dump(session_settings, handle, indent=2) | |
| except OSError as exc: | |
| print(f"Failed to write session settings: {exc}") | |
| session_state_file = None | |
| if streaming: | |
| if session_cache_path is None: | |
| session_cache_path = os.path.join(target_dir, "kv_cache.pt") | |
| if os.path.exists(session_cache_path): | |
| session_state_file = package_session_state(target_dir) | |
| # Handle None frame_filter | |
| if frame_filter is None: | |
| frame_filter = "All" | |
| # Build a GLB file name | |
| glbfile = os.path.join( | |
| target_dir, | |
| f"glbscene_{conf_thres}_{sanitize_frame_filter_label(frame_filter)}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb", | |
| ) | |
| # Convert predictions to GLB | |
| glbscene = predictions_to_glb( | |
| predictions, | |
| conf_thres=conf_thres, | |
| filter_by_frames=frame_filter, | |
| mask_black_bg=mask_black_bg, | |
| mask_white_bg=mask_white_bg, | |
| show_cam=show_cam, | |
| mask_sky=mask_sky, | |
| target_dir=target_dir, | |
| prediction_mode=prediction_mode, | |
| ) | |
| glbscene.export(file_obj=glbfile) | |
| # Cleanup | |
| del predictions | |
| gc.collect() | |
| torch.cuda.empty_cache() | |
| end_time = time.time() | |
| print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") | |
| num_frames = max(0, len(frame_filter_choices) - 1) | |
| log_msg = f"Reconstruction Success ({num_frames} frames). Waiting for visualization." | |
| return ( | |
| glbfile, | |
| log_msg, | |
| gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), | |
| session_state_file, | |
| ) | |
| # ------------------------------------------------------------------------- | |
| # 5) Helper functions for UI resets + re-visualization | |
| # ------------------------------------------------------------------------- | |
def clear_fields():
    """
    Reset the 3D viewer at the start of the "Reconstruct" chain.

    The listener has a single output (the Model3D viewer); returning ``None``
    empties it.
    """
    emptied_viewer = None
    return emptied_viewer
def update_log():
    """Return the placeholder status line shown while reconstruction runs."""
    busy_message = "Loading and Reconstructing..."
    return busy_message
def update_visualization(
    target_dir,
    conf_thres,
    frame_filter,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    mode_value,
    is_example,
):
    """
    Re-render the saved reconstruction with new visualization settings.

    The GLB is cached in ``target_dir`` under a name derived from the settings.
    ``predictions.npz`` is read only when that file does not exist yet.

    Args:
        target_dir: Working directory of the current session (``None``, ``""`` or
            ``"None"`` when nothing was reconstructed).
        conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam,
        mask_sky, prediction_mode: Options forwarded to ``predictions_to_glb``;
            ``frame_filter=None`` is treated as ``"All"``, as in ``gradio_demo``.
        mode_value: Processing mode; only used in the cached GLB name.
        is_example: The hidden-textbox flag; ``"True"`` skips rendering.

    Returns:
        ``(glb_path_or_None, log_message)``.
    """
    not_ready_msg = "No reconstruction available. Please click the Reconstruct button first."
    # Example clicks reconstruct on their own; skip intermediate control changes.
    if is_example == "True":
        return None, not_ready_msg
    if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
        return None, not_ready_msg

    predictions_path = os.path.join(target_dir, "predictions.npz")
    if not os.path.exists(predictions_path):
        return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first."

    if frame_filter is None:
        frame_filter = "All"

    # Must match the naming scheme used by gradio_demo().
    sanitized_frame = sanitize_frame_filter_label(frame_filter)
    glbfile = os.path.join(
        target_dir,
        f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb",
    )

    if not os.path.exists(glbfile):
        key_list = (
            "pose_enc",
            "depth",
            "depth_conf",
            "world_points",
            "world_points_conf",
            "images",
            "extrinsic",
            "intrinsic",
            "world_points_from_depth",
        )
        # The context manager closes the NpzFile handle. Only keys actually stored
        # are loaded, mirroring gradio_demo(), which passes run_model()'s dict as-is.
        with np.load(predictions_path) as loaded:
            stored = set(loaded.files)
            predictions = {key: loaded[key] for key in key_list if key in stored}
        glbscene = predictions_to_glb(
            predictions,
            conf_thres=conf_thres,
            filter_by_frames=frame_filter,
            mask_black_bg=mask_black_bg,
            mask_white_bg=mask_white_bg,
            show_cam=show_cam,
            mask_sky=mask_sky,
            target_dir=target_dir,
            prediction_mode=prediction_mode,
        )
        glbscene.export(file_obj=glbfile)

    return glbfile, "Updating Visualization"
| # ------------------------------------------------------------------------- | |
| # Example images | |
| # ------------------------------------------------------------------------- | |
# Bundled example clips, as paths relative to the app's working directory.
(
    great_wall_video,
    colosseum_video,
    room_video,
    kitchen_video,
    fern_video,
    single_cartoon_video,
    single_oil_painting_video,
    pyramid_video,
) = (
    "examples/videos/great_wall.mp4",
    "examples/videos/Colosseum.mp4",
    "examples/videos/room.mp4",
    "examples/videos/kitchen.mp4",
    "examples/videos/fern.mp4",
    "examples/videos/single_cartoon.mp4",
    "examples/videos/single_oil_painting.mp4",
    "examples/videos/pyramid.mp4",
)
| # ------------------------------------------------------------------------- | |
| # 6) Build Gradio UI | |
| # ------------------------------------------------------------------------- | |
# Ocean theme; selected checkbox labels reuse the primary-button colours.
# Theme.set() returns the theme itself, so construction and tweaks chain.
theme = gr.themes.Ocean().set(
    checkbox_label_background_fill_selected="*button_primary_background_fill",
    checkbox_label_text_color_selected="*button_primary_text_color",
)
# Page layout. Gradio renders components in creation order and nests them by the
# enclosing Row/Column contexts, so the statement order below defines the UI.
with gr.Blocks(
    theme=theme,
    css="""
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }
    .example-log * {
        font-style: italic;
        font-size: 16px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent !important;
    }
    #my_radio .wrap {
        display: flex;
        flex-wrap: nowrap;
        justify-content: center;
        align-items: center;
    }
    #my_radio .wrap label {
        display: flex;
        width: 50%;
        justify-content: center;
        align-items: center;
        margin: 0;
        padding: 10px 0;
        box-sizing: border-box;
    }
    """,
) as demo:
    # Instead of gr.State, we use a hidden Textbox:
    # is_example is "True" after an example row is loaded and reset to "False" once
    # the Reconstruct chain finishes; update_visualization() skips work while "True".
    is_example = gr.Textbox(label="is_example", visible=False, value="None")
    # Hidden sinks for example-row columns that only gr.Examples needs
    # (the image count string and the preview image).
    num_images = gr.Textbox(label="num_images", visible=False, value="None")
    example_preview = gr.Image(label="Example Preview", visible=False)

    # Page header: title, links and usage instructions.
    gr.HTML(
        """
    <h1>🌅 STream3R: Scalable Sequential 3D Reconstruction with Causal Transformer</h1>
    <p>
    <a href="https://github.com/NIRVANALAN/STream3R">GitHub Repository</a> |
    <a href="https://nirvanalan.github.io/projects/stream3r">Project Page</a> |
    <a href="https://arxiv.org/abs/2508.10893">Paper</a>
    </p>
    <blockquote>
    Special thanks to VGG-T for their visualization demo, which this demo is built upon!
    </blockquote>
    <div style="font-size: 16px; line-height: 1.5;">
    <p>Upload a video or a set of images to create a 3D reconstruction of a scene or object. STream3R takes these images and generates a 3D point cloud, along with estimated camera poses.</p>
    <h3>Getting Started:</h3>
    <ol>
    <li><strong>Upload Your Data:</strong> Use the "Upload Video" or "Upload Images" buttons on the left to provide your input. Videos will be automatically split into individual frames (one frame per second).</li>
    <li><strong>Preview:</strong> Your uploaded images will appear in the gallery on the left.</li>
    <li><strong>Reconstruct:</strong> Click the "Reconstruct" button to start the 3D reconstruction process.</li>
    <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file. Note the visualization of 3D points may be slow for a large number of input images.</li>
    <li>
    <strong>Adjust Visualization (Optional):</strong>
    After reconstruction, you can fine-tune the visualization using the options below
    <details style="display:inline;">
    <summary style="display:inline;">(<strong>click to expand</strong>):</summary>
    <ul>
    <li><em>Confidence Threshold:</em> Adjust the filtering of points based on confidence.</li>
    <li><em>Show Points from Frame:</em> Select specific frames to display in the point cloud.</li>
    <li><em>Show Camera:</em> Toggle the display of estimated camera positions.</li>
    <li><em>Filter Sky / Filter Black Background:</em> Remove sky or black-background points.</li>
    <li><em>Select a Prediction Mode:</em> Choose between "Depthmap and Camera Branch" or "Pointmap Branch."</li>
    </ul>
    </details>
    </li>
    </ol>
    <p><strong style="color: #0ea5e9;">Please note:</strong> <span style="color: #0ea5e9; font-weight: bold;">STream3R typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of STream3R's processing time. </span></p>
    </div>
    """
    )

    # Hidden holder for the session's working directory. "None" is the placeholder
    # that gradio_demo() and update_visualization() treat as "nothing uploaded".
    target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")

    with gr.Row():
        # Left column: inputs and the preview gallery.
        with gr.Column(scale=2):
            input_video = gr.Video(label="Upload Video", interactive=True)
            input_images = gr.File(file_count="multiple", label="Upload Images", interactive=True)
            input_zip = gr.File(file_types=[".zip"], label="Upload ZIP of Images", interactive=True)
            # A previously downloaded session archive; handled by update_gallery_on_upload().
            session_state_input = gr.File(file_types=[".zip"], label="Load Session State", interactive=True)

            image_gallery = gr.Gallery(
                label="Preview",
                columns=4,
                height="300px",
                show_download_button=True,
                object_fit="contain",
                preview=True,
            )

        # Right column: viewer, status, session download, localization and controls.
        with gr.Column(scale=4):
            with gr.Column():
                gr.Markdown("**3D Reconstruction (Point Cloud and Camera Poses)**")
                log_output = gr.Markdown(
                    "Please upload a video or images, then click Reconstruct.", elem_classes=["custom-log"]
                )
                reconstruction_output = gr.Model3D(height=520, zoom_speed=0.5, pan_speed=0.5)
                # Filled with gradio_demo()'s session_state_file (streaming mode only).
                session_state_output = gr.File(label="Download Session State", interactive=False)
                # Text summary returned by localize_new_image().
                localization_output = gr.Textbox(label="Localization Result", lines=8, interactive=False)
                with gr.Row():
                    localization_image_input = gr.File(
                        label="Localize Single Image",
                        file_types=[".png", ".jpg", ".jpeg", ".bmp", ".webp"],
                        interactive=True,
                    )
                    localize_button = gr.Button("Localize Image", variant="secondary")

            with gr.Row():
                submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                # Resets every listed component to its empty value.
                clear_btn = gr.ClearButton(
                    [
                        input_video,
                        input_images,
                        input_zip,
                        session_state_input,
                        localization_image_input,
                        reconstruction_output,
                        log_output,
                        target_dir_output,
                        image_gallery,
                        session_state_output,
                        localization_output,
                    ],
                    scale=1,
                )

            with gr.Row():
                # Forwarded to predictions_to_glb() as prediction_mode.
                prediction_mode = gr.Radio(
                    ["Depthmap and Camera Branch", "Pointmap Branch"],
                    label="Select a Prediction Mode",
                    value="Depthmap and Camera Branch",
                    scale=1,
                    elem_id="my_radio",
                )

            with gr.Row():
                # (label, value) pairs: the component's value is the bool passed to
                # gradio_demo() as `streaming`.
                streaming = gr.Radio(
                    [('stream', True), ('batch', False)],
                    label="Streaming or Batch Mode",
                    value=True,
                    scale=1,
                )

            with gr.Row():
                # Forwarded to run_model() via gradio_demo(); also part of the GLB cache name.
                mode = gr.Radio(
                    ["causal", "window", "full"],
                    label="Select Processing Mode",
                    value="window",
                    scale=1,
                )

            with gr.Row():
                # Forwarded to predictions_to_glb() as conf_thres (label says percent;
                # exact interpretation lives in predictions_to_glb).
                conf_thres = gr.Slider(minimum=0, maximum=100, value=50, step=0.1, label="Confidence Threshold (%)")
                # Choices are replaced with per-frame entries after each reconstruction.
                frame_filter = gr.Dropdown(choices=["All"], value="All", label="Show Points from Frame")
                with gr.Column():
                    show_cam = gr.Checkbox(label="Show Camera", value=True)
                    mask_sky = gr.Checkbox(label="Filter Sky", value=False)
                    mask_black_bg = gr.Checkbox(label="Filter Black Background", value=False)
                    mask_white_bg = gr.Checkbox(label="Filter White Background", value=False)
# ---------------------- Examples section ----------------------
def build_examples_from_folder(examples_root="examples", max_examples=2):
    """
    Build ``gr.Examples`` rows from image folders under ``examples_root``.

    Each immediate sub-directory containing images becomes one row. If there are
    no sub-directories, ``examples_root`` itself is scanned. Sub-directories are
    visited in reverse name order. Within a folder, images are ordered by file
    name across all extensions, with the extension matched case-insensitively;
    hidden files such as ``._frame.png`` are ignored.

    Args:
        examples_root: Folder to scan (default ``"examples"``).
        max_examples: Maximum number of rows (non-negative int), or ``None`` for
            no limit. Defaults to 2.

    Returns:
        A list of rows whose columns match the ``inputs`` list given to
        ``gr.Examples``: preview image, video (unused), image count as a string,
        image paths, default visualization settings, the ``is_example`` flag and
        the processing mode. Returns ``[]`` if ``examples_root`` is missing.
    """
    image_exts = (".png", ".jpg", ".jpeg", ".bmp", ".webp")
    entries = []
    if not os.path.isdir(examples_root):
        return entries

    candidate_dirs = sorted(
        (
            os.path.join(examples_root, name)
            for name in os.listdir(examples_root)
            if os.path.isdir(os.path.join(examples_root, name))
        ),
        reverse=True,
    )
    if not candidate_dirs:
        candidate_dirs = [examples_root]

    for example_dir in candidate_dirs:
        if max_examples is not None and len(entries) >= max_examples:
            break
        # One sorted listing over all extensions keeps frames in name order even
        # when a folder mixes formats.
        image_files = sorted(
            os.path.join(example_dir, name)
            for name in os.listdir(example_dir)
            if not name.startswith(".")
            and name.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(example_dir, name))
        )
        if not image_files:
            continue
        entries.append(
            [
                image_files[0],  # preview image (for visualization only)
                None,  # input_video (unused for examples)
                str(len(image_files)),  # num_images
                image_files,  # input_images
                15.0,  # conf_thres
                False,  # mask_black_bg
                False,  # mask_white_bg
                True,  # show_cam
                False,  # mask_sky
                "Depthmap and Camera Branch",  # prediction_mode
                "True",  # is_example
                "causal",  # mode
            ]
        )
    return entries


examples = build_examples_from_folder()
def example_pipeline(
    preview_image,
    input_video,
    num_images_str,
    input_images,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    is_example_str,
    mode="causal",
):
    """
    Handle a click on an example row: stage its images, reconstruct, refresh the UI.

    ``preview_image``, ``num_images_str`` and ``is_example_str`` arrive only
    because gr.Examples passes every input column; they are not used here, and
    ``is_example`` is not returned.

    Returns a 16-tuple in the order of gr.Examples' ``outputs``:
    GLB path, log text, target dir, frame-filter dropdown, gallery image paths,
    session-state file, a reset for the session upload, a reset for the
    localization text, the streaming flag (always False), the mode, then the
    six visualization settings echoed back unchanged.
    """
    target_dir, image_paths, _ignored = handle_uploads(input_video, input_images)

    # Examples always run in batch mode and show points from every frame.
    use_streaming = False
    glbfile, log_msg, dropdown, session_file = gradio_demo(
        target_dir,
        conf_thres=conf_thres,
        frame_filter="All",
        mask_black_bg=mask_black_bg,
        mask_white_bg=mask_white_bg,
        show_cam=show_cam,
        mask_sky=mask_sky,
        prediction_mode=prediction_mode,
        mode=mode,
        streaming=use_streaming,
    )

    reset_session_upload = gr.update(value=None)
    reset_localization_text = gr.update(value="")
    echoed_settings = (conf_thres, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode)
    return (
        glbfile,
        log_msg,
        target_dir,
        dropdown,
        image_paths,
        session_file,
        reset_session_upload,
        reset_localization_text,
        use_streaming,
        mode,
        *echoed_settings,
    )
gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])

# Column order must match the rows built by build_examples_from_folder().
example_inputs = [
    example_preview,
    input_video,
    num_images,
    input_images,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
    is_example,
    mode,
]
# One slot per value in example_pipeline()'s 16-tuple.
example_outputs = [
    reconstruction_output,
    log_output,
    target_dir_output,
    frame_filter,
    image_gallery,
    session_state_output,
    session_state_input,
    localization_output,
    streaming,
    mode,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
]
gr.Examples(
    examples=examples,
    inputs=example_inputs,
    outputs=example_outputs,
    fn=example_pipeline,
    cache_examples=False,
    examples_per_page=50,
)
| # ------------------------------------------------------------------------- | |
| # "Reconstruct" button logic: | |
| # - Clear fields | |
| # - Update log | |
| # - gradio_demo(...) with the existing target_dir | |
| # - Then set is_example = "False" | |
| # ------------------------------------------------------------------------- | |
| submit_btn.click(fn=clear_fields, inputs=[], outputs=[reconstruction_output]).then( | |
| fn=update_log, inputs=[], outputs=[log_output] | |
| ).then( | |
| fn=gradio_demo, | |
| inputs=[ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| streaming, | |
| ], | |
| outputs=[reconstruction_output, log_output, frame_filter, session_state_output], | |
| ).then( | |
| fn=lambda: "False", inputs=[], outputs=[is_example] # set is_example to "False" | |
| ) | |
| # ------------------------------------------------------------------------- | |
| # Real-time Visualization Updates | |
| # ------------------------------------------------------------------------- | |
| conf_thres.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| frame_filter.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| mask_black_bg.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| mask_white_bg.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| show_cam.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| mask_sky.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
| prediction_mode.change( | |
| update_visualization, | |
| [ | |
| target_dir_output, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| mode, | |
| is_example, | |
| ], | |
| [reconstruction_output, log_output], | |
| ) | |
# Localize an uploaded image using the session in target_dir. The handler's two
# return values fill the result textbox and the 3D viewer.
localization_inputs = [
    target_dir_output,
    localization_image_input,
    mode,
    conf_thres,
    mask_black_bg,
    mask_white_bg,
    show_cam,
    mask_sky,
    prediction_mode,
]
localize_button.click(
    fn=localize_new_image,
    inputs=localization_inputs,
    outputs=[localization_output, reconstruction_output],
)
| # ------------------------------------------------------------------------- | |
| # Auto-update gallery whenever user uploads or changes their files | |
| # ------------------------------------------------------------------------- | |
| upload_outputs = [ | |
| reconstruction_output, | |
| target_dir_output, | |
| image_gallery, | |
| log_output, | |
| session_state_output, | |
| session_state_input, | |
| localization_output, | |
| streaming, | |
| mode, | |
| conf_thres, | |
| frame_filter, | |
| mask_black_bg, | |
| mask_white_bg, | |
| show_cam, | |
| mask_sky, | |
| prediction_mode, | |
| ] | |
| no_session_inputs = [input_video, input_images, input_zip, target_dir_output] | |
| input_video.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs) | |
| input_images.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs) | |
| input_zip.change(fn=update_gallery_without_session, inputs=no_session_inputs, outputs=upload_outputs) | |
| session_state_input.change( | |
| fn=update_gallery_on_upload, | |
| inputs=[input_video, input_images, input_zip, session_state_input, target_dir_output], | |
| outputs=upload_outputs, | |
| ) | |
# Enable the request queue (capped at 20 pending events) and start the server
# with errors surfaced in the UI and no public share link.
demo.queue(max_size=20)
demo.launch(show_error=True, share=False)