# Copyright (c) Meta Platforms, Inc. and affiliates. # All rights reserved. # # This source code is licensed under the license found in the # LICENSE file in the root directory of this source tree. import json import os import cv2 import torch import numpy as np from scipy.spatial.transform import Rotation import gradio as gr import shutil from datetime import datetime import glob import gc import time import zipfile import functools from typing import Any, Dict, Optional from stream3r.models.stream3r import STream3R from stream3r.stream_session import StreamSession from stream3r.models.components.utils.load_fn import load_and_preprocess_images from stream3r.models.components.utils.pose_enc import pose_encoding_to_extri_intri from stream3r.models.components.utils.geometry import unproject_depth_map_to_point_map from stream3r.utils.visual_utils import predictions_to_glb import warnings warnings.filterwarnings("ignore", message="Can't initialize NVML") import spaces device = "cuda" model = STream3R.from_pretrained("yslan/STream3R") def handle_zip(zip_file): outdir = "uploads" os.makedirs(outdir, exist_ok=True) with zipfile.ZipFile(zip_file.name, "r") as zf: zf.extractall(outdir) return f"Extracted {len(os.listdir(outdir))} files" # --- add near your imports/helpers --- ALLOWED_IMG_EXT = {".png", ".jpg", ".jpeg", ".bmp", ".webp"} def _is_within_dir(base_dir: str, path: str) -> bool: # Prevent zip-slip: ensure extracted path stays inside base_dir base_dir = os.path.abspath(base_dir) path = os.path.abspath(path) return os.path.commonpath([base_dir]) == os.path.commonpath([base_dir, path]) def extract_images_from_zip(zip_path: str, outdir: str) -> list[str]: """ Extracts only image files from a zip into outdir. Returns list of extracted file paths. """ os.makedirs(outdir, exist_ok=True) extracted = [] with zipfile.ZipFile(zip_path, "r") as zf: for member in zf.infolist(): # Skip directories and non-image files name = member.filename if name.endswith("/"): continue ext = os.path.splitext(name)[1].lower() if ext not in ALLOWED_IMG_EXT: continue # Construct final path safely base_name = os.path.basename(name) name_root, name_ext = os.path.splitext(base_name) dest_path = os.path.join(outdir, base_name) counter = 1 while os.path.exists(dest_path): dest_path = os.path.join(outdir, f"{name_root}_{counter}{name_ext}") counter += 1 # Zip-slip guard (in case filename has ../ etc.) if not _is_within_dir(outdir, dest_path): continue with zf.open(member) as src, open(dest_path, "wb") as dst: shutil.copyfileobj(src, dst) extracted.append(dest_path) return extracted def extract_session_state(zip_path: str, extract_root: str) -> str: """Extract a previously saved session archive into *extract_root*. Returns the directory that contains the restored session data. """ if os.path.exists(extract_root): shutil.rmtree(extract_root) os.makedirs(extract_root, exist_ok=True) with zipfile.ZipFile(zip_path, "r") as zf: zf.extractall(extract_root) entries = [os.path.join(extract_root, entry) for entry in os.listdir(extract_root)] dirs = [entry for entry in entries if os.path.isdir(entry)] files = [entry for entry in entries if os.path.isfile(entry)] if len(dirs) == 1 and not files: return dirs[0] return extract_root def package_session_state(target_dir: str) -> str: """Create a zip archive containing the entire session directory.""" if not os.path.isdir(target_dir): raise ValueError(f"Target directory does not exist: {target_dir}") os.makedirs("demo_cache", exist_ok=True) archive_name = f"{os.path.basename(os.path.normpath(target_dir))}_session.zip" archive_path = os.path.join("demo_cache", archive_name) if os.path.exists(archive_path): os.remove(archive_path) with zipfile.ZipFile(archive_path, "w", compression=zipfile.ZIP_DEFLATED) as zf: for root, _, files in os.walk(target_dir): for fname in files: file_path = os.path.join(root, fname) if os.path.abspath(file_path) == os.path.abspath(archive_path): continue arcname = os.path.join(os.path.basename(target_dir), os.path.relpath(file_path, target_dir)) zf.write(file_path, arcname) return archive_path def _copy_with_unique_name(src_path: str, dst_dir: str) -> str: """Copy *src_path* into *dst_dir*, avoiding filename collisions.""" base_name = os.path.basename(src_path) name, ext = os.path.splitext(base_name) candidate = base_name counter = 1 dest_path = os.path.join(dst_dir, candidate) while os.path.exists(dest_path): candidate = f"{name}_{counter}{ext}" dest_path = os.path.join(dst_dir, candidate) counter += 1 shutil.copy(src_path, dest_path) return dest_path def _resolve_path(file_data) -> Optional[str]: if file_data is None: return None if isinstance(file_data, dict) and "name" in file_data: return file_data["name"] if hasattr(file_data, "name"): return getattr(file_data, "name") return str(file_data) STREAM_SESSION_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "configs", "stream_session.json") @functools.lru_cache(maxsize=1) def load_stream_session_config() -> Dict[str, Any]: try: with open(STREAM_SESSION_CONFIG_PATH, "r", encoding="utf-8") as handle: data = json.load(handle) if isinstance(data, dict): return data except (OSError, json.JSONDecodeError): pass return {} def load_session_settings(target_dir: str) -> Dict[str, Any]: settings_path = os.path.join(target_dir, "session_settings.json") if not os.path.exists(settings_path): return {} try: with open(settings_path, "r", encoding="utf-8") as handle: data = json.load(handle) if isinstance(data, dict): return data except (json.JSONDecodeError, OSError) as exc: print(f"Failed to load session settings from {settings_path}: {exc}") return {} def build_frame_filter_choices(target_dir_images: str) -> list[str]: if not os.path.isdir(target_dir_images): return ["All"] files = sorted(os.listdir(target_dir_images)) return ["All"] + [f"{idx}: {name}" for idx, name in enumerate(files)] def sanitize_frame_filter_label(label: Optional[str]) -> str: if not label: return "All" return label.replace('.', '_').replace(':', '').replace(' ', '_') def select_top_k_frames(predictions: Dict[str, np.ndarray], images_dir: str, top_k: int) -> list[Dict[str, Any]]: if top_k <= 0: return [] if not os.path.isdir(images_dir): return [] image_files = sorted( [fname for fname in os.listdir(images_dir) if not fname.startswith('.')] ) extrinsics = predictions.get("extrinsic") if extrinsics is None: return [] num_frames = extrinsics.shape[0] if num_frames == 0: return [] top_k = min(top_k, num_frames) def _camera_position(extr: np.ndarray) -> np.ndarray: R = extr[:, :3] t = extr[:, 3] return (-R.T @ t).astype(np.float64) positions = np.array([_camera_position(extrinsics[i]) for i in range(num_frames)]) forward_vectors = np.array([extrinsics[i][2, :3] for i in range(num_frames)]) forward_norms = np.linalg.norm(forward_vectors, axis=1, keepdims=True) forward_vectors = np.divide(forward_vectors, forward_norms, out=np.zeros_like(forward_vectors), where=forward_norms > 0) conf_tensor = predictions.get("world_points_conf") if conf_tensor is None: conf_tensor = predictions.get("depth_conf") quality_scores = np.zeros(num_frames, dtype=np.float64) coverage_scores = np.zeros(num_frames, dtype=np.float64) for idx in range(num_frames): if conf_tensor is not None: conf = conf_tensor[idx].reshape(-1) if conf.size: conf = conf[~np.isnan(conf)] if conf.size: quality_scores[idx] = float(np.mean(conf)) high_thresh = np.percentile(conf, 75) coverage_scores[idx] = float(np.mean(conf >= high_thresh)) continue quality_scores[idx] = 0.0 coverage_scores[idx] = 0.0 else: quality_scores[idx] = 1.0 coverage_scores[idx] = 1.0 max_cov = coverage_scores.max() if max_cov > 0: coverage_scores = coverage_scores / max_cov else: coverage_scores = np.ones_like(coverage_scores) base_scores = quality_scores * (0.5 + 0.5 * coverage_scores) indices = list(range(num_frames)) indices.sort(key=lambda idx: base_scores[idx], reverse=True) bbox_min = positions.min(axis=0) bbox_max = positions.max(axis=0) scene_scale = float(np.linalg.norm(bbox_max - bbox_min)) pos_threshold = max(0.1, 0.1 * scene_scale) ori_threshold = 15.0 selected = [] for idx in indices: if not selected: selected.append(idx) else: accept = False min_dist = min(np.linalg.norm(positions[idx] - positions[j]) for j in selected) max_angle = max( np.degrees( np.arccos( np.clip(np.dot(forward_vectors[idx], forward_vectors[j]), -1.0, 1.0) ) ) for j in selected ) if min_dist >= pos_threshold or max_angle >= ori_threshold: accept = True elif len(selected) < max(1, top_k // 3): accept = True if accept: selected.append(idx) if len(selected) >= top_k: break if len(selected) < top_k: for idx in indices: if idx not in selected: selected.append(idx) if len(selected) >= top_k: break selected = sorted(selected[:top_k]) records = [] for idx in selected: filename = image_files[idx] if idx < len(image_files) else f"frame_{idx:06d}" records.append( { "index": int(idx), "filename": filename, "score": float(base_scores[idx]), "mean_confidence": float(quality_scores[idx]), "coverage_ratio": float(coverage_scores[idx]), } ) return records # ------------------------------------------------------------------------- # 1) Core model inference # ------------------------------------------------------------------------- @spaces.GPU(duration=180) # triggers ZeroGPU allocation for this call def run_model(target_dir: str, model: STream3R, mode: str="causal", streaming: bool=False) -> tuple[dict, str | None]: """ Run the STream3R model on images in the 'target_dir/images' folder. Args: target_dir: Directory containing the images subfolder model: STream3R model instance mode: Processing mode ("causal", "window", or "full") streaming: If True, use StreamSession for sequential processing; if False, use batch processing Returns: tuple[dict, str | None]: Predictions dictionary and optional path to the saved session cache when streaming mode is used. """ print(f"Processing images from {target_dir}") # Device check device = "cuda" if torch.cuda.is_available() else "cpu" if not torch.cuda.is_available(): raise ValueError("CUDA is not available. Check your environment.") # Move model to device model = model.to(device) model.eval() # Load and preprocess images image_names = glob.glob(os.path.join(target_dir, "images", "*")) image_names = sorted(image_names) print(f"Found {len(image_names)} images") if len(image_names) == 0: raise ValueError("No images found. Check your upload.") images = load_and_preprocess_images(image_names).to(device) print(f"Preprocessed images shape: {images.shape}") # Run inference print(f"Running inference in {'streaming' if streaming else 'batch'} mode...") dtype = torch.bfloat16 if torch.cuda.get_device_capability()[0] >= 8 else torch.float16 session_cache_path: Optional[str] = None with torch.no_grad(): with torch.amp.autocast(dtype=dtype, device_type=device): if streaming: # Use StreamSession for sequential processing if mode == "full": print("Warning: Streaming mode does not support 'full' attention mode. Switching to 'causal' mode.") mode = "causal" session = StreamSession(model, mode=mode) kv_cache_path = os.path.join(target_dir, "kv_cache.pt") if os.path.exists(kv_cache_path): print(f"Loading existing session cache from {kv_cache_path}") session.load_cache(kv_cache_path, device=images.device) existing_predictions = session.get_all_predictions() existing_frames = 0 for value in existing_predictions.values(): if isinstance(value, torch.Tensor) and value.dim() >= 2: existing_frames = max(existing_frames, value.shape[1]) total_frames = images.shape[0] if existing_frames > total_frames: raise ValueError( "Session cache contains more frames than available images. Please ensure the images folder " "matches the saved session state." ) if existing_frames == total_frames: print("No new frames detected; reusing cached predictions.") else: for i in range(existing_frames, total_frames): image = images[i : i + 1] session.forward_stream(image) predictions = session.get_all_predictions() session.save_cache(kv_cache_path) session_cache_path = kv_cache_path else: # Use batch processing (original behavior) predictions = model(images, mode=mode) # Convert pose encoding to extrinsic and intrinsic matrices print("Converting pose encoding to extrinsic and intrinsic matrices...") extrinsic, intrinsic = pose_encoding_to_extri_intri(predictions["pose_enc"], images.shape[-2:]) predictions["extrinsic"] = extrinsic predictions["intrinsic"] = intrinsic # Convert tensors to numpy for key in list(predictions.keys()): value = predictions[key] if isinstance(value, torch.Tensor): predictions[key] = value.cpu().numpy().squeeze(0) # remove batch dimension predictions.pop("pose_enc_list", None) # Generate world points from depth map print("Computing world points from depth map...") depth_map = predictions["depth"] # (S, H, W, 1) world_points = unproject_depth_map_to_point_map(depth_map, predictions["extrinsic"], predictions["intrinsic"]) predictions["world_points_from_depth"] = world_points # Clean up torch.cuda.empty_cache() return predictions, session_cache_path # ------------------------------------------------------------------------- # 2) Handle uploaded video/images --> produce target_dir + images # ------------------------------------------------------------------------- def handle_uploads(input_video, input_images, input_zip=None, session_state=None, current_target_dir: str | None = None): """ Create a new 'target_dir' + 'images' subfolder. - Copies uploaded images - Optionally extracts images from a ZIP - Optionally extracts frames from a video (1 fps) - Optionally loads a previously saved session archive Returns (target_dir, image_paths, session_loaded). """ start_time = time.time() gc.collect() torch.cuda.empty_cache() # Create a unique folder name timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") session_loaded = False if session_state: session_path = session_state.get("name") if isinstance(session_state, dict) and "name" in session_state else getattr(session_state, "name", None) session_path = session_path or session_state extract_root = os.path.join("demo_cache", f"session_{timestamp}") target_dir = extract_session_state(session_path, extract_root) session_loaded = True elif current_target_dir and os.path.isdir(current_target_dir): target_dir = current_target_dir else: target_dir = os.path.join("demo_cache", f"input_images_{timestamp}") if os.path.exists(target_dir): shutil.rmtree(target_dir) os.makedirs(target_dir, exist_ok=True) target_dir_images = os.path.join(target_dir, "images") os.makedirs(target_dir_images, exist_ok=True) image_paths: list[str] = [] # --- Handle images (list) --- if input_images: for file_data in input_images: file_path = file_data["name"] if isinstance(file_data, dict) and "name" in file_data else file_data copied_path = _copy_with_unique_name(file_path, target_dir_images) image_paths.append(copied_path) # --- Handle ZIP (extract images) --- if input_zip: zip_path = input_zip["name"] if isinstance(input_zip, dict) and "name" in input_zip else input_zip extracted = extract_images_from_zip(zip_path, target_dir_images) image_paths.extend(extracted) # --- Handle video (extract frames at 1 fps) --- if input_video: video_path = input_video["name"] if isinstance(input_video, dict) and "name" in input_video else input_video vs = cv2.VideoCapture(video_path) fps = vs.get(cv2.CAP_PROP_FPS) or 30.0 frame_interval = max(1, int(fps * 1)) # 1 frame/sec count = 0 video_frame_num = len(os.listdir(target_dir_images)) while True: gotit, frame = vs.read() if not gotit: break count += 1 if count % frame_interval == 0: image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png") cv2.imwrite(image_path, frame) image_paths.append(image_path) video_frame_num += 1 vs.release() image_paths = sorted(set(image_paths)) # de-dupe + sort # Ensure gallery reflects existing files in the images directory existing_images = sorted(glob.glob(os.path.join(target_dir_images, "*"))) image_paths = existing_images end_time = time.time() print(f"Prepared {len(image_paths)} files in {target_dir_images}; took {end_time - start_time:.3f}s") return target_dir, image_paths, session_loaded # ------------------------------------------------------------------------- # 3) Update gallery on upload # ------------------------------------------------------------------------- def update_gallery_on_upload(input_video, input_images, input_zip, session_state, current_target_dir): """ Handle any new uploads (video, images, or zip) and render preview. """ if not input_video and not input_images and not input_zip and not session_state: default_updates = [ gr.update(), gr.update(value=""), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), gr.update(), ] return (gr.update(), current_target_dir, None, None, None, *default_updates) target_dir, image_paths, session_loaded = handle_uploads( input_video, input_images, input_zip, session_state=session_state, current_target_dir=current_target_dir, ) if session_loaded: message = "Session state loaded. Add new frames and click 'Reconstruct' to continue." else: message = "Upload complete. Click 'Reconstruct' to begin 3D processing." target_dir_images = os.path.join(target_dir, "images") frame_filter_choices = build_frame_filter_choices(target_dir_images) frame_value = "All" frame_update = gr.update(choices=frame_filter_choices, value=frame_value) streaming_update = gr.update() mode_update = gr.update() conf_update = gr.update() mask_black_update = gr.update() mask_white_update = gr.update() show_cam_update = gr.update() mask_sky_update = gr.update() prediction_mode_update = gr.update() reconstruction_update = gr.update() session_state_reset = gr.update() localization_result_update = gr.update(value="") if session_loaded: settings = load_session_settings(target_dir) if settings: if "frame_filter" in settings: potential_value = settings.get("frame_filter", "All") if potential_value in frame_filter_choices: frame_value = potential_value frame_update = gr.update(choices=frame_filter_choices, value=frame_value) if "streaming" in settings: streaming_update = gr.update(value=bool(settings.get("streaming", True))) if settings.get("mode") in {"causal", "window", "full"}: mode_update = gr.update(value=settings["mode"]) if "conf_thres" in settings: try: conf_update = gr.update(value=float(settings["conf_thres"])) except (TypeError, ValueError): pass if "mask_black_bg" in settings: mask_black_update = gr.update(value=bool(settings.get("mask_black_bg", False))) if "mask_white_bg" in settings: mask_white_update = gr.update(value=bool(settings.get("mask_white_bg", False))) if "show_cam" in settings: show_cam_update = gr.update(value=bool(settings.get("show_cam", True))) if "mask_sky" in settings: mask_sky_update = gr.update(value=bool(settings.get("mask_sky", False))) pred_mode_value = settings.get("prediction_mode") if pred_mode_value in {"Depthmap and Camera Branch", "Pointmap Branch"}: prediction_mode_update = gr.update(value=pred_mode_value) try: conf_val = settings["conf_thres"] mode_val = settings["mode"] pred_mode_val = settings["prediction_mode"] mask_black_val = bool(settings.get("mask_black_bg", False)) mask_white_val = bool(settings.get("mask_white_bg", False)) show_cam_val = bool(settings.get("show_cam", True)) mask_sky_val = bool(settings.get("mask_sky", False)) glb_candidate = os.path.join( target_dir, f"glbscene_{conf_val}_{sanitize_frame_filter_label(frame_value)}_maskb{mask_black_val}_maskw{mask_white_val}_cam{show_cam_val}_sky{mask_sky_val}_pred{pred_mode_val.replace(' ', '_')}_mode{mode_val}.glb", ) if os.path.exists(glb_candidate): reconstruction_update = glb_candidate except (KeyError, AttributeError): pass session_state_reset = gr.update(value=None) localization_result_update = gr.update(value="") return ( reconstruction_update, target_dir, image_paths, message, None, session_state_reset, localization_result_update, streaming_update, mode_update, conf_update, frame_update, mask_black_update, mask_white_update, show_cam_update, mask_sky_update, prediction_mode_update, ) def update_gallery_without_session(input_video, input_images, input_zip, current_target_dir): return update_gallery_on_upload(input_video, input_images, input_zip, None, current_target_dir) def _format_vector(vec: np.ndarray) -> str: return "[" + ", ".join(f"{v:.4f}" for v in vec.tolist()) + "]" def localize_new_image( target_dir, localization_image, mode, conf_thres, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, ): if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return ("No reconstruction available. Please upload and reconstruct first.", gr.update()) image_path = _resolve_path(localization_image) if not image_path or not os.path.exists(image_path): return ("Please select an image to localize.", gr.update()) kv_cache_path = os.path.join(target_dir, "kv_cache.pt") predictions_path = os.path.join(target_dir, "predictions.npz") device = "cuda" if torch.cuda.is_available() else "cpu" try: image_tensor = load_and_preprocess_images([image_path]).to(device) except Exception as exc: return (f"Failed to preprocess image: {exc}", gr.update()) model.eval() model.to(device) session = getattr(localize_new_image, "_session", None) if session is None or session.mode != mode: session = StreamSession(model, mode=mode) localize_new_image._session = session session.clear() if not os.path.exists(kv_cache_path): if device == "cuda": torch.cuda.empty_cache() return ("Streaming cache not found. Run reconstruction in streaming mode first.", gr.update()) if not os.path.exists(predictions_path): if device == "cuda": torch.cuda.empty_cache() return ("Predictions not found. Run reconstruction before localizing.", gr.update()) try: session.load_cache(kv_cache_path, device=image_tensor.device) existing_predictions = session.get_all_predictions() existing_frames = 0 for value in existing_predictions.values(): if isinstance(value, torch.Tensor) and value.dim() >= 2: existing_frames = max(existing_frames, value.shape[1]) with torch.no_grad(): session.forward_stream(image_tensor) localized_predictions = session.get_all_predictions() except Exception as exc: session.clear() if image_tensor.device.type == "cuda": torch.cuda.empty_cache() return (f"Localization failed: {exc}", gr.update()) def _extract_frame(tensor: torch.Tensor, index: int) -> np.ndarray: if tensor.dim() < 2: return tensor.detach().cpu().numpy() frame = tensor[:, index:index + 1] return frame.detach().cpu().numpy() new_index = existing_frames try: pose_slice = localized_predictions["pose_enc"][:, new_index:new_index + 1] except KeyError: return ("pose_enc not available in session predictions.", gr.update()) H, W = image_tensor.shape[-2:] extrinsics_t, intrinsics_t = pose_encoding_to_extri_intri( pose_slice, image_size_hw=(H, W), build_intrinsics=True, ) extrinsic = extrinsics_t[0, 0].cpu().numpy() intrinsic = intrinsics_t[0, 0].cpu().numpy() if intrinsics_t is not None else None rotation_matrix = extrinsic[:, :3] translation = extrinsic[:, 3] camera_position = -rotation_matrix.T @ translation try: rotation = Rotation.from_matrix(rotation_matrix) yaw_pitch_roll = rotation.as_euler("zyx", degrees=True) except ValueError: yaw_pitch_roll = np.zeros(3) world_points_tensor = localized_predictions.get("world_points") if world_points_tensor is None: world_points_tensor = localized_predictions.get("world_points_from_depth") world_points = _extract_frame(world_points_tensor, new_index)[0] world_conf_tensor = localized_predictions.get("world_points_conf") if world_conf_tensor is None: world_conf_tensor = localized_predictions.get("depth_conf") if world_conf_tensor is not None: confidence = _extract_frame(world_conf_tensor, new_index)[0] else: confidence = np.ones(world_points.shape[:-1], dtype=np.float32) flattened_points = world_points.reshape(-1, 3) flattened_conf = confidence.reshape(-1) if conf_thres is None or conf_thres == 0.0: conf_threshold = 0.0 else: conf_threshold = np.percentile(flattened_conf, conf_thres) visibility_mask = (flattened_conf >= conf_threshold) & (flattened_conf > 1e-5) visible_points = flattened_points[visibility_mask] if visible_points.size == 0: visible_points = flattened_points bbox_min = visible_points.min(axis=0) if visible_points.size > 0 else np.zeros(3) bbox_max = visible_points.max(axis=0) if visible_points.size > 0 else np.zeros(3) summary_lines = [ f"Localized image: {os.path.basename(image_path)}", f"Translation (t): {_format_vector(translation)}", f"Camera position (world): {_format_vector(camera_position)}", f"Orientation yaw/pitch/roll (deg): {_format_vector(yaw_pitch_roll)}", f"Visible points (>= percentile {conf_thres:.1f}%): {visible_points.shape[0]}", f"World-space bounding box min: {_format_vector(bbox_min)}", f"World-space bounding box max: {_format_vector(bbox_max)}", ] if intrinsic is not None: summary_lines.append( "Intrinsics matrix:\n" + "\n".join(" " + _format_vector(row) for row in intrinsic) ) try: try: stored = np.load(predictions_path) except ValueError: stored = np.load(predictions_path, allow_pickle=True) base_predictions = {key: stored[key] for key in stored.files} stored.close() except Exception as exc: return (f"Failed to load stored predictions: {exc}", gr.update()) localization_glb_path = os.path.join(target_dir, "localization_preview.glb") world_to_camera = np.eye(4) world_to_camera[:3, :4] = extrinsic try: glbscene = predictions_to_glb( base_predictions, conf_thres=conf_thres, filter_by_frames="All", mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=False, mask_sky=mask_sky, target_dir=target_dir, prediction_mode=prediction_mode, extra_cameras=[world_to_camera], extra_camera_color=(255, 0, 0), ) glbscene.export(file_obj=localization_glb_path) except Exception as exc: localization_glb_path = None summary_lines.append(f"Warning: failed to update GLB preview ({exc})") session.clear() if image_tensor.device.type == "cuda": torch.cuda.empty_cache() return ("\n".join(summary_lines), localization_glb_path if localization_glb_path else gr.update()) # ------------------------------------------------------------------------- # 4) Reconstruction: uses the target_dir plus any viz parameters # ------------------------------------------------------------------------- def gradio_demo( target_dir, conf_thres=3.0, frame_filter="All", mask_black_bg=False, mask_white_bg=False, show_cam=True, mask_sky=False, prediction_mode="Pointmap Regression", mode="causal", streaming=False, ): """ Perform reconstruction using the already-created target_dir/images. """ if not os.path.isdir(target_dir) or target_dir == "None": return None, "No valid target directory found. Please upload first.", None, None start_time = time.time() gc.collect() torch.cuda.empty_cache() # Prepare frame_filter dropdown target_dir_images = os.path.join(target_dir, "images") frame_filter_choices = build_frame_filter_choices(target_dir_images) config = load_stream_session_config() top_k_frames = int(config.get("top_k_frames", 0) or 0) print("Running run_model...") with torch.no_grad(): predictions, session_cache_path = run_model(target_dir, model, mode=mode, streaming=streaming) # Save predictions prediction_save_path = os.path.join(target_dir, "predictions.npz") np.savez(prediction_save_path, **predictions) selected_frames = select_top_k_frames(predictions, target_dir_images, top_k_frames) selected_frames_path = os.path.join(target_dir, "selected_frames.json") if selected_frames: try: with open(selected_frames_path, "w", encoding="utf-8") as handle: json.dump({"top_k": top_k_frames, "frames": selected_frames}, handle, indent=2) except OSError as exc: print(f"Failed to write selected frames: {exc}") elif os.path.exists(selected_frames_path): try: os.remove(selected_frames_path) except OSError: pass frame_filter_value = frame_filter if frame_filter is not None else "All" session_settings = { "streaming": bool(streaming), "mode": mode, "conf_thres": float(conf_thres), "frame_filter": frame_filter_value, "mask_black_bg": bool(mask_black_bg), "mask_white_bg": bool(mask_white_bg), "show_cam": bool(show_cam), "mask_sky": bool(mask_sky), "prediction_mode": prediction_mode, } session_settings["top_k_frames"] = top_k_frames if selected_frames: session_settings["selected_frames"] = [frame["filename"] for frame in selected_frames] try: with open(os.path.join(target_dir, "session_settings.json"), "w", encoding="utf-8") as handle: json.dump(session_settings, handle, indent=2) except OSError as exc: print(f"Failed to write session settings: {exc}") session_state_file = None if streaming: if session_cache_path is None: session_cache_path = os.path.join(target_dir, "kv_cache.pt") if os.path.exists(session_cache_path): session_state_file = package_session_state(target_dir) # Handle None frame_filter if frame_filter is None: frame_filter = "All" # Build a GLB file name glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{sanitize_frame_filter_label(frame_filter)}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode}.glb", ) # Convert predictions to GLB glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, mask_sky=mask_sky, target_dir=target_dir, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) # Cleanup del predictions gc.collect() torch.cuda.empty_cache() end_time = time.time() print(f"Total time: {end_time - start_time:.2f} seconds (including IO)") num_frames = max(0, len(frame_filter_choices) - 1) log_msg = f"Reconstruction Success ({num_frames} frames). Waiting for visualization." return ( glbfile, log_msg, gr.Dropdown(choices=frame_filter_choices, value=frame_filter, interactive=True), session_state_file, ) # ------------------------------------------------------------------------- # 5) Helper functions for UI resets + re-visualization # ------------------------------------------------------------------------- def clear_fields(): """ Clears the 3D viewer, the stored target_dir, and empties the gallery. """ return None def update_log(): """ Display a quick log message while waiting. """ return "Loading and Reconstructing..." def update_visualization( target_dir, conf_thres, frame_filter, mask_black_bg, mask_white_bg, show_cam, mask_sky, prediction_mode, mode_value, is_example, ): """ Reload saved predictions from npz, create (or reuse) the GLB for new parameters, and return it for the 3D viewer. If is_example == "True", skip. """ # If it's an example click, skip as requested if is_example == "True": return None, "No reconstruction available. Please click the Reconstruct button first." if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): return None, "No reconstruction available. Please click the Reconstruct button first." predictions_path = os.path.join(target_dir, "predictions.npz") if not os.path.exists(predictions_path): return None, f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first." key_list = [ "pose_enc", "depth", "depth_conf", "world_points", "world_points_conf", "images", "extrinsic", "intrinsic", "world_points_from_depth", ] loaded = np.load(predictions_path) predictions = {key: np.array(loaded[key]) for key in key_list} sanitized_frame = sanitize_frame_filter_label(frame_filter) glbfile = os.path.join( target_dir, f"glbscene_{conf_thres}_{sanitized_frame}_maskb{mask_black_bg}_maskw{mask_white_bg}_cam{show_cam}_sky{mask_sky}_pred{prediction_mode.replace(' ', '_')}_mode{mode_value}.glb", ) if not os.path.exists(glbfile): glbscene = predictions_to_glb( predictions, conf_thres=conf_thres, filter_by_frames=frame_filter, mask_black_bg=mask_black_bg, mask_white_bg=mask_white_bg, show_cam=show_cam, mask_sky=mask_sky, target_dir=target_dir, prediction_mode=prediction_mode, ) glbscene.export(file_obj=glbfile) return glbfile, "Updating Visualization" # ------------------------------------------------------------------------- # Example images # ------------------------------------------------------------------------- great_wall_video = "examples/videos/great_wall.mp4" colosseum_video = "examples/videos/Colosseum.mp4" room_video = "examples/videos/room.mp4" kitchen_video = "examples/videos/kitchen.mp4" fern_video = "examples/videos/fern.mp4" single_cartoon_video = "examples/videos/single_cartoon.mp4" single_oil_painting_video = "examples/videos/single_oil_painting.mp4" pyramid_video = "examples/videos/pyramid.mp4" # ------------------------------------------------------------------------- # 6) Build Gradio UI # ------------------------------------------------------------------------- theme = gr.themes.Ocean() theme.set( checkbox_label_background_fill_selected="*button_primary_background_fill", checkbox_label_text_color_selected="*button_primary_text_color", ) with gr.Blocks( theme=theme, css=""" .custom-log * { font-style: italic; font-size: 22px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; font-weight: bold !important; color: transparent !important; text-align: center !important; } .example-log * { font-style: italic; font-size: 16px !important; background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%); -webkit-background-clip: text; background-clip: text; color: transparent !important; } #my_radio .wrap { display: flex; flex-wrap: nowrap; justify-content: center; align-items: center; } #my_radio .wrap label { display: flex; width: 50%; justify-content: center; align-items: center; margin: 0; padding: 10px 0; box-sizing: border-box; } """, ) as demo: # Instead of gr.State, we use a hidden Textbox: is_example = gr.Textbox(label="is_example", visible=False, value="None") num_images = gr.Textbox(label="num_images", visible=False, value="None") example_preview = gr.Image(label="Example Preview", visible=False) gr.HTML( """
GitHub Repository | Project Page | Paper
Special thanks to VGG-T for their visualization demo, which this demo is built upon!
Upload a video or a set of images to create a 3D reconstruction of a scene or object. STream3R takes these images and generates a 3D point cloud, along with estimated camera poses.
Please note: STream3R typically reconstructs a scene in less than 1 second. However, visualizing 3D points may take tens of seconds due to third-party rendering, which are independent of STream3R's processing time.