import os
import subprocess
import sys
import time
from datetime import datetime
from pathlib import Path

import cv2
import gradio as gr
import hydra
import numpy as np
import torch
import torchvision.transforms as transforms
from omegaconf import DictConfig, OmegaConf
from omegaconf.omegaconf import open_dict
from PIL import Image

from datasets.video.minecraft_video_dataset import *
from utils.ckpt_utils import download_latest_checkpoint, is_run_id
from utils.cluster_utils import submit_slurm_job
from utils.distributed_utils import is_rank_zero
from utils.print_utils import cyan

ACTION_KEYS = [
    "inventory",
    "ESC",
    "hotbar.1",
    "hotbar.2",
    "hotbar.3",
    "hotbar.4",
    "hotbar.5",
    "hotbar.6",
    "hotbar.7",
    "hotbar.8",
    "hotbar.9",
    "forward",
    "back",
    "left",
    "right",
    "cameraY",
    "cameraX",
    "jump",
    "sneak",
    "sprint",
    "swapHands",
    "attack",
    "use",
    "pickItem",
    "drop",
]

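# Each action step is encoded as a vector with one slot per key above
# (25 in total). A minimal sketch of encoding a single "forward" press:
#   v = torch.zeros(len(ACTION_KEYS))
#   v[ACTION_KEYS.index("forward")] = 1
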
# Keyboard characters accepted in the action textbox, mapped to
# (action name, value). Camera actions use signed values; "noop" is not in
# ACTION_KEYS, so it encodes as an all-zero row (do nothing).
KEY_TO_ACTION = {
    "Q": ("forward", 1),
    "E": ("back", 1),
    "W": ("cameraY", -1),  # turn up
    "S": ("cameraY", 1),   # turn down
    "A": ("cameraX", -1),  # turn left
    "D": ("cameraX", 1),   # turn right
    "U": ("drop", 1),
    "N": ("noop", 1),
    "1": ("hotbar.1", 1),
}

def parse_input_to_tensor(input_str):
    """
    Convert an input string into a (sequence_length, 25) tensor, where each row
    encodes the action for one step.

    Args:
        input_str (str): A string of key characters from KEY_TO_ACTION
            (e.g., "QQDDWW").

    Returns:
        torch.Tensor: A tensor of shape (sequence_length, 25), where each row
            carries the mapped action's value at that action's index; unknown
            characters and "noop" leave the row all zeros.
    """
    seq_len = len(input_str)
    action_tensor = torch.zeros((seq_len, len(ACTION_KEYS)))
    for i, char in enumerate(input_str):
        # Default to (None, 0) so unmapped characters do not raise on unpacking.
        action, value = KEY_TO_ACTION.get(char.upper(), (None, 0))
        if action and action in ACTION_KEYS:
            index = ACTION_KEYS.index(action)
            action_tensor[i, index] = value
    return action_tensor

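# Usage sketch: "QD" yields a (2, 25) tensor; row 0 sets "forward" to 1,
# row 1 sets "cameraX" to 1 (turn right).
#   actions = parse_input_to_tensor("QD")
#   assert actions.shape == (2, len(ACTION_KEYS))
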
def load_image_as_tensor(image_path) -> torch.Tensor:
    """
    Load an image and convert it to a tensor normalized to [0, 1].

    Args:
        image_path (str | PIL.Image.Image): Path to the image file, or an
            already-loaded PIL image.

    Returns:
        torch.Tensor: Image tensor of shape (C, H, W), normalized to [0, 1].
    """
    if isinstance(image_path, str):
        image = Image.open(image_path).convert("RGB")
    else:
        image = image_path
    transform = transforms.Compose([
        transforms.ToTensor(),
    ])
    return transform(image)

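# Usage sketch (assumes the asset exists on disk):
#   first_frame = load_image_as_tensor("assets/ice_plains.png")  # (3, H, W) in [0, 1]
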
def run_local(cfg: DictConfig):
    # Imported lazily so the full experiment stack is only loaded when needed.
    from experiments import build_experiment

    # Record which config groups hydra resolved for this run, and write their
    # names back into the config.
    hydra_cfg = hydra.core.hydra_config.HydraConfig.get()
    cfg_choice = OmegaConf.to_container(hydra_cfg.runtime.choices)

    with open_dict(cfg):
        if cfg_choice["experiment"] is not None:
            cfg.experiment._name = cfg_choice["experiment"]
        if cfg_choice["dataset"] is not None:
            cfg.dataset._name = cfg_choice["dataset"]
        if cfg_choice["algorithm"] is not None:
            cfg.algorithm._name = cfg_choice["algorithm"]

    # Build the experiment and return its interactive algorithm handle.
    experiment = build_experiment(cfg, None, cfg.checkpoint_path)
    return experiment.exec_interactive(cfg.experiment.tasks[0])

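# The config groups read above ("experiment", "dataset", "algorithm") are
# selected via standard hydra overrides; a sketch, assuming those groups
# exist under ./configurations:
#   python <this_script>.py experiment=<name> dataset=<name> algorithm=<name>
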
# Interactive session state, mutated by the gradio callbacks below.
memory_frames = []
memory_curr_frame = 0
input_history = ""

# Starting-frame assets selectable in the UI.
ICE_PLAINS_IMAGE = "assets/ice_plains.png"
DESERT_IMAGE = "assets/desert.png"
SAVANNA_IMAGE = "assets/savanna.png"
PLAINS_IMAGE = "assets/plans.png"
PLACE_IMAGE = "assets/place.png"
SUNFLOWERS_IMAGE = "assets/sunflower_plains.png"
SUNFLOWERS_RAIN_IMAGE = "assets/rain_sunflower_plains.png"

DEFAULT_IMAGE = ICE_PLAINS_IMAGE
device = "cuda:0"

| def save_video(frames, path="output.mp4", fps=10): |
| h, w, _ = frames[0].shape |
| out = cv2.VideoWriter(path, cv2.VideoWriter_fourcc(*'XVID'), fps, (w, h)) |
| for frame in frames: |
| out.write(cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)) |
| out.release() |
|
|
| ffmpeg_cmd = [ |
| "ffmpeg", "-y", "-i", path, "-c:v", "libx264", "-crf", "23", "-preset", "medium", path |
| ] |
| subprocess.run(ffmpeg_cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL) |
| return path |
|
|
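# Usage sketch: frames is a sequence of (H, W, 3) uint8 RGB arrays.
#   save_video(frames, path="outputs_gradio/demo.mp4", fps=10)
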
@hydra.main(
    version_base=None,
    config_path="configurations",
    config_name="config",
)
def run(cfg: DictConfig):
    algo = run_local(cfg)
    algo.to(device)

    # Zero action/pose placeholders used to prime the model with the first frame.
    actions = torch.zeros((1, 25))
    poses = torch.zeros((1, 5))

    memory_frames.append(load_image_as_tensor(DEFAULT_IMAGE))

    # Feed the initial frame so generation starts from the chosen image.
    _ = algo.interactive(
        memory_frames[0],
        actions[0],
        poses[0],
        memory_curr_frame,
        device=device,
    )

    def set_denoising_steps(denoising_steps, sampling_timesteps_state):
        # Propagate the slider value to both the wrapper and the diffusion model.
        algo.sampling_timesteps = denoising_steps
        algo.diffusion_model.sampling_timesteps = denoising_steps
        sampling_timesteps_state = denoising_steps
        print("Set denoising steps to", algo.sampling_timesteps)
        return sampling_timesteps_state

    def update_image_and_log(keys):
        global input_history
        global memory_curr_frame

        actions = parse_input_to_tensor(keys)
        for i in range(len(actions)):
            memory_curr_frame += 1
            # Generate the next frame, conditioned on the first frame and the
            # current action.
            new_frame = algo.interactive(
                memory_frames[0],
                actions[i],
                None,
                memory_curr_frame,
                device=device,
            )
            memory_frames.append(new_frame)

        # Stack all frames into a (T, H, W, C) uint8 array for video export.
        out_video = torch.stack(memory_frames)
        out_video = out_video.permute(0, 2, 3, 1).numpy()
        out_video = np.clip(out_video, a_min=0.0, a_max=1.0)
        out_video = (out_video * 255).astype(np.uint8)

        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        os.makedirs("outputs_gradio", exist_ok=True)
        filename = f"outputs_gradio/{timestamp}.mp4"
        save_video(out_video, filename)

        input_history += keys
        return out_video[-1], filename, input_history

    def reset():
        global memory_curr_frame
        global input_history
        global memory_frames

        algo.reset()
        memory_frames = [load_image_as_tensor(DEFAULT_IMAGE)]
        memory_curr_frame = 0
        input_history = ""

        # Re-prime the model with the (possibly newly selected) starting frame.
        _ = algo.interactive(
            memory_frames[0],
            actions[0],
            poses[0],
            memory_curr_frame,
            device=device,
        )
        return input_history, DEFAULT_IMAGE

    def on_image_click(SELECTED_IMAGE):
        # Switch the starting frame and restart the session from it.
        global DEFAULT_IMAGE
        DEFAULT_IMAGE = SELECTED_IMAGE
        reset()
        return SELECTED_IMAGE

| css = """ |
| h1 { |
| text-align: center; |
| display:block; |
| } |
| """ |
|
|
| |
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # WORLDMEM: Long-term Consistent World Generation with Memory

            <div style="text-align: center;">
                <!-- Public Website -->
                <a style="display:inline-block" href="https://nirvanalan.github.io/projects/GA/">
                    <img src="https://img.shields.io/badge/public_website-8A2BE2">
                </a>
                <!-- GitHub Stars -->
                <a style="display:inline-block; margin-left: .5em" href="https://github.com/NIRVANALAN/GaussianAnything">
                    <img src="https://img.shields.io/github/stars/NIRVANALAN/GaussianAnything?style=social">
                </a>
                <!-- Project Page -->
                <a style="display:inline-block; margin-left: .5em" href="https://nirvanalan.github.io/projects/GA/">
                    <img src="https://img.shields.io/badge/project_page-blue">
                </a>
                <!-- arXiv Paper -->
                <a style="display:inline-block; margin-left: .5em" href="https://arxiv.org/abs/XXXX.XXXXX">
                    <img src="https://img.shields.io/badge/arXiv-paper-red">
                </a>
            </div>
            """
        )

        with gr.Row(variant="panel"):
            video_display = gr.Video(autoplay=True, loop=True)
            image_display = gr.Image(value=DEFAULT_IMAGE, interactive=False, label="Last Frame")

        with gr.Row(variant="panel"):
            with gr.Column(scale=2):
                input_box = gr.Textbox(label="Action Sequence", placeholder="Enter action sequence here...", lines=1, max_lines=1)
                log_output = gr.Textbox(label="History Log", interactive=False)
            with gr.Column(scale=1):
                slider = gr.Slider(minimum=10, maximum=50, value=algo.sampling_timesteps, step=1, label="Denoising Steps")
                submit_button = gr.Button("Generate")
                reset_btn = gr.Button("Reset")

        sampling_timesteps_state = gr.State(algo.sampling_timesteps)

        example_actions = [
            "DDDDDDDDEEEEEEEEEESSSAAAAAAAAWWW",
            "DDDDDDDDDDDDQQQQQQQQQQQQQQQDDDDDDDDDDDD",
            "DDDDWWWDDDDDDDDDDDDDDDDDDDDSSSAAAAAAAAAAAAAAAAAAAAAAAA",
            "SSUNNWWEEEEEEEEEAAASSUNNWWEEEEEEEEEAAAAAAAAAAAAAAAAAAAAAA",
        ]

        def set_action(action):
            return action

        gr.Markdown("### Action sequence examples.")
        buttons = []
        with gr.Row():
            for action in example_actions[:2]:
                with gr.Column(scale=len(action)):
                    buttons.append(gr.Button(action))
        with gr.Row():
            for action in example_actions[2:4]:
                with gr.Column(scale=len(action)):
                    buttons.append(gr.Button(action))

        # Clicking an example button copies its sequence into the input box.
        for button, action in zip(buttons, example_actions):
            button.click(set_action, inputs=[gr.State(value=action)], outputs=input_box)

| gr.Markdown("### Click on the images below to reset the sequence and generate from the new image.") |
|
|
| with gr.Row(): |
| image_display_1 = gr.Image(value=SUNFLOWERS_IMAGE, interactive=False, label="Sunflower Plains") |
| image_display_2 = gr.Image(value=DESERT_IMAGE, interactive=False, label="Desert") |
| image_display_3 = gr.Image(value=SAVANNA_IMAGE, interactive=False, label="Savanna") |
| image_display_4 = gr.Image(value=ICE_PLAINS_IMAGE, interactive=False, label="Ice Plains") |
| image_display_5 = gr.Image(value=SUNFLOWERS_RAIN_IMAGE, interactive=False, label="Rainy Sunflower Plains") |
| image_display_6 = gr.Image(value=PLACE_IMAGE, interactive=False, label="Place") |
|
|
        gr.Markdown(
            """
            ## Instructions & Notes:

            1. Enter an action sequence in the **"Action Sequence"** text box and click **"Generate"** to begin.
            2. You can keep extending the video by clicking **"Generate"** again; previous sequences are logged in the history panel.
            3. Click **"Reset"** to clear the current sequence and start fresh.
            4. Action sequences are composed from the following keys:
                - W: turn up
                - S: turn down
                - A: turn left
                - D: turn right
                - Q: move forward
                - E: move backward
                - N: no-op (do nothing)
                - 1: switch to hotbar 1
                - U: drop item
            5. More denoising steps give more detailed results but take longer; **20 steps** is a good balance between quality and speed.
            6. If you find this project interesting or useful, please consider giving it a ⭐️ on [GitHub]()!
            7. For feedback or suggestions, feel free to open a GitHub issue or contact me directly at **zeqixiao1@gmail.com**.
            """
        )

        submit_button.click(update_image_and_log, inputs=[input_box], outputs=[image_display, video_display, log_output])
        reset_btn.click(reset, outputs=[log_output, image_display])
        image_display_1.select(lambda: on_image_click(SUNFLOWERS_IMAGE), outputs=image_display)
        image_display_2.select(lambda: on_image_click(DESERT_IMAGE), outputs=image_display)
        image_display_3.select(lambda: on_image_click(SAVANNA_IMAGE), outputs=image_display)
        image_display_4.select(lambda: on_image_click(ICE_PLAINS_IMAGE), outputs=image_display)
        image_display_5.select(lambda: on_image_click(SUNFLOWERS_RAIN_IMAGE), outputs=image_display)
        image_display_6.select(lambda: on_image_click(PLACE_IMAGE), outputs=image_display)

        slider.change(fn=set_denoising_steps, inputs=[slider, sampling_timesteps_state], outputs=sampling_timesteps_state)

    # launch() blocks, so only one call can take effect; the two original
    # variants are merged into a single call here.
    demo.launch(server_name="0.0.0.0", server_port=30066, share=True)

if __name__ == "__main__":
    run()