import os
import sys
import subprocess
import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import shutil
import requests
import base64
import io
import spaces

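# Local paths: where the inference repo is cloned and where checkpoints are stored.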
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
REPO_DIR = os.path.abspath("HunyuanVideo-1.5")
MODEL_DIR = os.path.abspath("ckpts")

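# Hugging Face repos for the main weights, the Glyph-SDXL-v2 ByT5 encoder,
# the Qwen2.5-VL text encoder, and the vision encoder weights.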
HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"

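# Runtime configuration: the 480p image-to-video distilled transformer,
# bf16 weights, CPU offloading disabled.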
TRANSFORMER_VERSION = "480p_i2v_distilled"
DTYPE = torch.bfloat16
ENABLE_OFFLOADING = False


def setup_environment():
    """Clone the inference repo and download all required model weights."""
    print("=" * 50)
    print("Checking Environment & Dependencies...")

    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # Make the cloned repo importable as `hyvideo`.
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    os.makedirs(MODEL_DIR, exist_ok=True)
    target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)

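    # Fetch only the files this variant needs (transformer weights, VAE,
    # scheduler, tokenizer) rather than the full repository.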
    if not os.path.exists(target_transformer):
        print(f"Downloading Main Weights from {HF_MAIN_REPO}...")
        try:
            from huggingface_hub import snapshot_download
            allow_patterns = [
                f"transformer/{TRANSFORMER_VERSION}/*",
                "vae/*",
                "scheduler/*",
                "tokenizer/*",
            ]
            snapshot_download(
                repo_id=HF_MAIN_REPO,
                local_dir=MODEL_DIR,
                allow_patterns=allow_patterns,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading main weights: {e}")
            sys.exit(1)

    llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm")
    if not os.path.exists(llm_target) or not os.listdir(llm_target):
        print(f"Downloading LLM Text Encoder from {HF_LLM_REPO}...")
        try:
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=HF_LLM_REPO,
                local_dir=llm_target,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading LLM: {e}")

    vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
    if not os.path.exists(vision_target) or not os.listdir(vision_target):
        print(f"Downloading Vision Encoder from {HF_VISION_REPO}...")
        try:
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id=HF_VISION_REPO,
                local_dir=vision_target,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading Vision Encoder: {e}")

    glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
    glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")

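    # The pipeline expects Glyph-SDXL-v2 weights in a specific layout
    # (assets/ plus checkpoints/byt5_model.pt), so the download is restructured.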
    if not os.path.exists(glyph_ckpt_target):
        print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
        try:
            from huggingface_hub import snapshot_download
            glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
            snapshot_download(repo_id=HF_GLYPH_REPO, local_dir=glyph_temp, local_dir_use_symlinks=False)

            os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
            os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)

            # Copy the bundled assets into place.
            src_assets = os.path.join(glyph_temp, "assets")
            if os.path.exists(src_assets):
                for f in os.listdir(src_assets):
                    shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))

            # The checkpoint may be shipped as pytorch_model.bin or
            # model.safetensors; either way it is moved to the byt5_model.pt
            # path the pipeline expects.
            src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
            if os.path.exists(src_bin):
                shutil.move(src_bin, glyph_ckpt_target)
            else:
                src_safe = os.path.join(glyph_temp, "model.safetensors")
                if os.path.exists(src_safe):
                    shutil.move(src_safe, glyph_ckpt_target)

            shutil.rmtree(glyph_temp, ignore_errors=True)
        except Exception as e:
            print(f"Error setting up Glyph weights: {e}")

print("Environment Ready.") |
|
|
print("=" * 50) |
|
|
|
|
|
setup_environment() |
|
|
|
|
|
|
|
|
|
|
|
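# These imports only work after setup_environment() has cloned the repo and
# put it on sys.path.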
try:
    import hyvideo.commons
    import hyvideo.pipelines.hunyuan_video_pipeline
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    from hyvideo.commons.infer_state import initialize_infer_state
    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
except ImportError as e:
    print(f"CRITICAL ERROR: {e}")
    sys.exit(1)

import gradio as gr

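# ZeroGPU Spaces expose no CUDA device at import time, so the repo's
# get_gpu_memory() probe would misreport available memory; patch it to
# claim a fixed 80 GB so the pipeline's memory heuristics stay sane.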
def dummy_get_gpu_memory(device=None):
    return 80 * 1024 * 1024 * 1024  # report 80 GB, in bytes


print("🛠️ Applying ZeroGPU Monkey Patch...")
hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory
hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory


def encode_image_to_base64(pil_image):
    """Serialize a PIL image to a base64 data URL for the chat API."""
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"


def rewrite_prompt_external(user_prompt, pil_image):
    """Call the HF Router API to rewrite the prompt with Qwen2.5-VL."""
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
        return user_prompt

    print("🧠 Rewriting prompt via API...")

    API_URL = "https://router.huggingface.co/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}

    # Template the repo's i2v rewrite system prompt with the user prompt and
    # encode the reference image for the request.
    full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
    base64_img = encode_image_to_base64(pil_image)

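    # OpenAI-compatible chat payload: a single user message with a text part
    # and an image_url part carrying the base64 data URL.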
    payload = {
        "messages": [
            {
                "role": "user",
                "content": [
                    {"type": "text", "text": full_instruction},
                    {"type": "image_url", "image_url": {"url": base64_img}},
                ],
            }
        ],
        "model": "Qwen/Qwen2.5-VL-7B-Instruct",
        "max_tokens": 512,
        "temperature": 0.7,
    }

    try:
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()
        rewritten = data["choices"][0]["message"]["content"]
        print(f"✅ Rewritten: {rewritten[:50]}...")
        return rewritten
    except Exception as e:
        # Fall back to the raw prompt on any network or parsing failure.
        print(f"❌ Rewrite failed: {e}")
        return user_prompt


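# initialize_infer_state() expects an argparse-style namespace; this minimal
# stand-in disables SageAttention and torch.compile.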
class ArgsNamespace:
    def __init__(self):
        self.use_sageattn = False
        self.sage_blocks_range = "0-53"
        self.enable_torch_compile = False


initialize_infer_state(ArgsNamespace())

print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...") |
|
|
try: |
|
|
pipe = HunyuanVideo_1_5_Pipeline.create_pipeline( |
|
|
pretrained_model_name_or_path=MODEL_DIR, |
|
|
transformer_version=TRANSFORMER_VERSION, |
|
|
enable_offloading=ENABLE_OFFLOADING, |
|
|
enable_group_offloading=ENABLE_OFFLOADING, |
|
|
transformer_dtype=DTYPE, |
|
|
device=torch.device('cpu') |
|
|
) |
|
|
pipe.to('cuda') |
|
|
print("✅ Model loaded into CPU RAM.") |
|
|
except Exception as e: |
|
|
print(f"❌ Failed to load model: {e}") |
|
|
sys.exit(1) |
|
|
|
|
|
def save_video_tensor(video_tensor, path, fps=24):
    """Write a float video tensor in [0, 1] to an mp4 file."""
    if isinstance(video_tensor, list):
        video_tensor = video_tensor[0]
    if video_tensor.ndim == 5:  # drop the batch dimension
        video_tensor = video_tensor[0]
    vid = (video_tensor * 255).clamp(0, 255).to(torch.uint8)
    vid = vid.permute(1, 2, 3, 0).cpu().numpy()  # (C, T, H, W) -> (T, H, W, C)
    imageio.mimwrite(path, vid, fps=fps)


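# ZeroGPU attaches a GPU only while this function runs, for at most
# `duration` seconds per call; everything above executes without one.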
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")

    # Normalize the Gradio input (numpy array or PIL image) to RGB PIL.
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB")
    else:
        pil_image = input_image.convert("RGB")

    actual_prompt = prompt
    if do_rewrite:
        actual_prompt = rewrite_prompt_external(prompt, pil_image)

    if seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    generator = torch.Generator(device="cpu").manual_seed(int(seed))

    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")

    try:
        pipe.execution_device = torch.device("cuda")

        output = pipe(
            prompt=actual_prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=pil_image,
            seed=int(seed),
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        raise gr.Error(f"Inference Failed: {e}")

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    os.makedirs("outputs", exist_ok=True)
    output_path = f"outputs/gen_{timestamp}.mp4"
    save_video_tensor(output.videos, output_path)

    return output_path, actual_prompt


css = '''#col-container { max-width: 900px; margin: 0 auto; }
.dark .progress-text{color: white !important}'''


def create_ui():
    with gr.Blocks(title="HunyuanVideo 1.5 I2V", css=css) as demo:
        gr.Markdown("# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo")
        gr.Markdown(f"This is a demo for HunyuanVideo 1.5 I2V {TRANSFORMER_VERSION}, released together with a collection of 10 other checkpoints (text-to-video, 720p, upscalers). Check out the [HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) for more.")

        with gr.Row():
            with gr.Column():
                img = gr.Image(label="Reference", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True)

                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        steps = gr.Slider(2, 50, value=6, step=1, label="Steps")
                        guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
                    with gr.Row():
                        shift = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift")
                        length = gr.Slider(1, 129, value=61, step=4, label="Length")
                    seed = gr.Number(value=-1, label="Seed", precision=0, info="-1 is a random seed")
                btn = gr.Button("Generate", variant="primary")

            with gr.Column():
                out = gr.Video(label="Result", autoplay=True)
                final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False)

        btn.click(
            generate,
            inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk],
            outputs=[out, final_prompt_box],
        )

    return demo


if __name__ == "__main__":
    ui = create_ui()
    ui.queue().launch(server_name="0.0.0.0", share=True)