import spaces  # must be imported before torch for Hugging Face ZeroGPU Spaces

import torch

print(f"torch version: {torch.__version__}")

import gc
import os
import shutil
import sys
import time
from pathlib import Path

import cv2
import gradio as gr

# Make the repository root importable when this file lives in a subfolder.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image

os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/preprocess_results"
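# Per-session output layout (sketch, inferred from the helpers below):
#   $ANYSPLAT_PROCESSED/<session_hash>/images/000000.png, ...   uploaded frames
#   $ANYSPLAT_PROCESSED/<session_hash>/gaussians.ply            exported splat
#   plus the interpolated RGB/depth videos written by save_interpolated_video
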
# 1) Core model inference
def get_reconstructed_scene(outdir, model, device):
    # Load images from <outdir>/images in sorted (frame) order.
    image_dir = os.path.join(outdir, "images")
    image_files = sorted(os.path.join(image_dir, f) for f in os.listdir(image_dir))
    images = [process_image(img_path) for img_path in image_files]
    images = torch.stack(images, dim=0).unsqueeze(0).to(device)  # [1, K, 3, 448, 448]
    b, v, c, h, w = images.shape
    assert c == 3, "Images must have 3 channels"

    # Run inference. (images + 1) * 0.5 maps pixel values from [-1, 1] back to
    # [0, 1], assuming process_image normalizes to [-1, 1].
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)

    # Render an interpolated camera path and save the RGB and depth videos.
    pred_all_extrinsic = pred_context_pose["extrinsic"]
    pred_all_intrinsic = pred_context_pose["intrinsic"]
    video, depth_colored = save_interpolated_video(
        pred_all_extrinsic,
        pred_all_intrinsic,
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )

    # Export the predicted Gaussians (first batch element) as a PLY file.
    plyfile = os.path.join(outdir, "gaussians.ply")
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    # Clean up
    torch.cuda.empty_cache()
    return plyfile, video, depth_colored
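
# Usage sketch for the function above (assumes `model` and `device` are
# initialized as in the __main__ block below, and that the hypothetical
# folder already has a populated images/ subdirectory):
#
#   plyfile, rgb_video, depth_video = get_reconstructed_scene(
#       "/path/to/scene", model, device
#   )
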
# 2) Handle uploaded video/images --> produce target_dir + images
def handle_uploads(input_video, input_images, session_id):
    """
    Create a new 'target_dir' with an 'images' subfolder, and place
    user-uploaded images or frames extracted from the video into it.
    Return (target_dir, image_paths).
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")

    # Clean up if somehow that folder already exists
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    # --- Handle images ---
    if input_images is not None:
        for file_data in input_images:
            # Gradio may hand back dicts with a "name" key or plain file paths.
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    # --- Handle video ---
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        fps = vs.get(cv2.CAP_PROP_FPS)
        # Sample roughly 1 frame/sec; guard against zero/missing FPS metadata
        # so the modulo below can never divide by zero.
        frame_interval = max(int(fps), 1)

        count = 0
        video_frame_num = 0
        while True:
            gotit, frame = vs.read()
            if not gotit:
                break
            count += 1
            if count % frame_interval == 0:
                image_path = os.path.join(target_dir_images, f"{video_frame_num:06}.png")
                cv2.imwrite(image_path, frame)
                image_paths.append(image_path)
                video_frame_num += 1
        vs.release()

    # Sort final images for gallery
    image_paths = sorted(image_paths)
    end_time = time.time()
    print(f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds")
    return target_dir, image_paths
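
# Example result (sketch): a 3-second, 30 fps upload sampled at ~1 frame/sec
# yields roughly
#   target_dir  -> $ANYSPLAT_PROCESSED/<session_hash>
#   image_paths -> [.../images/000000.png, .../images/000001.png, .../images/000002.png]
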
# 3) Update gallery on upload
def update_gallery_on_upload(input_video, input_images, session_id):
    """
    Whenever the user uploads or changes files, handle them immediately and
    show them in the gallery. Returns (reconstruction, target_dir, image_paths),
    where the reconstruction slot is always None (uploading clears any previous
    result). If nothing is uploaded, returns (None, None, None).
    """
    if not input_video and not input_images:
        return None, None, None
    target_dir, image_paths = handle_uploads(input_video, input_images, session_id)
    return None, target_dir, image_paths
# Assumption: the `spaces` import suggests this demo targets ZeroGPU, where
# GPU work must run inside a @spaces.GPU-decorated function; the decorator is
# a no-op on regular hardware.
@spaces.GPU
def generate_splat(images_folder, session_id=None):
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    # The hidden textbox already holds the per-session target dir, so fall
    # back to it when no session id is supplied.
    if session_id:
        base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    else:
        base_dir = images_folder

    print("Running run_model...")
    with torch.no_grad():
        plyfile, video, depth_colored = get_reconstructed_scene(base_dir, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")
    return plyfile, video, depth_colored
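
# Note: generate_splat uses the module-level `model` and `device` created in
# the __main__ block below. It is wired to the Reconstruct button, which
# passes the hidden target-dir textbox plus the per-user session id.
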
def start_session(request: gr.Request):
    """
    Initialize a new user session and return the session identifier.

    This function is triggered when the Gradio demo loads and creates a unique
    session hash that will be used to organize outputs and temporary files
    for this specific user session.

    Args:
        request (gr.Request): Gradio request object containing session information

    Returns:
        str: Unique session hash identifier
    """
    return request.session_hash
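
# Because every browser session gets its own hash, concurrent users write to
# disjoint folders under $ANYSPLAT_PROCESSED and cannot overwrite each other.
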
if __name__ == "__main__":
    share = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load model
    model = AnySplat.from_pretrained("lhjiang/anysplat")
    model = model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    theme = gr.themes.Ocean()
    theme.set(
        checkbox_label_background_fill_selected="*button_primary_background_fill",
        checkbox_label_text_color_selected="*button_primary_text_color",
    )
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    .custom-log * {
        font-style: italic;
        font-size: 22px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        font-weight: bold !important;
        color: transparent !important;
        text-align: center !important;
    }
    .example-log * {
        font-style: italic;
        font-size: 16px !important;
        background-image: linear-gradient(120deg, #0ea5e9 0%, #6ee7b7 60%, #34d399 100%);
        -webkit-background-clip: text;
        background-clip: text;
        color: transparent !important;
    }
    #my_radio .wrap {
        display: flex;
        flex-wrap: nowrap;
        justify-content: center;
        align-items: center;
    }
    #my_radio .wrap label {
        display: flex;
        width: 50%;
        justify-content: center;
        align-items: center;
        margin: 0;
        padding: 10px 0;
        box-sizing: border-box;
    }
    """
    with gr.Blocks(css=css, title="AnySplat Demo", theme=theme) as demo:
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden fields; most are only consumed by the (disabled) examples below.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.Markdown(
                """# AnySplat – Feed-forward 3D Gaussian Splatting from Unconstrained Views

• Source: [Github](https://github.com/OpenRobotLab/AnySplat)"""
            )

            with gr.Row():
                with gr.Column():
                    input_video = gr.Video(label="Upload Video", interactive=True, height=512)
                    input_images = gr.File(
                        file_count="multiple",
                        label="Upload Images",
                        interactive=True,
                        visible=False,
                    )
                    submit_btn = gr.Button("Reconstruct", scale=1, variant="primary")
                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )
                with gr.Column():
                    reconstruction_output = gr.Model3D(
                        label="3D Reconstructed Gaussian Splat",
                        height=512,
                        zoom_speed=0.5,
                        pan_speed=0.5,
                        camera_position=[20, 20, 20],
                    )
                    with gr.Row():
                        rgb_video = gr.Video(label="RGB Video", interactive=False, autoplay=True)
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )
        # ---------------------- Examples section ----------------------
        # (Disabled: `example_pipeline` is not defined in this file.)
        # examples = [
        #     [None, "examples/video/re10k_1eca36ec55b88fe4.mp4", "re10k", "1eca36ec55b88fe4", "2", "Real", "True",],
        #     [None, "examples/video/bungeenerf_colosseum.mp4", "bungeenerf", "colosseum", "8", "Synthetic", "True",],
        #     [None, "examples/video/fox.mp4", "InstantNGP", "fox", "14", "Real", "True",],
        #     [None, "examples/video/matrixcity_street.mp4", "matrixcity", "street", "32", "Synthetic", "True",],
        #     [None, "examples/video/vrnerf_apartment.mp4", "vrnerf", "apartment", "32", "Real", "True",],
        #     [None, "examples/video/vrnerf_kitchen.mp4", "vrnerf", "kitchen", "17", "Real", "True",],
        #     [None, "examples/video/vrnerf_riverview.mp4", "vrnerf", "riverview", "12", "Real", "True",],
        #     [None, "examples/video/vrnerf_workshop.mp4", "vrnerf", "workshop", "32", "Real", "True",],
        #     [None, "examples/video/fillerbuster_ramen.mp4", "fillerbuster", "ramen", "32", "Real", "True",],
        #     [None, "examples/video/meganerf_rubble.mp4", "meganerf", "rubble", "10", "Real", "True",],
        #     [None, "examples/video/llff_horns.mp4", "llff", "horns", "12", "Real", "True",],
        #     [None, "examples/video/llff_fortress.mp4", "llff", "fortress", "7", "Real", "True",],
        #     [None, "examples/video/dtu_scan_106.mp4", "dtu", "scan_106", "20", "Real", "True",],
        #     [None, "examples/video/horizongs_hillside_summer.mp4", "horizongs", "hillside_summer", "55", "Synthetic", "True",],
        #     [None, "examples/video/kitti360.mp4", "kitti360", "kitti360", "64", "Real", "True",],
        # ]
        # gr.Examples(
        #     examples=examples,
        #     inputs=[
        #         input_images,
        #         input_video,
        #         dataset_name,
        #         scene_name,
        #         num_images,
        #         image_type,
        #         is_example,
        #     ],
        #     outputs=[
        #         reconstruction_output,
        #         rgb_video,
        #         depth_video,
        #         target_dir_output,
        #         image_gallery,
        #     ],
        #     fn=example_pipeline,
        #     cache_examples=False,
        #     examples_per_page=50,
        # )
        submit_btn.click(
            fn=generate_splat,
            # The session id lets generate_splat resolve the per-session
            # folder; target_dir_output serves as the fallback.
            inputs=[target_dir_output, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video],
        )

        input_video.change(
            fn=update_gallery_on_upload,
            inputs=[input_video, input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
        )
        input_images.change(
            fn=update_gallery_on_upload,
            inputs=[input_video, input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
        )

    demo.queue().launch(show_error=True, share=share)