|
|
import os |
|
|
import sys |
|
|
import subprocess |
|
|
import traceback |
|
|
import gc |
|
|
import tempfile |
|
|
import random |
|
|
import time |
|
|
from pathlib import Path |
|
|
|
|
|
# Install runtime dependencies before the third-party imports below are
# resolved. Each requirement is passed as its own argv entry to the pip of the
# interpreter running this script: no shell is involved (so ``imageio[ffmpeg]``
# cannot be glob-expanded) and the packages land in the active environment.
# A failed install is reported but, as before, does not abort startup.
_PIP_REQUIREMENT_SETS = (
    (
        "spaces-0.1.0-py3-none-any.whl",
        "moviepy==1.0.3",
        "imageio[ffmpeg]",
        "librosa",
        "soundfile",
        "accelerate",
    ),
    ("git+https://github.com/tolgacangoz/diffusers.git",),
)
for _requirements in _PIP_REQUIREMENT_SETS:
    _pip_result = subprocess.run(
        [sys.executable, "-m", "pip", "install", *_requirements],
        check=False,
    )
    if _pip_result.returncode != 0:
        print(
            f"WARNING: pip install exited with status {_pip_result.returncode} "
            f"for: {' '.join(_requirements)}",
            file=sys.stderr,
        )
|
|
|
|
|
import spaces |
|
|
import torch |
|
|
import numpy as np |
|
|
import librosa |
|
|
import soundfile as sf |
|
|
from PIL import Image |
|
|
from moviepy.editor import VideoFileClip, concatenate_videoclips |
|
|
from huggingface_hub import snapshot_download |
|
|
import gradio as gr |
|
|
|
|
|
# The Wan pipelines come from the diffusers build installed above. A failed
# import is reported on stderr instead of being swallowed silently; startup
# still continues so the rest of the module can be imported.
try:
    import diffusers
    from diffusers import AutoencoderKLWan, WanPipeline, WanImageToVideoPipeline, UniPCMultistepScheduler, WanSpeechToVideoPipeline
    from diffusers.utils import export_to_video
except ImportError:
    print(
        "WARNING: could not import the diffusers pipelines; "
        "video generation will be unavailable.",
        file=sys.stderr,
    )
    traceback.print_exc()
|
|
|
|
|
# Model identifiers handed to ``from_pretrained`` when the pipelines are loaded.
MODEL_ID_TI2V = "FastVideo/FastWan2.2-TI2V-5B-FullAttn-Diffusers"
MODEL_ID_S2V = "tolgacangoz/Wan2.2-S2V-14B-Diffusers"

# Registry of loaded pipelines, keyed by role. Every slot starts as None and
# stays None if the corresponding pipeline is never loaded.
MODELS = dict.fromkeys(("ti2v_text", "ti2v_image", "s2v"))

# Execution device: the GPU when torch can see one, the CPU otherwise.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
|
|
|
|
|
def _place_pipeline(pipe):
    """Put *pipe* on the execution device and return it.

    On CUDA the pipeline is set up with model CPU offload; otherwise it is
    moved to ``DEVICE``. A ``RuntimeError`` from either step falls back to
    placing the pipeline on the CPU.
    """
    try:
        if DEVICE == "cuda":
            pipe.enable_model_cpu_offload()
        else:
            pipe.to(DEVICE)
    except RuntimeError:
        pipe.to("cpu")
    return pipe


def _load_ti2v_pipeline(pipeline_cls, vae):
    """Load a TI2V pipeline of type *pipeline_cls* sharing *vae*, ready to run."""
    pipe = pipeline_cls.from_pretrained(MODEL_ID_TI2V, vae=vae, torch_dtype=torch.bfloat16)
    pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config, flow_shift=8.0)
    return _place_pipeline(pipe)


def load_models_at_startup():
    """Load every pipeline and register it in ``MODELS``.

    Loading is best-effort. The text/image-to-video pipelines and the
    speech-to-video pipeline are loaded in two independent steps, so a failure
    in one does not prevent the other from loading. A failure leaves the
    affected ``MODELS`` slots as they were and is reported on stderr with its
    traceback, rather than being swallowed silently.
    """
    try:
        # The VAE is loaded once in float32 and shared by both TI2V pipelines.
        vae = AutoencoderKLWan.from_pretrained(MODEL_ID_TI2V, subfolder="vae", torch_dtype=torch.float32)
        # Each pipeline is registered as soon as it is ready, so the text
        # pipeline stays available even if the image pipeline fails to load.
        MODELS["ti2v_text"] = _load_ti2v_pipeline(WanPipeline, vae)
        MODELS["ti2v_image"] = _load_ti2v_pipeline(WanImageToVideoPipeline, vae)
    except Exception:
        print("ERROR: failed to load the text/image-to-video pipelines.", file=sys.stderr)
        traceback.print_exc()

    try:
        s2v_pipe = WanSpeechToVideoPipeline.from_pretrained(
            MODEL_ID_S2V,
            torch_dtype=torch.bfloat16,
        )
        MODELS["s2v"] = _place_pipeline(s2v_pipe)
    except Exception:
        print("ERROR: failed to load the speech-to-video pipeline.", file=sys.stderr)
        traceback.print_exc()


load_models_at_startup()
|
|
|
|
|
def auto_duration_estimator(mode, input_data, duration_val):
    """Estimate, in seconds, how long a generation request will take.

    Args:
        mode: ``"s2v"`` selects the audio-based estimate; any other value
            selects the text/image-to-video estimate.
        input_data: For ``"s2v"``, the path of the audio file. Otherwise the
            collection of input images (may be empty or ``None``).
        duration_val: Requested video length in seconds. Ignored for ``"s2v"``.

    Returns:
        The estimate as an ``int``. For ``"s2v"`` this is 120 when no audio
        path is given or the audio length cannot be determined.
    """
    base_overhead = 45  # flat number of seconds added to every estimate

    if mode == "s2v":
        if not input_data:
            return 120
        try:
            try:
                # ``path=`` replaces ``filename=``, which librosa deprecated in 0.10.
                audio_seconds = librosa.get_duration(path=input_data)
            except TypeError:
                # Older librosa releases only know the ``filename=`` keyword.
                audio_seconds = librosa.get_duration(filename=input_data)
            return int(base_overhead + (audio_seconds * 15))
        except Exception:
            # Unreadable audio must not block the request; use the default.
            return 120

    num_images = len(input_data) if input_data else 0
    if num_images > 0:
        # Every image contributes at least two seconds of video.
        total_seconds = max(duration_val, num_images * 2)
    else:
        total_seconds = duration_val
    return int(base_overhead + (total_seconds * 12))
|
|
|
|
|
def _remove_quietly(path):
    """Delete *path* if it is set and exists; cleanup problems are ignored."""
    if path is None:
        return
    try:
        os.remove(path)
    except OSError:
        pass


def fast_stitch_videos(video_paths):
    """Concatenate video files into one MP4 using ffmpeg stream copy.

    Args:
        video_paths: Paths of the clips, in playback order.

    Returns:
        ``None`` for empty input, the only path when a single clip is given,
        otherwise the path of a newly created temporary MP4. If stitching
        fails for any reason, the path of the last clip is returned instead.
    """
    if not video_paths:
        return None
    if len(video_paths) == 1:
        return video_paths[0]

    list_path = None
    out_path = None
    try:
        with tempfile.NamedTemporaryFile(mode="w", suffix=".txt", delete=False) as list_file:
            list_path = list_file.name
            for path in video_paths:
                # Inside the single-quoted entries of an ffmpeg concat list a
                # literal quote must be written as '\''.
                escaped = str(path).replace("'", "'\\''")
                list_file.write(f"file '{escaped}'\n")

        # Reserve an output name; ffmpeg overwrites the empty placeholder (-y).
        with tempfile.NamedTemporaryFile(suffix="_stitched_stream.mp4", delete=False) as tmp:
            out_path = tmp.name

        cmd = [
            "ffmpeg", "-y", "-f", "concat", "-safe", "0",
            "-i", list_path, "-c", "copy", out_path,
        ]
        subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
        return out_path
    except Exception:
        # Best effort: drop the unusable output and fall back to the newest clip.
        _remove_quietly(out_path)
        return video_paths[-1]
    finally:
        # The list file is only needed while ffmpeg runs, whether or not it succeeded.
        _remove_quietly(list_path)
|
|
|
|
|
def _load_rgb_images(input_files):
    """Open the given inputs as RGB PIL images; unreadable entries are logged and skipped."""
    if not input_files:
        return []
    files_list = input_files if isinstance(input_files, list) else [input_files]
    images = []
    for item in files_list:
        path = item.name if hasattr(item, "name") else item
        try:
            images.append(Image.open(path).convert("RGB"))
        except Exception:
            print(f"WARNING: skipping unreadable input image {path!r}", file=sys.stderr)
            traceback.print_exc()
    return images


def _export_clip(frames, clip_paths, suffix, fps):
    """Write *frames* to a new temporary MP4 and append its path to *clip_paths*."""
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as tmp:
        export_to_video(frames, tmp.name, fps=fps)
        clip_paths.append(tmp.name)


# NOTE(review): the public Hugging Face `spaces` package exposes this decorator
# as `spaces.GPU(duration=...)`. Calling the module object itself only works if
# the bundled spaces-0.1.0 wheel makes it callable - confirm against that wheel.
@spaces(duration=lambda *args: auto_duration_estimator("ti2v", args[0], args[5]))
def generate_ti2v_gpu_stream(input_files, prompt, height, width, negative_prompt, duration_seconds, guidance_scale, steps, seed, randomize_seed, progress=gr.Progress(track_tqdm=True)):
    """Generate a video chunk by chunk, yielding the stitched result after each chunk.

    When at least one input image can be opened, one clip is generated per
    image with the image-to-video pipeline. Otherwise the text-to-video
    pipeline generates consecutive clips from the prompt alone.

    Args:
        input_files: Optional image file(s); paths or objects with a ``name``.
        prompt: Text prompt passed to the pipeline.
        height, width: Requested size; rounded down to a multiple of 32 (minimum 32).
        negative_prompt: Negative prompt passed to the pipeline.
        duration_seconds: Requested total length of the video in seconds.
        guidance_scale: Guidance scale passed to the pipeline.
        steps: Number of inference steps.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random master seed instead of using ``seed``.
        progress: Gradio progress tracker; not referenced in the body.

    Yields:
        ``(stitched_video_path, last_frame, master_seed)`` after every finished
        clip, where ``last_frame`` is the final frame of the newest clip.

    Raises:
        gr.Error: If either TI2V pipeline is missing from ``MODELS``.
    """
    text_to_video_pipe = MODELS.get("ti2v_text")
    image_to_video_pipe = MODELS.get("ti2v_image")
    if not text_to_video_pipe or not image_to_video_pipe:
        raise gr.Error("Models failed to load at startup.")

    MOD_VALUE = 32
    SAFE_CHUNK_DURATION = 4.0  # upper bound, in seconds, for a single generated clip
    FIXED_FPS = 24

    target_h = max(MOD_VALUE, (int(height) // MOD_VALUE) * MOD_VALUE)
    target_w = max(MOD_VALUE, (int(width) // MOD_VALUE) * MOD_VALUE)

    master_seed = random.randint(0, 2**32 - 1) if randomize_seed else int(seed)

    # Arguments shared by every pipeline call in both modes.
    pipe_kwargs = {
        "prompt": prompt,
        "negative_prompt": negative_prompt,
        "height": target_h,
        "width": target_w,
        "guidance_scale": float(guidance_scale),
        "num_inference_steps": int(steps),
    }

    video_clips_paths = []
    last_preview_frame = None
    pil_images = _load_rgb_images(input_files)

    if len(pil_images) > 0:
        # Each image gets an equal share of the duration: at least 2 s, and
        # capped at SAFE_CHUNK_DURATION per clip.
        seconds_per_image = max(2.0, duration_seconds / len(pil_images))
        num_frames = int(min(seconds_per_image, SAFE_CHUNK_DURATION) * FIXED_FPS)

        for i, img in enumerate(pil_images):
            generator = torch.Generator(device=DEVICE).manual_seed(master_seed + i)
            resized_image = img.resize((target_w, target_h))

            try:
                with torch.inference_mode():
                    output_frames = image_to_video_pipe(
                        image=resized_image,
                        num_frames=num_frames,
                        generator=generator,
                        **pipe_kwargs,
                    ).frames[0]

                _export_clip(output_frames, video_clips_paths, f"_img_{i}.mp4", FIXED_FPS)

                if len(output_frames) > 0:
                    last_preview_frame = output_frames[-1]

                current_stitched = fast_stitch_videos(video_clips_paths)
                yield current_stitched, last_preview_frame, master_seed
            except Exception:
                # Best effort: a failing image is skipped, but the reason is logged.
                print(f"WARNING: generation failed for input image #{i}; skipping it.", file=sys.stderr)
                traceback.print_exc()
                continue
    else:
        # Every text-only chunk is a full SAFE_CHUNK_DURATION long, so the
        # output length is the requested duration rounded up to a multiple of it.
        num_chunks = int(np.ceil(duration_seconds / SAFE_CHUNK_DURATION))
        frames_per_chunk = int(SAFE_CHUNK_DURATION * FIXED_FPS)

        for i in range(num_chunks):
            generator = torch.Generator(device=DEVICE).manual_seed(master_seed + (i * 100))

            with torch.inference_mode():
                output_frames = text_to_video_pipe(
                    num_frames=frames_per_chunk,
                    generator=generator,
                    **pipe_kwargs,
                ).frames[0]

            _export_clip(output_frames, video_clips_paths, f"_chunk_{i}.mp4", FIXED_FPS)

            if len(output_frames) > 0:
                last_preview_frame = output_frames[-1]

            current_stitched = fast_stitch_videos(video_clips_paths)
            yield current_stitched, last_preview_frame, master_seed
|
|
|
|
|
def merge_audio_video(video_path, audio_path, output_path):
    """Combine the video of one file with the audio of another using ffmpeg.

    The first video stream of ``video_path`` is copied unchanged, the first
    audio stream of ``audio_path`` is encoded as AAC, and ``-shortest`` ends
    the output with the shorter of the two inputs. An existing output file is
    overwritten.

    Returns:
        ``output_path``.

    Raises:
        subprocess.CalledProcessError: If ffmpeg exits with a non-zero status.
    """
    ffmpeg_args = ["ffmpeg", "-y"]
    ffmpeg_args += ["-i", video_path, "-i", audio_path]
    ffmpeg_args += ["-c:v", "copy", "-c:a", "aac"]
    ffmpeg_args += ["-map", "0:v:0", "-map", "1:a:0"]
    ffmpeg_args += ["-shortest", output_path]

    subprocess.run(ffmpeg_args, check=True)
    return output_path
|
|
|
|
|
def load_audio_for_model(audio_filepath):
    """Load an audio file with librosa at a 16 kHz sample rate.

    Args:
        audio_filepath: Path of the audio file to read.

    Returns:
        ``(samples, sample_rate)`` on success, or ``(None, None)`` when the
        file cannot be loaded. The failure is logged to stderr with its
        traceback so the cause is not lost.
    """
    try:
        wav, sr = librosa.load(audio_filepath, sr=16000)
    except Exception:
        print(f"WARNING: could not load audio from {audio_filepath!r}", file=sys.stderr)
        traceback.print_exc()
        return None, None
    return wav, sr
|
|
|
|
|
# NOTE(review): the public Hugging Face `spaces` package exposes this decorator
# as `spaces.GPU(duration=...)`. Calling the module object itself only works if
# the bundled spaces-0.1.0 wheel makes it callable - confirm against that wheel.
@spaces(duration=lambda *args: auto_duration_estimator("s2v", args[1], 0))
def generate_s2v_gpu(image_input, audio_filepath, prompt, seed, randomize_seed):
    """Generate a speech-driven video from a reference image and an audio file.

    Args:
        image_input: Reference image as a PIL image.
        audio_filepath: Path of the driving audio file.
        prompt: Text prompt passed to the pipeline.
        seed: Seed used when ``randomize_seed`` is false.
        randomize_seed: Draw a random seed instead of using ``seed``.

    Returns:
        ``(video_path, seed_used)``; the video contains the input audio.

    Raises:
        gr.Error: If the S2V pipeline is not loaded, an input is missing, or
            the audio cannot be loaded.
        subprocess.CalledProcessError: Propagated from ``merge_audio_video``
            when ffmpeg fails.
    """
    pipe = MODELS.get("s2v")
    if not pipe:
        raise gr.Error("S2V Model not initialized.")

    if image_input is None or audio_filepath is None:
        raise gr.Error("Inputs Missing")

    audio_values, sample_rate = load_audio_for_model(audio_filepath)
    if audio_values is None:
        raise gr.Error("Invalid Audio")

    init_image = image_input.convert("RGB")
    w, h = init_image.size
    # Round each side down to a multiple of 16, but never to 0: an image
    # smaller than 16 px would otherwise produce an invalid resize target.
    w = max(16, (w // 16) * 16)
    h = max(16, (h // 16) * 16)
    init_image = init_image.resize((w, h), Image.LANCZOS)

    current_seed = random.randint(0, 2**32 - 1) if randomize_seed else int(seed)
    generator = torch.Generator(device=DEVICE).manual_seed(current_seed)

    with torch.inference_mode():
        out = pipe(
            image=init_image,
            audio=audio_values,
            num_inference_steps=25,
            guidance_scale=4.0,
            sampling_rate=sample_rate,
            prompt=prompt,
            generator=generator,
        )

    frames = out.frames[0]

    # Reserve names for the intermediate silent video and the final output.
    with tempfile.NamedTemporaryFile(suffix="_temp_mute.mp4", delete=False) as tmp_vid:
        temp_mute_path = tmp_vid.name
    with tempfile.NamedTemporaryFile(suffix="_output_s2v.mp4", delete=False) as tmp_final:
        final_video_path = tmp_final.name

    try:
        # NOTE(review): the export frame rate is hard-coded to 30; confirm it
        # matches the rate the S2V pipeline generates frames for, otherwise
        # the merged audio and video will not stay in sync.
        export_to_video(frames, temp_mute_path, fps=30)
        final_output = merge_audio_video(temp_mute_path, audio_filepath, final_video_path)
    except Exception:
        # Do not leave the unused output placeholder behind on failure.
        try:
            os.remove(final_video_path)
        except OSError:
            pass
        raise
    finally:
        # The silent intermediate is never needed once merging has been attempted.
        try:
            os.remove(temp_mute_path)
        except OSError:
            pass

    return final_output, current_seed
|
|
|
|
|
# Gradio UI: one tab for text/image-to-video streaming, one for speech-to-video.
# NOTE(review): the original indentation of this section was lost; the nesting
# of rows, columns and the accordion below is reconstructed from the statement
# order - confirm it matches the intended layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Wan 2.2 Unified Streaming Video Platform")

    with gr.Tabs():
        with gr.TabItem("Text & Image to Video (Streaming & Long Duration)"):
            with gr.Row():
                with gr.Column(scale=1):
                    ti2v_files = gr.File(label="Input Images", file_count="multiple", type="filepath", file_types=["image"])
                    ti2v_prompt = gr.Textbox(label="Prompt", value="Cinematic view, realistic lighting, 4k", lines=2)
                    ti2v_duration = gr.Slider(minimum=2, maximum=300, step=1, value=5, label="Total Duration (s)")

                    with gr.Accordion("Advanced", open=False):
                        ti2v_neg = gr.Textbox(label="Negative Prompt", value="low quality, distortion, text, watermark", lines=2)
                        ti2v_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=42)
                        ti2v_rand = gr.Checkbox(label="Random Seed", value=True)
                        # The sliders below use explicit keywords: the third
                        # positional parameter of gr.Slider is ``value``, not
                        # ``step``, so the former positional
                        # (min, max, step, value) calls were mis-bound.
                        with gr.Row():
                            ti2v_h = gr.Slider(minimum=256, maximum=1024, step=32, value=832, label="Height")
                            ti2v_w = gr.Slider(minimum=256, maximum=1024, step=32, value=832, label="Width")
                        ti2v_steps = gr.Slider(minimum=2, maximum=10, step=1, value=4, label="Steps")
                        ti2v_scale = gr.Slider(minimum=1.0, maximum=8.0, step=0.1, value=5.0, label="CFG")

                    btn_ti2v = gr.Button("Start Streaming Generation", variant="primary")

                with gr.Column(scale=2):
                    with gr.Row():
                        out_ti2v = gr.Video(label="Live Video Stream", autoplay=True)
                        out_preview_ti2v = gr.Image(label="Last Frame Preview", interactive=False)
                    out_seed_ti2v = gr.Number(label="Seed Used")

            # The input order must match the positional parameters of
            # generate_ti2v_gpu_stream.
            btn_ti2v.click(
                fn=generate_ti2v_gpu_stream,
                inputs=[ti2v_files, ti2v_prompt, ti2v_h, ti2v_w, ti2v_neg, ti2v_duration, ti2v_scale, ti2v_steps, ti2v_seed, ti2v_rand],
                outputs=[out_ti2v, out_preview_ti2v, out_seed_ti2v],
            )

        with gr.TabItem("Speech to Video (S2V)"):
            with gr.Row():
                with gr.Column(scale=1):
                    s2v_img = gr.Image(label="Reference Image", type="pil")
                    s2v_audio = gr.Audio(label="Audio Input", type="filepath")
                    s2v_prompt = gr.Textbox(label="Prompt", value="Realistic movement, talking face")
                    s2v_seed = gr.Slider(label="Seed", minimum=0, maximum=2**32-1, step=1, value=42)
                    s2v_rand = gr.Checkbox(label="Random Seed", value=True)
                    btn_s2v = gr.Button("Generate S2V", variant="primary")

                with gr.Column(scale=2):
                    out_s2v = gr.Video(label="Result")
                    out_seed_s2v = gr.Number(label="Seed Used")

            btn_s2v.click(generate_s2v_gpu, [s2v_img, s2v_audio, s2v_prompt, s2v_seed, s2v_rand], [out_s2v, out_seed_s2v])


if __name__ == "__main__":
    demo.queue().launch()