multimodalart's picture
Update app.py
f200a98 verified
raw
history blame
7.1 kB
import os
import sys
import subprocess
import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import spaces
# --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
# Source repository providing the `hyvideo` package imported after setup.
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
REPO_DIR = os.path.abspath("HunyuanVideo-1.5")
# Absolute path so the pipeline loader finds the folder regardless of the
# working directory it is later resolved from.
MODEL_DIR = os.path.abspath("ckpts")
# Hugging Face repo holding the HunyuanVideo 1.5 checkpoints. The
# `transformer/<TRANSFORMER_VERSION>` layout downloaded by setup_environment()
# belongs to the 1.5 release; the original "tencent/HunyuanVideo" repo uses a
# different layout, so none of the download patterns would match there.
HF_REPO_ID = "tencent/HunyuanVideo-1.5"
# Configuration
TRANSFORMER_VERSION = "480p_i2v_distilled"  # subfolder under <MODEL_DIR>/transformer/
DTYPE = torch.bfloat16  # passed to create_pipeline() as transformer_dtype
ENABLE_OFFLOADING = True  # drives both enable_offloading and enable_group_offloading
def setup_environment():
    """Clone the HunyuanVideo-1.5 repo and download weights if they are missing.

    Side effects:
      * clones REPO_URL into REPO_DIR when that path does not exist yet;
      * prepends REPO_DIR to ``sys.path`` so the ``hyvideo`` package is importable;
      * when ``<MODEL_DIR>/transformer/<TRANSFORMER_VERSION>`` is absent, downloads
        the files matching ``allow_patterns`` from HF_REPO_ID into MODEL_DIR.

    A failing ``git clone`` raises ``subprocess.CalledProcessError``. The process
    exits with status 1 if the download raises, or if it finishes without
    producing the transformer folder that marks a complete download.
    """
    print("=" * 50)
    print("Checking Environment & Dependencies...")

    # 1. Clone repository (skipped whenever the target path already exists).
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
    else:
        print(f"Repository exists at {REPO_DIR}")

    # 2. Add repo to the Python path so `hyvideo` can be imported.
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # 3. Download weights. The transformer subfolder is the marker for a
    #    completed download.
    transformer_path = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
    if not os.path.exists(transformer_path):
        print(f"Downloading weights to {MODEL_DIR}...")
        # NOTE(review): confirm every folder below actually exists in HF_REPO_ID;
        # some encoders may have to be fetched from separate repos.
        allow_patterns = [
            f"transformer/{TRANSFORMER_VERSION}/*",
            "vae/*",
            "text_encoder/*",
            "vision_encoder/*",
            "scheduler/*",
            "tokenizer/*",
        ]
        try:
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=HF_REPO_ID,
                local_dir=MODEL_DIR,
                allow_patterns=allow_patterns,
            )
        except Exception as e:
            print(f"Error downloading weights: {e}")
            sys.exit(1)
        # The download may complete without fetching anything if no file
        # matches the patterns, so verify the marker folder really appeared
        # instead of failing later inside the pipeline loader.
        if not os.path.isdir(transformer_path):
            print(
                f"Error downloading weights: {transformer_path} is missing after "
                f"download (check HF_REPO_ID / TRANSFORMER_VERSION)"
            )
            sys.exit(1)
        print("Download complete.")
    else:
        print(f"Weights found in {MODEL_DIR}")
    print("Environment Ready.")
    print("=" * 50)
# Run setup immediately: the hyvideo imports below need the cloned repo on
# sys.path and the weights on disk.
setup_environment()

# --- Part 2: Imports from Cloned Repo ---
# Allocator config is only a default; a value already in the environment wins.
os.environ.setdefault('PYTORCH_CUDA_ALLOC_CONF', 'expandable_segments:True')
# HyVideo code expects distributed env vars even on a single GPU. Unlike the
# allocator setting, these are overwritten unconditionally.
os.environ.update({'RANK': '0', 'WORLD_SIZE': '1'})

try:
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.infer_state import initialize_infer_state
except ImportError as import_err:
    print(f"CRITICAL ERROR: Could not import hyvideo modules. {import_err}")
    sys.exit(1)

import gradio as gr
# --- Part 3: Model Initialization (Pre-Load) ---
# Mock args for inference configuration (required by hyvideo's internal logic).
class ArgsNamespace:
    """Minimal stand-in for the argument namespace passed to hyvideo's infer state.

    Only the three attributes below are provided. The defaults reproduce the
    previously hard-coded values: ``use_sageattn`` off, ``sage_blocks_range``
    ``"0-53"`` and ``enable_torch_compile`` off. They can now be overridden
    via keyword arguments without editing the class.
    """

    def __init__(self, *, use_sageattn=False, sage_blocks_range="0-53",
                 enable_torch_compile=False):
        # Presumably toggles SageAttention kernels inside hyvideo (TODO confirm).
        self.use_sageattn = use_sageattn
        # Range string in hyvideo's own format; looks like "start-end" block
        # indices (NOTE(review): confirm inclusivity against hyvideo).
        self.sage_blocks_range = sage_blocks_range
        self.enable_torch_compile = enable_torch_compile
# Hand a mock argument namespace to hyvideo's infer-state initializer before
# any pipeline is created; per the note on ArgsNamespace, hyvideo's internal
# logic requires it.
initialize_infer_state(ArgsNamespace())

# Module-level pipeline handle: stays None until pre_load_model() assigns it,
# and generate() refuses to run while it is still None.
pipe: "HunyuanVideo_1_5_Pipeline | None" = None
def pre_load_model():
    """Create the global HunyuanVideo 1.5 pipeline before the UI is launched.

    Builds the pipeline from MODEL_DIR using TRANSFORMER_VERSION, DTYPE and
    ENABLE_OFFLOADING and stores it in the module-level ``pipe``. Exits the
    process with status 1 if MODEL_DIR is not a directory or if pipeline
    creation raises (the traceback is printed first).
    """
    global pipe

    # Guard: the loader needs the checkpoint directory on disk.
    if not os.path.isdir(MODEL_DIR):
        print(f"❌ Error: Model directory not found at {MODEL_DIR}")
        sys.exit(1)

    print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION}) from {MODEL_DIR}...")
    pipeline_kwargs = dict(
        pretrained_model_name_or_path=MODEL_DIR,
        transformer_version=TRANSFORMER_VERSION,
        enable_offloading=ENABLE_OFFLOADING,
        # Group offloading is tied to the same switch as regular offloading.
        enable_group_offloading=ENABLE_OFFLOADING,
        transformer_dtype=DTYPE,
    )
    try:
        pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(**pipeline_kwargs)
    except Exception as exc:
        print(f"❌ Failed to load model: {exc}")
        import traceback
        traceback.print_exc()
        sys.exit(1)
    else:
        print("✅ Model loaded successfully!")
def save_video_tensor(video_tensor, path, fps=24):
    """Encode a decoded video tensor and write it to ``path`` with imageio.

    Args:
        video_tensor: pipeline output. A list is reduced to its first element
            and a 5-D tensor to its first batch item. The remaining 4-D tensor
            is assumed to be channel-first ``(C, T, H, W)`` with values in
            ``[0, 1]``. NOTE(review): confirm against the pipeline's
            ``output_type="pt"`` contract.
        path: destination file; imageio picks the writer from the extension.
        fps: frames per second written to the file.

    Raises:
        ValueError: if ``video_tensor`` is an empty list.
    """
    if isinstance(video_tensor, list):
        if not video_tensor:
            raise ValueError("save_video_tensor received an empty list of videos")
        video_tensor = video_tensor[0]
    if video_tensor.ndim == 5:
        video_tensor = video_tensor[0]
    # Scale in float32 so the product is not pre-rounded at a reduced input
    # precision (e.g. bfloat16). Then round: a bare uint8 cast truncates,
    # biasing every pixel downward by up to one level.
    frames = (
        video_tensor.detach()
        .float()
        .mul(255)
        .round()
        .clamp(0, 255)
        .to(torch.uint8)
        .permute(1, 2, 3, 0)  # (C, T, H, W) -> (T, H, W, C): one HWC frame per step
        .cpu()
        .numpy()
    )
    imageio.mimwrite(path, frames, fps=fps)
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance):
    """Run image-to-video inference and return the path of the written MP4.

    Args:
        input_image: reference frame from the UI (PIL image) or a NumPy array.
        prompt: text prompt describing the motion.
        length: frame count, passed as ``video_length``.
        steps: number of inference steps.
        shift: passed to the pipeline as ``flow_shift``.
        seed: RNG seed; ``-1`` or an empty field (``None``) selects a random seed.
        guidance: passed to the pipeline as ``guidance_scale``.

    Returns:
        Relative path of the saved ``outputs/gen_<timestamp>.mp4`` file.

    Raises:
        gr.Error: if the pipeline is not loaded, no image was supplied, or
            inference fails.
    """
    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")
    if isinstance(input_image, np.ndarray):
        input_image = Image.fromarray(input_image)
    # Normalize every PIL input (e.g. RGBA or palette images) to 3-channel RGB,
    # not only the ones that arrived as arrays.
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")

    # A cleared gr.Number yields None; treat it like -1 ("random") instead of
    # crashing in int(None).
    if seed is None or int(seed) == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)
    print(f"Generating: {prompt} | Seed: {seed}")
    try:
        output = pipe(
            prompt=prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=input_image,
            seed=seed,
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        # Chain the original exception so the server log keeps the real traceback.
        raise gr.Error(f"Inference Failed: {e}") from e

    # Microseconds keep filenames unique when two runs finish in the same second.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    os.makedirs("outputs", exist_ok=True)
    output_path = os.path.join("outputs", f"gen_{timestamp}.mp4")
    save_video_tensor(output.videos, output_path)
    return output_path
# --- Part 4: UI Definition & Launch ---
def create_ui():
    """Assemble and return the Gradio Blocks app for image-to-video generation.

    Layout: one row with an input column (reference image, prompt, sampling
    controls, seed, Generate button) and an output column (result video).
    Clicking Generate calls ``generate`` and shows the returned video path.
    """
    with gr.Blocks(title="HunyuanVideo 1.5 I2V") as demo:
        gr.Markdown(f"### 🎬 HunyuanVideo 1.5 I2V ({TRANSFORMER_VERSION})")
        with gr.Row():
            # Input column; component creation order is the on-screen order.
            with gr.Column():
                ref_image = gr.Image(label="Reference", type="pil", height=250)
                prompt_box = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                with gr.Row():
                    steps_slider = gr.Slider(2, 20, value=6, step=1, label="Steps")
                    guidance_slider = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
                with gr.Row():
                    shift_slider = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift")
                    length_slider = gr.Slider(1, 129, value=61, step=4, label="Length")
                seed_box = gr.Number(value=-1, label="Seed", precision=0)
                run_button = gr.Button("Generate", variant="primary")
            # Output column.
            with gr.Column():
                video_out = gr.Video(label="Result", autoplay=True)
        # Must follow generate(input_image, prompt, length, steps, shift, seed, guidance).
        generate_inputs = [
            ref_image,
            prompt_box,
            length_slider,
            steps_slider,
            shift_slider,
            seed_box,
            guidance_slider,
        ]
        run_button.click(generate, inputs=generate_inputs, outputs=[video_out])
    return demo
if __name__ == "__main__":
    # Load the pipeline first: pre_load_model() exits on failure, so the UI
    # never starts while `pipe` is still None.
    pre_load_model()

    # Build the app, enable the request queue, then serve on all interfaces
    # with a public share link requested.
    ui = create_ui()
    ui.queue()
    ui.launch(server_name="0.0.0.0", share=True)