Spaces:
Runtime error
Runtime error
| import spaces | |
| import torch | |
| print(f'torch version:{torch.__version__}') | |
| import functools | |
| import gc | |
| import os | |
| import subprocess | |
| import shutil | |
| import sys | |
| import tempfile | |
| import time | |
| from datetime import datetime | |
| from pathlib import Path | |
| import uuid | |
| import cv2 | |
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| from PIL import Image | |
# Append the parent of the directory containing this file to sys.path.
# NOTE(review): presumably this is what lets the `src.*` imports below resolve
# when the script is launched directly -- confirm against the repo layout.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image
import open3d as o3d
# Root directory for per-session working folders; it is read back by
# extract_frames, generate_splats_from_images and cleanup. The
# "proprocess_results" spelling is part of the on-disk path and is kept as-is.
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"
def get_reconstructed_scene(outdir, model, device):
    """Reconstruct a scene from ``outdir/images`` and export it next to them.

    Every file found in ``outdir/images`` is loaded with ``process_image``,
    stacked into one batch and passed to ``model.inference``. The resulting
    Gaussians are rendered with ``save_interpolated_video``, written to
    ``outdir/gaussians.ply`` and then converted to ``outdir/gaussians.glb``.

    Args:
        outdir: Working directory that contains an ``images`` subfolder; the
            PLY and GLB files are written directly into it.
        model: Object exposing ``inference(images)`` and a ``decoder``
            attribute (the loaded AnySplat model in this script).
        device: Torch device the image batch is moved to.

    Returns:
        Tuple ``(glbfile, video, depth_colored)``: the path of the exported
        GLB file followed by the two values returned by
        ``save_interpolated_video``.

    Raises:
        ValueError: If ``outdir/images`` contains no files, or the stacked
            batch does not have 3 channels.
    """
    images_dir = os.path.join(outdir, "images")
    image_files = sorted(
        [os.path.join(images_dir, f) for f in os.listdir(images_dir)]
    )
    if not image_files:
        # Fail early with an actionable message instead of letting
        # torch.stack choke on an empty list.
        raise ValueError(f"No input images found in {images_dir}")

    images = [process_image(img_path) for img_path in image_files]
    images = torch.stack(images, dim=0).unsqueeze(0).to(device)  # [1, K, 3, 448, 448]
    b, _, c, h, w = images.shape
    if c != 3:
        raise ValueError("Images must have 3 channels")

    # NOTE(review): the (x + 1) * 0.5 rescale presumably maps process_image's
    # output from [-1, 1] to [0, 1] -- confirm against process_image.
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)
    pred_all_extrinsic = pred_context_pose["extrinsic"]
    pred_all_intrinsic = pred_context_pose["intrinsic"]
    video, depth_colored = save_interpolated_video(
        pred_all_extrinsic,
        pred_all_intrinsic,
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )

    plyfile = os.path.join(outdir, "gaussians.ply")
    glbfile = os.path.join(outdir, "gaussians.glb")
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    # Imported lazily, as in the original, so they are only needed once the
    # export step is reached.
    import trimesh
    import numpy as np

    # 1. Load PLY and preserve attributes
    mesh = trimesh.load(plyfile, process=False)

    # 2. Check or assign vertex colors
    if mesh.visual.vertex_colors is None or mesh.visual.vertex_colors.shape[1] < 4:
        # NOTE(review): assumes mesh.metadata['vertex_color'] holds an (N, 3)
        # array; this key is not guaranteed by anything visible here.
        rgb = np.array(mesh.metadata['vertex_color'], dtype=np.uint8)
        alpha = np.full((rgb.shape[0], 1), 255, dtype=np.uint8)
        mesh.visual.vertex_colors = np.concatenate([rgb, alpha], axis=1)

    # 3. Export GLB
    mesh.export(glbfile, file_type='glb')
    # Report the file that was actually written (previously a hard-coded,
    # unrelated file name was printed).
    print(f"Export complete: {glbfile}")

    # Clean up
    torch.cuda.empty_cache()
    return glbfile, video, depth_colored
| # 2) Handle uploaded video/images --> produce target_dir + images | |
def extract_frames(input_video, session_id):
    """Reset the session's working folder and fill it with sampled video frames.

    The folder ``$ANYSPLAT_PROCESSED/<session_id>`` is deleted if it exists
    and recreated with an ``images`` subfolder. When a video is given, every
    ``int(fps)``-th frame (about one frame per second) is written there as a
    zero-padded, sequentially numbered PNG.

    Args:
        input_video: Path of the video, a dict carrying the path under
            ``"name"``, or ``None`` to only (re)create the empty folders.
        session_id: Name of the per-session subfolder.

    Returns:
        Tuple ``(target_dir, image_paths)``: the session folder and the
        sorted list of PNG paths that were written (empty without a video).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)  # also creates target_dir

    image_paths = []
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # The reported FPS can be 0 (unopened/unsupported source) or a
            # fraction below 1; fall back to an interval of 1 so the modulo
            # below can never divide by zero.
            frame_interval = int(fps) if fps >= 1 else 1  # 1 frame/sec

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(
                        target_dir_images, f"{video_frame_num:06}.png"
                    )
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Always give the decoder/file handle back, even if reading fails.
            vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
def update_gallery_on_upload(input_video, session_id):
    """Extract preview frames as soon as the uploaded video changes.

    Args:
        input_video: Value of the video input; falsy (``None`` or empty)
            when the input was cleared.
        session_id: Session identifier forwarded to ``extract_frames``.

    Returns:
        A 3-tuple matching the outputs wired to the video's change event
        (3D viewer, target directory, gallery):
        ``(None, target_dir, image_paths)`` after extraction, or
        ``(None, None, None)`` when no video is present.
    """
    # The original condition also tested an undefined name ``input_images``,
    # which raised NameError whenever the video was cleared.
    if not input_video:
        return None, None, None

    target_dir, image_paths = extract_frames(input_video, session_id)
    return None, target_dir, image_paths
def generate_splats_from_video(video_path, session_id=None):
    """Extract frames from a video and reconstruct the scene in one call.

    Args:
        video_path: Video input forwarded to ``extract_frames``.
        session_id: Session identifier; a random hex UUID is generated when
            it is ``None``.

    Returns:
        The three values returned by ``generate_splats_from_images`` followed
        by the list of extracted frame paths.
    """
    sid = uuid.uuid4().hex if session_id is None else session_id

    frames_dir, frame_paths = extract_frames(video_path, sid)
    scene_file, rgb_clip, depth_clip = generate_splats_from_images(frames_dir, sid)

    return scene_file, rgb_clip, depth_clip, frame_paths
def generate_splats_from_images(images_folder, session_id=None):
    """Run the reconstruction for a session's extracted frames.

    The working directory is ``$ANYSPLAT_PROCESSED/<session_id>``; its
    ``images`` subfolder must already hold the input frames. Inference runs
    under ``torch.no_grad()`` using the module-level ``model`` and ``device``
    that the ``__main__`` section of this script assigns.

    Args:
        images_folder: Accepted for interface compatibility with the UI
            wiring. It does not select the working directory, which is
            derived from ``session_id`` only.
        session_id: Session identifier; a random hex UUID is generated when
            it is ``None``.

    Returns:
        Tuple ``(plyfile, video, depth_colored)`` exactly as returned by
        ``get_reconstructed_scene``.

    Raises:
        FileNotFoundError: If the session has no ``images`` directory, e.g.
            because no video was uploaded before reconstruction was requested.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    images_dir = os.path.join(base_dir, "images")
    if not os.path.isdir(images_dir):
        # Same exception type the directory listing inside the reconstruction
        # would raise, but reported before any model work starts.
        raise FileNotFoundError(
            f"No extracted frames found in {images_dir}; upload a video first."
        )

    print("Running run_model...")
    with torch.no_grad():
        plyfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    return plyfile, video, depth_colored
def cleanup(request: gr.Request):
    """Delete the working directory that belongs to the ending session.

    Registered as the demo's unload handler. The directory
    ``$ANYSPLAT_PROCESSED/<session_hash>`` is removed recursively and any
    error during removal is ignored. Nothing happens when the request carries
    no session hash.

    Args:
        request (gr.Request): Gradio request object; only ``session_hash``
            is read.
    """
    session_hash = request.session_hash
    if not session_hash:
        return

    session_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_hash)
    shutil.rmtree(session_dir, ignore_errors=True)
def start_session(request: gr.Request):
    """Return the session hash of the incoming request.

    Registered as the demo's load handler; the returned value is stored in
    the session state and later used to name the session's working folder.

    Args:
        request (gr.Request): Gradio request object; only ``session_hash``
            is read.

    Returns:
        The request's ``session_hash``.
    """
    session_hash = request.session_hash
    return session_hash
if __name__ == "__main__":
    # Script entry point: load the model once, build the Gradio UI and serve it.
    # `model` and `device` become module globals here and are read by
    # generate_splats_from_images.
    share = True  # passed to demo.launch() below
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = AnySplat.from_pretrained(
        "lhjiang/anysplat"
    )
    model = model.to(device)
    model.eval()
    # Inference only: no parameter needs gradients.
    for param in model.parameters():
        param.requires_grad = False

    theme = gr.themes.Ocean()
    theme.set(
        checkbox_label_background_fill_selected="*button_primary_background_fill",
        checkbox_label_text_color_selected="*button_primary_text_color",
    )
    css = """
#col-container {
margin: 0 auto;
max-width: 1024px;
}
"""

    # NOTE(review): the nesting of the layout blocks below was reconstructed
    # from a copy of this file that had lost its indentation -- confirm it
    # against the deployed UI.
    with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
        # Per-browser-session identifier, filled in when the page loads.
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden fields; only target_dir_output is wired to an event below.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
<div style="text-align: center;">
<p style="font-size:16px; display: inline; margin: 0;">
<strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
</p>
<a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
<img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
</a>
</div>
"""
            )
            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Upload Video", interactive=True, height=512)
                    submit_btn = gr.Button(
                        "Reconstruct", scale=1, variant="primary"
                    )
                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )
                with gr.Column():
                    with gr.Column():
                        reconstruction_output = gr.Model3D(
                            label="3D Reconstructed Gaussian Splat",
                            height=512,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                            camera_position=[20, 20, 20],
                        )
                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )
            with gr.Row():
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/matrixcity_street.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                    # [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
                    # [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
                    # [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
                    # [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
                    # [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
                    # [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
                    # [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
                    # [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
                    # [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
                    # [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
                ]
                # Each example runs the full video -> splat pipeline; the four
                # outputs match generate_splats_from_video's return tuple.
                gr.Examples(
                    examples=examples,
                    inputs=[
                        input_video
                    ],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )

        # "Reconstruct" runs the model on the frames already extracted for
        # this session.
        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[target_dir_output, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video])
        # Uploading or clearing the video refreshes the frame preview and
        # resets the 3D viewer.
        input_video.change(
            fn=update_gallery_on_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
        )
        # Remove the session's working folder when the page is closed.
        demo.unload(cleanup)

    demo.queue()
    # Use the `share` flag defined at the top of this block instead of a
    # second hard-coded literal.
    demo.launch(show_error=True, share=share)