# Scraped Hugging Face file-viewer header, kept as comments so the file parses:
#   multimodalart's picture
#   Update app.py
#   e313051 verified
#   raw / history / blame
#   12.6 kB
import os
import sys
import subprocess
import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import shutil
import requests
import base64
import io
import spaces
# --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
# Git repository cloned by setup_environment() and then inserted into sys.path.
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
# Both directories are relative names resolved to absolute paths at import time.
REPO_DIR = os.path.abspath("HunyuanVideo-1.5")
MODEL_DIR = os.path.abspath("ckpts")
# Repositories
# Hugging Face Hub repo ids; setup_environment() downloads each into MODEL_DIR.
HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"  # transformer/, vae/, scheduler/, tokenizer/
HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"  # -> text_encoder/Glyph-SDXL-v2
HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"  # -> text_encoder/llm
HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"  # -> vision_encoder/siglip
# Configuration
# Sub-folder of transformer/ that is downloaded and passed as transformer_version.
TRANSFORMER_VERSION = "480p_i2v_distilled"
# Passed as transformer_dtype when the pipeline is created.
DTYPE = torch.bfloat16
# Passed as both enable_offloading and enable_group_offloading.
ENABLE_OFFLOADING = False
def _download_snapshot(repo_id, local_dir, allow_patterns=None):
    """Download a Hub repo (optionally filtered) into *local_dir* as real files."""
    # Imported lazily so a missing huggingface_hub is reported by the caller's
    # error handling instead of failing at module import.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        allow_patterns=allow_patterns,
        local_dir_use_symlinks=False,
    )


def _is_missing_or_empty(path):
    """Return True when *path* does not exist or is an empty directory."""
    return not os.path.exists(path) or not os.listdir(path)


def _install_glyph_weights(glyph_root, glyph_ckpt_target):
    """Download the Glyph repo and rearrange it into assets/ and checkpoints/."""
    glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
    try:
        _download_snapshot(HF_GLYPH_REPO, glyph_temp)
        assets_dir = os.path.join(glyph_root, "assets")
        os.makedirs(assets_dir, exist_ok=True)
        os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)

        # Copy Assets
        src_assets = os.path.join(glyph_temp, "assets")
        if os.path.exists(src_assets):
            for name in os.listdir(src_assets):
                shutil.copy(os.path.join(src_assets, name), os.path.join(assets_dir, name))

        # Move Model: the first weight file that exists becomes byt5_model.pt.
        for candidate in ("pytorch_model.bin", "model.safetensors"):
            src_weights = os.path.join(glyph_temp, candidate)
            if os.path.exists(src_weights):
                shutil.move(src_weights, glyph_ckpt_target)
                break
        else:
            # Previously this case passed silently and surfaced later as a
            # missing-file error far from its cause.
            print(
                f"Warning: no pytorch_model.bin or model.safetensors found in "
                f"{HF_GLYPH_REPO}; Glyph checkpoint was not installed."
            )
    finally:
        # Always drop the scratch download, including when a step above fails.
        shutil.rmtree(glyph_temp, ignore_errors=True)


def setup_environment():
    """Clone the inference repo, expose it on sys.path and fetch all weights.

    Each step is skipped when its target already exists. A failed git clone
    raises (``check=True``) and a failed main-weights download exits the
    process with status 1; failures of the LLM, vision-encoder and Glyph
    downloads are printed and setup continues.
    """
    print("=" * 50)
    print("Checking Environment & Dependencies...")

    # 1. Clone Code Repository
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # 2. Add Repo to Python Path
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # 3. Download Main Weights (Transformer, VAE, Scheduler)
    os.makedirs(MODEL_DIR, exist_ok=True)
    target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
    if not os.path.exists(target_transformer):
        print(f"Downloading Main Weights from {HF_MAIN_REPO}...")
        try:
            _download_snapshot(
                HF_MAIN_REPO,
                MODEL_DIR,
                allow_patterns=[
                    f"transformer/{TRANSFORMER_VERSION}/*",
                    "vae/*",
                    "scheduler/*",
                    "tokenizer/*",
                ],
            )
        except Exception as e:
            print(f"Error downloading main weights: {e}")
            sys.exit(1)

    # 4. Download LLM Text Encoder (Qwen)
    # 5. Download Vision Encoder (SigLIP)
    # Both are best-effort: an error is reported and setup carries on.
    encoder_downloads = (
        ("LLM Text Encoder", "LLM", HF_LLM_REPO,
         os.path.join(MODEL_DIR, "text_encoder", "llm")),
        ("Vision Encoder", "Vision Encoder", HF_VISION_REPO,
         os.path.join(MODEL_DIR, "vision_encoder", "siglip")),
    )
    for label, error_label, repo_id, target_dir in encoder_downloads:
        if not _is_missing_or_empty(target_dir):
            continue
        print(f"Downloading {label} from {repo_id}...")
        try:
            _download_snapshot(repo_id, target_dir)
        except Exception as e:
            print(f"Error downloading {error_label}: {e}")

    # 6. Download & Restructure Glyph Weights
    glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
    glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
    if not os.path.exists(glyph_ckpt_target):
        print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
        try:
            _install_glyph_weights(glyph_root, glyph_ckpt_target)
        except Exception as e:
            print(f"Error setting up Glyph weights: {e}")

    print("Environment Ready.")
    print("=" * 50)
# Runs at import time: the hyvideo imports below can only resolve after the
# repository has been cloned and placed on sys.path.
setup_environment()
# --- Part 2: Imports & Patching ---
try:
    # The two plain module imports keep the module objects reachable by dotted
    # name for the monkey patch that follows.
    import hyvideo.commons
    import hyvideo.pipelines.hunyuan_video_pipeline
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.infer_state import initialize_infer_state
    # Import the specific I2V System Prompt from the repo
    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
except ImportError as import_error:
    # Nothing below can work without the repo, so stop immediately.
    print(f"CRITICAL ERROR: {import_error}")
    sys.exit(1)
import gradio as gr
def dummy_get_gpu_memory(device=None):
    """Return a fixed 80 GiB (in bytes) without querying any device.

    Args:
        device: Accepted for signature compatibility and ignored.
    """
    bytes_per_gib = 1024 ** 3
    return 80 * bytes_per_gib
# Rebind get_gpu_memory to dummy_get_gpu_memory so it returns a constant.
# NOTE(review): presumably needed because no GPU is attached at import time
# under ZeroGPU — confirm against the hyvideo implementation.
print("🛠️ Applying ZeroGPU Monkey Patch...")
hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory
# The pipeline module is patched separately; presumably it holds its own
# reference to the function (e.g. via a from-import) — confirm.
hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory
# --- Part 3: Prompt Rewrite Logic (External API) ---
def encode_image_to_base64(pil_image):
    """Encode an image as a base64 JPEG data URL.

    Args:
        pil_image: PIL-style image exposing ``mode``, ``convert`` and ``save``.

    Returns:
        A string of the form ``data:image/jpeg;base64,<payload>``.
    """
    # JPEG has no alpha or palette support, so saving e.g. an RGBA or P image
    # fails. Normalise anything that is not plain RGB / greyscale first.
    if getattr(pil_image, "mode", "RGB") not in ("RGB", "L"):
        pil_image = pil_image.convert("RGB")
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"
def rewrite_prompt_external(user_prompt, pil_image):
    """Calls HF Router API to rewrite prompt using Qwen2.5-VL.

    The rewrite is best-effort: when ``HF_TOKEN`` is not set, or when any step
    (building the instruction, encoding the image, the HTTP call, parsing the
    reply) fails or yields an empty rewrite, *user_prompt* is returned
    unchanged.

    Args:
        user_prompt: Prompt text typed by the user.
        pil_image: Reference image sent along with the instruction.

    Returns:
        The rewritten prompt, or *user_prompt* as a fallback.
    """
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
        return user_prompt

    print("🧠 Rewriting prompt via API...")
    API_URL = "https://router.huggingface.co/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}

    try:
        # Request preparation lives inside the try as well: a formatting or
        # image-encoding error must fall back too, not abort the generation.
        # Combine the official Hunyuan System Prompt with the User Input
        # The system prompt string contains a {} placeholder for the user input
        full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
        base64_img = encode_image_to_base64(pil_image)
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": full_instruction},
                        {"type": "image_url", "image_url": {"url": base64_img}},
                    ],
                }
            ],
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "max_tokens": 512,
            "temperature": 0.7,
        }
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()
        rewritten = data["choices"][0]["message"]["content"]
        # An empty or non-text reply would otherwise become the prompt.
        if not isinstance(rewritten, str) or not rewritten.strip():
            raise ValueError("API returned an empty rewrite")
        print(f"✅ Rewritten: {rewritten[:50]}...")
        return rewritten
    except Exception as e:
        # Deliberate catch-all boundary: the rewrite is optional.
        print(f"❌ Rewrite failed: {e}")
        return user_prompt
# --- Part 4: Model Initialization (CPU) ---
class ArgsNamespace:
    """Plain attribute holder handed to initialize_infer_state()."""

    def __init__(self):
        # Set as instance attributes, in this order.
        defaults = (
            ("use_sageattn", False),
            ("sage_blocks_range", "0-53"),
            ("enable_torch_compile", False),
        )
        for attr_name, attr_value in defaults:
            setattr(self, attr_name, attr_value)
# Called once, before the pipeline is created.
initialize_infer_state(ArgsNamespace())
print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...")
try:
    # Build on the CPU device first, then request the move to CUDA.
    pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
        pretrained_model_name_or_path=MODEL_DIR,
        transformer_version=TRANSFORMER_VERSION,
        enable_offloading=ENABLE_OFFLOADING,
        enable_group_offloading=ENABLE_OFFLOADING,
        transformer_dtype=DTYPE,
        device=torch.device('cpu'),
    )
    pipe.to('cuda')
    # NOTE(review): the message says CPU RAM although .to('cuda') has just
    # run — presumably ZeroGPU defers the real transfer; confirm.
    print("✅ Model loaded into CPU RAM.")
except Exception as load_error:
    # The app is useless without the model, so stop the process.
    print(f"❌ Failed to load model: {load_error}")
    sys.exit(1)
def save_video_tensor(video_tensor, path, fps=24):
    """Write a video tensor to *path* with imageio.

    Args:
        video_tensor: Tensor, or list whose first element is used. A 5-D
            tensor is reduced to its first entry along dim 0.
        path: Output file path passed to ``imageio.mimwrite``.
        fps: Frames per second passed to ``imageio.mimwrite``.
    """
    clip = video_tensor[0] if isinstance(video_tensor, list) else video_tensor
    if clip.ndim == 5:
        clip = clip[0]
    # Scale by 255, clamp to [0, 255], convert to uint8.
    # Assumes values are floats in [0, 1] — TODO confirm.
    scaled = clip * 255
    frames_u8 = scaled.clamp(0, 255).to(torch.uint8)
    # Move dim 0 to the end; presumably (C, T, H, W) -> (T, H, W, C) — confirm.
    frames = frames_u8.permute(1, 2, 3, 0).cpu().numpy()
    imageio.mimwrite(path, frames, fps=fps)
# --- Part 5: Inference ---
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
    """Generate a video from a reference image and a text prompt.

    Args:
        input_image: Reference image, either a numpy array or a PIL-style image.
        prompt: User prompt text.
        length: Number of frames, passed as ``video_length``.
        steps: Passed as ``num_inference_steps``.
        shift: Passed as ``flow_shift``.
        seed: Integer seed; ``-1`` or an empty field selects a random seed.
        guidance: Passed as ``guidance_scale``.
        do_rewrite: When true, the prompt goes through rewrite_prompt_external().
        progress: Gradio progress tracker (tracks tqdm).

    Returns:
        Tuple ``(output_path, actual_prompt)``.

    Raises:
        gr.Error: If the pipeline or the image is missing, or inference fails.
    """
    import uuid  # local import: only needed to build a unique file name

    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")

    # Process Input Image
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB")
    else:
        pil_image = input_image.convert("RGB")

    # 1. Prompt Rewrite (if enabled)
    # A None prompt would break the slicing in the log line below.
    prompt = prompt or ""
    actual_prompt = prompt
    if do_rewrite:
        actual_prompt = rewrite_prompt_external(prompt, pil_image)

    # 2. Setup Generator
    # A cleared Number field arrives as None; treat it like -1 (random).
    if seed is None or int(seed) == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)

    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")
    try:
        pipe.execution_device = torch.device("cuda")
        output = pipe(
            prompt=actual_prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=pil_image,
            seed=seed,
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        # Chain the cause so the original traceback stays in the logs.
        raise gr.Error(f"Inference Failed: {e}") from e

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    # The timestamp has one-second resolution; the random suffix keeps two
    # requests finishing in the same second from overwriting each other.
    output_path = f"outputs/gen_{timestamp}_{uuid.uuid4().hex[:8]}.mp4"
    save_video_tensor(output.videos, output_path)
    return output_path, actual_prompt
# --- Part 6: UI ---
# Custom stylesheet passed to gr.Blocks(css=...) in create_ui().
# NOTE(review): no component in this file sets elem_id="col-container", so the
# first rule has no visible target here.
css = '''#col-container { max-width: 900px; margin: 0 auto; }
.dark .progress-text{color: white !important}'''
def create_ui():
    """Build the Gradio Blocks app and wire the Generate button to generate().

    Returns:
        The ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="HunyuanVideo 1.5 I2V", css=css) as demo:
        # Markdown only treats '#' as a heading marker when a space follows;
        # without it the title was rendered as literal text.
        gr.Markdown("# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo")
        gr.Markdown(f"This is a demo for HunyuanVideo 1.5 I2v {TRANSFORMER_VERSION}, released together with a collection of 10 other checkpoints (text-to-video, 720p, upscalers). Check out the [HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) for more")
        with gr.Row():
            # Left column: inputs and controls.
            with gr.Column():
                img = gr.Image(label="Reference", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        steps = gr.Slider(2, 50, value=6, step=1, label="Steps")
                        guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
                    with gr.Row():
                        shift = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift")
                        length = gr.Slider(1, 129, value=61, step=4, label="Length")
                    seed = gr.Number(value=-1, label="Seed", precision=0, info="-1 is a random seed")
                btn = gr.Button("Generate", variant="primary")
            # Right column: results.
            with gr.Column():
                out = gr.Video(label="Result", autoplay=True)
                final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False)
        # Input order must match generate()'s positional parameters.
        btn.click(
            generate,
            inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk],
            outputs=[out, final_prompt_box],
        )
    return demo
if __name__ == "__main__":
    # Build the app and enable queuing, then serve on all interfaces with a
    # share link requested.
    ui = create_ui().queue()
    ui.launch(server_name="0.0.0.0", share=True)