# yyy / app.py — Hugging Face Space by sharul20001
# Last commit: ea404ab "Update app.py" (verified)
import os

import gradio as gr
import torch
from diffusers import DiffusionPipeline
from diffusers.utils import export_to_video
# =====================================================
# Primary model (Wan 2.2), with damo-vilab as the fallback
# =====================================================
# NOTE(review): the original author flagged this repo id as one to replace
# with a name that actually exists on the Hugging Face Hub.
WAN_MODEL = "BAAI/Wan2.2-videogen"
FALLBACK_MODEL = "damo-vilab/text-to-video-ms-1.7b"

# Access token, needed only if the primary model is private/gated.
# None when the HF_TOKEN environment variable is not set.
hf_token = os.getenv("HF_TOKEN")
def _from_pretrained(model_id, token=None):
    """Load *model_id* as an fp16 DiffusionPipeline and move it to CUDA."""
    return DiffusionPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        variant="fp16",
        # `token` replaces the deprecated `use_auth_token` keyword in diffusers.
        token=token,
    ).to("cuda")


def load_model(model_id):
    """Load the preferred pipeline, falling back to FALLBACK_MODEL on failure.

    Args:
        model_id: Hugging Face Hub repo id of the preferred pipeline. It is
            loaded with the HF_TOKEN token (if set) so gated/private repos work.

    Returns:
        A fp16 DiffusionPipeline moved to the CUDA device, either for
        *model_id* or, if that failed, for FALLBACK_MODEL.

    Raises:
        Whatever DiffusionPipeline.from_pretrained / .to raise if the fallback
        model also fails to load.
    """
    try:
        pipe = _from_pretrained(model_id, token=hf_token)
    except Exception as e:
        # Deliberate best-effort startup boundary: any failure (bad repo id,
        # no access, missing fp16 variant, ...) switches to the fallback.
        print(f"⚠ Error loading {model_id}, fallback ke {FALLBACK_MODEL}. Error: {e}")
        # As in the original, the fallback is loaded without a token
        # (presumably a public repo — confirm if FALLBACK_MODEL changes).
        pipe = _from_pretrained(FALLBACK_MODEL)
        print(f"✅ Loaded model: {FALLBACK_MODEL}")
        return pipe
    print(f"✅ Loaded model: {model_id}")
    return pipe
# Load the pipeline once at startup; generate_video reuses this global.
pipe = load_model(model_id=WAN_MODEL)
# =====================================================
# Video generation
# =====================================================
def generate_video(prompt, num_frames=16, fps=8, seed=42, progress=gr.Progress(track_tqdm=True)):
    """Generate a video for *prompt* and write it to an .mp4 file.

    Args:
        prompt: Text description of the video.
        num_frames: Number of frames to generate. Cast to int defensively in
            case the UI slider delivers a float.
        fps: Frames per second of the written video (also cast to int).
        seed: Seed for the latent-noise generator, for reproducible output.
        progress: Gradio progress tracker; track_tqdm=True lets Gradio follow
            tqdm bars emitted during generation.

    Returns:
        (path, path): the same .mp4 path twice, one for the Video preview and
        one for the File download component.

    Raises:
        gr.Error: if *prompt* is empty or whitespace-only.
    """
    if not prompt or not prompt.strip():
        raise gr.Error("Prompt tidak boleh kosong.")

    # A dedicated CPU generator seeded like torch.manual_seed(seed) yields the
    # same noise, without reseeding torch's process-wide RNG on every request.
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    progress(0, desc="🚀 Mulai generate video...")
    output = pipe(
        prompt=prompt,
        num_frames=int(num_frames),
        generator=generator,
    )

    progress(0.7, desc="📸 Menyusun frame jadi video...")
    video_frames = output.frames[0]  # frames of the first (only) video in the output

    # DiffusionPipeline has no save_pretrained_video(); export_to_video is the
    # diffusers helper for writing frames to .mp4. Without an explicit path it
    # writes to a new temporary file and returns its path, so successive
    # requests no longer overwrite a shared "output.mp4".
    out_path = export_to_video(video_frames, fps=int(fps))

    progress(1, desc="✅ Selesai!")
    return out_path, out_path  # Video preview + File download
# =====================================================
# Gradio UI
# =====================================================
with gr.Blocks() as demo:
    gr.Markdown("## 🎬 WAN 2.2 Video Generator (Hugging Face Space)")

    with gr.Row():
        prompt_inp = gr.Textbox(
            label="Prompt",
            placeholder="Masukkan deskripsi video...",
            value="Seekor naga robot terbang melintasi kota futuristik",
        )

    with gr.Row():
        frame_slider = gr.Slider(8, 64, step=8, value=16, label="Jumlah Frame")
        fps_slider = gr.Slider(4, 16, step=1, value=8, label="FPS (kecepatan video)")

    btn = gr.Button("🚀 Generate Video")

    with gr.Row():
        video_out = gr.Video(label="Hasil")
        # `type` only controls how an *input* file is preprocessed, and the
        # value "file" is rejected by Gradio 4+, so it is omitted for this output.
        download_link = gr.File(label="Unduh Video")

    # generate_video returns the same path twice: preview + downloadable file.
    btn.click(
        fn=generate_video,
        inputs=[prompt_inp, frame_slider, fps_slider],
        outputs=[video_out, download_link],
    )

if __name__ == "__main__":
    # Run one generation at a time (a single GPU pipeline) with up to 5 queued.
    try:
        # Gradio 4+: concurrency_count was removed in favour of this keyword.
        demo.queue(max_size=5, default_concurrency_limit=1)
    except TypeError:
        # Gradio 3.x does not accept default_concurrency_limit.
        demo.queue(max_size=5, concurrency_count=1)
    demo.launch()