Spaces:
Running
on
Zero
Running
on
Zero
File size: 13,036 Bytes
fb34514 b01e03f fb34514 bfecffd 939c549 8f13c28 007bd3f bfecffd f200a98 939c549 6e7e289 e313051 bfecffd 939c549 bfecffd 6e7e289 bfecffd 6e7e289 f200a98 bfecffd 6e7e289 f200a98 bfecffd 6e7e289 939c549 bfecffd 6e7e289 bfecffd 6e7e289 bfecffd 939c549 bfecffd 939c549 6e7e289 3483f07 6e7e289 3483f07 6e7e289 939c549 6e7e289 939c549 3483f07 939c549 6e7e289 939c549 6e7e289 939c549 6e7e289 939c549 6e7e289 939c549 bfecffd 8f13c28 bfecffd ecc3183 939c549 8f13c28 bfecffd ecc3183 bfecffd 340434e 8f13c28 340434e 939c549 340434e 8f13c28 bfecffd 134049c bfecffd 134049c bfecffd d8659cc 8f13c28 d8659cc bfecffd 8f13c28 939c549 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 939c549 bfecffd e855781 939c549 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 8f13c28 5210f53 e313051 bfecffd 91c178f 5210f53 e313051 bfecffd e313051 bfecffd e313051 8f13c28 e313051 993262e bfecffd 8f13c28 bfecffd 8f13c28 bfecffd 91c178f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 |
import os
import subprocess
import sys

# Install flash-attn at startup. FLASH_ATTENTION_SKIP_CUDA_BUILD asks its
# setup script not to compile the CUDA kernels from source (presumably so a
# prebuilt wheel is fetched instead — confirm against flash-attn's setup.py).
# The flag is *added* to the inherited environment: passing a bare dict as
# ``env`` would replace it wholesale and drop PATH, HOME, proxy and pip
# settings for the child process.
# ``sys.executable -m pip`` installs into the interpreter running this app.
# Best effort: no check=True, so a failed install is not fatal here.
subprocess.run(
    [
        sys.executable, "-m", "pip", "install",
        "flash-attn==2.7.4.post1", "--no-build-isolation",
    ],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
)
import os
import sys
import torch
import datetime
import numpy as np
from PIL import Image
import imageio
import shutil
import requests
import base64
import io
import spaces
# --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
# Source checkout of the pipeline code, and the local weights root (both
# resolved against the current working directory).
REPO_URL = "https://github.com/Tencent-Hunyuan/HunyuanVideo-1.5.git"
REPO_DIR, MODEL_DIR = (
    os.path.abspath(folder) for folder in ("HunyuanVideo-1.5", "ckpts")
)

# Hugging Face Hub repositories the weights are fetched from.
HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"                 # transformer / vae / scheduler / tokenizer
HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"  # Glyph-SDXL-v2 ByT5 weights
HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"               # LLM text encoder
HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"     # vision encoder (stored as vision_encoder/siglip)

# Runtime configuration.
TRANSFORMER_VERSION = "480p_i2v_distilled"  # subfolder of transformer/ to download and load
DTYPE = torch.bfloat16                      # passed as transformer_dtype
ENABLE_OFFLOADING = False                   # also drives enable_group_offloading
def setup_environment():
    """Make sure the HunyuanVideo-1.5 code and all model weights exist locally.

    Each step is skipped when its target is already present:

    1. clone ``REPO_URL`` into ``REPO_DIR``;
    2. put ``REPO_DIR`` first on ``sys.path`` so ``hyvideo`` is importable;
    3. download the main weights (transformer ``TRANSFORMER_VERSION``, VAE,
       scheduler, tokenizer) from ``HF_MAIN_REPO`` into ``MODEL_DIR``;
    4. download the LLM text encoder into ``MODEL_DIR/text_encoder/llm``;
    5. download the vision encoder into ``MODEL_DIR/vision_encoder/siglip``;
    6. download the Glyph ByT5 weights and rearrange them under
       ``MODEL_DIR/text_encoder/Glyph-SDXL-v2``.

    Error policy: a failed ``git clone`` raises
    ``subprocess.CalledProcessError``; a failed main-weights download exits
    the process with status 1; failures in steps 4-6 are only printed, so
    they surface later when the pipeline is loaded.

    Note: steps 3-5 only check that the target exists (or, for 4-5, is
    non-empty), so an interrupted earlier download is not resumed.
    """
    print("=" * 50)
    print("Checking Environment & Dependencies...")

    # 1. Clone code repository
    if not os.path.exists(REPO_DIR):
        print(f"Cloning repository to {REPO_DIR}...")
        subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)

    # 2. Add repo to Python path
    if REPO_DIR not in sys.path:
        sys.path.insert(0, REPO_DIR)

    # 3. Main weights (transformer, VAE, scheduler, tokenizer) — required.
    os.makedirs(MODEL_DIR, exist_ok=True)
    target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
    if not os.path.exists(target_transformer):
        print(f"Downloading Main Weights from {HF_MAIN_REPO}...")
        try:
            _snapshot(
                HF_MAIN_REPO,
                MODEL_DIR,
                allow_patterns=[
                    f"transformer/{TRANSFORMER_VERSION}/*",
                    "vae/*",
                    "scheduler/*",
                    "tokenizer/*",
                ],
            )
        except Exception as e:
            print(f"Error downloading main weights: {e}")
            sys.exit(1)

    # 4-5. LLM text encoder and vision encoder — best effort.
    optional_encoders = (
        (HF_LLM_REPO, os.path.join(MODEL_DIR, "text_encoder", "llm"),
         "LLM Text Encoder", "LLM"),
        (HF_VISION_REPO, os.path.join(MODEL_DIR, "vision_encoder", "siglip"),
         "Vision Encoder", "Vision Encoder"),
    )
    for repo_id, target, label, error_label in optional_encoders:
        if _is_missing_or_empty(target):
            print(f"Downloading {label} from {repo_id}...")
            try:
                _snapshot(repo_id, target)
            except Exception as e:
                print(f"Error downloading {error_label}: {e}")

    # 6. Glyph weights — best effort.
    glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
    glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
    if not os.path.exists(glyph_ckpt_target):
        print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
        try:
            _restructure_glyph_weights(glyph_root, glyph_ckpt_target)
        except Exception as e:
            print(f"Error setting up Glyph weights: {e}")

    print("Environment Ready.")
    print("=" * 50)


def _snapshot(repo_id, local_dir, allow_patterns=None):
    """Download ``repo_id`` from the Hugging Face Hub into ``local_dir``.

    ``allow_patterns=None`` downloads the whole repository.
    ``local_dir_use_symlinks=False`` is kept for older ``huggingface_hub``
    releases; newer ones deprecate and ignore it.
    """
    # Imported lazily (and inside the caller's try) so a missing package is
    # reported through the same error path as a failed download.
    from huggingface_hub import snapshot_download

    snapshot_download(
        repo_id=repo_id,
        local_dir=local_dir,
        allow_patterns=allow_patterns,
        local_dir_use_symlinks=False,
    )


def _is_missing_or_empty(path):
    """Return True when ``path`` does not exist or is an empty directory."""
    return not os.path.exists(path) or not os.listdir(path)


def _restructure_glyph_weights(glyph_root, glyph_ckpt_target):
    """Fetch the Glyph ByT5 repo and lay it out as ``glyph_root/{assets,checkpoints}``.

    The repo is downloaded into ``MODEL_DIR/glyph_temp``. Its ``assets/``
    tree is copied into ``glyph_root/assets``, and the model file
    (``pytorch_model.bin``, else ``model.safetensors``) is moved to
    ``glyph_ckpt_target``. The temporary folder is removed even when a step
    fails; exceptions propagate to the caller.
    """
    glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
    try:
        _snapshot(HF_GLYPH_REPO, glyph_temp)
        os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
        os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)

        # Copy assets (recursively, so nested folders don't break the copy).
        src_assets = os.path.join(glyph_temp, "assets")
        if os.path.isdir(src_assets):
            shutil.copytree(src_assets, os.path.join(glyph_root, "assets"), dirs_exist_ok=True)

        # Move the model file; the first candidate found wins.
        for candidate in ("pytorch_model.bin", "model.safetensors"):
            src_model = os.path.join(glyph_temp, candidate)
            if os.path.exists(src_model):
                shutil.move(src_model, glyph_ckpt_target)
                break
        else:
            print(
                f"Warning: neither pytorch_model.bin nor model.safetensors found in "
                f"{HF_GLYPH_REPO}; {glyph_ckpt_target} was not created."
            )
    finally:
        shutil.rmtree(glyph_temp, ignore_errors=True)
setup_environment()

# --- Part 2: Imports & Patching ---
# Everything below comes from the freshly cloned repository, so it can only
# be imported once setup_environment() has put REPO_DIR on sys.path.
try:
    import hyvideo.commons
    import hyvideo.pipelines.hunyuan_video_pipeline
    from hyvideo.commons.infer_state import initialize_infer_state
    from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
    # Official image-to-video rewrite system prompt shipped with the repo.
    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
except ImportError as import_err:
    print(f"CRITICAL ERROR: {import_err}")
    raise SystemExit(1)

import gradio as gr
def dummy_get_gpu_memory(device=None):
    """Return a constant 80 GiB (in bytes), ignoring ``device``.

    Installed in place of hyvideo's ``get_gpu_memory`` by the ZeroGPU monkey
    patch in this file — presumably because real GPU memory cannot be
    queried while the Space is loading (TODO confirm).
    """
    bytes_per_gib = 1 << 30
    return 80 * bytes_per_gib
print("🛠️ Applying ZeroGPU Monkey Patch...")
# Rebind get_gpu_memory in its defining module and in the pipeline module,
# which presumably imported the name directly and so holds its own reference.
for _patched_module in (hyvideo.commons, hyvideo.pipelines.hunyuan_video_pipeline):
    _patched_module.get_gpu_memory = dummy_get_gpu_memory
del _patched_module
# --- Part 3: Prompt Rewrite Logic (External API) ---
def encode_image_to_base64(pil_image):
    """Encode an image as a JPEG ``data:`` URL for OpenAI-style chat APIs.

    Args:
        pil_image: a PIL image (anything with ``mode``, ``convert`` and
            ``save(fp, format=...)``).

    Returns:
        ``"data:image/jpeg;base64,<payload>"``.
    """
    # JPEG cannot store an alpha channel or a palette, so e.g. RGBA / P
    # images would make save() raise; normalise everything to RGB first.
    if pil_image.mode != "RGB":
        pil_image = pil_image.convert("RGB")
    buffered = io.BytesIO()
    pil_image.save(buffered, format="JPEG")
    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
    return f"data:image/jpeg;base64,{img_str}"
def rewrite_prompt_external(user_prompt, pil_image):
    """Rewrite ``user_prompt`` with Qwen2.5-VL through the HF router chat API.

    The repository's image-to-video rewrite system prompt is filled with the
    user's text and sent, together with ``pil_image`` as a base64 JPEG, to
    the chat-completions endpoint at router.huggingface.co.

    Strictly best effort: if ``HF_TOKEN`` is unset, if building the request,
    calling the API or parsing the reply fails, or if the reply is empty,
    the original ``user_prompt`` is returned unchanged.

    Args:
        user_prompt: the prompt typed by the user.
        pil_image: the reference image (PIL image).

    Returns:
        The rewritten prompt (whitespace-stripped), or ``user_prompt``.
    """
    api_key = os.environ.get("HF_TOKEN")
    if not api_key:
        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
        return user_prompt
    print("🧠 Rewriting prompt via API...")
    API_URL = "https://router.huggingface.co/v1/chat/completions"
    headers = {"Authorization": f"Bearer {api_key}"}
    try:
        # Request construction is inside the try as well: a template or image
        # encoding failure must degrade to the user's prompt, not crash the
        # caller. The system prompt contains a {} placeholder for the user input.
        full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
        base64_img = encode_image_to_base64(pil_image)
        payload = {
            "messages": [
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": full_instruction},
                        {"type": "image_url", "image_url": {"url": base64_img}},
                    ],
                }
            ],
            "model": "Qwen/Qwen2.5-VL-7B-Instruct",
            "max_tokens": 512,
            "temperature": 0.7,
        }
        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
        response.raise_for_status()
        data = response.json()
        rewritten = data["choices"][0]["message"]["content"]
    except Exception as e:
        print(f"❌ Rewrite failed: {e}")
        return user_prompt
    # An empty / null completion would otherwise replace the user's prompt.
    if not isinstance(rewritten, str) or not rewritten.strip():
        print("❌ Rewrite failed: empty response")
        return user_prompt
    rewritten = rewritten.strip()
    print(f"✅ Rewritten: {rewritten[:50]}...")
    return rewritten
# --- Part 4: Model Initialization ---
class ArgsNamespace:
    """Fixed settings object handed to ``initialize_infer_state``.

    Stands in for parsed command-line arguments (presumably mirroring the
    attributes of the repo's argparse namespace — confirm). SageAttention,
    torch.compile and step caching are all switched off; the cache_* and
    total_steps values are still populated for completeness.
    """

    def __init__(self):
        settings = {
            "sage_blocks_range": "0-53",
            "no_cache_block_id": "0-0",
            "use_sageattn": False,
            "enable_torch_compile": False,
            "enable_cache": False,
            "cache_type": "deepcache",
            "cache_start_step": 11,
            "cache_end_step": 45,
            "total_steps": 50,
            "cache_step_interval": 4,
        }
        for attr_name, attr_value in settings.items():
            setattr(self, attr_name, attr_value)
import traceback

# Initialise hyvideo's inference state from the fixed ArgsNamespace settings
# before the pipeline is built.
initialize_infer_state(ArgsNamespace())
print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...")
try:
    # Weights are materialised on the CPU first ...
    pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
        pretrained_model_name_or_path=MODEL_DIR,
        transformer_version=TRANSFORMER_VERSION,
        enable_offloading=ENABLE_OFFLOADING,
        enable_group_offloading=ENABLE_OFFLOADING,
        transformer_dtype=DTYPE,
        device=torch.device("cpu"),
    )
    # ... then the whole pipeline is moved to CUDA.
    # NOTE(review): no GPU is attached at import time on ZeroGPU; this relies
    # on the `spaces` package deferring the move until a GPU call — confirm.
    pipe.to("cuda")
    print("✅ Model loaded and moved to CUDA.")
except Exception as e:
    print(f"❌ Failed to load model: {e}")
    traceback.print_exc()  # keep the stack trace; the message alone is rarely enough
    sys.exit(1)
def save_video_tensor(video_tensor, path, fps=24):
    """Write a generated video tensor to ``path`` with ``imageio.mimwrite``.

    Args:
        video_tensor: a tensor laid out as (C, T, H, W), optionally with a
            leading batch dimension, or a list of such tensors. Only the
            first list element / batch entry is written. Values are assumed
            to lie in [0, 1] — TODO confirm against the pipeline's output.
        path: destination file; imageio selects the writer from it.
        fps: frames per second of the written video.
    """
    if isinstance(video_tensor, list):
        video_tensor = video_tensor[0]
    if video_tensor.ndim == 5:
        video_tensor = video_tensor[0]
    # Scale in float32 (in case the output is low-precision, e.g. bfloat16)
    # and round to nearest: a bare float->uint8 cast truncates toward zero,
    # biasing every pixel down by about half a level.
    vid = (video_tensor.float() * 255).round().clamp(0, 255).to(torch.uint8)
    # (C, T, H, W) -> (T, H, W, C): one H x W x C frame per time step.
    vid = vid.permute(1, 2, 3, 0).cpu().numpy()
    imageio.mimwrite(path, vid, fps=fps)
# --- Part 5: Inference ---
@spaces.GPU(duration=120)
def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
    """Generate one image-to-video clip and save it as an .mp4 under ``outputs/``.

    Args:
        input_image: reference image (PIL image from the UI, or a numpy array).
        prompt: motion/scene description typed by the user.
        length: number of frames (``video_length``).
        steps: number of inference steps.
        shift: flow-matching ``flow_shift``.
        seed: RNG seed; -1 (or ``None`` from an emptied Number field) picks
            a random one.
        guidance: ``guidance_scale``.
        do_rewrite: rewrite the prompt with ``rewrite_prompt_external`` first.
        progress: Gradio progress tracker; ``track_tqdm`` mirrors tqdm bars
            emitted during the call.

    Returns:
        ``(output_path, actual_prompt)`` — the saved video path and the
        prompt that was actually used.

    Raises:
        gr.Error: if the pipeline or the reference image is missing, or if
            inference fails.
    """
    if pipe is None:
        raise gr.Error("Pipeline not initialized!")
    if input_image is None:
        raise gr.Error("Reference image required.")

    # The UI sends a PIL image (type="pil"); raw arrays are accepted too.
    if isinstance(input_image, np.ndarray):
        pil_image = Image.fromarray(input_image).convert("RGB")
    else:
        pil_image = input_image.convert("RGB")

    # 1. Prompt rewrite (best effort: falls back to the user's prompt).
    prompt = prompt or ""
    actual_prompt = rewrite_prompt_external(prompt, pil_image) if do_rewrite else prompt

    # 2. Seed and generator. An emptied gr.Number arrives as None.
    if seed is None or seed == -1:
        seed = torch.randint(0, 1000000, (1,)).item()
    seed = int(seed)
    generator = torch.Generator(device="cpu").manual_seed(seed)
    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")

    try:
        pipe.execution_device = torch.device("cuda")
        output = pipe(
            prompt=actual_prompt,
            height=480, width=854, aspect_ratio="16:9",
            video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
            reference_image=pil_image,
            seed=seed,
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True,
        )
    except Exception as e:
        print(f"Error: {e}")
        raise gr.Error(f"Inference Failed: {e}") from e

    # Microsecond resolution so runs finishing in the same second cannot
    # overwrite each other's file.
    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    os.makedirs("outputs", exist_ok=True)
    output_path = os.path.join("outputs", f"gen_{timestamp}.mp4")
    save_video_tensor(output.videos, output_path)
    return output_path, actual_prompt
# --- Part 6: UI ---
# Custom CSS: cap the app at 900px wide and centre it; force white progress
# text in dark mode.
css = (
    ".gradio-container .app { max-width: 900px !important; margin: 0 auto; }\n"
    ".dark .progress-text{color: white !important}"
)
def create_ui():
    """Build the Gradio interface for image-to-video generation.

    Left column: reference image, prompt, rewrite toggle and an "Advanced
    Options" accordion (steps, guidance, shift, length, seed). Right column:
    the generated video and the prompt actually used. The Generate button
    calls ``generate``.

    Returns:
        The ``gr.Blocks`` app (not yet queued or launched).
    """
    with gr.Blocks(title="HunyuanVideo 1.5 I2V") as demo:
        gr.Markdown("# 🎬 HunyuanVideo 1.5 I2V 480p distilled demo")
        gr.Markdown(
            f"This is a demo for HunyuanVideo 1.5 I2V {TRANSFORMER_VERSION}, "
            "released together with a collection of 10 other checkpoints "
            "(text-to-video, 720p, upscalers). Check out the "
            "[HunyuanVideo-1.5 model page](https://huggingface.co/tencent/HunyuanVideo-1.5) "
            "for more details."
        )
        with gr.Row():
            with gr.Column():
                img = gr.Image(label="Reference", type="pil")
                prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
                rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Strongly Recommended)", value=True)
                with gr.Accordion("Advanced Options", open=False):
                    with gr.Row():
                        steps = gr.Slider(2, 50, value=6, step=1, label="Steps")
                        guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
                    with gr.Row():
                        shift = gr.Slider(1.0, 20.0, value=5.0, step=0.5, label="Shift")
                        # step=4 from 1 yields lengths of the form 4k + 1 (1, 5, ..., 129).
                        length = gr.Slider(1, 129, value=61, step=4, label="Length")
                    seed = gr.Number(value=-1, label="Seed", precision=0, info="-1 is a random seed")
                btn = gr.Button("Generate", variant="primary")
            with gr.Column():
                out = gr.Video(label="Result", autoplay=True)
                final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False)
        btn.click(
            generate,
            inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk],
            outputs=[out, final_prompt_box],
        )
    return demo
if __name__ == "__main__":
    ui = create_ui()
    # NOTE(review): css is passed to launch() rather than gr.Blocks(), which
    # presumably targets a Gradio version that accepts it there — confirm
    # against the pinned gradio release.
    ui.queue().launch(server_name="0.0.0.0", share=True, css=css)