Spaces:
Running on Zero
Running on Zero
| """ | |
| Game Editing โ G-Buffer conditioned stylized video generation | |
| Hugging Face Space (ZeroGPU) version. | |
| """ | |
| import os | |
| import sys | |
| import torch | |
| import tempfile | |
| import spaces | |
| import gradio as gr | |
| from PIL import Image | |
| from huggingface_hub import hf_hub_download | |
# ── download models on CPU at startup ──────────────────────────────────────
# Optional HF token; only needed if the checkpoint repo requires auth.
HF_TOKEN = os.environ.get("HF_TOKEN", None)
print("Downloading models...")
# Fine-tuned G-Buffer conditioned DiT checkpoint.
CKPT_PATH = hf_hub_download(
    repo_id="Brian9999/game-editing",
    filename="model.safetensors",
    token=HF_TOKEN,
)
# UMT5-XXL text encoder weights from the Wan2.1 1.3B base release.
T5_PATH = hf_hub_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    filename="models_t5_umt5-xxl-enc-bf16.pth",
)
# Wan2.1 video VAE weights.
VAE_PATH = hf_hub_download(
    repo_id="Wan-AI/Wan2.1-T2V-1.3B",
    filename="Wan2.1_VAE.pth",
)
print("Models downloaded.")
# Patch: diffsynth/transformers indirectly imports torchaudio which fails on ZeroGPU.
# We create a proper mock so importlib.util.find_spec doesn't choke on __spec__=None.
import types, importlib
_mock = types.ModuleType("torchaudio")
# A real ModuleSpec makes find_spec("torchaudio") succeed instead of raising.
_mock.__spec__ = importlib.machinery.ModuleSpec("torchaudio", None)
_mock.__version__ = "0.0.0"
sys.modules["torchaudio"] = _mock

# NOTE: these imports must come AFTER the torchaudio mock is installed above.
from diffsynth.utils.data import save_video, VideoData
from diffsynth.pipelines.wan_video import WanVideoPipeline, ModelConfig
from diffsynth.models.model_loader import MODEL_CONFIGS
from gbuffer_utils import expand_patch_embedding, inject_gbuffer_unit
# Register the fine-tuned GBuffer DiT model hash (not in PyPI diffsynth).
# in_dim=96 = 16 (base) + 5*16 (gbuffer latents for albedo/depth/metallic/normal/roughness)
_GBUFFER_DIT_CONFIG = {
    "model_hash": "a87823c16fa5119ca6ef32cefc0be86d",
    "model_name": "wan_video_dit",
    "model_class": "diffsynth.models.wan_video_dit.WanModel",
    "extra_kwargs": {
        "has_image_input": False,
        "patch_size": [1, 2, 2],
        "in_dim": 96,   # widened input channels for the 5 extra G-buffer latents
        "dim": 1536,
        "ffn_dim": 8960,
        "freq_dim": 256,
        "text_dim": 4096,
        "out_dim": 16,
        "num_heads": 12,
        "num_layers": 30,
        "eps": 1e-06,
    },
}
# Idempotent registration: skip if this hash is already known (e.g. on app reload).
if not any(c["model_hash"] == _GBUFFER_DIT_CONFIG["model_hash"] for c in MODEL_CONFIGS):
    MODEL_CONFIGS.append(_GBUFFER_DIT_CONFIG)
    print("Registered GBuffer DiT model config.")
# ── constants ──────────────────────────────────────────────────────────────
# Standard Wan2.1 Chinese negative prompt as shipped with the base model
# (the previous text was mojibake-garbled; restored from the Wan-AI release).
NEGATIVE_PROMPT = (
    "色调艳丽，过曝，静态，细节模糊不清，字幕，风格，作品，画作，画面，静止，整体发灰，"
    "最差质量，低质量，JPEG压缩残留，丑陋的，残缺的，多余的手指，画得不好的手部，"
    "画得不好的脸部，畸形的，毁容的，形态畸形的肢体，手指融合，静止不动的画面，"
    "杂乱的背景，三条腿，背景人很多，倒着走"
)
# Style presets: dropdown key -> style sentence appended to / used as the prompt.
# "(custom)" means the user supplies their own prompt text.
PRESET_STYLES = {
    "(custom)": "",
    "bright_sunny": "the scene is bathed in bright, warm golden sunlight under a clear blue sky, creating a serene and radiant atmosphere.",
    "snowy_winter": "the scene is set in a frozen, snow-covered environment under cold, pale winter light with falling snowflakes, creating a silent and ethereal winter wonderland atmosphere.",
    "sunset_dramatic": "the scene is bathed in dramatic sunset light under a deep orange and crimson sky, with long shadows and volumetric god rays, creating an epic and cinematic atmosphere.",
    "cyberpunk_neon": "the scene is illuminated by vibrant pink and blue neon lights with glowing holographic particles and electric haze, creating a futuristic cyberpunk atmosphere.",
    "underwater": "the scene is submerged deep underwater under soft, filtered aquamarine light, with shimmering caustics and drifting bubbles, creating a mysterious and tranquil deep-sea atmosphere.",
    "autumn_warm": "the scene is filled with vibrant autumn foliage and golden and crimson fallen leaves under warm amber afternoon light, creating a cozy and nostalgic autumn atmosphere.",
    "moonlit_night": "the scene is set under a brilliant full moon, bathed in cool silver moonlight with fireflies dancing in the darkness and stars visible overhead, creating a mystical and enchanting nighttime atmosphere.",
}
def read_video_frames(video_path, num_frames, width, height):
    """Load up to *num_frames* frames from *video_path*, resized to (width, height).

    Returns a list of PIL images; fewer than num_frames if the clip is shorter.
    """
    raw_frames = VideoData(video_path).raw_data()
    resized = []
    for frame in raw_frames[:num_frames]:
        resized.append(frame.resize((width, height), Image.LANCZOS))
    return resized
# ── Pre-load pipeline on CPU at startup (free, no GPU needed) ──────────────
print("Building pipeline on CPU...")
pipe = WanVideoPipeline.from_pretrained(
    torch_dtype=torch.bfloat16,
    device="cpu",
    model_configs=[
        ModelConfig(CKPT_PATH),  # fine-tuned G-Buffer DiT
        ModelConfig(T5_PATH),    # UMT5-XXL text encoder
        ModelConfig(VAE_PATH),   # Wan2.1 video VAE
    ],
)
# Widen the DiT patch embedding to accept the 5 extra G-buffer latent inputs,
# then insert the G-buffer conditioning unit into the pipeline's unit graph.
expand_patch_embedding(pipe, num_gbuffers=5)
inject_gbuffer_unit(pipe)
print("Pipeline ready on CPU.")
# ── CPU pre-processing (outside @spaces.GPU to save GPU quota) ─────────────
def prepare_inputs(
    albedo_video, depth_video, metallic_video, normal_video, roughness_video,
    prompt, style_preset, num_frames, height, width,
):
    """Validate inputs and decode the G-buffer videos on CPU.

    Returns (final_prompt, gbuffer_videos, num_frames, height, width) where
    gbuffer_videos is a list of five frame lists in the order
    albedo/depth/metallic/normal/roughness.

    Raises:
        gr.Error: if any G-buffer video is missing, or neither a prompt nor
            a style preset was provided.
    """
    gbuffer_paths = [albedo_video, depth_video, metallic_video, normal_video, roughness_video]
    names = ["Albedo", "Depth", "Metallic", "Normal", "Roughness"]
    for name, path in zip(names, gbuffer_paths):
        if path is None:
            raise gr.Error(f"Missing G-buffer: {name}")

    # Build the final prompt from user text and/or the preset style sentence.
    # Fix: strip the user prompt before use so the returned prompt matches the
    # emptiness check (previously the raw, un-stripped text was passed through).
    user_prompt = prompt.strip()
    if style_preset and style_preset != "(custom)":
        style_text = PRESET_STYLES[style_preset]
        final_prompt = f"{user_prompt.rstrip('.')}; {style_text}" if user_prompt else style_text
    else:
        if not user_prompt:
            raise gr.Error("Please enter a prompt or select a style preset.")
        final_prompt = user_prompt

    # Gradio sliders/dropdowns may deliver floats/strings; normalize to int.
    num_frames = int(num_frames)
    height = int(height)
    width = int(width)
    # Decode and resize all five G-buffer videos to the target resolution.
    gbuffer_videos = [read_video_frames(p, num_frames, width, height) for p in gbuffer_paths]
    return final_prompt, gbuffer_videos, num_frames, height, width
# ── inference (ZeroGPU: GPU allocated only during this call) ───────────────
# Fix: the @spaces.GPU decorator was missing, so on a ZeroGPU Space no GPU
# would ever be allocated and pipe.to("cuda") would fail. duration matches the
# "~180 seconds" estimate shown in the UI.
@spaces.GPU(duration=180)
def generate(
    albedo_video, depth_video, metallic_video, normal_video, roughness_video,
    prompt, style_preset,
    num_frames, height, width, seed, cfg_scale, num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Run G-buffer conditioned video generation; return the output mp4 path.

    All validation and video decoding happens on CPU in prepare_inputs();
    the GPU is held only for denoising and VAE decoding.
    """
    # CPU pre-processing (validation, video loading, prompt building)
    final_prompt, gbuffer_videos, num_frames, height, width = prepare_inputs(
        albedo_video, depth_video, metallic_video, normal_video, roughness_video,
        prompt, style_preset, num_frames, height, width,
    )
    seed = int(seed)

    # Move all model weights from CPU to GPU
    pipe.to("cuda")
    pipe.device = "cuda"

    # Tiled VAE settings bound peak memory at these resolutions.
    tiled = True
    tile_size = (30, 52)
    tile_stride = (15, 26)

    pipe.scheduler.set_timesteps(num_inference_steps, denoising_strength=1.0, shift=5.0)

    inputs_posi = {"prompt": final_prompt}
    inputs_nega = {"negative_prompt": NEGATIVE_PROMPT}
    # Full shared-input dict expected by the pipeline units; unused modalities
    # are explicitly None. The G-buffer frames ride in via "gbuffer_videos".
    inputs_shared = {
        "input_image": None, "end_image": None,
        "input_video": None, "denoising_strength": 1.0,
        "control_video": None, "reference_image": None,
        "vace_video": None, "vace_video_mask": None, "vace_reference_image": None, "vace_scale": 1.0,
        "seed": seed, "rand_device": "cpu",
        "height": height, "width": width, "num_frames": num_frames,
        "cfg_scale": cfg_scale, "cfg_merge": False,
        "sigma_shift": 5.0,
        "motion_bucket_id": None, "longcat_video": None,
        "tiled": tiled, "tile_size": tile_size, "tile_stride": tile_stride,
        "sliding_window_size": None, "sliding_window_stride": None,
        "input_audio": None, "audio_sample_rate": 16000,
        "s2v_pose_video": None, "audio_embeds": None, "s2v_pose_latents": None, "motion_video": None,
        "animate_pose_video": None, "animate_face_video": None, "animate_inpaint_video": None, "animate_mask_video": None,
        "vap_video": None,
        "gbuffer_videos": gbuffer_videos,
    }

    # Run the pipeline's preparation units (text encoding, latent init,
    # G-buffer encoding, ...) to populate inputs_shared (incl. "latents").
    for unit in pipe.units:
        inputs_shared, inputs_posi, inputs_nega = pipe.unit_runner(unit, pipe, inputs_shared, inputs_posi, inputs_nega)

    pipe.load_models_to_device(pipe.in_iteration_models)
    models = {name: getattr(pipe, name) for name in pipe.in_iteration_models}

    from tqdm import tqdm
    with torch.no_grad():
        for progress_id, timestep in enumerate(tqdm(pipe.scheduler.timesteps, desc="Generating")):
            timestep = timestep.unsqueeze(0).to(dtype=pipe.torch_dtype, device="cuda")
            noise_pred_posi = pipe.model_fn(**models, **inputs_shared, **inputs_posi, timestep=timestep)
            if cfg_scale != 1.0:
                # Classifier-free guidance: extrapolate from the negative
                # prediction toward the positive one.
                noise_pred_nega = pipe.model_fn(**models, **inputs_shared, **inputs_nega, timestep=timestep)
                noise_pred = noise_pred_nega + cfg_scale * (noise_pred_posi - noise_pred_nega)
            else:
                noise_pred = noise_pred_posi
            inputs_shared["latents"] = pipe.scheduler.step(
                noise_pred, pipe.scheduler.timesteps[progress_id], inputs_shared["latents"]
            )

    # Decode the final latents to RGB frames with the (tiled) VAE.
    pipe.load_models_to_device(["vae"])
    with torch.no_grad():
        video = pipe.vae.decode(
            inputs_shared["latents"], device="cuda",
            tiled=tiled, tile_size=tile_size, tile_stride=tile_stride,
        )
        video = pipe.vae_output_to_video(video)

    out_path = os.path.join(tempfile.mkdtemp(), "output.mp4")
    save_video(video, out_path, fps=15, quality=5)
    return out_path
# ── UI ─────────────────────────────────────────────────────────────────────
# Display names of the five required G-buffer inputs, in upload/tab order.
MODALITIES = ["Albedo", "Depth", "Metallic", "Normal", "Roughness"]
def on_style_change(style):
    """Return the preset's style sentence, or "" for "(custom)" / no selection."""
    if not style or style == "(custom)":
        return ""
    return PRESET_STYLES[style]
def update_status(*videos):
    """Return a markdown status line showing which G-buffers are uploaded."""
    present = [vid is not None for vid in videos]
    uploaded = sum(present)
    missing = [label for label, ok in zip(MODALITIES, present) if not ok]
    if uploaded == 5:
        return "**All 5 G-buffers uploaded. Ready to generate!**"
    return f"**Uploaded {uploaded}/5 G-buffers.** Missing: {', '.join(missing)}"
| with gr.Blocks(title="Game Editing", theme=gr.themes.Soft()) as demo: | |
| gr.Markdown("# Game Editing") | |
| gr.Markdown( | |
| "> **Quick Start:** Click one of the **Examples** below to auto-fill G-buffer videos and style, then hit **Generate**.\n>\n" | |
| "> **Note:** The entire generation process takes approximately **180 seconds**. " | |
| "For better results, increase the *Inference Steps* if you have sufficient GPU quota. " | |
| "In our paper, we use **50 steps**." | |
| ) | |
| scroll_btn = gr.Button("Jump to Examples", variant="secondary", size="sm") | |
| scroll_btn.click(fn=None, js="() => { document.querySelector('#examples-anchor').scrollIntoView({behavior: 'smooth'}); }") | |
| status = gr.Markdown("**Upload 5 G-buffer videos, pick a style, and hit Generate. (0/5 uploaded)**") | |
| # โโ Top row: Input (left) + Output (right), videos aligned โโ | |
| VIDEO_HEIGHT = 400 | |
| with gr.Row(equal_height=True): | |
| # ๅทฆไพง Column | |
| with gr.Column(scale=1): | |
| with gr.Tabs(): | |
| with gr.Tab("Albedo"): | |
| albedo = gr.Video(label="Albedo video", sources=["upload"], height=VIDEO_HEIGHT) | |
| with gr.Tab("Depth"): | |
| depth = gr.Video(label="Depth video", sources=["upload"], height=VIDEO_HEIGHT) | |
| with gr.Tab("Metallic"): | |
| metallic = gr.Video(label="Metallic video", sources=["upload"], height=VIDEO_HEIGHT) | |
| with gr.Tab("Normal"): | |
| normal = gr.Video(label="Normal video", sources=["upload"], height=VIDEO_HEIGHT) | |
| with gr.Tab("Roughness"): | |
| roughness = gr.Video(label="Roughness video", sources=["upload"], height=VIDEO_HEIGHT) | |
| # ไบไปถ็ปๅฎ | |
| all_videos = [albedo, depth, metallic, normal, roughness] | |
| for v in all_videos: | |
| v.change(update_status, inputs=all_videos, outputs=[status]) | |
| v.clear(update_status, inputs=all_videos, outputs=[status]) | |
| # ๅณไพง Column | |
| with gr.Column(scale=1): | |
| # ๐ก ๆ ธๅฟๆๅทง๏ผๅณ่พนไนๅฅไธ Tabs๏ผๅผบๅถ UI ้กถ้จๅฏน้ฝ | |
| with gr.Tabs(): | |
| with gr.Tab("Result"): | |
| output_video = gr.Video(label="Generated Video", height=VIDEO_HEIGHT) | |
| # โโ Bottom: Settings โโ | |
| gr.Markdown("### Settings") | |
| with gr.Row(): | |
| style_preset = gr.Dropdown( | |
| choices=list(PRESET_STYLES.keys()), | |
| value="bright_sunny", | |
| label="Style Preset", | |
| scale=1, | |
| ) | |
| prompt = gr.Textbox( | |
| label="Prompt", | |
| value=PRESET_STYLES["bright_sunny"], | |
| lines=2, | |
| placeholder="Describe the scene and desired style...", | |
| scale=3, | |
| ) | |
| with gr.Row(): | |
| num_frames = gr.Slider(17, 81, value=73, step=4, label="Frames") | |
| seed = gr.Number(value=0, label="Seed", precision=0) | |
| height = gr.Dropdown([480, 720], value=480, label="Height") | |
| width = gr.Dropdown([832, 1280], value=832, label="Width") | |
| cfg_scale = gr.Slider(1.0, 10.0, value=5.0, step=0.5, label="CFG Scale") | |
| num_steps = gr.Slider(10, 100, value=20, step=5, label="Inference Steps") | |
| run_btn = gr.Button("Generate", variant="primary", size="lg") | |
| # โโ Examples โโ | |
| gr.Markdown("### Examples", elem_id="examples-anchor") | |
| gr.Examples( | |
| examples=[ | |
| [ | |
| "examples/clip_0000_albedo.mp4", | |
| "examples/clip_0000_depth.mp4", | |
| "examples/clip_0000_metallic.mp4", | |
| "examples/clip_0000_normal.mp4", | |
| "examples/clip_0000_roughness.mp4", | |
| "bright_sunny", | |
| ], | |
| [ | |
| "examples/clip_0000_albedo.mp4", | |
| "examples/clip_0000_depth.mp4", | |
| "examples/clip_0000_metallic.mp4", | |
| "examples/clip_0000_normal.mp4", | |
| "examples/clip_0000_roughness.mp4", | |
| "cyberpunk_neon", | |
| ], | |
| [ | |
| "examples/clip_0000_albedo.mp4", | |
| "examples/clip_0000_depth.mp4", | |
| "examples/clip_0000_metallic.mp4", | |
| "examples/clip_0000_normal.mp4", | |
| "examples/clip_0000_roughness.mp4", | |
| "snowy_winter", | |
| ], | |
| ], | |
| inputs=[ | |
| albedo, depth, metallic, normal, roughness, | |
| style_preset, | |
| ], | |
| label="Click an example to load G-buffer videos and style preset", | |
| ) | |
| style_preset.change(on_style_change, inputs=[style_preset], outputs=[prompt]) | |
| run_btn.click( | |
| fn=generate, | |
| inputs=[ | |
| albedo, depth, metallic, normal, roughness, | |
| prompt, style_preset, | |
| num_frames, height, width, seed, cfg_scale, num_steps, | |
| ], | |
| outputs=[output_video], | |
| ) | |
| if __name__ == "__main__": | |
| demo.launch(server_name="0.0.0.0", server_port=7860) | |