# genvidtest / app.py (Hugging Face Space source)
# Last commit: 5d0bf06 "Update app.py" (verified) by smartdigitalnetworks, 17.2 kB
import spaces
import torch
from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
from diffusers.utils.export_utils import export_to_video
import gradio as gr
import tempfile
import numpy as np
from PIL import Image
import random
import gc
from torchao.quantization import quantize_
from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
from torchao.quantization import Int8WeightOnlyConfig
import aoti
# Checkpoint to load and output-resolution limits (pixels).
MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
MAX_DIM = 832  # upper bound for either side of the resized input
MIN_DIM = 480  # lower bound for either side of the resized input
SQUARE_DIM = 640  # side length used when the input is square
MULTIPLE_OF = 16  # width/height are snapped to a multiple of this

# Seeds are drawn from [0, MAX_SEED].
MAX_SEED = np.iinfo(np.int32).max

# Frame rate of the exported video and the allowed frame-count range.
FIXED_FPS = 16
MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 80
# Duration-slider bounds in seconds, rounded to one decimal place.
MIN_DURATION, MAX_DURATION = (
    round(frame_count / FIXED_FPS, 1)
    for frame_count in (MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
)

# Device the NSFW detector and its inputs are moved to.
device = "cuda"
# -------------------- NSFW detector loading --------------------
# Best effort: on any failure (import, download, device move) both handles end
# up as None, and the NSFW checks that test for None are skipped.
nsfw_processor = None
nsfw_model = None
try:
    print("Loading NSFW detector...")
    from transformers import AutoProcessor, AutoModelForImageClassification

    nsfw_processor = AutoProcessor.from_pretrained("Falconsai/nsfw_image_detection")
    nsfw_model = AutoModelForImageClassification.from_pretrained(
        "Falconsai/nsfw_image_detection"
    ).to(device)
    print("NSFW detector loaded successfully.")
except Exception as e:
    print(f"Failed to load NSFW detector: {e}")
    # Reset both, since the processor may have loaded before the model failed.
    nsfw_model = nsfw_processor = None
# ---------------------------------------------------------------
class GenerationError(Exception):
    """Custom exception signalling a failed video generation.

    ``_generate_video`` catches it and reports ``str(exc)`` in its error dict.
    """
def detect_image_nsfw(image: Image.Image, threshold: float = 0.5) -> bool:
    """Return True if the NSFW classifier flags *image*.

    The image is converted to RGB and run through the NSFW classifier. The
    softmax probability of label index 1 is compared against *threshold*, and
    the image is flagged if the probability is strictly greater.

    If the detector failed to load (either handle is None), no check is
    possible and False is returned. This matches ``detect_frames_nsfw``.

    :param image: PIL image to classify.
    :param threshold: probability above which the image is flagged.
    :return: True if flagged as NSFW, otherwise False.
    """
    if nsfw_model is None or nsfw_processor is None:
        return False
    # Normalise mode (RGBA / L / P, ...) so the processor always gets 3 channels,
    # the same way the frame path does before classification.
    rgb_image = image.convert("RGB")
    inputs = nsfw_processor(images=rgb_image, return_tensors="pt").to(device)
    with torch.no_grad():
        logits = nsfw_model(**inputs).logits
    probs = torch.nn.functional.softmax(logits, dim=-1)
    # Label index 1 is treated as the NSFW class.
    # NOTE(review): confirm against the model config's id2label.
    nsfw_score = probs[0, 1].item()
    return nsfw_score > threshold
def _frame_to_pil(frame):
    """Convert one generated frame (PIL image, torch tensor or numpy array) to an RGB PIL image.

    - Tensors are detached, moved to CPU and cast to float32. A 3-channel
      channel-first ``[C, H, W]`` tensor is permuted to ``[H, W, C]``.
    - Floating-point data whose maximum is <= 1.05 is treated as [0, 1] and
      scaled to [0, 255].
    - Non-uint8 data is clipped to [0, 255] before the uint8 cast, so
      out-of-range values saturate instead of wrapping around.
    """
    if torch.is_tensor(frame):
        data = frame.detach().cpu().float()
        if data.ndim == 3 and data.shape[0] == 3:
            data = data.permute(1, 2, 0)  # [C, H, W] -> [H, W, C]
        frame = data.numpy()
    if isinstance(frame, np.ndarray):
        data = frame
        if np.issubdtype(data.dtype, np.floating) and data.size and data.max() <= 1.05:
            data = data * 255.0
        if data.dtype != np.uint8:
            data = np.clip(data, 0, 255).astype(np.uint8)
        frame = Image.fromarray(data)
    return frame.convert("RGB")


def detect_frames_nsfw(frames, sample_n=5, threshold=0.5):
    """Check a sequence of generated frames for NSFW content.

    Up to *sample_n* frames are sampled at evenly spaced, de-duplicated indices.
    This avoids classifying every frame. Each sampled frame is converted to RGB
    and passed to ``detect_image_nsfw``.

    :param frames: indexable sequence of frames (PIL images, torch tensors or
        numpy arrays, e.g. the pipeline's ``.frames[0]``).
    :param sample_n: number of frames to sample across the sequence.
    :param threshold: NSFW probability threshold passed to ``detect_image_nsfw``.
    :return: True as soon as any sampled frame is flagged. Returns False when
        nothing is flagged, when *frames* is empty/None, or when the detector
        is not loaded.
    """
    # An empty sequence would make linspace produce index -1 and crash on frames[0].
    if frames is None or len(frames) == 0:
        return False
    if nsfw_model is None or nsfw_processor is None:
        return False
    num_frames = len(frames)
    # np.unique drops repeats when sample_n exceeds the number of frames.
    sample_indices = np.unique(np.linspace(0, num_frames - 1, sample_n, dtype=int))
    return any(
        detect_image_nsfw(_frame_to_pil(frames[idx]), threshold)
        for idx in sample_indices
    )
# Build the Wan 2.2 image-to-video pipeline from MODEL_ID, replacing both
# transformers (`transformer` and `transformer_2`) with bf16 weights from
# 'cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers'. Each transformer is placed on the
# GPU at load time, and the whole pipeline is then moved to CUDA.
# NOTE(review): the UI labels guidance_scale / guidance_scale_2 as the
# "high noise" / "low noise" stages; presumably these map to transformer /
# transformer_2 respectively — confirm against the diffusers Wan pipeline.
pipe = WanImageToVideoPipeline.from_pretrained(MODEL_ID,
    transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
        subfolder='transformer_2',
        torch_dtype=torch.bfloat16,
        device_map='cuda',
    ),
    torch_dtype=torch.bfloat16,
).to('cuda')
# Load the same Lightx2v LoRA twice. The first copy goes into `transformer`
# as adapter "lightx2v".
# NOTE(review): the filename suggests a CFG/step-distilled adapter, which is
# presumably why guidance defaults to 1 and steps are low — confirm.
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightx2v"
)
# The second copy goes into `transformer_2` as adapter "lightx2v_2".
kwargs_lora = {}
kwargs_lora["load_into_transformer_2"] = True
pipe.load_lora_weights(
    "Kijai/WanVideo_comfy",
    weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
    adapter_name="lightx2v_2", **kwargs_lora
)
# Activate both adapters, then fuse them into the base weights. `transformer`
# is fused at lora_scale 3.0 and `transformer_2` at 1.0. Per the diffusers
# docs, unload_lora_weights() after fuse_lora() removes the adapter layers
# while the fused effect remains in the weights.
pipe.set_adapters(["lightx2v", "lightx2v_2"], adapter_weights=[1., 1.])
pipe.fuse_lora(adapter_names=["lightx2v"], lora_scale=3., components=["transformer"])
pipe.fuse_lora(adapter_names=["lightx2v_2"], lora_scale=1., components=["transformer_2"])
pipe.unload_lora_weights()
# Quantize with torchao. The text encoder gets int8 weight-only quantization;
# both transformers get FP8 dynamic-activation + FP8-weight quantization.
# Order matters: the fused LoRA weights are the ones being quantized.
quantize_(pipe.text_encoder, Int8WeightOnlyConfig())
quantize_(pipe.transformer, Float8DynamicActivationFloat8WeightConfig())
quantize_(pipe.transformer_2, Float8DynamicActivationFloat8WeightConfig())
# Load precompiled transformer blocks through the local `aoti` helper module
# (repo 'zerogpu-aoti/Wan2', variant 'fp8da').
# NOTE(review): presumably ahead-of-time compiled blocks that must match the
# fp8 quantization above — confirm in aoti.py.
aoti.aoti_blocks_load(pipe.transformer, 'zerogpu-aoti/Wan2', variant='fp8da')
aoti.aoti_blocks_load(pipe.transformer_2, 'zerogpu-aoti/Wan2', variant='fp8da')
# Default text pre-filled in the Prompt box.
default_prompt_i2v: str = "make this image come alive, cinematic motion, smooth animation"
# Default negative prompt (Chinese). English gloss: vivid/garish colours,
# overexposed, static, blurry details, subtitles, style, artwork, painting,
# picture, still, overall greyish, worst quality, low quality, JPEG compression
# artifacts, ugly, mutilated, extra fingers, poorly drawn hands, poorly drawn
# face, deformed, disfigured, malformed limbs, fused fingers, motionless scene,
# cluttered background, three legs, many people in the background, walking
# backwards.
default_negative_prompt: str = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
def resize_image(image: Image.Image) -> Image.Image:
    """Fit *image* to the model's resolution constraints.

    - Square input: resized straight to SQUARE_DIM x SQUARE_DIM.
    - Wider than MAX_DIM:MIN_DIM: center-cropped to that ratio, target MAX_DIM x MIN_DIM.
    - Taller than MIN_DIM:MAX_DIM: center-cropped to that ratio, target MIN_DIM x MAX_DIM.
    - Otherwise: the longer side targets MAX_DIM and the shorter side follows
      the aspect ratio.

    Each target side is snapped to the nearest multiple of MULTIPLE_OF and
    clamped to [MIN_DIM, MAX_DIM] before a LANCZOS resize.
    """
    width, height = image.size
    if width == height:
        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)

    ratio = width / height
    widest = MAX_DIM / MIN_DIM
    tallest = MIN_DIM / MAX_DIM
    source = image

    if ratio > widest:
        # Too wide: keep the full height, center-crop the width.
        new_w = int(round(height * widest))
        x0 = (width - new_w) // 2
        source = image.crop((x0, 0, x0 + new_w, height))
        target_w, target_h = MAX_DIM, MIN_DIM
    elif ratio < tallest:
        # Too tall: keep the full width, center-crop the height.
        new_h = int(round(width / tallest))
        y0 = (height - new_h) // 2
        source = image.crop((0, y0, width, y0 + new_h))
        target_w, target_h = MIN_DIM, MAX_DIM
    elif width > height:
        target_w, target_h = MAX_DIM, int(round(MAX_DIM / ratio))
    else:
        target_w, target_h = int(round(MAX_DIM * ratio)), MAX_DIM

    def _snap(side):
        # Nearest multiple of MULTIPLE_OF, then clamp into [MIN_DIM, MAX_DIM].
        return max(MIN_DIM, min(MAX_DIM, round(side / MULTIPLE_OF) * MULTIPLE_OF))

    return source.resize((_snap(target_w), _snap(target_h)), Image.LANCZOS)
def get_num_frames(duration_seconds: float):
    """Convert a duration in seconds to the frame count passed to the pipeline.

    The duration is converted at FIXED_FPS and rounded. The result is clamped
    to [MIN_FRAMES_MODEL, MAX_FRAMES_MODEL], and one extra frame is added.
    """
    requested = int(round(duration_seconds * FIXED_FPS))
    clamped = min(max(requested, MIN_FRAMES_MODEL), MAX_FRAMES_MODEL)
    return clamped + 1
def get_duration(
    input_image,
    prompt,
    steps,
    negative_prompt,
    duration_seconds,
    guidance_scale,
    guidance_scale_2,
    seed,
    randomize_seed,
):
    """Estimate the GPU time budget, in seconds, for one ``_generate_video`` call.

    Passed as the ``duration`` of ``@spaces.GPU``, so it receives the same
    arguments as ``_generate_video``. The per-step cost starts from a 15 s
    reference for a volume of 81 * 832 * 624 (frames * width * height). It is
    scaled by (volume / reference) ** 1.5 and multiplied by the step count,
    and a flat 10 s is added.

    Returns 10 when no image is given, because generation errors out early then.
    """
    if input_image is None:
        return 10

    reference_volume = 81 * 832 * 624
    seconds_per_reference_step = 15

    w, h = resize_image(input_image).size
    volume = get_num_frames(duration_seconds) * w * h
    per_step = seconds_per_reference_step * (volume / reference_volume) ** 1.5
    return 10 + int(steps) * per_step
# Module-level progress tracker used inside _generate_video.
# NOTE(review): Gradio usually injects gr.Progress via a handler default
# argument; confirm these updates reach the UI from inside the spaces.GPU call.
progress = gr.Progress()


@spaces.GPU(duration=get_duration)
def _generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """Run image-to-video generation on the GPU and report the outcome as data.

    The function proceeds as follows:
    1. If the NSFW detector is loaded, the input image is checked with it.
    2. The image is resized and the Wan pipeline is run.
    3. If the detector is loaded, sampled output frames are checked.
    4. The frames are written to a temporary ``.mp4`` at FIXED_FPS.

    Errors are returned rather than raised, so the caller decides how to
    surface them:
    - success: ``({"status": "success"}, video_path, seed_used)``
    - failure: ``({"error": message, "status": "failed"}, None, None)``
    """
    progress(0, desc="Starting")

    def callback_fn(_pipe, step, timestep, callback_kwargs):
        # Called by the pipeline after each denoising step; must return callback_kwargs.
        print(f"[Step {step}] Timestep: {timestep}")
        progress((step + 1.0) / num_steps, desc=f"Video generating, {step + 1}/{num_steps} steps")
        return callback_kwargs

    try:
        num_steps = int(steps)
        # Raise GenerationError (not gr.Error): it is caught below and its
        # plain str() is forwarded, whereas gr.Error's str() is not the bare message.
        if input_image is None:
            raise GenerationError("Please upload an input image.")

        detector_ready = nsfw_model is not None and nsfw_processor is not None

        # Input check: only the image is classified, so the advice targets the image.
        if detector_ready and detect_image_nsfw(input_image):
            raise GenerationError(
                "The input image contains NSFW content and cannot be used. "
                "Please upload a different image and try again."
            )

        num_frames = get_num_frames(duration_seconds)
        current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
        resized_image = resize_image(input_image)

        output_frames_list = pipe(
            image=resized_image,
            prompt=prompt,
            negative_prompt=negative_prompt,
            height=resized_image.height,
            width=resized_image.width,
            num_frames=num_frames,
            guidance_scale=float(guidance_scale),
            guidance_scale_2=float(guidance_scale_2),
            num_inference_steps=num_steps,
            generator=torch.Generator(device="cuda").manual_seed(current_seed),
            callback_on_step_end=callback_fn,
        ).frames[0]

        # Output check runs BEFORE any file is created, so a rejected video
        # does not leave an orphaned temp file (delete=False) on disk.
        if detector_ready:
            num_total_frames = len(output_frames_list)
            # One sample per FIXED_FPS/2 frames (~2 per second), at least one.
            sample_count = max(1, int(num_total_frames // (FIXED_FPS / 2)))
            if detect_frames_nsfw(output_frames_list, sample_count, 0.6):
                raise GenerationError(
                    "Generated video contains NSFW content and cannot be displayed. "
                    "Please modify the input image or prompt and try again."
                )

        with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
            video_path = tmpfile.name
        export_to_video(output_frames_list, video_path, fps=FIXED_FPS)

        progress(1, desc="Complete")
        return {"status": "success"}, video_path, current_seed
    except GenerationError as e:
        # Expected, user-facing failures: report the message only.
        return {"error": str(e), "status": "failed"}, None, None
    except Exception as e:
        # Unexpected failure: keep the traceback in the logs, then report.
        import traceback

        traceback.print_exc()
        return {"error": str(e), "status": "failed"}, None, None
    finally:
        # Collect garbage first so freed tensors' blocks can actually be
        # released by empty_cache().
        gc.collect()
        torch.cuda.empty_cache()
def generate_video(
    input_image,
    prompt,
    steps=4,
    negative_prompt=default_negative_prompt,
    duration_seconds=MAX_DURATION,
    guidance_scale=1,
    guidance_scale_2=1,
    seed=42,
    randomize_seed=False,
):
    """UI handler: run the GPU-bound generation and surface failures in Gradio.

    All arguments are forwarded unchanged to ``_generate_video``. If it
    reports failure, its error message is raised as ``gr.Error``. Otherwise
    the function returns ``(video_path, seed_used)``, the values for the video
    output and the seed slider.
    """
    info, video_path, current_seed = _generate_video(
        input_image,
        prompt,
        steps,
        negative_prompt,
        duration_seconds,
        guidance_scale,
        guidance_scale_2,
        seed,
        randomize_seed,
    )
    if info["status"] != "failed":
        return video_path, current_seed
    raise gr.Error(info["error"])
# Page CSS: centres the main container (elem_id="col-container") and caps its
# width at 1200px.
css="""
#col-container {
margin: 0 auto;
max-width: 1200px;
}
"""
# Markdown rendered at the top of the page: heading, blurb, and usage-cap note.
title = "# AI Video Maker"
description = "Bring your imagination to motion with AI Video Maker. Simply upload a photo and provide a prompt to generate smooth, high-quality animations for your content creation. "
note = "*Note: This demo has a daily usage cap. If you have reached the limit or need faster rendering, please visit [AI Image to Video](https://www.imgtovideo.ai/) to continue animating without interruptions.*"
# Build the Gradio UI. Components render in creation order within each
# `with` context, so the statements below are layout-order-sensitive.
with gr.Blocks(css=css).queue() as demo:
    # Centred page container (width capped by `css`).
    with gr.Column(elem_id="col-container"):
        gr.Markdown(title)
        gr.Markdown(description)
        gr.Markdown(note)
        with gr.Row():
            # Left column: input image, prompt, and generation settings.
            with gr.Column():
                gr.Markdown("### Input")
                input_image_component = gr.Image(type="pil", label="Input Image")
                prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v, lines=3)
                # Seconds; converted to a clamped frame count by get_num_frames.
                duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
                with gr.Accordion("Advanced Settings", open=False):
                    negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
                    # Also an output: receives the seed actually used for the last run.
                    seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42, interactive=True)
                    randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True, interactive=True)
                    steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=6, label="Inference Steps")
                    guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale - high noise stage")
                    guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1, label="Guidance Scale 2 - low noise stage")
                generate_button = gr.Button("Generate Video", variant="primary")
            # Right column: generated video.
            with gr.Column():
                gr.Markdown("### Output")
                video_output = gr.Video(label="Generated Video", autoplay=True, interactive=False)
        # Order must match generate_video's positional parameters.
        ui_inputs = [
            input_image_component, prompt_input, steps_slider,
            negative_prompt_input, duration_seconds_input,
            guidance_scale_input, guidance_scale_2_input, seed_input, randomize_seed_checkbox
        ]
        # Outputs: the video file path and the seed used (written back to the slider).
        generate_button.click(fn=generate_video, inputs=ui_inputs, outputs=[video_output, seed_input])
        # Example rows only fill image, prompt and steps; generate_video's
        # defaults supply the remaining parameters. With cache_mode="lazy",
        # per the Gradio docs, each example is generated and cached on first use.
        gr.Examples(
            examples=[
                [
                    "wan_i2v_input.JPG",
                    "POV selfie video, white cat with sunglasses standing on surfboard, relaxed smile, tropical beach behind (clear water, green hills, blue sky with clouds). Surfboard tips, cat falls into ocean, camera plunges underwater with bubbles and sunlight beams. Brief underwater view of cat's face, then cat resurfaces, still filming selfie, playful summer vacation mood.",
                    4,
                ],
                [
                    "wan22_input_2.jpg",
                    "A sleek lunar vehicle glides into view from left to right, kicking up moon dust as astronauts in white spacesuits hop aboard with characteristic lunar bouncing movements. In the distant background, a VTOL craft descends straight down and lands silently on the surface. Throughout the entire scene, ethereal aurora borealis ribbons dance across the star-filled sky, casting shimmering curtains of green, blue, and purple light that bathe the lunar landscape in an otherworldly, magical glow.",
                    4,
                ],
                [
                    "kill_bill.jpeg",
                    "Uma Thurman's character, Beatrix Kiddo, holds her razor-sharp katana blade steady in the cinematic lighting. Suddenly, the polished steel begins to soften and distort, like heated metal starting to lose its structural integrity. The blade's perfect edge slowly warps and droops, molten steel beginning to flow downward in silvery rivulets while maintaining its metallic sheen. The transformation starts subtly at first - a slight bend in the blade - then accelerates as the metal becomes increasingly fluid. The camera holds steady on her face as her piercing eyes gradually narrow, not with lethal focus, but with confusion and growing alarm as she watches her weapon dissolve before her eyes. Her breathing quickens slightly as she witnesses this impossible transformation. The melting intensifies, the katana's perfect form becoming increasingly abstract, dripping like liquid mercury from her grip. Molten droplets fall to the ground with soft metallic impacts. Her expression shifts from calm readiness to bewilderment and concern as her legendary instrument of vengeance literally liquefies in her hands, leaving her defenseless and disoriented.",
                    6,
                ],
            ],
            inputs=[input_image_component, prompt_input, steps_slider],
            outputs=[video_output, seed_input],
            fn=generate_video,
            cache_examples=True,
            cache_mode="lazy"
        )


# Launch the Gradio server when run as a script.
if __name__ == "__main__":
    demo.launch()