# --- Startup / environment setup -------------------------------------------
# NOTE: imports are deliberately interleaved with side-effecting statements
# (env vars, a pip install) because order matters on Hugging Face Spaces:
# `spaces` and `torch` must be imported before the wheel install below.
import spaces
import torch


print(f'torch version:{torch.__version__}')


import functools
import gc
import os


# Limit CUDA extension builds to compute capability 9.0 (Hopper, e.g. H100),
# the GPU class this Space runs on.
os.environ['TORCH_CUDA_ARCH_LIST'] = '9.0'


import subprocess


def sh(cmd): subprocess.check_call(cmd, shell=True)  # run a shell command; raises CalledProcessError on failure

# Install the prebuilt pytorch3d wheel shipped with this app (CPython 3.10,
# linux x86_64) at import time so downstream model code can use it.
sh("pip install pytorch3d-0.7.9-cp310-cp310-linux_x86_64.whl")


import shutil
import sys
import tempfile
import time
from datetime import datetime
from pathlib import Path
import uuid


import cv2
import gradio as gr


from huggingface_hub import hf_hub_download
from PIL import Image


# Make the repository root importable so the `src.*` packages resolve.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))


from src.misc.image_io import save_interpolated_video
from src.model.model.anysplat import AnySplat
from src.model.ply_export import export_ply
from src.utils.image import process_image


# Root directory for per-session intermediate results.
# ("proprocess" spelling kept as-is: the path is read back at runtime.)
os.environ["ANYSPLAT_PROCESSED"] = f"{os.getcwd()}/proprocess_results"


# NOTE(review): functools, tempfile, datetime, hf_hub_download, Image and
# argparse appear unused in this file — candidates for cleanup.
from plyfile import PlyData
import numpy as np
import argparse
from io import BytesIO
|
|
|
def process_ply_to_splat(ply_file_path):
    """Convert a Gaussian-splat PLY file into the compact binary .splat format.

    Each Gaussian becomes one fixed 32-byte record: position (3x float32),
    linear scale (3x float32), RGBA color (4x uint8) and the normalized
    rotation quaternion (4x uint8).  Records are emitted in decreasing order
    of exp(sum log-scales) * sigmoid(opacity), i.e. biggest / most opaque
    splats first.

    Args:
        ply_file_path: Path to a PLY file with 3DGS vertex attributes
            (x/y/z, scale_0..2 as log-scales, rot_0..3, f_dc_0..2 SH DC
            coefficients, opacity as a logit).

    Returns:
        bytes: the packed .splat payload.
    """
    plydata = PlyData.read(ply_file_path)
    vert = plydata["vertex"]
    # Sort key: -(volume proxy) * sigmoid(opacity).
    sorted_indices = np.argsort(
        -np.exp(vert["scale_0"] + vert["scale_1"] + vert["scale_2"])
        / (1 + np.exp(-vert["opacity"]))
    )
    # Zeroth-order spherical-harmonics basis constant; hoisted out of the
    # loop (it was re-assigned on every iteration).
    SH_C0 = 0.28209479177387814
    buffer = BytesIO()
    for idx in sorted_indices:
        # Reuse the cached vertex element rather than re-indexing plydata.
        v = vert[idx]
        position = np.array([v["x"], v["y"], v["z"]], dtype=np.float32)
        # Stored scales are log-space; the .splat format wants linear scales.
        scales = np.exp(
            np.array(
                [v["scale_0"], v["scale_1"], v["scale_2"]],
                dtype=np.float32,
            )
        )
        rot = np.array(
            [v["rot_0"], v["rot_1"], v["rot_2"], v["rot_3"]],
            dtype=np.float32,
        )
        # SH DC term -> RGB in [0, 1]; alpha = sigmoid(opacity logit).
        color = np.array(
            [
                0.5 + SH_C0 * v["f_dc_0"],
                0.5 + SH_C0 * v["f_dc_1"],
                0.5 + SH_C0 * v["f_dc_2"],
                1 / (1 + np.exp(-v["opacity"])),
            ]
        )
        buffer.write(position.tobytes())
        buffer.write(scales.tobytes())
        buffer.write((color * 255).clip(0, 255).astype(np.uint8).tobytes())
        # Quaternion components mapped from [-1, 1] to [0, 255].
        buffer.write(
            ((rot / np.linalg.norm(rot)) * 128 + 128)
            .clip(0, 255)
            .astype(np.uint8)
            .tobytes()
        )

    return buffer.getvalue()
|
|
def save_splat_file(splat_data, output_path):
    """Write the packed .splat payload *splat_data* to *output_path*."""
    Path(output_path).write_bytes(splat_data)
|
|
def get_reconstructed_scene(outdir, image_files, model, device):
    """Run AnySplat inference on a set of images and export the results.

    Args:
        outdir: Directory that receives the rendered videos and the PLY file.
        image_files: Paths of the input images (one viewpoint each).
        model: Loaded AnySplat model (already on *device*, eval mode).
        device: Torch device the image batch is moved to.

    Returns:
        (plyfile, video, depth_colored): path of the exported Gaussian PLY,
        plus the RGB and colored-depth video paths produced by
        ``save_interpolated_video``.
    """

    # Stack per-image tensors into a (1, views, C, H, W) batch.
    images = [process_image(img_path) for img_path in image_files]
    images = torch.stack(images, dim=0).unsqueeze(0).to(device)
    b, v, c, h, w = images.shape

    assert c == 3, "Images must have 3 channels"

    # (images + 1) * 0.5 — presumably process_image yields values in
    # [-1, 1] and the model expects [0, 1]; TODO confirm in src.utils.image.
    gaussians, pred_context_pose = model.inference((images + 1) * 0.5)

    pred_all_extrinsic = pred_context_pose["extrinsic"]
    pred_all_intrinsic = pred_context_pose["intrinsic"]
    # Render RGB / depth videos along interpolated camera poses.
    video, depth_colored = save_interpolated_video(
        pred_all_extrinsic,
        pred_all_intrinsic,
        b,
        h,
        w,
        gaussians,
        outdir,
        model.decoder,
    )
    plyfile = os.path.join(outdir, "gaussians.ply")

    # Export only the DC spherical-harmonics band to keep the PLY small.
    export_ply(
        gaussians.means[0],
        gaussians.scales[0],
        gaussians.rotations[0],
        gaussians.harmonics[0],
        gaussians.opacities[0],
        Path(plyfile),
        save_sh_dc_only=True,
    )

    # Free GPU memory between requests (ZeroGPU workers are shared).
    torch.cuda.empty_cache()
    return plyfile, video, depth_colored
|
|
def extract_images(input_images, session_id):
    """Copy uploaded image files into this session's working directory.

    Rebuilds ``$ANYSPLAT_PROCESSED/<session_id>/images`` from scratch and
    copies every uploaded file into it.

    Args:
        input_images: Iterable of file paths (or gradio file dicts carrying a
            ``"name"`` key), or None.
        session_id: Unique id namespacing this session's working directory.

    Returns:
        (target_dir, image_paths): the session directory and the list of
        copied image paths.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    target_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir_images = os.path.join(target_dir, "images")

    # Always start from a clean slate for the session.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []
    for file_data in (input_images or []):
        # Gradio may hand us either a plain path or a dict with a "name" key.
        if isinstance(file_data, dict) and "name" in file_data:
            src_path = file_data["name"]
        else:
            src_path = file_data
        dst_path = os.path.join(target_dir_images, os.path.basename(src_path))
        shutil.copy(src_path, dst_path)
        image_paths.append(dst_path)

    elapsed = time.time() - start_time
    print(
        f"Files copied to {target_dir_images}; took {elapsed:.3f} seconds"
    )
    return target_dir, image_paths
| |
|
|
def extract_frames(input_video, session_id):
    """Sample roughly one frame per second from a video into the session dir.

    Rebuilds ``$ANYSPLAT_PROCESSED/<session_id>/images`` from scratch and
    writes the sampled frames as zero-padded PNGs.

    Args:
        input_video: Path to a video file (or gradio dict carrying a
            ``"name"`` key), or None.
        session_id: Unique id namespacing this session's working directory.

    Returns:
        (target_dir, image_paths): the session directory and the sorted list
        of extracted frame paths.
    """
    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    target_dir = base_dir
    target_dir_images = os.path.join(target_dir, "images")

    # Always start from a clean slate for the session.
    if os.path.exists(target_dir):
        shutil.rmtree(target_dir)
    os.makedirs(target_dir)
    os.makedirs(target_dir_images)

    image_paths = []

    if input_video is not None:
        if isinstance(input_video, dict) and "name" in input_video:
            video_path = input_video["name"]
        else:
            video_path = input_video

        vs = cv2.VideoCapture(video_path)
        try:
            fps = vs.get(cv2.CAP_PROP_FPS)
            # Sample ~1 frame per second.  Guard against broken FPS metadata
            # (0, NaN or < 1 fps): the original int(fps) could be 0 and crash
            # the modulo below with ZeroDivisionError.  Fall back to keeping
            # every frame.  (fps == fps is False only for NaN.)
            frame_interval = int(fps) if fps and fps == fps else 0
            if frame_interval < 1:
                frame_interval = 1

            count = 0
            video_frame_num = 0
            while True:
                gotit, frame = vs.read()
                if not gotit:
                    break
                count += 1
                if count % frame_interval == 0:
                    image_path = os.path.join(
                        target_dir_images, f"{video_frame_num:06}.png"
                    )
                    cv2.imwrite(image_path, frame)
                    image_paths.append(image_path)
                    video_frame_num += 1
        finally:
            # Release the capture handle even if decoding fails mid-stream.
            vs.release()

    image_paths = sorted(image_paths)

    end_time = time.time()
    print(
        f"Files copied to {target_dir_images}; took {end_time - start_time:.3f} seconds"
    )
    return target_dir, image_paths
|
|
|
|
def update_gallery_on_video_upload(input_video, session_id):
    """Extract frames from a freshly-uploaded video for the preview gallery.

    Returns a 3-tuple; the leading None clears any previous reconstruction
    from the 3D viewer, followed by the session dir and the frame paths.
    """
    if not input_video:
        return None, None, None
    return (None,) + extract_frames(input_video, session_id)
|
|
def update_gallery_on_images_upload(input_images, session_id):
    """Copy freshly-uploaded images into the session dir for the gallery.

    Returns a 3-tuple; the leading None clears any previous reconstruction
    from the 3D viewer, followed by the session dir and the copied paths.
    """
    if not input_images:
        return None, None, None
    return (None,) + extract_images(input_images, session_id)
|
|
@spaces.GPU()
def generate_splats_from_video(video_path, session_id=None):
    """
    Perform Gaussian Splatting from Unconstrained Views of a Given Video, using a Feed-forward model.

    Args:
        video_path (str): Path to the input video file on disk.
        session_id (str, optional): Namespace for this session's working
            files; a random hex id is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D object from the given video.
        rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        image_paths: A list of paths of frames extracted from the video that are used for training Gaussian Splatting.
    """
    if session_id is None:
        session_id = uuid.uuid4().hex

    # The target dir is not needed here; frames already land under the
    # session's working directory keyed by session_id.
    _, image_paths = extract_frames(video_path, session_id)
    plyfile, rgb_vid, depth_vid = generate_splats_from_images(image_paths, session_id)

    return plyfile, rgb_vid, depth_vid, image_paths
| |
@spaces.GPU()
def generate_splats_from_images(image_paths, session_id=None):
    """
    Perform Gaussian Splatting from Unconstrained Views of Given Images, using a Feed-forward model.

    Args:
        image_paths (list): Paths to the input image files on disk; gallery
            entries may also be (path, caption) tuples.
        session_id (str, optional): Namespace for this session's working
            files; a random hex id is generated when omitted.
    Returns:
        plyfile: Path to the reconstructed 3D object from the given image files.
        rgb_vid: Path to the interpolated RGB video, increasing the frame rate using Gaussian splatting and interpolation of frames.
        depth_vid: Path to the interpolated depth video, increasing the frame rate using Gaussian splatting and interpolation of frames.
    """
    # Gradio galleries yield (path, caption) tuples; normalize to raw paths.
    processed_image_paths = []
    for file_data in image_paths:
        if isinstance(file_data, tuple):
            file_path, _ = file_data
            processed_image_paths.append(file_path)
        else:
            processed_image_paths.append(file_data)

    image_paths = processed_image_paths
    print(image_paths)

    # A single view is duplicated so the model always sees >= 2 views.
    if len(image_paths) == 1:
        image_paths.append(image_paths[0])

    if session_id is None:
        session_id = uuid.uuid4().hex

    start_time = time.time()
    gc.collect()
    torch.cuda.empty_cache()

    base_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)

    print("Running run_model...")
    # `model` and `device` are module-level globals created in __main__.
    with torch.no_grad():
        plyfile, rgb_vid, depth_vid = get_reconstructed_scene(base_dir, image_paths, model, device)

    end_time = time.time()
    print(f"Total time: {end_time - start_time:.2f} seconds (including IO)")

    return plyfile, rgb_vid, depth_vid
|
|
def cleanup(request: gr.Request):
    """Remove this session's working directory when the client disconnects."""
    session_id = request.session_hash
    if not session_id:
        return
    session_dir = os.path.join(os.environ["ANYSPLAT_PROCESSED"], session_id)
    shutil.rmtree(session_dir, ignore_errors=True)
| |
def start_session(request: gr.Request):
    """Return the per-connection session hash; stored in a gr.State on load."""
    return request.session_hash
|
|
|
|
if __name__ == "__main__":
    # NOTE(review): this local is never read — launch() below hard-codes
    # share=True directly.
    share = True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the pretrained feed-forward AnySplat model once at startup and
    # freeze it: this app only ever runs inference.
    model = AnySplat.from_pretrained(
        "lhjiang/anysplat"
    )
    model = model.to(device)
    model.eval()
    for param in model.parameters():
        param.requires_grad = False

    # Center the main column and cap its width.
    css = """
    #col-container {
        margin: 0 auto;
        max-width: 1024px;
    }
    """
    with gr.Blocks(css=css, title="AnySplat Demo") as demo:
        # Per-browser-session id, populated by start_session on page load and
        # threaded through every callback so working dirs don't collide.
        session_state = gr.State()
        demo.load(start_session, outputs=[session_state])

        # Hidden bookkeeping fields; only target_dir_output is currently
        # written (by the upload callbacks), the rest look like placeholders.
        target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
        is_example = gr.Textbox(label="is_example", visible=False, value="None")
        num_images = gr.Textbox(label="num_images", visible=False, value="None")
        dataset_name = gr.Textbox(label="dataset_name", visible=False, value="None")
        scene_name = gr.Textbox(label="scene_name", visible=False, value="None")
        image_type = gr.Textbox(label="image_type", visible=False, value="None")

        with gr.Column(elem_id="col-container"):

            # Title / badge header.
            gr.HTML(
                """
            <div style="text-align: center;">
                <p style="font-size:16px; display: inline; margin: 0;">
                    <strong>AnySplat</strong> – Feed-forward 3D Gaussian Splatting from Unconstrained Views
                </p>
                <a href="https://github.com/OpenRobotLab/AnySplat" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    <img src="https://img.shields.io/badge/GitHub-Repo-blue" alt="GitHub Repo">
                </a>
            </div>
            <div style="text-align: center;">
                <strong>HF Space by:</strong>
                <a href="https://twitter.com/alexandernasa/" style="display: inline-block; vertical-align: middle; margin-left: 0.5em;">
                    <img src="https://img.shields.io/twitter/url/https/twitter.com/cloudposse.svg?style=social&label=Follow Me" alt="GitHub Repo">
                </a>
            </div>
            """
            )

            with gr.Row():
                with gr.Column():

                    # Input side: either a video (frames get sampled) or a
                    # set of still images.
                    with gr.Tab("Video"):
                        input_video = gr.Video(label="Upload Video", sources=["upload"], interactive=True, height=512)
                    with gr.Tab("Images"):
                        input_images = gr.File(file_count="multiple", label="Upload Files", height=512)

                    submit_btn = gr.Button(
                        "🖌️ Generate Gaussian Splat", scale=1, variant="primary"
                    )

                    # Preview of the frames/images that will be reconstructed;
                    # also serves as the input to the submit handler below.
                    image_gallery = gr.Gallery(
                        label="Preview",
                        columns=4,
                        height="300px",
                        show_download_button=True,
                        object_fit="contain",
                        preview=True,
                    )

                with gr.Column():
                    # Output side: 3D viewer plus rendered videos.
                    with gr.Column():
                        gr.HTML(
                            """
                        <p style="opacity: 0.6; font-style: italic;">
                            This might take a few seconds to load the 3D model
                        </p>
                        """
                        )
                        reconstruction_output = gr.Model3D(
                            label="Ply Gaussian Model",
                            height=512,
                            zoom_speed=0.5,
                            pan_speed=0.5,
                        )

                    with gr.Row():
                        rgb_video = gr.Video(
                            label="RGB Video", interactive=False, autoplay=True
                        )
                        depth_video = gr.Video(
                            label="Depth Video",
                            interactive=False,
                            autoplay=True,
                        )
            with gr.Row():
                # Bundled sample videos; results are cached so examples load
                # without a GPU round-trip.
                examples = [
                    ["examples/video/re10k_1eca36ec55b88fe4.mp4"],
                    ["examples/video/spann3r.mp4"],
                    ["examples/video/bungeenerf_colosseum.mp4"],
                    ["examples/video/fox.mp4"],
                    ["examples/video/vrnerf_apartment.mp4"],
                ]

                gr.Examples(
                    examples=examples,
                    inputs=[
                        input_video
                    ],
                    outputs=[
                        reconstruction_output,
                        rgb_video,
                        depth_video,
                        image_gallery
                    ],
                    fn=generate_splats_from_video,
                    cache_examples=True,
                )

        # Reconstruct from whatever is currently shown in the gallery
        # (gallery entries are (path, caption) tuples; the handler unpacks).
        submit_btn.click(
            fn=generate_splats_from_images,
            inputs=[image_gallery, session_state],
            outputs=[reconstruction_output, rgb_video, depth_video])

        # Uploading a video samples its frames into the gallery and clears
        # any previous 3D result.
        input_video.upload(
            fn=update_gallery_on_video_upload,
            inputs=[input_video, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )

        # Uploading images copies them into the gallery likewise.
        input_images.upload(
            fn=update_gallery_on_images_upload,
            inputs=[input_images, session_state],
            outputs=[reconstruction_output, target_dir_output, image_gallery],
            show_api=False
        )

        # Delete the session's working directory when the tab closes.
        demo.unload(cleanup)
    demo.queue()
    # mcp_server=True exposes the @spaces.GPU functions as MCP tools, using
    # their docstrings as tool descriptions.
    demo.launch(show_error=True, share=True, mcp_server=True)