import tempfile
import time

import gradio as gr
import spaces
import torch
from diffusers import AutoencoderKLWan, HeliosDMDScheduler, HeliosPyramidPipeline
from diffusers.utils import export_to_video, load_image, load_video


# ---------------------------------------------------------------------------
# Pre-load model
# ---------------------------------------------------------------------------
# Everything in this section runs once at import time so that the pipeline is
# already resident when the first generation request arrives.
MODEL_ID = "BestWishYsh/Helios-Distilled"

# The VAE is loaded in float32 while the rest of the pipeline uses bfloat16.
vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float32)
scheduler = HeliosDMDScheduler.from_pretrained(MODEL_ID, subfolder="scheduler")
pipe = HeliosPyramidPipeline.from_pretrained(
    MODEL_ID, vae=vae, scheduler=scheduler, torch_dtype=torch.bfloat16, is_distilled=True
)
pipe.to("cuda")

# Best-effort backend selection: try "_flash_3_hub" first and fall back to
# "flash_hub" on any failure. The exception type raised by an unavailable
# backend is not visible from this file, hence the deliberately broad catch.
try:
    pipe.transformer.set_attention_backend("_flash_3_hub")
except Exception:
    pipe.transformer.set_attention_backend("flash_hub")

# ----------------------------- JIT -----------------------------
# pipe.transformer.compile(mode="max-autotune-no-cudagraphs", dynamic=False)
# pipe.transformer.compile_repeated_blocks(fullgraph=True)

# ----------------------------- AoTI -----------------------------
# def make_exported_contiguous(exported):
#     for key, val in exported.constants.items():
#         if not val.is_contiguous():
#             exported.constants[key] = val.contiguous()
#     return exported

# @spaces.GPU(duration=1500)  # maximum duration allowed during startup
# def compile():
#     pipe("prompt", width=640, height=384, pyramid_num_inference_steps_list=[1, 1, 1])

#     with spaces.aoti_capture(pipe.transformer) as call_low:
#         pipe("prompt", width=160, height=96)
#     exported_low = torch.export.export(pipe.transformer, args=call_low.args, kwargs=call_low.kwargs)
#     exported_low = make_exported_contiguous(exported_low)
#     compiled_low = spaces.aoti_compile(exported_low)

#     with spaces.aoti_capture(pipe.transformer) as call_mid:
#         pipe("prompt", width=320, height=192)
#     exported_mid = torch.export.export(pipe.transformer, args=call_mid.args,
kwargs=call_mid.kwargs) # exported_mid = make_exported_contiguous(exported_mid) # compiled_mid = spaces.aoti_compile(exported_mid) # with spaces.aoti_capture(pipe.transformer) as call_high: # pipe("prompt", width=640, height=384) # exported_high = torch.export.export(pipe.transformer, args=call_high.args, kwargs=call_high.kwargs) # exported_high = make_exported_contiguous(exported_high) # compiled_high = spaces.aoti_compile(exported_high) # # push_to_hub(compiled_low, "BestWishYsh/HeliosBench-Weights", "transformer_low.pt") # # push_to_hub(compiled_mid, "BestWishYsh/HeliosBench-Weights", "transformer_mid.pt") # # push_to_hub(compiled_high, "BestWishYsh/HeliosBench-Weights", "transformer_high.pt") # compiled_mid.weights = compiled_low.weights # compiled_high.weights = compiled_low.weights # return compiled_low, compiled_mid, compiled_high # compiled_low, compiled_mid, compiled_high = compile() # def combined(*args, **kwargs): # hidden_states = kwargs['hidden_states'] # if hidden_states.shape[-1] == 20: # return compiled_low(*args, **kwargs) # elif hidden_states.shape[-1] == 40: # return compiled_mid(*args, **kwargs) # else: # return compiled_high(*args, **kwargs) # spaces.aoti_apply(combined, pipe.transformer) # --------------------------------------------------------------------------- # Generation # --------------------------------------------------------------------------- @spaces.GPU(duration=120) def generate_video( mode: str, prompt: str, image_input, video_input, height: int, width: int, num_frames: int, num_inference_steps: int, seed: int, is_amplify_first_chunk: bool, progress=gr.Progress(track_tqdm=True), ): if not prompt: raise gr.Error("Please provide a prompt.") generator = torch.Generator(device="cuda").manual_seed(int(seed)) kwargs = { "prompt": prompt, "height": int(height), "width": int(width), "num_frames": int(num_frames), "guidance_scale": 1.0, "generator": generator, "output_type": "np", "pyramid_num_inference_steps_list": [ 
int(num_inference_steps), int(num_inference_steps), int(num_inference_steps), ], "is_amplify_first_chunk": is_amplify_first_chunk, } if mode == "Image-to-Video" and image_input is not None: img = load_image(image_input).resize((int(width), int(height))) kwargs["image"] = img elif mode == "Video-to-Video" and video_input is not None: kwargs["video"] = load_video(video_input) t0 = time.time() output = pipe(**kwargs).frames[0] elapsed = time.time() - t0 tmp = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) export_to_video(output, tmp.name, fps=24) info = f"Generated in {elapsed:.1f}s · {num_frames} frames · {height}×{width}" return tmp.name, info # --------------------------------------------------------------------------- # UI Setup # --------------------------------------------------------------------------- def update_conditional_visibility(mode): if mode == "Image-to-Video": return gr.update(visible=True), gr.update(visible=False) elif mode == "Video-to-Video": return gr.update(visible=False), gr.update(visible=True) else: return gr.update(visible=False), gr.update(visible=False) CSS = """ #header { text-align: center; margin-bottom: 1.5em; } #header h1 { font-size: 2.2em; margin-bottom: 0.2em; } .logo { max-height: 100px; margin: 0 auto 10px auto; display: block; } .link-buttons { display: flex; justify-content: center; gap: 15px; margin-top: 10px; } .link-buttons a { background-color: #2b3137; color: #ffffff !important; padding: 8px 20px; border-radius: 6px; text-decoration: none; font-weight: 600; font-size: 1em; transition: all 0.2s ease-in-out; box-shadow: 0 2px 4px rgba(0,0,0,0.1); } .link-buttons a:hover { background-color: #4a535c; transform: translateY(-1px); } .contain { max-width: 1350px; margin: 0 auto !important; } """ with gr.Blocks(title="Helios Video Generation") as demo: gr.HTML( """
If you like our project, please give us a star ⭐ on GitHub for the latest update.