# Hugging Face Spaces app: Wan 2.2 multi-frame image-to-video generator.
# (The original paste carried HF Spaces page residue — "Spaces: Runtime error" —
# which was not part of the program.)
| import os | |
| import sys | |
| import time | |
| import threading | |
| from pathlib import Path | |
| import gc | |
| import random | |
| import numpy as np | |
| import torch | |
| import gradio as gr | |
| from PIL import Image | |
| from diffusers import FlowMatchEulerDiscreteScheduler | |
| from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline | |
| from diffusers.models.transformers.transformer_wan import WanTransformer3DModel | |
| from diffusers.utils.export_utils import export_to_video | |
| from pathlib import Path | |
| # Reuse the optimization utilities from the reference project. | |
| CURRENT_DIR = Path(__file__).resolve().parent | |
| from optimization_quantized import optimize_pipeline_int8 # noqa: E402 | |
# ---------------------------------------------------------------------------
# Model constraints and app-wide constants.
# ---------------------------------------------------------------------------
MAX_DIMENSION = 832        # longest side accepted by the model (pixels)
MIN_DIMENSION = 480        # shortest side accepted by the model (pixels)
DIMENSION_MULTIPLE = 16    # output sides are rounded to this stride
SQUARE_SIZE = 480          # fixed output size for square inputs

MAX_SEED = np.iinfo(np.int32).max
FIXED_FPS = 16             # Wan 2.2 I2V generates at a fixed 16 fps
MIN_FRAMES_MODEL = 8       # minimum frames per pipeline call
MAX_FRAMES_MODEL = 81      # maximum frames per pipeline call

# Per-segment duration bounds (seconds) derived from the frame limits above.
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)

# Default negative prompt (Chinese) — runtime string, kept verbatim.
default_negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
    "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,"
    "畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
)

MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
PERSISTENT_DIR = Path('/data/.huggingface')
# Derive from PERSISTENT_DIR instead of repeating the literal (same value).
MODEL_LOCAL_DIR = PERSISTENT_DIR / "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# MODEL_LOCAL_DIR.mkdir(parents=True, exist_ok=True)

# Point the Hugging Face cache at the Space's persistent volume. Was a third
# copy of the same hard-coded path; now derived from PERSISTENT_DIR.
os.environ["HF_HOME"] = str(PERSISTENT_DIR)

OUTPUT_DIR = CURRENT_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def _load_pipeline():
    """Build the Wan 2.2 image-to-video pipeline with bf16 transformers.

    Both transformer stages come from the bf16 repack of the model; the
    scheduler is then replaced with FlowMatchEuler configured with shift=8.0.
    """
    print("Loading models from local directory or downloading...")

    def _bf16_stage(subfolder):
        # One bf16 transformer stage, balanced across available devices.
        return WanTransformer3DModel.from_pretrained(
            'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
            subfolder=subfolder,
            torch_dtype=torch.bfloat16,
            device_map='balanced',
        )

    pipeline = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        transformer=_bf16_stage('transformer'),
        transformer_2=_bf16_stage('transformer_2'),
        torch_dtype=torch.bfloat16,
        device_map='balanced',
    )
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        pipeline.scheduler.config, shift=8.0
    )
    return pipeline
def _optimise_pipeline(pipe):
    """Warm up and int8-optimize *pipe* in place using a dummy request."""
    print("Optimizing pipeline...")
    # Reclaim host and CUDA memory a few times before the memory-hungry pass.
    for _attempt in range(3):
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
    warmup_kwargs = dict(
        image=Image.new("RGB", (832, 480)),
        prompt="prompt",
        height=480,
        width=832,
        num_frames=81,
    )
    optimize_pipeline_int8(pipe, **warmup_kwargs)
    print("All models loaded and optimized. Gradio app is ready.")
# Load and optimize the pipeline once at import time so the Gradio app
# serves its first request without a cold start.
pipe = _load_pipeline()
_optimise_pipeline(pipe)
def save_video(frames, path, fps):
    """Encode *frames* to an H.264 MP4 at *path*.

    Args:
        frames: iterable of PIL images or HxWx3 arrays. Non-uint8 arrays are
            assumed to be floats in [0, 1] (the usual diffusers output) and
            are converted to uint8 before encoding — TODO confirm the frame
            dtype against the pipeline's ``output_type``.
        path: destination file path (str or Path).
        fps: output frame rate.
    """
    import imageio

    with imageio.get_writer(str(path), fps=fps, codec="libx264", quality=8) as writer:
        for frame in frames:
            arr = np.asarray(frame)
            if arr.dtype != np.uint8:
                # Clamp, scale to 0-255 and cast so the encoder gets bytes.
                arr = (np.clip(arr, 0.0, 1.0) * 255).round().astype(np.uint8)
            writer.append_data(arr)
def process_image_for_video(image: Image.Image) -> Image.Image:
    """Resize *image* so both sides satisfy the model's size constraints.

    Square inputs snap to SQUARE_SIZE x SQUARE_SIZE. Otherwise the image is
    first shrunk under MAX_DIMENSION, then grown over MIN_DIMENSION if
    needed, and each side is rounded to DIMENSION_MULTIPLE with a final
    per-orientation floor applied.
    """
    src_w, src_h = image.size
    if src_w == src_h:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    ratio = src_w / src_h
    w, h = src_w, src_h

    # Shrink so neither side exceeds the maximum (scale by the longer side).
    if w > MAX_DIMENSION or h > MAX_DIMENSION:
        factor = MAX_DIMENSION / w if ratio > 1 else MAX_DIMENSION / h
        w *= factor
        h *= factor

    # Grow so neither side falls under the minimum (scale by the shorter side).
    if w < MIN_DIMENSION or h < MIN_DIMENSION:
        factor = MIN_DIMENSION / h if ratio > 1 else MIN_DIMENSION / w
        w *= factor
        h *= factor

    # Round to the model's stride, then enforce the orientation-dependent
    # floors (e.g. a landscape image keeps width >= SQUARE_SIZE).
    out_w = int(round(w / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    out_h = int(round(h / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    out_w = max(out_w, MIN_DIMENSION if ratio < 1 else SQUARE_SIZE)
    out_h = max(out_h, MIN_DIMENSION if ratio > 1 else SQUARE_SIZE)

    return image.resize((out_w, out_h), Image.Resampling.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
    """Scale *target_image* to cover *reference_image*'s size, then center-crop.

    Returns an image with exactly the reference dimensions.
    """
    ref_w, ref_h = reference_image.size
    src_w, src_h = target_image.size
    # "Cover" scaling: the larger ratio guarantees both sides span the reference.
    factor = max(ref_w / src_w, ref_h / src_h)
    scaled_w, scaled_h = int(src_w * factor), int(src_h * factor)
    scaled = target_image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)
    # Center-crop back down to the reference dimensions.
    off_x = (scaled_w - ref_w) // 2
    off_y = (scaled_h - ref_h) // 2
    return scaled.crop((off_x, off_y, off_x + ref_w, off_y + ref_h))
def preview_uploaded_images(file_paths):
    """Return the uploaded file paths for gallery preview.

    Filters out empty slots (None / "") so the gallery never receives a
    missing path, matching the filtering done by ``update_gallery`` in the
    UI wiring. Returns an empty list when nothing was uploaded.
    """
    if not file_paths:
        return []
    return [path for path in file_paths if path]
def generate_video_sequence(
    frame1,
    frame2,
    frame3,
    frame4,
    frame5,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one MP4 that transitions through the uploaded keyframes.

    Each consecutive pair of frames becomes one pipeline call (start frame as
    ``image``, next frame as ``last_image``); the segments are concatenated,
    encoded, and the file is scheduled for deletion 30 seconds later.

    Args:
        frame1..frame5: keyframe file paths (slots 3-5 optional).
        prompt: text prompt shared by every segment.
        negative_prompt: negative prompt (defaults to the module-level one).
        duration_seconds: requested per-segment length; clamped to the
            model's frame-count limits.
        steps: inference steps per segment.
        guidance_scale / guidance_scale_2: CFG scales passed through to the
            pipeline's two guidance parameters.
        seed: base seed when ``randomize_seed`` is False (segment index added).
        randomize_seed: draw a fresh random seed per segment when True.
        progress: Gradio progress tracker.

    Returns:
        Tuple of (video path as str, seed of the last segment,
        comma-separated seeds for all segments).

    Raises:
        gr.Error: fewer than two frames, an empty slot, or an unreadable image.
    """
    # Keep only the filled slots, preserving upload order.
    image_paths = [path for path in (frame1, frame2, frame3, frame4, frame5) if path]
    if len(image_paths) < 2:
        raise gr.Error("Please provide at least the first two frames before generating.")
    if len(image_paths) > 5:
        image_paths = image_paths[:5]
    raw_images = []
    for path in image_paths:
        # NOTE(review): empty entries were already filtered out above, so
        # this guard is effectively defensive.
        if path is None or path == "":
            raise gr.Error("Encountered an empty image slot. Please upload images sequentially without gaps.")
        try:
            with Image.open(path) as img:
                raw_images.append(img.convert("RGB"))
        except Exception as exc:
            raise gr.Error(f"Failed to read image '{path}': {exc}") from exc
    progress(0.05, desc="Preprocessing images...")
    # Normalize the first frame to model-friendly dimensions, then force every
    # later frame to exactly that size via cover-scale + center-crop.
    processed_images = [process_image_for_video(raw_images[0])]
    for img in raw_images[1:]:
        resized = resize_and_crop_to_match(img, processed_images[-1])
        processed_images.append(resized)
    num_segments = len(processed_images) - 1
    # Convert the requested duration into a frame count within model limits.
    frames_per_segment = int(round(duration_seconds * FIXED_FPS))
    frames_per_segment = int(np.clip(frames_per_segment, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    all_frames = []
    seeds_per_segment = []
    base_seed = int(seed)
    for idx in range(num_segments):
        start_img = processed_images[idx]
        end_img = processed_images[idx + 1]
        # Either a fresh random seed per segment, or base_seed + idx wrapped
        # into the valid int32 range.
        current_seed = (
            random.randint(0, MAX_SEED)
            if randomize_seed
            else (base_seed + idx) % (MAX_SEED + 1)
        )
        seeds_per_segment.append(current_seed)
        progress_value = 0.1 + 0.7 * ((idx + 1) / max(1, num_segments))
        progress(progress_value, desc=f"Generating segment {idx + 1}/{num_segments}...")
        # One first/last-frame conditioned generation per keyframe pair.
        segment_frames = pipe(
            image=start_img,
            last_image=end_img,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=start_img.height,
            width=start_img.width,
            num_frames=frames_per_segment,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]
        # Drop the first frame of every segment after the first: it
        # duplicates the previous segment's final frame.
        if idx > 0:
            segment_frames = segment_frames[1:]
        all_frames.extend(segment_frames)
    progress(0.9, desc="Encoding and saving video...")
    video_filename = (
        f"wan_multi_{int(time.time())}_{seeds_per_segment[-1]}_"
        f"{processed_images[0].width}x{processed_images[0].height}_"
        f"{len(all_frames)}.mp4"
    )
    video_path = OUTPUT_DIR / video_filename
    save_video(all_frames, video_path, fps=FIXED_FPS)
    print(
        f"[OK] Saved video -> {video_path} "
        f"({video_path.stat().st_size / 1024:.1f} KB; {len(all_frames)} frames)"
    )
    def _delete_file(path):
        # Best-effort cleanup of the cached video after it has been served.
        try:
            os.remove(path)
            print(f"[CLEANUP] Deleted cached video: {path}")
        except Exception as error:
            print(f"[CLEANUP ERROR] {error}")
    # Delete the output 30 seconds from now to keep the Space's disk small.
    threading.Timer(30, _delete_file, args=[str(video_path)]).start()
    progress(1.0, desc="Done!")
    seeds_summary = ", ".join(str(s) for s in seeds_per_segment)
    last_seed = seeds_per_segment[-1] if seeds_per_segment else base_seed
    return str(video_path), last_seed, seeds_summary
# App-level CSS overrides: widen the layout, fix dark-mode progress text,
# and float the tab bar over the settings group.
css = """
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
#general_items{margin-top: 2em}
#group_all{overflow:visible}
#group_all .styler{overflow:visible}
#group_tabs .tabitem{padding: 0}
.tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
"""
# ---------------------------------------------------------------------------
# Gradio UI: five keyframe slots, a preview gallery, generation settings,
# and the output video column, all wired to generate_video_sequence.
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 Multi-Frame Video")
    gr.Markdown(
        "Upload 2-5 keyframes, provide a prompt, and generate smooth transitions between each pair."
    )
    with gr.Row(elem_id="general_items"):
        with gr.Column():
            with gr.Group(elem_id="group_all"):
                with gr.Column():
                    # Five upload slots; the first two are mandatory.
                    frame_slots = []
                    max_slots = 5
                    for idx in range(max_slots):
                        slot_label = f"Frame {idx + 1}" + (" (required)" if idx < 2 else " (optional)")
                        frame = gr.Image(
                            label=slot_label,
                            type="filepath",
                            sources=["upload", "clipboard"],
                            interactive=True,
                        )
                        frame_slots.append(frame)
            # Gallery mirroring the currently filled slots.
            image_preview = gr.Gallery(
                label="Frame Preview",
                columns=5,
                height="auto",
                show_label=True,
            )
            prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced Settings", open=False):
                # Per-segment duration, bounded by the model's frame limits.
                duration_seconds_input = gr.Slider(
                    minimum=MIN_DURATION,
                    maximum=MAX_DURATION,
                    step=0.1,
                    value=2.1,
                    label="Duration per segment (seconds)",
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt", value=default_negative_prompt, lines=3
                )
                steps_slider = gr.Slider(
                    minimum=1, maximum=30, step=1, value=16, label="Inference Steps"
                )
                guidance_scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=1.0,
                    label="Guidance Scale - high noise",
                )
                guidance_scale_2_input = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=1.0,
                    label="Guidance Scale - low noise",
                )
                with gr.Row():
                    seed_input = gr.Slider(
                        label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42
                    )
                    randomize_seed_checkbox = gr.Checkbox(
                        label="Randomize seed", value=True
                    )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
            seeds_output = gr.Textbox(
                label="Seeds per segment", interactive=False, show_copy_button=True
            )
    def update_gallery(*images):
        # Drop empty slots so the gallery only shows uploaded frames.
        return [img for img in images if img]
    # Refresh the preview gallery whenever any slot changes.
    for frame_slot in frame_slots:
        frame_slot.change(
            fn=update_gallery,
            inputs=frame_slots,
            outputs=image_preview,
        )
    # Inputs must match generate_video_sequence's positional signature:
    # five frames, prompt, then the advanced settings.
    ui_inputs = (
        frame_slots
        + [
            prompt,
            negative_prompt_input,
            duration_seconds_input,
            steps_slider,
            guidance_scale_input,
            guidance_scale_2_input,
            seed_input,
            randomize_seed_checkbox,
        ]
    )
    # (video, last seed echoed back into the seed slider, seeds summary).
    ui_outputs = [output_video, seed_input, seeds_output]
    generate_button.click(fn=generate_video_sequence, inputs=ui_inputs, outputs=ui_outputs)
# Launch the app with a public share link when run as a script.
if __name__ == "__main__":
    app.launch(share=True)