|
|
from typing import List, Literal |
|
|
from pathlib import Path |
|
|
from functools import partial |
|
|
import spaces |
|
|
import gradio as gr |
|
|
import numpy as np |
|
|
import torch |
|
|
from torchvision.datasets.utils import download_and_extract_archive |
|
|
from einops import repeat |
|
|
from omegaconf import OmegaConf |
|
|
from modeling.pipeline import VMemPipeline |
|
|
from diffusers.utils import export_to_video |
|
|
from navigation import Navigator |
|
|
from utils import tensor_to_pil, get_default_intrinsics, load_img_and_K, transform_img_and_K |
|
|
import os |
|
|
import shutil |
|
|
|
|
|
|
|
|
# Path of the inference configuration; parsed once here and reused across
# the app (see load_image_for_navigation for the model frame size).
CONFIG_PATH = "configs/inference/inference.yaml"

CONFIG = OmegaConf.load(CONFIG_PATH)

# Prefer CUDA when available, otherwise fall back to CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Single global pipeline instance shared by every request.
MODEL = VMemPipeline(CONFIG, DEVICE)

# Holds at most one Navigator for the current session; emptied whenever a
# new image is selected (see process_uploaded_image / start_navigation).
# NOTE(review): being module-global, concurrent users would share this
# navigator — the app runs with max_threads=1 to serialize access.
NAVIGATORS = []

# Frames per second used when exporting navigation videos for the UI.
NAVIGATION_FPS = 13

# Default resolution for the placeholder-image generators below.
WIDTH = 576
HEIGHT = 576
|
|
|
|
|
|
|
|
def clear_visualization_directory():
    """Wipe and recreate the shared visualization directory.

    Removes any previously generated images so one user's results are never
    shown to another.  Failures are logged as warnings and never raised.
    """
    target = "./visualization"
    try:
        # Delete the whole tree if it exists, then start from an empty dir.
        if os.path.exists(target):
            shutil.rmtree(target)
        os.makedirs(target, exist_ok=True)
        print(f"Cleared visualization directory: {target}")
    except Exception as e:
        # Best effort only: a cleanup failure should not break the demo.
        print(f"Warning: Could not clear visualization directory: {e}")
|
|
|
|
|
|
|
|
# Bundled example images shown in the selection gallery.  Order matters:
# gallery click events index into this list (see start_navigation).
IMAGE_PATHS = ['test_samples/oxford.jpg',
               'test_samples/open_door.jpg',
               'test_samples/living_room.jpg',
               'test_samples/arc_de_tromphe.jpeg',
               'test_samples/changi.jpg',
               'test_samples/jesus.jpg',]
|
|
|
|
|
|
|
|
if not IMAGE_PATHS:
    # Fallback path: with no bundled samples, synthesize gradient images so
    # the demo UI still has content to display.

    def _placeholder_frame(sample_idx, num_samples, height, width):
        """Build one (H, W, 3) uint8 gradient frame.

        Vectorized replacement for the original per-pixel loops; np.floor of
        the float quotient matches ``int(255 * coord / extent)`` exactly for
        non-negative values (truncation == floor), so pixel values are
        unchanged while avoiding ~H*W Python-level iterations per image.
        """
        rows = np.floor(255.0 * np.arange(height) / height).astype(np.uint8)
        cols = np.floor(255.0 * np.arange(width) / width).astype(np.uint8)
        img = np.empty((height, width, 3), dtype=np.uint8)
        img[..., 0] = rows[:, None]   # red ramps top -> bottom
        img[..., 1] = cols[None, :]   # green ramps left -> right
        # blue encodes which sample this is
        img[..., 2] = int(255 * (sample_idx + 1) / num_samples)
        return img

    def create_placeholder_images(num_samples=5, height=HEIGHT, width=WIDTH):
        """Create placeholder images for the demo"""
        return [_placeholder_frame(i, num_samples, height, width)
                for i in range(num_samples)]

    def create_placeholder_video_and_poses(num_samples=5, num_frames=1, height=HEIGHT, width=WIDTH):
        """Create placeholder videos and poses for the demo"""
        videos = []
        poses = []
        for i in range(num_samples):
            frames = []
            for j in range(num_frames):
                img = _placeholder_frame(i, num_samples, height, width)
                # CHW float tensor scaled to [0, 1].
                frame = torch.from_numpy(img.transpose(2, 0, 1)).float() / 255.0
                frames.append(frame)
            videos.append(torch.stack(frames))
            # Identity extrinsics for every placeholder frame.
            poses.append(torch.eye(4).unsqueeze(0).repeat(num_frames, 1, 1))
        return videos, poses

    first_frame_list = create_placeholder_images(num_samples=5)
    video_list, poses_list = create_placeholder_video_and_poses(num_samples=5)
|
|
|
|
|
|
|
|
def load_image_for_navigation(image_path):
    """Load an image from disk and prepare it for a navigation session.

    Args:
        image_path: Filesystem path of the image to load.

    Returns:
        dict with keys:
            "image": PIL image for display in the UI,
            "video": tensor holding the single starting frame,
            "pose":  (1, 4, 4) identity camera pose for that frame.
    """
    image, _ = load_img_and_K(image_path, None, K=None, device=DEVICE)

    # Crop/resize to the model's expected resolution.  Reuse the module-level
    # CONFIG instead of re-parsing the YAML file from disk on every call
    # (the original reloaded OmegaConf.load(CONFIG_PATH) here each time).
    image, _ = transform_img_and_K(image, (CONFIG.model.height, CONFIG.model.width), mode="crop", K=None)

    # The session "video" starts as just this one frame, at the origin pose.
    video = image
    pose = torch.eye(4).unsqueeze(0)

    return {
        "image": tensor_to_pil(image),
        "video": video,
        "pose": pose
    }
|
|
|
|
|
|
|
|
class CustomProgressBar:
    """Thin wrapper around a progress bar that silences ``set_postfix``.

    Any attribute not defined here is forwarded to the wrapped bar, so the
    wrapper is a drop-in replacement wherever a progress bar is expected.
    """

    def __init__(self, pbar):
        self.pbar = pbar

    def set_postfix(self, **kwargs):
        # Intentionally a no-op: postfix updates are suppressed.
        pass

    def __getattr__(self, attr):
        # Called only for attributes missing on the wrapper itself;
        # delegate those lookups to the wrapped progress bar.
        return getattr(self.pbar, attr)
|
|
|
|
|
def get_duration_navigate_video(video: torch.Tensor,
                                poses: torch.Tensor,
                                x_angle: float,
                                y_angle: float,
                                distance: float
                                ):
    """Estimate the GPU-seconds budget for one ``navigate_video`` call.

    Starts from a 15 s base, adds 10 s for sharp rotations, 10 s for long
    translations, and up to 10 s proportional to the current video length.
    Used as the dynamic ``duration`` for the ``@spaces.GPU`` decorator.
    """
    duration = 15

    # Sharp camera rotations need more generated frames.
    if abs(x_angle) > 20 or abs(y_angle) > 30:
        duration += 10

    # Long forward moves likewise.
    if distance > 100:
        duration += 10

    # A longer context video slows each step; surcharge is capped at 10 s.
    duration += min(10, len(video))

    return duration
|
|
|
|
|
@spaces.GPU(duration=get_duration_navigate_video)
@torch.autocast("cuda")
@torch.no_grad()
def navigate_video(
    video: torch.Tensor,
    poses: torch.Tensor,
    x_angle: float,
    y_angle: float,
    distance: float,
):
    """
    Generate new video frames by navigating in the 3D scene.
    This function uses the Navigator class from navigation.py to handle movement:
    - y_angle parameter controls left/right turning (turn_left/turn_right methods)
    - distance parameter controls forward movement (move_forward method)
    - x_angle parameter controls vertical angle (not directly implemented in Navigator)

    Each Navigator instance is stored based on the video session to maintain state.

    Returns a 5-tuple consumed by the Gradio outputs:
    (updated video tensor, updated pose tensor, last-frame PIL image,
     exported video file, gallery list of (PIL image, caption)).
    """
    try:
        # First frame of the accumulated video, used to seed the navigator.
        initial_frame = tensor_to_pil(video[0])

        # Lazily create and initialize the (single, global) navigator on the
        # first call of a session; NAVIGATORS is emptied whenever a new image
        # is selected, so this branch runs once per chosen image.
        if len(NAVIGATORS) == 0:
            NAVIGATORS.append(Navigator(MODEL, step_size=0.1, num_interpolation_frames=4))
            # Start from the first stored camera pose (4x4 matrix) ...
            initial_pose = poses[0].cpu().numpy().reshape(4, 4)
            # ... and the library's default intrinsics.
            initial_K = np.array(get_default_intrinsics()[0])
            NAVIGATORS[0].initialize(initial_frame, initial_pose, initial_K)

        navigator = NAVIGATORS[0]

        new_frames = []

        # Vertical (pitch) control is not supported by Navigator yet.
        if abs(x_angle) > 0:
            pass

        # Turning takes priority over translation (the UI only sends one
        # non-zero control per click).  The angle is halved before being
        # passed on; positive y_angle turns left, negative turns right.
        if abs(y_angle) > 0:
            if y_angle > 0:
                new_frames = navigator.turn_left(abs(y_angle//2))
            else:
                new_frames = navigator.turn_right(abs(y_angle//2))
        elif distance > 0:
            # Roughly one navigator step per 10 distance units.
            steps = max(1, int(distance / 10))
            new_frames = navigator.move_forward(steps)
        elif distance < 0:
            steps = max(1, int(abs(distance) / 10))
            new_frames = navigator.move_backward(steps)

        if not new_frames:
            # Nothing was generated: return the current state unchanged.
            return video, poses, tensor_to_pil(video[-1]), export_to_video([tensor_to_pil(video[i]) for i in range(len(video))], fps=NAVIGATION_FPS), [(tensor_to_pil(video[i]), f"t={i}") for i in range(len(video))]

        # Convert the navigator's PIL frames to CHW float tensors in [-1, 1].
        new_frame_tensors = []
        for frame in new_frames:
            frame_np = np.array(frame) / 255.0
            frame_tensor = torch.from_numpy(frame_np.transpose(2, 0, 1)).float() * 2.0 - 1.0
            new_frame_tensors.append(frame_tensor)

        new_frames_tensor = torch.stack(new_frame_tensors)

        # All new frames are tagged with the navigator's final pose.
        # NOTE(review): intermediate interpolated frames therefore share one
        # pose — confirm this matches Navigator's own pose bookkeeping.
        current_pose = navigator.current_pose
        new_poses = torch.from_numpy(current_pose).float().unsqueeze(0).repeat(len(new_frames), 1, 1)

        new_poses = new_poses.view(len(new_frames), 4, 4)

        # Append the new frames/poses to the running session tensors on CPU.
        updated_video = torch.cat([video.cpu(), new_frames_tensor], dim=0)
        updated_poses = torch.cat([poses.cpu(), new_poses], dim=0)

        # Build the gallery entries and the exported preview video.
        all_images = [(tensor_to_pil(updated_video[i]), f"t={i}") for i in range(len(updated_video))]
        updated_video_pil = [tensor_to_pil(updated_video[i]) for i in range(len(updated_video))]

        return (
            updated_video,
            updated_poses,
            tensor_to_pil(updated_video[-1]),
            export_to_video(updated_video_pil, fps=NAVIGATION_FPS),
            all_images,
        )
    except Exception as e:
        # Surface the failure to the UI but keep the session alive by
        # returning the state unchanged.
        print(f"Error in navigate_video: {e}")
        gr.Warning(f"Navigation error: {e}")

        current_frame = tensor_to_pil(video[-1]) if len(video) > 0 else None
        all_frames = [(tensor_to_pil(video[i]), f"t={i}") for i in range(len(video))]
        video_frames = [tensor_to_pil(video[i]) for i in range(len(video))]
        video_output = export_to_video(video_frames, fps=NAVIGATION_FPS) if video_frames else None
        return video, poses, current_frame, video_output, all_frames
|
|
|
|
|
|
|
|
def undo_navigation(
    video: torch.Tensor,
    poses: torch.Tensor,
):
    """Roll back the most recent navigation move.

    Delegates to Navigator.undo() (which uses the pipeline's
    undo_latest_move to keep surfel state consistent) and trims the session
    video/pose tensors to the navigator's remaining frame count.  When there
    is nothing to undo — or no navigator exists — a user-facing warning is
    shown and the current state is returned unchanged.
    """
    def _outputs(vid, pos):
        # Assemble the five outputs the UI expects for the given tensors.
        gallery = [(tensor_to_pil(vid[i]), f"t={i}") for i in range(len(vid))]
        return (
            vid,
            pos,
            tensor_to_pil(vid[-1]),
            export_to_video([tensor_to_pil(vid[i]) for i in range(len(vid))], fps=NAVIGATION_FPS),
            gallery,
        )

    if len(NAVIGATORS) == 0:
        gr.Warning("No navigation session available!")
        return _outputs(video, poses)

    navigator = NAVIGATORS[0]
    if not navigator.undo():
        gr.Warning("You have no moves left to undo!")
        return _outputs(video, poses)

    # Keep only as many frames/poses as the navigator still holds.
    kept = len(navigator.frames)
    return _outputs(video[:kept], poses[:kept])
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def render_demonstrate(
    s: Literal["Selection", "Generation"],
    idx: int,
    demonstrate_stage: gr.State,
    demonstrate_selected_index: gr.State,
    demonstrate_current_video: gr.State,
    demonstrate_current_poses: gr.State
):
    """
    Render the image-to-scene-navigation demo for the current stage.

    Called from the top-level @gr.render callback, so the component tree is
    rebuilt whenever one of the watched States changes.

    Args:
        s: Current UI stage — "Selection" (pick/upload an image) or
           "Generation" (navigate the chosen scene).
        idx: Index into IMAGE_PATHS of the selected example image, or None
             when the user uploaded their own image.
        demonstrate_stage, demonstrate_selected_index,
        demonstrate_current_video, demonstrate_current_poses:
            gr.State handles written by the event handlers below to drive
            re-rendering and to carry the session's video/pose tensors.
    """
    gr.Markdown(
        """
## Single Image → Consistent Scene Navigation
> #### _Select an image and navigate through the scene by controlling camera movements._
""",
        elem_classes=["task-title"]
    )
    match s:
        case "Selection":
            with gr.Group():
                # --- Upload-your-own-image box ------------------------------
                with gr.Group(elem_classes=["gradio-box"]):
                    gr.Markdown("### Upload Your Own Image")
                    gr.Markdown("_Upload an image to navigate through its 3D scene_")
                    with gr.Row():
                        with gr.Column(scale=3):
                            upload_image = gr.Image(
                                label="Upload an image",
                                type="filepath",
                                height=300,
                                elem_id="upload-image"
                            )
                        with gr.Column(scale=1):
                            gr.Markdown("#### Instructions:")
                            gr.Markdown("1. Upload a clear, high-quality image")
                            gr.Markdown("2. Images with distinct visual features work best")
                            gr.Markdown("3. Landscape or architectural scenes are ideal")
                            upload_btn = gr.Button("Start Navigation", variant="primary", size="lg")

                def process_uploaded_image(image_path):
                    # Validate, reset per-session state, and switch the UI to
                    # the "Generation" stage with the uploaded image loaded.
                    if image_path is None:
                        gr.Warning("Please upload an image first")
                        return "Selection", None, None, None
                    try:
                        # Remove any previous user's rendered artifacts.
                        clear_visualization_directory()
                        result = load_image_for_navigation(image_path)
                        # Drop any existing navigator so navigate_video
                        # re-initializes from this image on the next move.
                        global NAVIGATORS
                        NAVIGATORS = []
                        return (
                            "Generation",
                            None,
                            result["video"],
                            result["pose"],
                        )
                    except Exception as e:
                        print(f"Error in process_uploaded_image: {e}")
                        gr.Warning(f"Error processing uploaded image: {e}")
                        return "Selection", None, None, None

                upload_btn.click(
                    fn=process_uploaded_image,
                    inputs=[upload_image],
                    outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses]
                )

                # --- Example-image gallery ----------------------------------
                gr.Markdown("### Or Choose From Our Examples")

                # Human-readable captions for the bundled sample images.
                image_captions = {
                    'test_samples/oxford.jpg': 'Oxford University',
                    'test_samples/open_door.jpg': 'Bedroom Interior',
                    'test_samples/living_room.jpg': 'Living Room',
                    'test_samples/arc_de_tromphe.jpeg': 'Arc de Triomphe',
                    'test_samples/jesus.jpg': 'Jesus College',
                    'test_samples/changi.jpg': 'Changi Airport',
                }

                # Pair each path with its caption (falling back to filename).
                gallery_images = []
                for img_path in IMAGE_PATHS:
                    try:
                        caption = image_captions.get(img_path, os.path.basename(img_path))
                        gallery_images.append((img_path, caption))
                    except Exception as e:
                        print(f"Error loading image {img_path}: {e}")

                demonstrate_image_gallery = gr.Gallery(
                    value=gallery_images,
                    label="Select an Image to Start Navigation",
                    columns=len(gallery_images),
                    height=400,
                    allow_preview=True,
                    preview=False,
                    elem_id="navigation-gallery"
                )

                gr.Markdown("_Click on an image to begin navigation_")

                def start_navigation(evt: gr.SelectData):
                    # Same flow as process_uploaded_image, but triggered by a
                    # gallery click; evt.index selects from IMAGE_PATHS.
                    try:
                        clear_visualization_directory()
                        selected_path = IMAGE_PATHS[evt.index]
                        result = load_image_for_navigation(selected_path)
                        global NAVIGATORS
                        NAVIGATORS = []
                        return (
                            "Generation",
                            evt.index,
                            result["video"],
                            result["pose"],
                        )
                    except Exception as e:
                        print(f"Error in start_navigation: {e}")
                        gr.Warning(f"Error starting navigation: {e}")
                        return "Selection", None, None, None

                demonstrate_image_gallery.select(
                    fn=start_navigation,
                    inputs=None,
                    outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses]
                )

        case "Generation":
            with gr.Row():
                # Left column: current view, preview video, frame gallery.
                with gr.Column(scale=3):
                    with gr.Row():
                        demonstrate_current_view = gr.Image(
                            label="Current View",
                            width=256,
                            height=256,
                        )
                        demonstrate_video = gr.Video(
                            label="Generated Video",
                            width=256,
                            height=256,
                            autoplay=True,
                            loop=True,
                            show_share_button=True,
                            show_download_button=True,
                        )

                    demonstrate_generated_gallery = gr.Gallery(
                        value=[],
                        label="Generated Frames",
                        columns=[6],
                    )

                    # Pre-populate the current view when an example image was
                    # chosen (idx is None for user uploads).
                    if idx is not None:
                        try:
                            selected_path = IMAGE_PATHS[idx]
                            result = load_image_for_navigation(selected_path)
                            # NOTE(review): sets .value after construction —
                            # confirm Gradio honors this at render time.
                            demonstrate_current_view.value = result["image"]
                        except Exception as e:
                            print(f"Error initializing current view: {e}")

                # Right column: navigation controls.
                with gr.Column():
                    gr.Markdown("### Navigation Controls ↓")
                    with gr.Accordion("Instructions", open=False):
                        gr.Markdown("""
- **The model will predict the next few frames based on your camera movements. Repeat the process to continue navigating through the scene.**
- **Use the navigation controls to move forward/backward and turn left/right.**
- **At the end of your navigation, you can save your camera path for later use.**

""")

                    with gr.Group():
                        gr.Markdown("_**Select a direction to move:**_")

                        # Turning buttons: positive y_angle turns left,
                        # negative turns right (see navigate_video).  Every
                        # button shares the same state inputs/outputs.
                        with gr.Row(elem_id="basic-controls"):
                            gr.Button(
                                "↰20°\nVeer",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=20,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )

                            gr.Button(
                                "↖10°\nTurn",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=10,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )

                            gr.Button(
                                "↗10°\nTurn",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=-10,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                            gr.Button(
                                "↱\n20° Veer",
                                size="sm",
                                min_width=0,
                                variant="primary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=-20,
                                    distance=0,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )

                        # Translation buttons: ±10 distance units per click.
                        with gr.Row(elem_id="forward-backward-controls"):
                            gr.Button(
                                "↓\nBackward",
                                size="sm",
                                min_width=0,
                                variant="secondary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=0,
                                    distance=-10,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )

                            gr.Button(
                                "↑\nForward",
                                size="sm",
                                min_width=0,
                                variant="secondary",
                            ).click(
                                fn=partial(
                                    navigate_video,
                                    x_angle=0,
                                    y_angle=0,
                                    distance=10,
                                ),
                                inputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                ],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )
                    gr.Markdown("---")
                    with gr.Group():
                        gr.Markdown("_**Navigation controls:**_")
                        with gr.Row():
                            gr.Button("Undo Last Move", variant="huggingface").click(
                                fn=undo_navigation,
                                inputs=[demonstrate_current_video, demonstrate_current_poses],
                                outputs=[
                                    demonstrate_current_video,
                                    demonstrate_current_poses,
                                    demonstrate_current_view,
                                    demonstrate_video,
                                    demonstrate_generated_gallery,
                                ],
                            )

                            def save_camera_poses(video, poses):
                                # Persist the navigator's camera trajectory as
                                # a transforms JSON under ./visualization.
                                if len(NAVIGATORS) > 0:
                                    navigator = NAVIGATORS[0]
                                    os.makedirs("./visualization", exist_ok=True)
                                    save_path = f"./visualization/transforms_{len(navigator.frames)}_frames.json"
                                    navigator.save_camera_poses(save_path)
                                    return gr.Info(f"Camera poses saved to {save_path}")
                                return gr.Warning("No navigation instance found")

                            gr.Button("Save Camera", variant="huggingface").click(
                                fn=save_camera_poses,
                                inputs=[demonstrate_current_video, demonstrate_current_poses],
                                outputs=[]
                            )

                            def reset_navigation():
                                # Back to the selection screen; clear the
                                # navigator and all session state.
                                global NAVIGATORS
                                NAVIGATORS = []
                                return "Selection", None, None, None

                            gr.Button("Choose New Image", variant="secondary").click(
                                fn=reset_navigation,
                                inputs=[],
                                outputs=[demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses]
                            )
|
|
|
|
|
|
|
|
|
|
|
with gr.Blocks(theme=gr.themes.Base(primary_hue="blue")) as demo:
    # Global CSS tweaks: tab styling, button colors, gallery hover effects,
    # and the dashed upload-image box.
    gr.HTML(
        """
<style>
[data-tab-id="task-1"], [data-tab-id="task-2"], [data-tab-id="task-3"] {
font-size: 16px !important;
font-weight: bold;
}
#page-title h1 {
color: #002147 !important;
}
.task-title h2 {
color: #004080 !important;
}
.header-button-row {
gap: 4px !important;
}
.header-button-row div {
width: 131.0px !important;
}
.header-button-column {
width: 131.0px !important;
gap: 5px !important;
}
.header-button a {
border: 1px solid #002147;
}
.header-button .button-icon {
margin-right: 8px;
}
.demo-button-column .gap {
gap: 5px !important;
}
#basic-controls {
column-gap: 0px;
}
#basic-controls-tab {
padding: 0px;
}
#advanced-controls-tab {
padding: 0px;
}
#forward-backward-controls {
column-gap: 0px;
justify-content: center;
margin-top: 8px;
}
#selected-demo-button {
color: #004080;
text-decoration: underline;
}
.demo-button {
text-align: left !important;
display: block !important;
}
#navigation-gallery {
margin-bottom: 15px;
}
#navigation-gallery .gallery-item {
cursor: pointer;
border-radius: 6px;
transition: transform 0.2s, box-shadow 0.2s;
}
#navigation-gallery .gallery-item:hover {
transform: scale(1.02);
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1);
}
#navigation-gallery .gallery-item.selected {
border: 3px solid #002147;
}
/* Upload image styling */
#upload-image {
border-radius: 8px;
border: 2px dashed #002147;
padding: 10px;
transition: all 0.3s ease;
}
#upload-image:hover {
border-color: #002147;
box-shadow: 0 4px 12px rgba(0, 0, 0, 0.1);
}
/* Box styling */
.gradio-box {
border-radius: 10px;
margin-bottom: 20px;
padding: 15px;
background-color: #002147;
border: 1px solid #002147;
}
/* Start Navigation button styling */
button[data-testid="Start Navigation"] {
background-color: #004080 !important;
border-color: #004080 !important;
color: white !important;
}
button[data-testid="Start Navigation"]:hover {
background-color: #002147 !important;
border-color: #002147 !important;
}
/* Override Gradio's primary button color */
.gradio-button.primary {
background-color: #004080 !important;
border-color: #004080 !important;
color: white !important;
}
.gradio-button.primary:hover {
background-color: #002147 !important;
border-color: #002147 !important;
}
</style>
"""
    )

    # Active demo/task id; 3 selects the scene-navigation demo in render_demo.
    demo_idx = gr.State(value=3)

    # Sidebar: logo, title, and external links.
    with gr.Sidebar():
        gr.Image("assets/title_logo.png", width=60, height=60, show_label=False, show_download_button=False, container=False, interactive=False, show_fullscreen_button=False)
        gr.Markdown("# Consistent Interactive Video Scene Generation with Surfel-Indexed View Memory", elem_id="page-title")
        gr.Markdown(
            "### Interactive Demo for [_VMem_](http://arxiv.org/abs/2506.18903) that enables interactive consistent video scene generation."
        )
        gr.Markdown("---")
        gr.Markdown("#### Links ↓")
        with gr.Row(elem_classes=["header-button-row"]):
            with gr.Column(elem_classes=["header-button-column"], min_width=0):
                gr.Button(
                    value="Website",
                    link="https://v-mem.github.io/",
                    icon="https://simpleicons.org/icons/googlechrome.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
                gr.Button(
                    value="Paper",
                    link="http://arxiv.org/abs/2506.18903",
                    icon="https://simpleicons.org/icons/arxiv.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
            with gr.Column(elem_classes=["header-button-column"], min_width=0):
                gr.Button(
                    value="Code",
                    link="https://github.com/runjiali-rl/vmem",
                    icon="https://simpleicons.org/icons/github.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
                gr.Button(
                    value="Weights",
                    link="https://huggingface.co/liguang0115/vmem",
                    icon="https://simpleicons.org/icons/huggingface.svg",
                    elem_classes=["header-button"],
                    size="md",
                    min_width=0,
                )
        gr.Markdown("---")
        gr.Markdown("This demo interface is adapted from the History-Guided Video Diffusion demo template. We thank the authors for their work.")

    # Session state: UI stage, chosen example index, and the accumulated
    # video/pose tensors handed to navigate_video / undo_navigation.
    demonstrate_stage = gr.State(value="Selection")
    demonstrate_selected_index = gr.State(value=None)
    demonstrate_current_video = gr.State(value=None)
    demonstrate_current_poses = gr.State(value=None)

    @gr.render(inputs=[demo_idx, demonstrate_stage, demonstrate_selected_index])
    def render_demo(
        _demo_idx, _demonstrate_stage, _demonstrate_selected_index
    ):
        # Rebuilt whenever demo_idx or the navigation stage/selection change.
        match _demo_idx:
            case 3:
                render_demonstrate(_demonstrate_stage, _demonstrate_selected_index, demonstrate_stage, demonstrate_selected_index, demonstrate_current_video, demonstrate_current_poses)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # share=True exposes a public Gradio link; max_threads=1 serializes
    # requests because MODEL and NAVIGATORS are global singletons.
    demo.launch(debug=False,
                share=True,
                max_threads=1,
                show_error=False,
                )
|
|
|