# wanflf / app.py
# Author: Revrse
# Last commit: "Update app.py" (ee64f8b, verified)
import os
# PyTorch 2.8 (temporary hack)
# At startup, upgrade torch to a pre-release build below 2.9 from the PyTorch
# nightly CUDA 12.6 wheel index, and upgrade `spaces`. This must run before
# `import spaces` / `import torch` below so the upgraded packages are imported.
# The shell exit status returned by os.system is not checked, so a failed
# install is silent.
# NOTE(review): presumably required by the optimization in `optimization.py` — confirm.
os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')
# --- 1. Model Download and Setup (Diffusers Backend) ---
import spaces
import torch
from diffusers import FlowMatchEulerDiscreteScheduler
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
from gradio_client import Client, handle_file # Import for API call
# Import the optimization function from the separate file
from optimization import optimize_pipeline_
# --- Constants and Model Loading ---
# Hugging Face repo id of the Wan 2.2 image-to-video pipeline (diffusers format).
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
# --- Flexible dimension constants (used by process_image_for_video) ---
# Target upper bound for the longest side of the start frame, in pixels.
MAX_DIMENSION = 832
# Target lower bound for the shortest side of the start frame, in pixels.
MIN_DIMENSION = 480
# Output width and height are rounded to a multiple of this value.
DIMENSION_MULTIPLE = 16
# Side length used for exactly-square inputs.
SQUARE_SIZE = 480
# Largest seed offered by the UI slider and drawn by random seed generation.
MAX_SEED = np.iinfo(np.int32).max
# Frame rate used both to convert a duration into a frame count and to encode the MP4.
FIXED_FPS = 16
# Bounds that the requested frame count is clipped to.
MIN_FRAMES_MODEL = 8
MAX_FRAMES_MODEL = 81
# Duration bounds (seconds) for the UI slider, derived from the frame bounds:
# 8 / 16 -> 0.5 s and 81 / 16 = 5.0625 -> 5.1 s.
MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
# Default negative prompt (in Chinese). It is the default for generate_video and
# is pre-filled in the UI. English gloss: garish colors, overexposed, static,
# blurry details, subtitles, style, artwork, painting, picture, still, overall
# grayish, worst quality, low quality, JPEG compression artifacts, ugly,
# mutilated, extra fingers, poorly drawn hands, poorly drawn faces, deformed,
# disfigured, malformed limbs, fused fingers, motionless frame, cluttered
# background, three legs, many people in the background, walking backwards,
# overexposed.
default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
print("Loading models into memory. This may take a few minutes...")


def _load_transformer(subfolder):
    """Load one Wan transformer from the bf16 mirror repo, in bfloat16, placed on CUDA.

    Both expert transformers live in the same repository and differ only in
    ``subfolder``.
    """
    return WanTransformer3DModel.from_pretrained(
        'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder=subfolder,
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    )


# Keyword arguments are evaluated left to right, so 'transformer' loads before
# 'transformer_2', and both load before the rest of the pipeline.
pipe = WanImageToVideoPipeline.from_pretrained(
    MODEL_ID,
    transformer=_load_transformer('transformer'),
    transformer_2=_load_transformer('transformer_2'),
    torch_dtype=torch.bfloat16,
)

# Replace the stock scheduler with a flow-matching Euler scheduler built from
# the same config, overriding only `shift`.
_stock_scheduler_config = pipe.scheduler.config
pipe.scheduler = FlowMatchEulerDiscreteScheduler.from_config(_stock_scheduler_config, shift=8.0)
pipe.to('cuda')
print("Optimizing pipeline...")

# Release Python garbage and cached CUDA memory before the optimization pass.
# NOTE(review): repeated three times, presumably so objects freed by one
# collection round can be reclaimed by the next — confirm.
for _ in range(3):
    gc.collect()
    torch.cuda.synchronize()
    torch.cuda.empty_cache()

# Run the optimization once, using a blank landscape dummy frame of
# MAX_DIMENSION x MIN_DIMENSION at the maximum frame count.
_warmup_width = MAX_DIMENSION
_warmup_height = MIN_DIMENSION
optimize_pipeline_(
    pipe,
    image=Image.new('RGB', (_warmup_width, _warmup_height)),
    prompt='prompt',
    height=_warmup_height,
    width=_warmup_width,
    num_frames=MAX_FRAMES_MODEL,
)
print("All models loaded and optimized. Gradio app is ready.")
# --- 2. Image Processing and Application Logic ---
def generate_end_frame(start_img, gen_prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an end frame from ``start_img`` using an external Gradio Space.

    The start image is written to a temporary PNG file. The file is uploaded to
    the ``/unified_image_generator`` endpoint of the
    ``multimodalart/nano-banana-private`` Space, together with ``gen_prompt``
    and the ``HF_TOKEN`` secret. The endpoint's result is printed and returned
    unchanged. The temporary file is removed whether or not the call succeeds.

    Args:
        start_img: PIL image used as the reference frame.
        gen_prompt: Text prompt sent to the remote generator.
        progress: Gradio progress tracker.

    Returns:
        Whatever ``Client.predict`` returns for that endpoint.
        NOTE(review): presumably an image filepath usable as a ``gr.Image``
        value — confirm against the remote Space.

    Raises:
        gr.Error: If no start image is given or ``HF_TOKEN`` is not set.
    """
    if start_img is None:
        raise gr.Error("Please provide a Start Frame first.")
    hf_token = os.getenv("HF_TOKEN")
    if not hf_token:
        raise gr.Error("HF_TOKEN not found in environment variables. Please set it in your Space secrets.")

    # delete=False keeps the file after the handle closes so it can be reopened
    # by path (for saving and uploading); it is removed in the finally below.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmpfile:
        tmp_path = tmpfile.name
    # The try starts immediately after the file exists. Previously, a failure in
    # saving or in Client(...) leaked the temp file.
    try:
        start_img.save(tmp_path)
        progress(0.1, desc="Connecting to image generation API...")
        client = Client("multimodalart/nano-banana-private")
        progress(0.5, desc=f"Generating with prompt: '{gen_prompt}'...")
        result = client.predict(
            prompt=gen_prompt,
            images=[{"image": handle_file(tmp_path)}],
            manual_token=hf_token,
            api_name="/unified_image_generator",
        )
    finally:
        os.remove(tmp_path)
    progress(1.0, desc="Done!")
    print(result)
    return result
def switch_to_upload_tab():
    """Build a ``gr.Tabs`` update whose selected tab id is ``"upload_tab"``.

    NOTE(review): this only has a visible effect when used as an event output
    that targets a ``gr.Tabs`` containing a tab with that id; confirm such a
    component exists.
    """
    upload_tab_id = "upload_tab"
    return gr.Tabs(selected=upload_tab_id)
def process_image_for_video(image: Image.Image) -> Image.Image:
    """Resize a start frame to dimensions suitable for video generation.

    Rules, applied in order:

    1. Square images are resized to ``SQUARE_SIZE`` x ``SQUARE_SIZE``.
    2. If either side exceeds ``MAX_DIMENSION``, the image is scaled so that
       its longest side equals ``MAX_DIMENSION``.
    3. If either side is then below ``MIN_DIMENSION``, the image is scaled so
       that its shortest side equals ``MIN_DIMENSION``.
    4. Both sides are rounded to the nearest multiple of ``DIMENSION_MULTIPLE``.
       They are then clamped to the grid-aligned range
       ``[MIN_DIMENSION, MAX_DIMENSION]``.

    The aspect ratio is preserved as closely as the bounds allow. An image more
    elongated than ``MAX_DIMENSION / MIN_DIMENSION`` cannot satisfy both bounds.
    Its long side is clamped (squashing it) instead of exceeding ``MAX_DIMENSION``.

    Args:
        image: Source PIL image.

    Returns:
        A new LANCZOS-resampled PIL image.
    """
    width, height = image.size

    # Rule 1: square inputs get a fixed size.
    if width == height:
        return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)

    aspect_ratio = width / height
    is_landscape = aspect_ratio > 1
    new_width, new_height = width, height

    # Rule 2: scale down so the longest side fits MAX_DIMENSION.
    if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
        scale = MAX_DIMENSION / (new_width if is_landscape else new_height)
        new_width *= scale
        new_height *= scale

    # Rule 3: scale up so the shortest side reaches MIN_DIMENSION.
    if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
        scale = MIN_DIMENSION / (new_height if is_landscape else new_width)
        new_width *= scale
        new_height *= scale

    # Rule 4: snap to the size grid, then clamp. Rule 3 can push the long side
    # back above MAX_DIMENSION (e.g. 1920x1080 -> ~853x480 -> 848x480 after
    # rounding), so the upper clamp is needed, not merely defensive. Bounds
    # are aligned to the grid so clamping keeps sides multiples of DIMENSION_MULTIPLE.
    upper = (MAX_DIMENSION // DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE
    lower = -(-MIN_DIMENSION // DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE

    def _snap(value):
        # Round to the nearest grid multiple, then clamp into [lower, upper].
        snapped = int(round(value / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
        return min(max(snapped, lower), upper)

    final_width = _snap(new_width)
    final_height = _snap(new_height)
    return image.resize((final_width, final_height), Image.Resampling.LANCZOS)
def resize_and_crop_to_match(target_image, reference_image):
    """Scale ``target_image`` to cover ``reference_image``'s size, then center-crop to it.

    The target is scaled uniformly (aspect ratio preserved) by the smallest
    factor that makes both sides at least as large as the reference. The crop
    therefore only trims overflow and never extends past the resized image.

    Args:
        target_image: PIL image to adapt.
        reference_image: PIL image whose ``size`` is matched exactly.

    Returns:
        A new PIL image whose size equals ``reference_image.size``.
    """
    ref_width, ref_height = reference_image.size
    target_width, target_height = target_image.size
    scale = max(ref_width / target_width, ref_height / target_height)
    # Float rounding can make target * (ref / target) land just below ref, and
    # int() would then truncate the binding side to ref - 1. crop() would then
    # reach outside the image (padding a black row/column) and start at -1.
    # Flooring at the reference size guarantees coverage.
    new_width = max(ref_width, int(target_width * scale))
    new_height = max(ref_height, int(target_height * scale))
    resized = target_image.resize((new_width, new_height), Image.Resampling.LANCZOS)
    left = (new_width - ref_width) // 2
    top = (new_height - ref_height) // 2
    return resized.crop((left, top, left + ref_width, top + ref_height))
@spaces.GPU(duration=120)
def generate_video(
    start_image_pil,
    end_image_pil,
    prompt,
    negative_prompt=default_negative_prompt,
    duration_seconds=2.1,
    steps=8,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate a video that interpolates between a start and an end frame.

    The start frame is resized with ``process_image_for_video``, which fixes
    the output resolution. The end frame is resized and center-cropped to that
    exact size. The Wan 2.2 image-to-video pipeline (``pipe``) then runs with
    the start frame as ``image`` and the end frame as ``last_image``.

    Args:
        start_image_pil: First frame (PIL image).
        end_image_pil: Last frame (PIL image).
        prompt: Text prompt describing the transition.
        negative_prompt: Negative prompt passed to the pipeline.
        duration_seconds: Requested length in seconds. It is converted to
            ``round(duration_seconds * FIXED_FPS)`` frames and clamped to
            ``[MIN_FRAMES_MODEL, MAX_FRAMES_MODEL]``.
        steps: Number of inference steps.
        guidance_scale: Passed to the pipeline as ``guidance_scale``.
        guidance_scale_2: Passed to the pipeline as ``guidance_scale_2``.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: If true, draw a fresh seed from ``[0, MAX_SEED]``.
        progress: Gradio progress tracker.

    Returns:
        ``(video_path, seed)``. ``video_path`` is an MP4 encoded at
        ``FIXED_FPS`` in a temporary file that this function does not delete;
        ``seed`` is the seed actually used.

    Raises:
        gr.Error: If either frame is missing.
    """
    if start_image_pil is None or end_image_pil is None:
        raise gr.Error("Please upload both a start and an end image.")

    progress(0.1, desc="Preprocessing images...")
    # The processed start image defines the output resolution; the end image
    # is made to match it exactly.
    processed_start_image = process_image_for_video(start_image_pil)
    processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)
    target_height, target_width = processed_start_image.height, processed_start_image.width

    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    # Clamp with builtins: np.clip on a scalar returns a NumPy integer, whereas
    # the pipeline call and the status message only need a plain int.
    requested_frames = int(round(duration_seconds * FIXED_FPS))
    num_frames = min(max(requested_frames, MIN_FRAMES_MODEL), MAX_FRAMES_MODEL)

    progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
    output_frames_list = pipe(
        image=processed_start_image,
        last_image=processed_end_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        height=target_height,
        width=target_width,
        num_frames=num_frames,
        guidance_scale=float(guidance_scale),
        guidance_scale_2=float(guidance_scale_2),
        num_inference_steps=int(steps),
        generator=torch.Generator(device="cuda").manual_seed(current_seed),
    ).frames[0]

    progress(0.9, desc="Encoding and saving video...")
    # delete=False keeps the file after the handle closes so it can be returned by path.
    with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
        video_path = tmpfile.name
    export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
    progress(1.0, desc="Done!")
    return video_path, current_seed
# --- 3. Simplified Gradio User Interface (Examples removed) ---
css = '''
.fillable{max-width: 1100px !important}
.section-title {font-size: 20px; margin-bottom: 8px;}
.kv {margin-bottom: 8px;}
.controls {gap: 8px;}
'''

# Prompt sent to the external image generator by the "Generate End Frame" button.
END_FRAME_PROMPT = "this image is a still frame from a movie. generate a new frame with what happens on this scene 5 seconds in the future"


def _generate_end_frame_5s(start_img, progress=gr.Progress(track_tqdm=True)):
    """Event handler: generate an end frame "5 seconds later" than ``start_img``.

    ``progress`` is declared here, on the function registered with the event,
    instead of wrapping generate_end_frame in a lambda. Gradio injects a
    tracked progress object only into the registered function, so this lets
    generate_end_frame's progress updates reach the UI.
    """
    return generate_end_frame(start_img, END_FRAME_PROMPT, progress)


with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
    # Header
    with gr.Row():
        gr.Markdown("### Wan 2.2 — First & Last Frame Video (Diffusers)")
        gr.Markdown("Compact UI — examples removed.")

    # Main layout: inputs on the left, output preview on the right.
    with gr.Row():
        with gr.Column(scale=6):
            gr.Markdown("<div class='section-title'>Inputs</div>", elem_id="inputs_title")
            start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"], elem_classes=["kv"])
            end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"], elem_classes=["kv"])
            prompt = gr.Textbox(label="Prompt", placeholder="Describe the transition between frames", lines=2, elem_classes=["kv"])

            with gr.Row():
                # Fills the End Frame from the Start Frame via an external image model.
                generate_5seconds = gr.Button("Generate End Frame (5s later)", elem_classes=["kv"])
                generate_button = gr.Button("Generate Video", variant="primary", elem_classes=["kv"])

            # Advanced settings are collapsed in an accordion to keep the UI lean.
            with gr.Accordion("Advanced Settings (click to open)", open=False):
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=5, label="Video Duration (s)")
                negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
                guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
                guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
                with gr.Row():
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)

        with gr.Column(scale=6):
            gr.Markdown("<div class='section-title'>Output</div>", elem_id="output_title")
            output_video = gr.Video(label="Generated Video", autoplay=True)
            seed_display = gr.Textbox(label="Seed Used (for reproducibility)", interactive=False)

    # Hook up events. The order must match generate_video's positional parameters.
    ui_inputs = [
        start_image,
        end_image,
        prompt,
        negative_prompt_input,
        duration_seconds_input,
        steps_slider,
        guidance_scale_input,
        guidance_scale_2_input,
        seed_input,
        randomize_seed_checkbox,
    ]
    # generate_video returns (video_path, seed). Writing the seed back into the
    # slider lets a run be reproduced by unticking "Randomize seed".
    ui_outputs = [output_video, seed_input]
    generate_button.click(
        fn=generate_video,
        inputs=ui_inputs,
        outputs=ui_outputs,
    ).success(
        # After a successful run, mirror the seed into the read-only
        # "Seed Used" box, which was otherwise never populated.
        fn=lambda seed: str(int(seed)),
        inputs=[seed_input],
        outputs=[seed_display],
    )

    # Generate an end frame from the start image and place it in the End Frame input.
    # (The former switch_to_upload_tab step targeted a non-existent Tabs
    # component with no outputs, so it did nothing and was dropped.)
    generate_5seconds.click(
        fn=_generate_end_frame_5s,
        inputs=[start_image],
        outputs=[end_image],
    )
def _main():
    """Script entry point: launch the Gradio app.

    NOTE(review): ``share=True`` asks Gradio for a public share link; confirm
    that exposing this app publicly is intended when running outside Spaces.
    """
    app.launch(share=True)


if __name__ == "__main__":
    _main()