# PusaV1 / app.py
# Gradio Space: text-to-video generation with Wan2.1-T2V-14B + the Pusa LoRA.
import gradio as gr
import os
import tempfile
from huggingface_hub import snapshot_download
from diffsynth import ModelManager, WanVideoPusaPipeline, save_video
import spaces
# Constants
WAN_SUBFOLDER = "Wan2.1-T2V-14B"  # subfolder inside the Hub repo holding the base T2V model
MODEL_REPO_ID = "RaphaelLiu/PusaV1"  # HuggingFace Hub repo the weights are fetched from
MODEL_ZOO_DIR = "./model_zoo"  # local root directory all weights are downloaded into
WAN_MODEL_PATH = os.path.join(MODEL_ZOO_DIR, WAN_SUBFOLDER)  # local path of the base model tree
# NOTE(review): this path is NOT covered by ensure_model_downloaded's
# allow_patterns (only the Wan subfolder is fetched) — confirm the LoRA
# file is provided some other way, e.g. baked into the Space image.
LORA_PATH = os.path.join(MODEL_ZOO_DIR, "PusaV1", "pusa_v1.pt")
# Ensure model is downloaded
def ensure_model_downloaded():
    """Download the Wan2.1-T2V-14B weights into MODEL_ZOO_DIR if absent.

    Only files under WAN_SUBFOLDER are fetched (``allow_patterns``), so the
    LoRA file at LORA_PATH is not downloaded here. ``snapshot_download`` is
    resumable, so an interrupted download is completed on the next call.
    """
    # Guard clause: nothing to do when the model tree already exists.
    if os.path.exists(WAN_MODEL_PATH):
        return
    print("Downloading Wan2.1-T2V-14B from HuggingFace Hub...")
    # ``local_dir_use_symlinks`` is deprecated and ignored by recent
    # huggingface_hub releases (files are always materialized in local_dir),
    # so it is intentionally no longer passed.
    snapshot_download(
        repo_id=MODEL_REPO_ID,
        local_dir=MODEL_ZOO_DIR,
        repo_type="model",
        allow_patterns=[f"{WAN_SUBFOLDER}/**"],
    )
    print("Model downloaded.")
# Video generation logic
@spaces.GPU
def generate_video(prompt: str) -> str:
    """Generate a video from a text prompt and return the path to the MP4.

    Args:
        prompt: free-text description of the desired video.

    Returns:
        Filesystem path of the rendered ``video.mp4`` (fed to ``gr.Video``).

    Raises:
        gr.Error: if the prompt is empty or whitespace-only, so the UI shows
            a friendly message instead of burning GPU time on a blank input.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Please enter a prompt describing the video.")
    ensure_model_downloaded()
    # NOTE(review): the model and pipeline are rebuilt on every request.
    # Caching them at module level would avoid the repeated load cost, but
    # @spaces.GPU lifecycle may require per-call construction — confirm
    # before hoisting.
    manager = ModelManager(pretrained_model_dir=WAN_MODEL_PATH)
    model = manager.load_model()
    pipeline = WanVideoPusaPipeline(model=model)
    pipeline.set_lora_adapters(LORA_PATH)
    # Run inference.
    result = pipeline(prompt)
    # Write into a fresh per-call temp dir so concurrent requests never
    # clobber each other's output file.
    output_dir = tempfile.mkdtemp()
    output_path = os.path.join(output_dir, "video.mp4")
    save_video(result.frames, output_path, fps=8)
    return output_path
# Gradio UI
# Assemble the UI: a prompt box, a trigger button, and a video output pane.
with gr.Blocks() as demo:
    gr.Markdown("## πŸŽ₯ Wan2.1-T2V-14B with Pusa LoRA | Text-to-Video Generator")
    prompt_box = gr.Textbox(
        lines=4,
        label="Prompt",
        placeholder="Describe your video (e.g. A coral reef full of colorful fish...)",
    )
    run_button = gr.Button("Generate Video")
    output_video = gr.Video(label="Output")
    # Clicking the button runs the generator and routes the MP4 to the player.
    run_button.click(fn=generate_video, inputs=prompt_box, outputs=output_video)

# Start the Gradio server (Spaces picks up the `demo` object by convention).
demo.launch()