# NOTE(review): `spaces` is imported before torch; Hugging Face ZeroGPU
# (used via the @spaces.GPU() decorators in this file) presumably needs this
# ordering — keep it first unless confirmed otherwise.
import spaces
import torch

# Log the installed torch version at startup.
print(f'torch version:{torch.__version__}')

import functools
import gc
import os

# Restrict CUDA target architectures to compute capability 9.0. This is set
# before the pytorch3d wheel install further down; presumably it only matters
# if something compiles CUDA extensions in this process — confirm.
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'
| |
|
| | import subprocess |
| |
|
import sys


def sh(cmd):
    """Run an external command and raise ``CalledProcessError`` if it fails.

    Args:
        cmd: Either a command string, which is run through the shell (the
            original behaviour), or a sequence of arguments, which is executed
            directly without a shell so nothing is re-parsed or word-split.
    """
    subprocess.check_call(cmd, shell=isinstance(cmd, str))


# Install the bundled pytorch3d wheel (path resolved relative to the current
# working directory) with *this* interpreter's pip, so it lands in the same
# environment that runs the app rather than whichever `pip` is first on PATH.
sh([sys.executable, "-m", "pip", "install", "pytorch3d-0.7.9-cp310-cp310-linux_x86_64.whl"])
| |
|
import shutil
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
import uuid

import cv2
import gradio as gr

from huggingface_hub import hf_hub_download
from PIL import Image

# Put the parent of this file's directory on sys.path so the `src.*` imports
# below resolve independently of the current working directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image

# Root for the per-session working directories (copied/extracted images,
# rendered videos, exported PLY). NOTE: "proprocess_results" (sic) is the
# real on-disk directory name; renaming it changes where files are written.
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"

from plyfile import PlyData
import numpy as np
import argparse
from io import BytesIO
| |
|
| |
|
def _vertex_fields_to_splat_bytes(vert):
    """Pack per-Gaussian vertex fields into ``.splat`` records (32 bytes each).

    ``vert`` is anything indexable by field name that returns a 1-D array per
    field (a plyfile vertex element, or a plain dict of numpy arrays). Fields
    read: ``x``, ``y``, ``z``, ``scale_0..2`` (log-scales), ``rot_0..3``,
    ``f_dc_0..2`` and ``opacity`` (a logit).

    Each record, in native byte order, is:
      * position: 3 x float32
      * scale:    3 x float32, ``exp`` of the log-scale
      * RGBA:     4 x uint8, colour ``0.5 + SH_C0 * f_dc`` and alpha
                  ``sigmoid(opacity)``, scaled by 255 and clipped
      * rotation: 4 x uint8, unit quaternion mapped from [-1, 1] to [0, 255]

    Records are ordered by descending
    ``exp(scale_0 + scale_1 + scale_2) * sigmoid(opacity)``.
    """
    sh_c0 = 0.28209479177387814

    def columns(*names):
        # (N, len(names)) float32 matrix built from the named per-vertex fields.
        return np.stack(
            [np.asarray(vert[name], dtype=np.float32) for name in names], axis=1
        )

    # argsort of the negated key gives descending order.
    order = np.argsort(
        -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
        / (1 + np.exp(-vert["opacity"]))
    )

    positions = columns("x", "y", "z")[order]
    scales = np.exp(columns("scale_0", "scale_1", "scale_2")[order])
    rgb = 0.5 + sh_c0 * columns("f_dc_0", "f_dc_1", "f_dc_2")[order]
    alpha = 1 / (1 + np.exp(-columns("opacity")[order]))
    rgba = np.concatenate([rgb, alpha], axis=1)

    quats = columns("rot_0", "rot_1", "rot_2", "rot_3")[order]
    norms = np.linalg.norm(quats, axis=1, keepdims=True)
    degenerate = norms[:, 0] == 0
    # A zero quaternion cannot be normalised (it would yield NaN, whose uint8
    # cast is undefined); encode it as the identity rotation instead.
    quats = quats / np.where(degenerate[:, None], np.float32(1), norms)
    quats[degenerate] = (1.0, 0.0, 0.0, 0.0)

    record = np.dtype(
        [
            ("position", np.float32, (3,)),
            ("scale", np.float32, (3,)),
            ("rgba", np.uint8, (4,)),
            ("rotation", np.uint8, (4,)),
        ]
    )
    packed = np.empty(len(order), dtype=record)
    packed["position"] = positions
    packed["scale"] = scales
    packed["rgba"] = (rgba * 255).clip(0, 255).astype(np.uint8)
    packed["rotation"] = (quats * 128 + 128).clip(0, 255).astype(np.uint8)
    return packed.tobytes()


def process_ply_to_splat(ply_file_path):
    """Convert a Gaussian-splat PLY file into ``.splat`` bytes.

    Reads the ``vertex`` element of *ply_file_path* with plyfile and packs it
    with numpy in one vectorised pass (no per-vertex Python loop).

    Args:
        ply_file_path: Path of the PLY file to read.

    Returns:
        bytes: 32 bytes per Gaussian (position, scale, RGBA, rotation),
        ordered by descending scale-product times opacity.
    """
    plydata = PlyData.read(ply_file_path)
    return _vertex_fields_to_splat_bytes(plydata["vertex"])
| |
|
def save_splat_file(splat_data, output_path):
    """Write the raw ``.splat`` bytes *splat_data* to *output_path*.

    Any existing file at *output_path* is overwritten.
    """
    with open(output_path, mode="wb") as destination:
        destination.write(splat_data)
| |
|
def get_reconstructed_scene(outdir, image_files, model, device):
    """Reconstruct a Gaussian-splat scene from a set of images.

    Args:
        outdir: Directory that receives ``gaussians.ply`` and whatever
            ``save_interpolated_video`` writes; created if missing.
        image_files: Paths of the input images, each loaded with
            ``process_image``.
        model: AnySplat model; ``model.inference`` and ``model.decoder`` are used.
        device: Torch device the stacked image batch is moved to.

    Returns:
        ``(plyfile, video, depth_colored)``: path of the exported PLY file plus
        the two values returned by ``save_interpolated_video``.

    Raises:
        ValueError: if ``image_files`` is empty or the processed images do not
            have 3 channels.
    """
    if not image_files:
        raise ValueError("At least one image is required for reconstruction")

    images = torch.stack([process_image(path) for path in image_files], dim=0)
    images = images.unsqueeze(0).to(device)  # add batch dim -> (1, views, C, H, W)
    b, _, c, h, w = images.shape
    # Explicit raise rather than assert, so the check survives `python -O`.
    if c != 3:
        raise ValueError("Images must have 3 channels")

    # Callers invoked with a fresh session id never created this directory.
    os.makedirs(outdir, exist_ok=True)

    # NOTE(review): (x + 1) * 0.5 presumably maps process_image's [-1, 1]
    # output to the [0, 1] range the model expects — confirm.
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)

    video, depth_colored = save_interpolated_video(
        pred_context_pose["extrinsic"],
        pred_context_pose["intrinsic"],
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )

    plyfile = os.path.join(outdir, "gaussians.ply")
    # Export batch element 0 (the batch dimension is 1, see unsqueeze above).
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    torch.cuda.empty_cache()
    return plyfile, video, depth_colored
| |
|
def extract_images(input_images, session_id):
    """Copy uploaded images into a fresh per-session working directory.

    Any existing directory for ``session_id`` under ``$ANYSPLAT_PROCESSED`` is
    deleted first; each upload is then copied into its ``images/`` subfolder.

    Args:
        input_images: Iterable of uploads, each a file path or a dict whose
            ``"name"`` key holds the path, or ``None``.
        session_id: Name of the session directory. It must be a single path
            component, because the directory is removed with ``rmtree``.

    Returns:
        ``(target_dir, image_paths)``: the session directory and the copied
        image paths in upload order.

    Raises:
        ValueError: if ``session_id`` would resolve to anything other than a
            direct child of ``$ANYSPLAT_PROCESSED``.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    processed_root = os.environ["ANYSPLAT_PROCESSED"]
    target_dir = os.path.join(processed_root, session_id)
    # session_id can be supplied by API clients; refuse absolute paths, ".."
    # and nested names so the rmtree below cannot escape processed_root.
    if os.path.dirname(os.path.realpath(target_dir)) != os.path.realpath(processed_root):
        raise ValueError(f"Invalid session id: {session_id!r}")
    target_dir_images = os.path.join(target_dir, "images")

    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)  # also creates target_dir

    image_paths = []
    if input_images is not None:
        for file_data in input_images:
            if isinstance(file_data, dict) and "name" in file_data:
                file_path = file_data["name"]
            else:
                file_path = file_data
            dst_path = os.path.join(target_dir_images, os.path.basename(file_path))
            shutil.copy(file_path, dst_path)
            image_paths.append(dst_path)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
| | |
| |
|
def _frames_sampling_interval(fps):
    """Return the frame stride used to keep about one frame per second.

    OpenCV may report an fps of 0 (or a non-finite value) when it is unknown,
    and ``int(fps)`` is 0 for rates below 1; all of these fall back to 1
    (keep every frame) instead of causing a modulo-by-zero.
    """
    if not 1 <= fps < float("inf"):
        return 1
    return int(fps)


def extract_frames(input_video, session_id):
    """Sample frames from a video into a fresh per-session working directory.

    Any existing directory for ``session_id`` under ``$ANYSPLAT_PROCESSED`` is
    deleted first. Frames are counted from 1 and every ``interval``-th frame
    is kept, where ``interval`` is the truncated fps (about one frame per
    second). Kept frames are written as ``images/000000.png``,
    ``images/000001.png``, ... A clip with fewer frames than ``interval``
    therefore yields no frames.

    Args:
        input_video: Video file path, a dict whose ``"name"`` key holds the
            path, or ``None``.
        session_id: Name of the session directory. It must be a single path
            component, because the directory is removed with ``rmtree``.

    Returns:
        ``(target_dir, image_paths)``: the session directory and the sorted
        paths of the written frames.

    Raises:
        ValueError: if ``session_id`` would resolve to anything other than a
            direct child of ``$ANYSPLAT_PROCESSED``.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    processed_root = os.environ["ANYSPLAT_PROCESSED"]
    target_dir = os.path.join(processed_root, session_id)
    # session_id can be supplied by API clients; refuse absolute paths, ".."
    # and nested names so the rmtree below cannot escape processed_root.
    if os.path.dirname(os.path.realpath(target_dir)) != os.path.realpath(processed_root):
        raise ValueError(f"Invalid session id: {session_id!r}")
    target_dir_images = os.path.join(target_dir, "images")

    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir_images)  # also creates target_dir

    image_paths = []
    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        capture = cv2.VideoCapture(video_path)
        try:
            frame_interval = _frames_sampling_interval(capture.get(cv2.CAP_PROP_FPS))
            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = capture.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(
                        target_dir_images, f"{video_frame_num:06}.png"
                    )
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Release the decoder/file handle even if reading or writing fails.
            capture.release()

    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
| |
|
| |
|
def update_gallery_on_video_upload(input_video, session_id):
    """Handle a video upload by extracting its frames for the preview gallery.

    Returns:
        A ``(model3d, target_dir, gallery)`` triple. The 3D viewer is always
        cleared (``None``). When a video was given, the session directory and
        extracted frame paths follow; otherwise both are ``None``.
    """
    if input_video:
        target_dir, image_paths = extract_frames(input_video, session_id)
        return None, target_dir, image_paths
    return None, None, None
| |
|
def update_gallery_on_images_upload(input_images, session_id):
    """Handle an image upload by copying the files for the preview gallery.

    Returns:
        A ``(model3d, target_dir, gallery)`` triple. The 3D viewer is always
        cleared (``None``). When images were given, the session directory and
        copied image paths follow; otherwise both are ``None``.
    """
    if input_images:
        target_dir, image_paths = extract_images(input_images, session_id)
        return None, target_dir, image_paths
    return None, None, None
| |
|
@spaces.GPU()
def generate_splats_from_video(video_path, session_id=None):
    """
    Perform feed-forward Gaussian Splatting from the unconstrained views of a given video.

    Frames are sampled from the video (about one per second) and reconstructed with the AnySplat model.

    Args:
        video_path (str): Path to the input video file on disk.
        session_id (str, optional): Name of the working directory for intermediate files; a random one is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D Gaussian model (.ply) for the given video.
        rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        image_paths: List of paths of the frames extracted from the video and used for the reconstruction.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    _, image_paths = extract_frames(video_path, session_id)
    # Frames are kept every ~fps frames, so very short or unreadable videos can
    # yield nothing; fail with a readable message instead of deep in the model.
    if not image_paths:
        raise gr.Error(
            "No frames could be extracted from the video; it may be too short or unreadable."
        )
    plyfile, rgb_vid, depth_vid = generate_splats_from_images(image_paths, session_id)

    return plyfile, rgb_vid, depth_vid, image_paths
| | |
@spaces.GPU()
def generate_splats_from_images(image_paths, session_id=None):
    """
    Perform feed-forward Gaussian Splatting from the unconstrained views of the given images.

    Args:
        image_paths (list[str] | str): Paths of the input image files on disk (a single path is also accepted). Gallery entries given as (path, caption) tuples are accepted too. A single image is duplicated so two views are passed to the model.
        session_id (str, optional): Name of the working directory for outputs; a random one is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D Gaussian model (.ply) for the given image files.
        rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
    """
    if isinstance(image_paths, (str, os.PathLike)):
        image_paths = [image_paths]
    # The submit button passes the gallery value, which is None when empty.
    if not image_paths:
        raise gr.Error("Please upload a video or at least one image first.")

    # Gallery values arrive as (path, caption) tuples; keep only the path.
    image_paths = [
        file_data[0] if isinstance(file_data, tuple) else file_data
        for file_data in image_paths
    ]
    print(image_paths)

    # NOTE(review): a lone image is duplicated, presumably because the model
    # needs at least two views — confirm.
    if len(image_paths) == 1:
        image_paths.append(image_paths[0])

    if session_id is None:
        session_id = uuid.uuid4().hex

    processed_root = os.environ["ANYSPLAT_PROCESSED"]
    base_dir = os.path.join(processed_root, session_id)
    # session_id is a public API parameter; outputs must stay inside processed_root.
    if os.path.dirname(os.path.realpath(base_dir)) != os.path.realpath(processed_root):
        raise ValueError(f"Invalid session id: {session_id!r}")

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    print("Running run_model...")
    # `model` and `device` are module globals created in the __main__ block.
    with torch.no_grad():
        plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")

    return plyfile, rgb_vid, depth_vid
| |
|
def cleanup(request: gr.Request):
    """Delete a session's working directory when its page is unloaded.

    Removal errors (e.g. the directory was never created) are ignored.
    """
    session_hash = request.session_hash
    if not session_hash:
        return
    session_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_hash)
    shutil.rmtree(session_dir, ignore_errors=True)
| | |
def start_session(request: gr.Request):
    """Return this client's Gradio session hash, used as the session id."""
    session_hash = request.session_hash
    return session_hash
| |
|
| |
|
if __name__ == "__main__":
    # Whether Gradio should also create a public share link (passed to launch()).
    share = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained weights once and freeze them for inference. `model`
    # and `device` become module globals read by generate_splats_from_images.
    model = AnySplat.from_pretrained(
        "lhjiang/anysplat"
    )
    model = model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    """
    with gr.Blocks(css=css, title="AnySplat Demo") as demo:
        # Filled with the Gradio session hash on page load (start_session);
        # it names the per-session working directory.
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden textboxes. Only target_dir_output is updated by the upload
        # handlers below; the others are not referenced elsewhere in this file.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):
            gr.HTML(
                """
                <div style="text-align: center;">
                    <p style="font-size:16px; display: inline; margin: 0;">
                        <strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
                    </p>
                    <a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
                    </a>
                </div>
                <div style="text-align: center;">
                    <strong>HF Space by:</strong>
                    <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                        <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="Follow on Twitter">
                    </a>
                </div>
                """
            )

            with gr.Row():
                # Left column: inputs, submit button and the frame preview.
                with gr.Column():
                    with gr.Tab("Video"):
                        input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
                    with gr.Tab("Images"):
                        input_images = gr.File(file_count="multiple", label="Upload Files", height=512)

                    submit_btn = gr.Button(
                        "🖌️ Generate Gaussian Splat", scale=1, variant="primary"
                    )

                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )

                # Right column: 3D model viewer plus rendered RGB/depth videos.
                with gr.Column():
                    with gr.Column():
                        gr.HTML(
                            """
                            <p style="opacity: 0.6; font-style: italic;">
                                This might take a few seconds to load the 3D model
                            </p>
                            """
                        )
                        reconstruction_output = gr.Model3D(
                            label="Ply Gaussian Model",
                            height=512,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                        )

                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )

            with gr.Row():
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/spann3r.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                ]

                # cache_examples=True: outputs are precomputed by running
                # generate_splats_from_video on each example.
                gr.Examples(
                    examples=examples,
                    inputs=[
                        input_video
                    ],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )

        # Reconstruct from whatever is currently shown in the preview gallery.
        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[image_gallery, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video])

        # Uploads only populate the gallery; reconstruction waits for submit.
        input_video.upload(
            fn=update_gallery_on_video_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )

        input_images.upload(
            fn=update_gallery_on_images_upload,
            inputs=[input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )

        # Remove the session's working directory when the page is closed.
        demo.unload(cleanup)
    demo.queue()
    demo.launch(show_error=True, share=share, mcp_server=True)