# Wan2.1 / simple_app.py — Text-to-Video demo Space
# (Hugging Face Space by Xenobd; last update "Update simple_app.py", commit 80502d4, verified)
import os
import re
import subprocess
import sys
import time

import gradio as gr
import torch
from huggingface_hub import snapshot_download
from tqdm import tqdm
# Force the device to CPU
# NOTE(review): `device` is never used in this file — generation runs in a
# subprocess (generate.py) that picks its own device. Kept to document intent.
device = torch.device("cpu")

# Download the model weights from the Hugging Face Hub into a local directory.
# snapshot_download is incremental: files already present under
# ./Wan2.1-T2V-1.3B are not re-downloaded.
snapshot_download(
repo_id="Wan-AI/Wan2.1-T2V-1.3B",
local_dir="./Wan2.1-T2V-1.3B"
)
print("✅ Model downloaded successfully.")
def infer(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from *prompt* by running generate.py as a subprocess.

    The subprocess's stdout/stderr is parsed line-by-line to drive three tqdm
    bars (mirrored into the Gradio UI via ``track_tqdm=True``):
    an overall bar (one tick per pipeline stage), a per-stage bar keyed on the
    subprocess's ``INFO:`` lines, and a video-denoising bar keyed on tqdm-style
    ``N%|...| n/total`` lines.

    Args:
        prompt: Text prompt describing the video to generate.
        progress: Gradio progress tracker (default mirrors tqdm bars).

    Returns:
        Path to the generated video file ("generated_video.mp4").

    Raises:
        gr.Error: If the subprocess exits non-zero or the output file is missing.
    """
    total_process_steps = 11   # total INFO-delimited stages generate.py prints
    irrelevant_steps = 4       # leading setup INFO lines to ignore
    relevant_steps = total_process_steps - irrelevant_steps

    # Overall progress: one tick per relevant stage (including the video phase).
    overall_bar = tqdm(total=relevant_steps, desc="Overall Process", position=1,
                       ncols=120, dynamic_ncols=False, leave=True)
    processed_steps = 0

    # Matches tqdm-style output such as "42%|####      | 21/50".
    progress_pattern = re.compile(r"(\d+)%\|.*\| (\d+)/(\d+)")
    video_progress_bar = None  # created lazily on the first progress line

    sub_bar = None             # bar for the current non-video stage
    sub_ticks = 0
    sub_tick_total = 1500      # synthetic capacity: stages report no real progress
    video_phase = False

    def _close_sub_bar():
        """Finish any open stage bar and credit one tick to the overall bar.

        Fix vs. original: the original duplicated this logic in two places and
        guarded the WHOLE cleanup with ``sub_ticks < sub_tick_total``, leaking
        an unclosed bar when the ticks had already reached the cap. Here only
        the jump-to-100% update is conditional; close/credit/reset always run.
        """
        nonlocal sub_bar, sub_ticks
        if sub_bar is not None:
            if sub_ticks < sub_tick_total:
                sub_bar.update(sub_tick_total - sub_ticks)  # snap bar to 100%
            sub_bar.close()
            overall_bar.update(1)
            sub_bar = None
            sub_ticks = 0

    # ✅ Use generate.py directly (must be in the same folder as this script).
    # sys.executable (fix vs. hard-coded "python") guarantees the child runs in
    # the same interpreter/virtualenv as this app; -u keeps its output unbuffered
    # so progress lines arrive promptly.
    command = [
        sys.executable, "-u", "generate.py",
        "--task", "t2v-1.3B",
        "--size", "480*480",
        "--ckpt_dir", "./Wan2.1-T2V-1.3B",
        "--sample_shift", "8",
        "--sample_guide_scale", "6",
        "--prompt", prompt,
        "--t5_cpu",
        "--offload_model", "True",
        "--save_file", "generated_video.mp4"
    ]

    print("🚀 Starting video generation process...")
    # Merge stderr into stdout and line-buffer so log/progress lines interleave
    # correctly and can be parsed one at a time.
    process = subprocess.Popen(
        command,
        stdout=subprocess.PIPE,
        stderr=subprocess.STDOUT,
        text=True,
        bufsize=1
    )

    stdout = process.stdout
    for line in iter(stdout.readline, ''):
        stripped_line = line.strip()
        print(f"[SUBPROCESS]: {stripped_line}")  # Debug print
        if not stripped_line:
            continue

        # Match video generation progress (like tqdm)
        progress_match = progress_pattern.search(stripped_line)
        if progress_match:
            _close_sub_bar()  # entering the video phase ends any pending stage
            video_phase = True
            current = int(progress_match.group(2))
            total = int(progress_match.group(3))
            if video_progress_bar is None:
                video_progress_bar = tqdm(total=total, desc="Video Generation", position=0,
                                          ncols=120, dynamic_ncols=True, leave=True)
            # Advance by the delta so repeated/monotonic lines both work.
            video_progress_bar.update(current - video_progress_bar.n)
            if video_progress_bar.n >= video_progress_bar.total:
                video_phase = False
                overall_bar.update(1)
                video_progress_bar.close()
                video_progress_bar = None
            continue

        # Each "INFO:" line marks the start of a new pipeline stage.
        if "INFO:" in stripped_line:
            parts = stripped_line.split("INFO:", 1)
            msg = parts[1].strip() if len(parts) > 1 else ""
            print(f"[INFO]: {msg}")
            if processed_steps < irrelevant_steps:
                # Skip the first few setup INFO lines entirely.
                processed_steps += 1
            else:
                _close_sub_bar()  # previous stage is complete
                sub_bar = tqdm(total=sub_tick_total, desc=msg, position=2,
                               ncols=120, dynamic_ncols=False, leave=True)
                sub_ticks = 0
            continue

    process.wait()

    # Final cleanup: make sure no bar is left open.
    if video_progress_bar is not None:
        video_progress_bar.close()
    if sub_bar is not None:
        sub_bar.close()
    overall_bar.close()

    # ✅ Output validation
    if process.returncode == 0:
        if os.path.exists("generated_video.mp4"):
            print("✅ Video generation completed successfully.")
            return "generated_video.mp4"
        print("❌ Video generation finished but output file is missing.")
        raise gr.Error("Output video not found after generation.")
    print("❌ Subprocess failed.")
    raise gr.Error("Video generation failed. Check logs above.")
# ✅ Gradio UI: a single column with a prompt box, a submit button, and a video
# player that displays the file path returned by infer().
with gr.Blocks() as demo:
    with gr.Column():
        gr.Markdown("# Wan 2.1 1.3B - Text to Video")
        gr.Markdown("Generate short videos from prompts. Duplicate this space to avoid queue limits.")
        prompt = gr.Textbox(label="Enter your prompt")
        submit_btn = gr.Button("Generate Video")
        video_res = gr.Video(label="Generated Video")

    # Wire the button to the generator; the returned path populates video_res.
    submit_btn.click(
        fn=infer,
        inputs=[prompt],
        outputs=[video_res]
    )

# queue() serializes requests (generation is long-running and CPU-bound here);
# NOTE(review): ssr_mode=False presumably works around a Spaces rendering
# issue — confirm against the Gradio version pinned for this Space.
demo.queue().launch(show_error=True, show_api=False, ssr_mode=False)