Spaces: Running on Zero
import os
# Must be set before transformers tokenizers spawn worker threads/processes,
# otherwise forked dataloader workers can deadlock.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
from diffusers import AutoencoderKLWan, WanImageToVideoPipeline, UniPCMultistepScheduler
from diffusers.utils import export_to_video
from transformers import CLIPVisionModel
import gradio as gr
import tempfile
import spaces
from huggingface_hub import hf_hub_download
import numpy as np
from PIL import Image
import random
import gc

# Base image-to-video model plus the CausVid distillation LoRA applied on top,
# which allows very low step counts (see the UI's "Inference Steps" slider).
MODEL_ID = "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"
LORA_REPO_ID = "Kijai/WanVideo_comfy"
LORA_FILENAME = "Wan21_CausVid_14B_T2V_lora_rank32.safetensors"
# Initialize model with error handling.
# All components are loaded in fp16 and moved to CUDA once at import time;
# the pipeline object `pipe` is then shared by every request.
try:
    # Vision encoder and VAE are loaded separately so their dtype can be pinned,
    # then injected into the pipeline instead of being re-downloaded.
    image_encoder = CLIPVisionModel.from_pretrained(MODEL_ID, subfolder="image_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKLWan.from_pretrained(MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
    pipe = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID, vae=vae, image_encoder=image_encoder, torch_dtype=torch.float16
    )
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)  # Increased for smoother transitions
    pipe.to("cuda")
    # Fetch the CausVid LoRA from the Hub, attach it, and fuse it into the base
    # weights so inference pays no per-step LoRA overhead.
    causvid_path = hf_hub_download(repo_id=LORA_REPO_ID, filename=LORA_FILENAME)
    pipe.load_lora_weights(causvid_path, adapter_name="causvid_lora")
    pipe.set_adapters(["causvid_lora"], adapter_weights=[1.0])  # Increased weight for better LoRA effect
    pipe.fuse_lora()
except Exception as e:
    # NOTE(review): gr.Error is an exception class — constructing it without
    # `raise` displays nothing (and there is no request context at import time
    # anyway). The bare `raise` below is what actually aborts startup.
    gr.Error(f"Model initialization failed: {str(e)}")
    raise
# The model requires height/width to be multiples of this value.
MOD_VALUE = 32
DEFAULT_H_SLIDER_VALUE = 256
DEFAULT_W_SLIDER_VALUE = 448
# Pixel budget used when deriving output dimensions from an uploaded image's
# aspect ratio (matches the default 256x448 resolution).
NEW_FORMULA_MAX_AREA = 256.0 * 448.0
SLIDER_MIN_H, SLIDER_MAX_H = 128, 512
SLIDER_MIN_W, SLIDER_MAX_W = 128, 672
MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 24
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 241  # 241 frames at 24 fps is ~10 seconds of video
default_prompt_i2v = ""
default_negative_prompt = "jerky motion, low quality, blurry, pixelated, distorted, unnatural movements, artifacts, static, overexposed, underexposed, grainy, inconsistent frames, watermark, text, logo, low detail, unnatural lighting, deformed objects"
| def _calculate_new_dimensions_wan(pil_image, mod_val, calculation_max_area, | |
| min_slider_h, max_slider_h, | |
| min_slider_w, max_slider_w, | |
| default_h, default_w): | |
| if pil_image is None: | |
| return default_h, default_w | |
| try: | |
| orig_w, orig_h = pil_image.size | |
| if orig_w <= 0 or orig_h <= 0: | |
| return default_h, default_w | |
| aspect_ratio = orig_h / orig_w | |
| calc_h = round(np.sqrt(calculation_max_area * aspect_ratio)) | |
| calc_w = round(np.sqrt(calculation_max_area / aspect_ratio)) | |
| calc_h = max(mod_val, (calc_h // mod_val) * mod_val) | |
| calc_w = max(mod_val, (calc_w // mod_val) * mod_val) | |
| new_h = int(np.clip(calc_h, min_slider_h, (max_slider_h // mod_val) * mod_val)) | |
| new_w = int(np.clip(calc_w, min_slider_w, (max_slider_w // mod_val) * mod_val)) | |
| return new_h, new_w | |
| except Exception as e: | |
| gr.Warning(f"Dimension calculation error: {str(e)}") | |
| return default_h, default_w | |
def handle_image_upload_for_dims_wan(uploaded_pil_image, current_h_val, current_w_val):
    """Gradio upload/clear callback: sync the H/W sliders to the image.

    Returns a pair of gr.update objects for (height_input, width_input).
    The current slider values are accepted for signature compatibility with
    the event wiring but are not used.
    """
    try:
        suggested_h, suggested_w = _calculate_new_dimensions_wan(
            uploaded_pil_image, MOD_VALUE, NEW_FORMULA_MAX_AREA,
            SLIDER_MIN_H, SLIDER_MAX_H, SLIDER_MIN_W, SLIDER_MAX_W,
            DEFAULT_H_SLIDER_VALUE, DEFAULT_W_SLIDER_VALUE,
        )
    except Exception as e:
        gr.Warning(f"Error calculating dimensions: {str(e)}")
        return gr.update(value=DEFAULT_H_SLIDER_VALUE), gr.update(value=DEFAULT_W_SLIDER_VALUE)
    return gr.update(value=suggested_h), gr.update(value=suggested_w)
def get_duration(input_image, prompt, height, width,
                 negative_prompt, duration_seconds,
                 guidance_scale, steps,
                 seed, randomize_seed,
                 progress):
    """Return the GPU time budget (seconds) for one generation request.

    Requests needing more than 6 steps or more than 4 seconds of video get
    the larger 60-second allocation; everything else fits in 30 seconds.
    Only ``steps`` and ``duration_seconds`` are inspected; the remaining
    parameters mirror generate_video's signature.
    """
    is_heavy_request = steps > 6 or duration_seconds > 4
    return 60 if is_heavy_request else 30
def generate_video(input_image, prompt, height, width,
                   negative_prompt=default_negative_prompt, duration_seconds=4,
                   guidance_scale=1.5, steps=6,
                   seed=42, randomize_seed=False,
                   progress=gr.Progress(track_tqdm=True)):
    """Run the Wan I2V pipeline and return (video_path, seed_used).

    The input image is converted to RGB, resized to the (MOD_VALUE-snapped)
    target resolution, and fed to the fused-LoRA pipeline. The frame count is
    clamped to the model's supported range and snapped to the 8k+1 pattern.
    Raises gr.Error (shown in the UI) if the pipeline itself fails.
    """
    if input_image is None:
        gr.Warning("Input image is None. Using default dimensions.")
        target_h = DEFAULT_H_SLIDER_VALUE
        target_w = DEFAULT_W_SLIDER_VALUE
    else:
        try:
            # Fix: the old `hasattr(input_image, 'rgb')` check was dead code —
            # PIL Images have no `.rgb` attribute — so RGBA/palette uploads were
            # never normalized. Convert explicitly; the pipeline expects RGB.
            if input_image.mode != "RGB":
                input_image = input_image.convert("RGB")
            target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
            target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)
            # Aspect ratio is intentionally not preserved here; the upload
            # callback already set the sliders to a matching aspect ratio.
            input_image = input_image.resize((target_w, target_h))
        except Exception as e:
            gr.Warning(f"Image processing error: {str(e)}")
            target_h = DEFAULT_H_SLIDER_VALUE
            target_w = DEFAULT_W_SLIDER_VALUE
    # Clamp to the model's frame range, then snap to the 8k+1 pattern and
    # re-clamp so the snap can never push us outside [MIN, MAX].
    num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
    if (num_frames - 1) % 8 != 0:
        num_frames = round((num_frames - 1) / 8) * 8 + 1
        num_frames = int(np.clip(num_frames, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    try:
        with torch.inference_mode():
            output_frames_list = pipe(
                image=input_image, prompt=prompt, negative_prompt=negative_prompt,
                height=target_h, width=target_w, num_frames=num_frames,
                guidance_scale=float(guidance_scale), num_inference_steps=int(steps),
                generator=torch.Generator(device="cuda").manual_seed(current_seed)
            ).frames[0]
    except Exception as e:
        # Fix: gr.Error must be *raised* to surface in the UI; the old code
        # constructed it and silently returned None, leaving the user with an
        # empty video component and no explanation.
        raise gr.Error(f"Video generation failed: {str(e)}") from e
    finally:
        # Free VRAM between requests regardless of success or failure.
        torch.cuda.empty_cache()
        gc.collect()
    # Reserve a temp filename (file handle closed immediately), then encode.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    return video_path, current_seed
# UI layout. Component creation order matters inside the Blocks context:
# components are laid out in the order they are instantiated.
with gr.Blocks() as demo:
    gr.Markdown("# Fast Wan 2.1 I2V (14B) with CausVid LoRA")
    gr.Markdown("[CausVid](https://github.com/tianweiy/CausVid) is a distilled version of Wan 2.1, [extracted as LoRA by Kijai](https://huggingface.co/Kijai/WanVideo_comfy/blob/main/Wan21_CausVid_14B_T2V_lora_rank32.safetensors)")
    with gr.Row():
        with gr.Column():
            input_image_component = gr.Image(type="pil", label="Input Image (auto-resized to target H/W)")
            prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
            duration_seconds_input = gr.Slider(minimum=round(MIN_FRAMES_MODEL/FIXED_FPS,1), maximum=10, step=0.1, value=10, label="Duration (seconds)", info=f"Clamped to {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps")
            with gr.Accordion("Advanced Settings", open=False):
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                with gr.Row():
                    # Slider step equals MOD_VALUE so users can only pick valid sizes.
                    height_input = gr.Slider(minimum=SLIDER_MIN_H, maximum=SLIDER_MAX_H, step=MOD_VALUE, value=DEFAULT_H_SLIDER_VALUE, label=f"Output Height (multiple of {MOD_VALUE})")
                    width_input = gr.Slider(minimum=SLIDER_MIN_W, maximum=SLIDER_MAX_W, step=MOD_VALUE, value=DEFAULT_W_SLIDER_VALUE, label=f"Output Width (multiple of {MOD_VALUE})")
                steps_slider = gr.Slider(minimum=1, maximum=6, step=1, value=4, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=20.0, step=0.5, value=1.5, label="Guidance Scale")
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
    # Keep the H/W sliders in sync with the uploaded image's aspect ratio;
    # clearing the image resets them to the defaults via the same handler.
    input_image_component.upload(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image_component, height_input, width_input],
        outputs=[height_input, width_input]
    )
    input_image_component.clear(
        fn=handle_image_upload_for_dims_wan,
        inputs=[input_image_component, height_input, width_input],
        outputs=[height_input, width_input]
    )
    # Order must match generate_video's positional parameters.
    ui_inputs = [
        input_image_component, prompt_input, height_input, width_input,
        negative_prompt_input, duration_seconds_input,
        guidance_scale_input, steps_slider, seed_input, randomize_seed_checkbox
    ]
    # seed_input is also an output so the UI reflects randomized seeds.
    generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])

if __name__ == "__main__":
    # queue() serializes GPU work; listen on all interfaces for the Space container.
    demo.queue(max_size=10).launch(share=False, server_name="0.0.0.0", server_port=7860)