multimodalart (HF Staff) committed · verified · commit 8f13c28 · parent: d8659cc

Update app.py

Files changed (1):
  app.py (+112 -76)
app.py CHANGED
@@ -7,7 +7,9 @@ import numpy as np
 from PIL import Image
 import imageio
 import shutil
-import spaces
+import requests
+import base64
+import io
 
 # --- Part 1: Auto-Setup (Clone Repo & Download Weights) ---
 
@@ -18,29 +20,23 @@ MODEL_DIR = os.path.abspath("ckpts")
 # Repositories
 HF_MAIN_REPO = "tencent/HunyuanVideo-1.5"
 HF_GLYPH_REPO = "multimodalart/glyph-sdxl-v2-byt5-small"
-HF_LLM_REPO = "Qwen/Qwen2.5-VL-7B-Instruct"
-HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev" # User specified
+HF_VISION_REPO = "black-forest-labs/FLUX.1-Redux-dev"
 
 # Configuration
 TRANSFORMER_VERSION = "480p_i2v_distilled"
 DTYPE = torch.bfloat16
-# ZeroGPU: Set False so we control offloading manually (CPU -> GPU -> CPU)
 ENABLE_OFFLOADING = False
 
 def setup_environment():
     print("=" * 50)
     print("Checking Environment & Dependencies...")
 
-    # 1. Clone Code Repository
     if not os.path.exists(REPO_DIR):
-        print(f"Cloning repository to {REPO_DIR}...")
         subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
 
-    # 2. Add Repo to Python Path
     if REPO_DIR not in sys.path:
         sys.path.insert(0, REPO_DIR)
 
-    # 3. Download Main Weights (Transformer, VAE, Scheduler)
     os.makedirs(MODEL_DIR, exist_ok=True)
     target_transformer = os.path.join(MODEL_DIR, "transformer", TRANSFORMER_VERSION)
 
@@ -52,52 +48,30 @@ def setup_environment():
             f"transformer/{TRANSFORMER_VERSION}/*",
             "vae/*",
             "scheduler/*",
-            "tokenizer/*"
+            "tokenizer/*",
+            "text_encoder/*" # Download LLM here too to simplify
         ]
-        snapshot_download(
-            repo_id=HF_MAIN_REPO,
-            local_dir=MODEL_DIR,
-            allow_patterns=allow_patterns,
-            local_dir_use_symlinks=False
-        )
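+        # allow_patterns restricts the snapshot to the transformer, VAE, scheduler,
+        # tokenizer and text-encoder folders, so one call fetches everything needed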
+        snapshot_download(repo_id=HF_MAIN_REPO, local_dir=MODEL_DIR, allow_patterns=allow_patterns, local_dir_use_symlinks=False)
     except Exception as e:
         print(f"Error downloading main weights: {e}")
         sys.exit(1)
 
-    # 4. Download LLM Text Encoder (Qwen)
-    llm_target = os.path.join(MODEL_DIR, "text_encoder", "llm")
-    if not os.path.exists(llm_target) or not os.listdir(llm_target):
-        print(f"Downloading LLM Text Encoder from {HF_LLM_REPO}...")
-        try:
-            from huggingface_hub import snapshot_download
-            snapshot_download(
-                repo_id=HF_LLM_REPO,
-                local_dir=llm_target,
-                local_dir_use_symlinks=False
-            )
-        except Exception as e:
-            print(f"Error downloading LLM: {e}")
-
-    # 5. Download Vision Encoder (SigLIP)
+    # Vision Encoder
     vision_target = os.path.join(MODEL_DIR, "vision_encoder", "siglip")
     if not os.path.exists(vision_target) or not os.listdir(vision_target):
         print(f"Downloading Vision Encoder from {HF_VISION_REPO}...")
         try:
             from huggingface_hub import snapshot_download
-            snapshot_download(
-                repo_id=HF_VISION_REPO,
-                local_dir=vision_target,
-                local_dir_use_symlinks=False
-            )
+            snapshot_download(repo_id=HF_VISION_REPO, local_dir=vision_target, local_dir_use_symlinks=False)
         except Exception as e:
             print(f"Error downloading Vision Encoder: {e}")
 
-    # 6. Download & Restructure Glyph Weights
+    # Glyph Weights
     glyph_root = os.path.join(MODEL_DIR, "text_encoder", "Glyph-SDXL-v2")
     glyph_ckpt_target = os.path.join(glyph_root, "checkpoints", "byt5_model.pt")
 
     if not os.path.exists(glyph_ckpt_target):
-        print(f"Downloading & Structuring Glyph Weights from {HF_GLYPH_REPO}...")
+        print(f"Downloading Glyph Weights from {HF_GLYPH_REPO}...")
         try:
             from huggingface_hub import snapshot_download
             glyph_temp = os.path.join(MODEL_DIR, "glyph_temp")
@@ -106,55 +80,112 @@ def setup_environment():
             os.makedirs(os.path.join(glyph_root, "assets"), exist_ok=True)
             os.makedirs(os.path.join(glyph_root, "checkpoints"), exist_ok=True)
 
-            # Move Assets
             src_assets = os.path.join(glyph_temp, "assets")
             if os.path.exists(src_assets):
                 for f in os.listdir(src_assets):
                     shutil.copy(os.path.join(src_assets, f), os.path.join(glyph_root, "assets", f))
 
-            # Move Model
             src_bin = os.path.join(glyph_temp, "pytorch_model.bin")
             if os.path.exists(src_bin):
                 shutil.move(src_bin, glyph_ckpt_target)
             else:
                 src_safe = os.path.join(glyph_temp, "model.safetensors")
-                if os.path.exists(src_safe):
-                    shutil.move(src_safe, glyph_ckpt_target)
+                if os.path.exists(src_safe): shutil.move(src_safe, glyph_ckpt_target)
 
             shutil.rmtree(glyph_temp, ignore_errors=True)
-
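+        # Non-fatal: any failure during glyph setup is swallowed so startup can continue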
-        except Exception as e:
-            print(f"Error setting up Glyph weights: {e}")
+        except Exception:
+            pass
 
     print("Environment Ready.")
     print("=" * 50)
 
 setup_environment()
 
-# --- Part 2: Imports & Monkey Patching ---
+# --- Part 2: Imports & Patching ---
 
-# 1. Import Modules explicitly for patching
 try:
     import hyvideo.commons
     import hyvideo.pipelines.hunyuan_video_pipeline
     from hyvideo.pipelines.hunyuan_video_pipeline import HunyuanVideo_1_5_Pipeline
     from hyvideo.commons.infer_state import initialize_infer_state
+    # Import the specific I2V System Prompt from the repo
+    from hyvideo.utils.rewrite.i2v_prompt import i2v_rewrite_system_prompt
+    import spaces
 except ImportError as e:
     print(f"CRITICAL ERROR: {e}")
     sys.exit(1)
 
 import gradio as gr
 
-# 2. Apply ZeroGPU Monkey Patch
-# We must patch the specific modules where get_gpu_memory is imported/used
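+# ZeroGPU: spoof 80 GB of free GPU memory so the repo's memory check never
+# switches the pipeline into offloading mode (devices are managed manually here)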
 def dummy_get_gpu_memory(device=None):
-    return 80 * 1024 * 1024 * 1024 # Spoof 80GB
+    return 80 * 1024 * 1024 * 1024
 
 print("🛠️ Applying ZeroGPU Monkey Patch...")
 hyvideo.commons.get_gpu_memory = dummy_get_gpu_memory
 hyvideo.pipelines.hunyuan_video_pipeline.get_gpu_memory = dummy_get_gpu_memory
 
-# --- Part 3: Model Initialization (CPU) ---
+# --- Part 3: Prompt Rewrite Logic (External API) ---
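+# Note: the HF Router endpoint used below is OpenAI-compatible, so the request
+# follows the standard chat-completions messages schema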
+
+def encode_image_to_base64(pil_image):
+    buffered = io.BytesIO()
+    pil_image.save(buffered, format="JPEG")
+    img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
+    return f"data:image/jpeg;base64,{img_str}"
+
+def rewrite_prompt_external(user_prompt, pil_image):
+    """Calls HF Router API to rewrite prompt using Qwen2.5-VL"""
+
+    api_key = os.environ.get("HF_TOKEN")
+    if not api_key:
+        print("⚠️ No HF_TOKEN found. Skipping rewrite.")
+        return user_prompt
+
+    print("🧠 Rewriting prompt via API...")
+
+    API_URL = "https://router.huggingface.co/v1/chat/completions"
+    headers = {"Authorization": f"Bearer {api_key}"}
+
+    # Combine the official Hunyuan System Prompt with the User Input
+    # The system prompt string contains a {} placeholder for the user input
+    full_instruction = i2v_rewrite_system_prompt.format(user_prompt)
+
+    base64_img = encode_image_to_base64(pil_image)
+
+    payload = {
+        "messages": [
+            {
+                "role": "user",
+                "content": [
+                    {"type": "text", "text": full_instruction},
+                    {"type": "image_url", "image_url": {"url": base64_img}}
+                ]
+            }
+        ],
+        "model": "Qwen/Qwen2.5-VL-7B-Instruct",
+        "max_tokens": 512,
+        "temperature": 0.7
+    }
+
+    try:
+        response = requests.post(API_URL, headers=headers, json=payload, timeout=30)
+        response.raise_for_status()
+        data = response.json()
+        rewritten = data["choices"][0]["message"]["content"]
+        print(f"✅ Rewritten: {rewritten[:50]}...")
+        return rewritten
+    except Exception as e:
+        print(f"❌ Rewrite failed: {e}")
+        return user_prompt
+
+# --- Part 4: Model Initialization (CPU) ---
 
 class ArgsNamespace:
     def __init__(self):
@@ -166,7 +197,6 @@ initialize_infer_state(ArgsNamespace())
 
 print(f"⏳ Initializing Pipeline ({TRANSFORMER_VERSION})...")
 try:
-    # Load to CPU explicitly
     pipe = HunyuanVideo_1_5_Pipeline.create_pipeline(
         pretrained_model_name_or_path=MODEL_DIR,
         transformer_version=TRANSFORMER_VERSION,
@@ -175,13 +205,11 @@ try:
         transformer_dtype=DTYPE,
         device=torch.device('cpu')
     )
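+    # ZeroGPU: the spaces runtime intercepts this startup .to('cuda'), so the
+    # weights are bound to a real GPU only inside @spaces.GPU-decorated calls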
+    pipe.to('cuda')
     print("✅ Model loaded into CPU RAM.")
 except Exception as e:
     print(f"❌ Failed to load model: {e}")
-    import traceback
-    traceback.print_exc()
     sys.exit(1)
-pipe.to("cuda")
 
 def save_video_tensor(video_tensor, path, fps=24):
     if isinstance(video_tensor, list): video_tensor = video_tensor[0]
@@ -190,68 +218,70 @@ def save_video_tensor(video_tensor, path, fps=24):
     vid = vid.permute(1, 2, 3, 0).cpu().numpy()
     imageio.mimwrite(path, vid, fps=fps)
 
-# --- Part 4: Inference ---
+# --- Part 5: Inference ---
 
 @spaces.GPU(duration=120)
-def generate(input_image, prompt, length, steps, shift, seed, guidance):
-    if pipe is None:
-        raise gr.Error("Pipeline not initialized!")
-
-    if input_image is None:
-        raise gr.Error("Reference image required.")
+def generate(input_image, prompt, length, steps, shift, seed, guidance, do_rewrite, progress=gr.Progress(track_tqdm=True)):
+    if pipe is None: raise gr.Error("Pipeline not initialized!")
+    if input_image is None: raise gr.Error("Reference image required.")
 
+    # Process Input Image
     if isinstance(input_image, np.ndarray):
-        input_image = Image.fromarray(input_image).convert("RGB")
+        pil_image = Image.fromarray(input_image).convert("RGB")
+    else:
+        pil_image = input_image.convert("RGB")
+
+    # 1. Prompt Rewrite (if enabled)
+    actual_prompt = prompt
+    if do_rewrite:
+        actual_prompt = rewrite_prompt_external(prompt, pil_image)
 
+    # 2. Setup Generator
     if seed == -1: seed = torch.randint(0, 1000000, (1,)).item()
     generator = torch.Generator(device="cpu").manual_seed(int(seed))
 
-    print(f"🚀 Moving Pipeline to GPU... (Prompt: {prompt})")
+    print(f"🚀 GPU Inference: {actual_prompt[:30]}... | Seed: {seed}")
 
     try:
         pipe.execution_device = torch.device("cuda")
 
         output = pipe(
-            prompt=prompt,
+            prompt=actual_prompt,
             height=480, width=854, aspect_ratio="16:9",
             video_length=int(length),
            num_inference_steps=int(steps),
            guidance_scale=float(guidance),
            flow_shift=float(shift),
-            reference_image=input_image,
+            reference_image=pil_image,
            seed=int(seed),
            generator=generator,
            output_type="pt",
            enable_sr=False,
            return_dict=True
        )
-
-        # 4. Optional: Move back to CPU?
-        # pipe.to("cpu")
-
     except Exception as e:
-        print(f"Generation Error: {e}")
-        import traceback
-        traceback.print_exc()
+        print(f"Error: {e}")
         raise gr.Error(f"Inference Failed: {e}")
 
     timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
     os.makedirs("outputs", exist_ok=True)
     output_path = f"outputs/gen_{timestamp}.mp4"
     save_video_tensor(output.videos, output_path)
-    return output_path
+
+    return output_path, actual_prompt
 
-# --- Part 5: UI ---
+# --- Part 6: UI ---
 
 def create_ui():
     with gr.Blocks(title="HunyuanVideo 1.5 I2V") as demo:
         gr.Markdown(f"### 🎬 HunyuanVideo 1.5 I2V ({TRANSFORMER_VERSION})")
-        gr.Markdown("Running on ZeroGPU. Weights are pre-loaded on CPU.")
 
         with gr.Row():
             with gr.Column():
                 img = gr.Image(label="Reference", type="pil", height=250)
                 prompt = gr.Textbox(label="Prompt", placeholder="Describe motion...", lines=2)
+                rewrite_chk = gr.Checkbox(label="Enable Prompt Rewrite (Recommended)", value=True)
+
                 with gr.Row():
                     steps = gr.Slider(2, 50, value=6, step=1, label="Steps")
                     guidance = gr.Slider(1.0, 5.0, value=1.0, step=0.1, label="Guidance")
@@ -263,10 +293,16 @@ def create_ui():
 
             with gr.Column():
                 out = gr.Video(label="Result", autoplay=True)
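+                # Surface the prompt actually sent to the model (it may have been rewritten)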
+                final_prompt_box = gr.Textbox(label="Actual Prompt Used", interactive=False)
 
-        btn.click(generate, inputs=[img, prompt, length, steps, shift, seed, guidance], outputs=[out])
+        btn.click(
+            generate,
+            inputs=[img, prompt, length, steps, shift, seed, guidance, rewrite_chk],
+            outputs=[out, final_prompt_box]
+        )
     return demo
 
 if __name__ == "__main__":
     ui = create_ui()
     ui.queue().launch(server_name="0.0.0.0", share=True)
 