# wan_more_images / app.py
# Author: kylin0421
# Commit: "Add out dir" (712ead5)
import os
import sys
import time
import threading
from pathlib import Path
import gc
import random
import numpy as np
import torch
import gradio as gr
from PIL import Image
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
from pathlib import Path
# Reuse the optimization utilities from the reference project.
CURRENT_DIR = Path(__file__).resolve().parent
from optimization_quantized import optimize_pipeline_int8 # noqa: E402

# Output-resolution constraints (pixels); dimensions are snapped to a 16-px grid.
MAX_DIMENSION = 832
MIN_DIMENSION = 480
DIMENSION_MULTIPLE = 16
SQUARE_SIZE = 480  # fixed side length used for square inputs
MAX_SEED = np.iinfo(np.int32).max
# Frame budget: segments are rendered at a fixed 16 fps with 8-81 frames each.
FIXED_FPS = 16
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
MIN_DURATION = round(MIN_FRAMES_MODEL / FIXED_FPS, 1)  # seconds per segment, lower bound
MAX_DURATION = round(MAX_FRAMES_MODEL / FIXED_FPS, 1)  # seconds per segment, upper bound
# Default negative prompt (Chinese). It is a runtime value fed to the model,
# so it is kept verbatim; roughly it discourages overexposure, static frames,
# blurry detail, subtitles/watermarks, and deformed anatomy.
default_negative_prompt = (
    "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,"
    "最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,"
    "畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
)
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# Persistent Hugging Face cache locations on the /data volume.
PERSISTENT_DIR = Path('/data/.huggingface')
MODEL_LOCAL_DIR = Path("/data/.huggingface/Wan-AI/Wan2.2-I2V-A14B-Diffusers")
#MODEL_LOCAL_DIR.mkdir(parents=True, exist_ok=True)
# NOTE(review): HF_HOME is set *after* the torch/diffusers imports above; the
# hub library may already have resolved its cache path by now — confirm this
# still redirects downloads to /data as intended.
os.environ["HF_HOME"] = "/data/.huggingface"
# Rendered videos are written here (see "Add out dir" commit).
OUTPUT_DIR = CURRENT_DIR / "outputs"
OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
def _load_pipeline():
    """Build the Wan 2.2 image-to-video pipeline with bf16 transformers.

    Both transformer stages come from the community bf16 repack; the rest of
    the pipeline is loaded from MODEL_ID. The scheduler is replaced with a
    FlowMatchEulerDiscreteScheduler configured with shift=8.0.
    """

    def _transformer(subfolder):
        # Shared loader for the two transformer stages (high/low noise).
        return WanTransformer3DModel.from_pretrained(
            'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
            subfolder=subfolder,
            torch_dtype=torch.bfloat16,
            device_map='balanced',
        )

    print("Loading models from local directory or downloading...")
    pipeline = WanImageToVideoPipeline.from_pretrained(
        MODEL_ID,
        transformer=_transformer('transformer'),
        transformer_2=_transformer('transformer_2'),
        torch_dtype=torch.bfloat16,
        device_map='balanced',
    )
    pipeline.scheduler = FlowMatchEulerDiscreteScheduler.from_config(
        pipeline.scheduler.config, shift=8.0
    )
    return pipeline
def _optimise_pipeline(pipe):
    """Reclaim GPU/host memory, then int8-optimize *pipe* with a warmup run."""
    print("Optimizing pipeline...")
    # Several cleanup passes to release as much memory as possible before
    # the (memory-hungry) optimization step.
    cleanup_passes = 3
    while cleanup_passes:
        gc.collect()
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        cleanup_passes -= 1
    # Warmup inputs mirror the app's maximum output size (832x480, 81 frames).
    warmup_kwargs = dict(
        image=Image.new("RGB", (832, 480)),
        prompt="prompt",
        height=480,
        width=832,
        num_frames=81,
    )
    optimize_pipeline_int8(pipe, **warmup_kwargs)
    print("All models loaded and optimized. Gradio app is ready.")
# Load and optimize the pipeline once at import time so the Gradio app
# serves requests from an already-warm model.
pipe = _load_pipeline()
_optimise_pipeline(pipe)
def save_video(frames, path, fps):
    """Encode an iterable of frames to an H.264 mp4 at *path* via imageio."""
    import imageio  # local import: not needed until the first video is saved

    writer = imageio.get_writer(str(path), fps=fps, codec="libx264", quality=8)
    try:
        for image in frames:
            writer.append_data(np.array(image))
    finally:
        writer.close()
def process_image_for_video(image: Image.Image) -> Image.Image:
    """Resize *image* to a model-friendly resolution.

    Square inputs map straight to SQUARE_SIZE x SQUARE_SIZE. Otherwise the
    image is shrunk so its long side fits MAX_DIMENSION, grown so its short
    side reaches MIN_DIMENSION, snapped to the DIMENSION_MULTIPLE grid, and
    finally clamped to the per-orientation minimums.
    """
    orig_w, orig_h = image.size
    if orig_w == orig_h:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)
    landscape = (orig_w / orig_h) > 1
    w, h = float(orig_w), float(orig_h)
    # Shrink so the longer side fits within MAX_DIMENSION.
    if w > MAX_DIMENSION or h > MAX_DIMENSION:
        factor = MAX_DIMENSION / (w if landscape else h)
        w, h = w * factor, h * factor
    # Grow so the shorter side reaches MIN_DIMENSION.
    if w < MIN_DIMENSION or h < MIN_DIMENSION:
        factor = MIN_DIMENSION / (h if landscape else w)
        w, h = w * factor, h * factor
    # Snap both sides to the model's pixel grid.
    snapped_w = int(round(w / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    snapped_h = int(round(h / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
    # Clamp: the short side never drops below its orientation-specific floor.
    snapped_w = max(snapped_w, SQUARE_SIZE if landscape else MIN_DIMENSION)
    snapped_h = max(snapped_h, MIN_DIMENSION if landscape else SQUARE_SIZE)
    return image.resize((snapped_w, snapped_h), Image.Resampling.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
    """Scale *target_image* to cover *reference_image*'s size, then center-crop.

    Cover-fit semantics: the target is scaled up/down just enough that both
    of its dimensions are at least as large as the reference's, and the
    overflow is cropped symmetrically.
    """
    ref_w, ref_h = reference_image.size
    src_w, src_h = target_image.size
    factor = max(ref_w / src_w, ref_h / src_h)
    scaled_w = int(src_w * factor)
    scaled_h = int(src_h * factor)
    scaled = target_image.resize((scaled_w, scaled_h), Image.Resampling.LANCZOS)
    # Center-crop the scaled image down to the reference dimensions.
    x0 = (scaled_w - ref_w) // 2
    y0 = (scaled_h - ref_h) // 2
    return scaled.crop((x0, y0, x0 + ref_w, y0 + ref_h))
def preview_uploaded_images(file_paths):
    """Echo uploaded file paths back to the gallery; empty list when none."""
    return file_paths if file_paths else []
def generate_video_sequence(
    frame1,
    frame2,
    frame3,
    frame4,
    frame5,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=True,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one video that interpolates through 2-5 uploaded keyframes.

    Each consecutive keyframe pair becomes one first/last-frame-conditioned
    segment of roughly ``duration_seconds`` seconds; segments are concatenated
    (dropping each later segment's duplicated first frame) and encoded to an
    mp4 under OUTPUT_DIR. The file is deleted by a background timer ~30 s
    after it is returned, to bound disk usage.

    Returns:
        (video_path_str, last_seed_used, comma-separated per-segment seeds).

    Raises:
        gr.Error: when fewer than two frames are supplied or a frame file
            cannot be opened.
    """
    # Keep only the filled slots, preserving upload order.
    image_paths = [path for path in (frame1, frame2, frame3, frame4, frame5) if path]
    if len(image_paths) < 2:
        raise gr.Error("Please provide at least the first two frames before generating.")
    if len(image_paths) > 5:
        image_paths = image_paths[:5]
    raw_images = []
    for path in image_paths:
        # Defensive re-check; empty entries were already filtered above.
        if path is None or path == "":
            raise gr.Error("Encountered an empty image slot. Please upload images sequentially without gaps.")
        try:
            with Image.open(path) as img:
                raw_images.append(img.convert("RGB"))
        except Exception as exc:
            raise gr.Error(f"Failed to read image '{path}': {exc}") from exc
    progress(0.05, desc="Preprocessing images...")
    # The first frame fixes the output resolution; every later frame is
    # cover-cropped to match it so all segments share one size.
    processed_images = [process_image_for_video(raw_images[0])]
    for img in raw_images[1:]:
        resized = resize_and_crop_to_match(img, processed_images[-1])
        processed_images.append(resized)
    num_segments = len(processed_images) - 1
    # Clamp the per-segment frame count to the model's supported range.
    frames_per_segment = int(round(duration_seconds * FIXED_FPS))
    frames_per_segment = int(np.clip(frames_per_segment, MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))
    all_frames = []
    seeds_per_segment = []
    base_seed = int(seed)
    for idx in range(num_segments):
        start_img = processed_images[idx]
        end_img = processed_images[idx + 1]
        # Fresh random seed per segment, or a deterministic offset from base_seed.
        current_seed = (
            random.randint(0, MAX_SEED)
            if randomize_seed
            else (base_seed + idx) % (MAX_SEED + 1)
        )
        seeds_per_segment.append(current_seed)
        progress_value = 0.1 + 0.7 * ((idx + 1) / max(1, num_segments))
        progress(progress_value, desc=f"Generating segment {idx + 1}/{num_segments}...")
        segment_frames = pipe(
            image=start_img,
            last_image=end_img,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=start_img.height,
            width=start_img.width,
            num_frames=frames_per_segment,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=int(steps),
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
        ).frames[0]
        # Every segment after the first starts on the previous segment's last
        # frame; drop it to avoid a visible stutter at the seams.
        if idx > 0:
            segment_frames = segment_frames[1:]
        all_frames.extend(segment_frames)
    progress(0.9, desc="Encoding and saving video...")
    video_filename = (
        f"wan_multi_{int(time.time())}_{seeds_per_segment[-1]}_"
        f"{processed_images[0].width}x{processed_images[0].height}_"
        f"{len(all_frames)}.mp4"
    )
    video_path = OUTPUT_DIR / video_filename
    save_video(all_frames, video_path, fps=FIXED_FPS)
    print(
        f"[OK] Saved video -> {video_path} "
        f"({video_path.stat().st_size / 1024:.1f} KB; {len(all_frames)} frames)"
    )

    def _delete_file(path):
        # Best-effort disk cleanup; failures are only logged.
        try:
            os.remove(path)
            print(f"[CLEANUP] Deleted cached video: {path}")
        except Exception as error:
            print(f"[CLEANUP ERROR] {error}")

    # Delete the rendered file after 30 s to keep the outputs dir bounded.
    threading.Timer(30, _delete_file, args=[str(video_path)]).start()
    progress(1.0, desc="Done!")
    seeds_summary = ", ".join(str(s) for s in seeds_per_segment)
    last_seed = seeds_per_segment[-1] if seeds_per_segment else base_seed
    return str(video_path), last_seed, seeds_summary
# Custom CSS for the Gradio layout: page max-width, dark-mode progress text,
# and tab/group styling tweaks.
css = """
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
#general_items{margin-top: 2em}
#group_all{overflow:visible}
#group_all .styler{overflow:visible}
#group_tabs .tabitem{padding: 0}
.tab-wrapper{margin-top: -33px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
"""
# UI definition: left column holds the five keyframe slots, preview gallery,
# prompt, and advanced settings; right column shows the result.
with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    gr.Markdown("# Wan 2.2 Multi-Frame Video")
    gr.Markdown(
        "Upload 2-5 keyframes, provide a prompt, and generate smooth transitions between each pair."
    )
    with gr.Row(elem_id="general_items"):
        with gr.Column():
            with gr.Group(elem_id="group_all"):
                with gr.Column():
                    # Five upload slots; only the first two are mandatory.
                    frame_slots = []
                    max_slots = 5
                    for idx in range(max_slots):
                        slot_label = f"Frame {idx + 1}" + (" (required)" if idx < 2 else " (optional)")
                        frame = gr.Image(
                            label=slot_label,
                            type="filepath",
                            sources=["upload", "clipboard"],
                            interactive=True,
                        )
                        frame_slots.append(frame)
            # Live preview of whatever has been uploaded so far.
            image_preview = gr.Gallery(
                label="Frame Preview",
                columns=5,
                height="auto",
                show_label=True,
            )
            prompt = gr.Textbox(label="Prompt")
            with gr.Accordion("Advanced Settings", open=False):
                duration_seconds_input = gr.Slider(
                    minimum=MIN_DURATION,
                    maximum=MAX_DURATION,
                    step=0.1,
                    value=2.1,
                    label="Duration per segment (seconds)",
                )
                negative_prompt_input = gr.Textbox(
                    label="Negative Prompt", value=default_negative_prompt, lines=3
                )
                steps_slider = gr.Slider(
                    minimum=1, maximum=30, step=1, value=16, label="Inference Steps"
                )
                guidance_scale_input = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=1.0,
                    label="Guidance Scale - high noise",
                )
                guidance_scale_2_input = gr.Slider(
                    minimum=0.0,
                    maximum=10.0,
                    step=0.5,
                    value=1.0,
                    label="Guidance Scale - low noise",
                )
                with gr.Row():
                    seed_input = gr.Slider(
                        label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42
                    )
                    randomize_seed_checkbox = gr.Checkbox(
                        label="Randomize seed", value=True
                    )
            generate_button = gr.Button("Generate Video", variant="primary")
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
            seeds_output = gr.Textbox(
                label="Seeds per segment", interactive=False, show_copy_button=True
            )

    def update_gallery(*images):
        """Show only the filled slots in the preview gallery."""
        return [img for img in images if img]

    # Refresh the preview whenever any of the five slots changes.
    for frame_slot in frame_slots:
        frame_slot.change(
            fn=update_gallery,
            inputs=frame_slots,
            outputs=image_preview,
        )

    # Input order must match generate_video_sequence's positional signature.
    ui_inputs = (
        frame_slots
        + [
            prompt,
            negative_prompt_input,
            duration_seconds_input,
            steps_slider,
            guidance_scale_input,
            guidance_scale_2_input,
            seed_input,
            randomize_seed_checkbox,
        ]
    )
    # seed_input doubles as an output so the last used seed shows in the UI.
    ui_outputs = [output_video, seed_input, seeds_output]
    generate_button.click(fn=generate_video_sequence, inputs=ui_inputs, outputs=ui_outputs)
if __name__ == "__main__":
    # share=True also exposes a temporary public Gradio link.
    app.launch(share=True)