victor HF Staff committed on
Commit
3e2d0e9
·
verified ·
1 Parent(s): 5e87129

refactor: defer model loads inside @spaces.GPU (avoid CPU OOM)

Browse files
Files changed (1) hide show
  1. app.py +116 -135
app.py CHANGED
@@ -1,7 +1,7 @@
1
  """Gradio ZeroGPU Space for LongCat-Video-Avatar 1.5 (single-person AI2V).
2
 
3
- Loads the INT8-quantized DiT + DMD2 8-step LoRA + Whisper-Large-v3 audio encoder
4
- and exposes one inference function: reference image + audio + prompt -> mp4.
5
  """
6
 
7
  # IMPORTANT: spaces must be imported before torch (per HF guide).
@@ -27,7 +27,7 @@ WEIGHTS_DIR = Path(os.environ.get("WEIGHTS_DIR", DEFAULT_WEIGHTS))
27
  WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
28
  BASE_DIR = WEIGHTS_DIR / "LongCat-Video"
29
  AVATAR_DIR = WEIGHTS_DIR / "LongCat-Video-Avatar-1.5"
30
- print(f"[boot] WEIGHTS_DIR={WEIGHTS_DIR}")
31
 
32
  # Make vendored package importable
33
  sys.path.insert(0, str(Path(__file__).parent.resolve()))
@@ -42,15 +42,13 @@ from PIL import Image
42
 
43
  # ---------------------------------------------------------------------------
44
  # 0) Replace xformers.memory_efficient_attention with a PyTorch-SDPA shim.
45
- # xformers wheels for torch 2.12+cu130 aren't available; SDPA is always
46
- # in-tree, fast on Blackwell, and matches the inputs the model passes.
47
  # ---------------------------------------------------------------------------
48
 
49
  def _install_sdpa_shim():
50
- import xformers.ops # the package exists; only its CUDA ext is broken
51
 
52
- # Replace BlockDiagonalMask with a thin record so we don't depend on
53
- # the xformers internal layout for cumulative seq starts.
54
  class _BDShim:
55
  def __init__(self, q_seqlen, kv_seqlen):
56
  self.q_seqlen = list(q_seqlen)
@@ -63,46 +61,40 @@ def _install_sdpa_shim():
63
  xformers.ops.fmha.attn_bias.BlockDiagonalMask = _BDShim
64
 
65
  def _meff(q, k, v, attn_bias=None, op=None, **_):
66
- # xformers convention: q, k, v are [B, M, H, D]
67
  if attn_bias is None:
68
  q_ = q.transpose(1, 2).contiguous()
69
  k_ = k.transpose(1, 2).contiguous()
70
  v_ = v.transpose(1, 2).contiguous()
71
- out = F.scaled_dot_product_attention(q_, k_, v_)
72
- return out.transpose(1, 2)
73
  if isinstance(attn_bias, _BDShim):
74
- # Variable-length cross-attention: batch elements concatenated
75
- # along seq dim. Loop per-element using SDPA.
76
- outs = []
77
- q_off = k_off = 0
78
  for q_len, k_len in zip(attn_bias.q_seqlen, attn_bias.kv_seqlen):
79
  q_b = q[:, q_off:q_off + q_len].transpose(1, 2).contiguous()
80
  k_b = k[:, k_off:k_off + k_len].transpose(1, 2).contiguous()
81
  v_b = v[:, k_off:k_off + k_len].transpose(1, 2).contiguous()
82
- o = F.scaled_dot_product_attention(q_b, k_b, v_b)
83
- outs.append(o.transpose(1, 2))
84
  q_off += q_len
85
  k_off += k_len
86
  return torch.cat(outs, dim=1)
87
  raise NotImplementedError(f"Unsupported attn_bias in SDPA shim: {type(attn_bias)}")
88
 
89
  xformers.ops.memory_efficient_attention = _meff
90
- print("[boot] installed xformers→SDPA shim")
91
 
92
 
93
  _install_sdpa_shim()
94
 
95
 
96
  # ---------------------------------------------------------------------------
97
- # 1) Download weights (one-time per container if /data is persistent)
98
  # ---------------------------------------------------------------------------
99
 
100
  def _ensure_weights():
101
  token = os.environ.get("HF_TOKEN")
102
- # We only need text_encoder / vae / tokenizer from the base LongCat-Video repo.
103
  base_marker = BASE_DIR / "vae" / "config.json"
104
  if not base_marker.exists():
105
- print("[boot] downloading LongCat-Video (vae/text_encoder/tokenizer)…")
106
  snapshot_download(
107
  "meituan-longcat/LongCat-Video",
108
  local_dir=str(BASE_DIR),
@@ -126,7 +118,7 @@ def _ensure_weights():
126
 
127
  avatar_marker = AVATAR_DIR / "base_model_int8" / "config.json"
128
  if not avatar_marker.exists():
129
- print("[boot] downloading LongCat-Video-Avatar-1.5 (INT8 + lora + whisper + vocal_separator)…")
130
  snapshot_download(
131
  "meituan-longcat/LongCat-Video-Avatar-1.5",
132
  local_dir=str(AVATAR_DIR),
@@ -136,28 +128,25 @@ def _ensure_weights():
136
  "lora/*",
137
  "scheduler/*",
138
  "vocal_separator/*",
139
- # Whisper-Large-v3: only the bf16 safetensors + tokenizer/config files
140
  "whisper-large-v3/model.safetensors",
141
  "whisper-large-v3/*.json",
142
  "whisper-large-v3/*.txt",
143
  ],
144
  ignore_patterns=[
145
- # Drop the fp32 sharded copies, flax, TF, and pickled legacy weights
146
  "whisper-large-v3/model.fp32*",
147
  "whisper-large-v3/flax_model*",
148
  "whisper-large-v3/tf_model*",
149
  "whisper-large-v3/pytorch_model*",
150
  ],
151
  )
152
- print("[boot] weights ready.")
153
 
154
 
155
  _ensure_weights()
156
 
157
 
158
  # ---------------------------------------------------------------------------
159
- # 2) Patch DiT config: prefer xformers, disable flash-attn (not buildable on
160
- # ZeroGPU's Blackwell sm_120). Both base and int8 configs share these flags.
161
  # ---------------------------------------------------------------------------
162
 
163
  def _patch_dit_config():
@@ -175,87 +164,87 @@ def _patch_dit_config():
175
  changed = True
176
  if changed:
177
  cfg_path.write_text(json.dumps(cfg, indent=2))
178
- print(f"[boot] patched {cfg_path.name} -> xformers backend")
179
 
180
 
181
  _patch_dit_config()
182
 
183
 
184
  # ---------------------------------------------------------------------------
185
- # 3) Load models on CPU at module level (spaces moves them to GPU on demand)
186
  # ---------------------------------------------------------------------------
187
 
188
- from transformers import AutoTokenizer, UMT5EncoderModel # noqa: E402
 
189
 
190
- from longcat_video.pipeline_longcat_video_avatar import LongCatVideoAvatarPipeline # noqa: E402
191
- from longcat_video.modules.scheduling_flow_match_euler_discrete import ( # noqa: E402
192
- FlowMatchEulerDiscreteScheduler,
193
- )
194
- from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan # noqa: E402
195
- from longcat_video.modules.quantization import load_quantized_dit # noqa: E402
196
- from longcat_video.audio_process import ( # noqa: E402
197
- get_audio_encoder,
198
- get_audio_feature_extractor,
199
- )
200
- from longcat_video.audio_process.torch_utils import save_video_ffmpeg # noqa: E402
201
 
202
- CP_SPLIT_HW = [1, 1] # single-GPU, no context-parallel split
 
 
 
 
203
 
204
- print("[boot] loading tokenizer + text_encoder (UMT5-XXL)…")
205
- tokenizer = AutoTokenizer.from_pretrained(str(BASE_DIR), subfolder="tokenizer")
206
- text_encoder = UMT5EncoderModel.from_pretrained(
207
- str(BASE_DIR),
208
- subfolder="text_encoder",
209
- torch_dtype=torch.bfloat16,
210
- low_cpu_mem_usage=True,
211
- )
212
 
213
- print("[boot] loading VAE (Wan)…")
214
- vae = AutoencoderKLWan.from_pretrained(str(BASE_DIR), subfolder="vae", torch_dtype=torch.bfloat16)
215
-
216
- print("[boot] loading scheduler…")
217
- scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(str(AVATAR_DIR), subfolder="scheduler")
218
-
219
- print("[boot] loading INT8 DiT + DMD2 LoRA…")
220
- dit = load_quantized_dit(str(AVATAR_DIR), subfolder="base_model_int8", cp_split_hw=CP_SPLIT_HW)
221
- _lora_path = AVATAR_DIR / "lora" / "dmd_lora.safetensors"
222
- if _lora_path.exists():
223
- dit.load_lora(str(_lora_path), "dmd", multiplier=1.0, lora_network_dim=128, lora_network_alpha=64)
224
- dit.enable_loras(["dmd"])
225
- print("[boot] DMD2 LoRA enabled (8-step distillation)")
226
-
227
- print("[boot] loading Whisper-Large-v3 audio encoder…")
228
- audio_encoder = get_audio_encoder(str(AVATAR_DIR / "whisper-large-v3"), "avatar-v1.5")
229
- audio_feature_extractor = get_audio_feature_extractor(str(AVATAR_DIR / "whisper-large-v3"), "avatar-v1.5")
230
-
231
- print("[boot] loading vocal separator (Kim_Vocal_2)…")
232
- from audio_separator.separator import Separator # noqa: E402
233
-
234
- VOCAL_TMP = Path("/tmp/vocal_out")
235
- VOCAL_TMP.mkdir(parents=True, exist_ok=True)
236
- vocal_separator = Separator(
237
- output_dir=str(VOCAL_TMP / "vocals"),
238
- output_single_stem="vocals",
239
- model_file_dir=str(AVATAR_DIR / "vocal_separator"),
240
- )
241
- vocal_separator.load_model("Kim_Vocal_2.onnx")
242
-
243
- print("[boot] assembling pipeline…")
244
- pipe = LongCatVideoAvatarPipeline(
245
- tokenizer=tokenizer,
246
- text_encoder=text_encoder,
247
- vae=vae,
248
- scheduler=scheduler,
249
- dit=dit,
250
- audio_encoder=audio_encoder,
251
- audio_feature_extractor=audio_feature_extractor,
252
- model_type="avatar-v1.5",
253
- )
254
- print("[boot] ready.")
 
 
 
 
 
 
 
 
255
 
256
 
257
  # ---------------------------------------------------------------------------
258
- # 4) Inference helpers
259
  # ---------------------------------------------------------------------------
260
 
261
  NEGATIVE_PROMPT = (
@@ -269,13 +258,14 @@ NEGATIVE_PROMPT = (
269
 
270
 
271
  def _extract_vocal(src: str) -> str:
272
- """Run the vocal separator; return path to vocals-only wav, or src if it fails."""
 
273
  try:
274
- outputs = vocal_separator.separate(src)
275
  if outputs:
276
- return str((VOCAL_TMP / "vocals" / outputs[0]).resolve())
277
  except Exception as e:
278
- print(f"[vocal] separation failed, using raw audio: {e}")
279
  return src
280
 
281
 
@@ -283,13 +273,12 @@ def _extract_vocal(src: str) -> str:
283
  # 5) GPU-bound inference function
284
  # ---------------------------------------------------------------------------
285
 
286
- @spaces.GPU(duration=300)
287
  def generate(
288
  image_path: str,
289
  audio_path: str,
290
  prompt: str,
291
  resolution: str,
292
- audio_cfg: float,
293
  seed: int,
294
  progress=gr.Progress(track_tqdm=True),
295
  ):
@@ -297,75 +286,73 @@ def generate(
297
  raise gr.Error("Please upload a reference image.")
298
  if not audio_path:
299
  raise gr.Error("Please upload an audio clip.")
300
- if not prompt or not prompt.strip():
301
- prompt = "A person is talking naturally."
302
 
303
- # Move pipeline onto GPU for the duration of this call.
304
- # NB: pipe.to() moves dit / text_encoder / vae but not audio_encoder,
305
- # so we move Whisper explicitly here.
306
- pipe.to("cuda")
307
- audio_encoder.to("cuda")
 
 
 
308
 
309
- width, height = (832, 480) if resolution == "480p" else (1280, 768)
310
  save_fps = 25
311
  audio_stride = 1
312
- num_frames = 93 # one 93-frame segment (~3.7s @ 25fps)
313
-
314
- import librosa # local import to keep boot fast
315
 
316
  # 1) Vocal isolation
317
- progress(0.05, desc="Isolating vocals…")
318
  vocal_path = _extract_vocal(audio_path)
319
 
320
- # 2) Pad audio to required duration
321
  speech, sr = librosa.load(vocal_path, sr=16000)
322
- target_duration = num_frames / save_fps
323
- pad = math.ceil((target_duration - len(speech) / sr) * sr)
324
  if pad > 0:
325
  speech = np.concatenate([speech, np.zeros(pad, dtype=speech.dtype)])
326
 
327
  # 3) Whisper audio embedding
328
- progress(0.15, desc="Encoding audio (Whisper-Large-v3)…")
329
- full_audio_emb = pipe.get_audio_embedding(
330
  speech, fps=save_fps * audio_stride, device="cuda", sample_rate=sr, model_type="avatar-v1.5"
331
  )
332
  if torch.isnan(full_audio_emb).any():
333
- raise gr.Error("Audio embedding contains NaN — try a cleaner audio clip.")
334
 
335
- # 4) Build per-frame windowed audio tensor: [1, T, 5, 5, D]
336
- indices = torch.arange(2 * 2 + 1) - 2 # 5-frame window centered on each latent frame
337
  center = torch.arange(0, audio_stride * num_frames, audio_stride).unsqueeze(1) + indices.unsqueeze(0)
338
  center = torch.clamp(center, min=0, max=full_audio_emb.shape[0] - 1)
339
  audio_emb = full_audio_emb[center][None, ...].to("cuda")
340
 
341
- # 5) Run AI2V generation (8 steps thanks to DMD2 LoRA)
342
  progress(0.30, desc="Generating video (DMD2 8-step)…")
343
  image = Image.open(image_path).convert("RGB")
344
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
345
 
346
- output, _latent = pipe.generate_ai2v(
347
  image=image,
348
- prompt=prompt.strip(),
349
  negative_prompt=NEGATIVE_PROMPT,
350
  resolution=resolution,
351
  num_frames=num_frames,
352
  num_inference_steps=8,
353
  text_guidance_scale=1.0,
354
- audio_guidance_scale=float(audio_cfg),
355
  output_type="both",
356
  generator=generator,
357
  audio_emb=audio_emb,
358
  use_distill=True,
359
  )
360
 
361
- # 6) Save with audio
362
  progress(0.92, desc="Muxing audio + video…")
363
  frames = (output[0] * 255).astype(np.uint8)
364
  out_tensor = torch.from_numpy(frames)
365
  out_base = Path(tempfile.gettempdir()) / f"longcat_{uuid.uuid4().hex[:8]}"
366
  save_video_ffmpeg(out_tensor, str(out_base), audio_path, fps=save_fps, quality=5)
367
  out_path = f"{out_base}.mp4"
368
- print(f"[gen] wrote {out_path}")
369
  return out_path
370
 
371
 
@@ -383,7 +370,6 @@ if (EXAMPLE_DIR / "man.png").exists() and (EXAMPLE_DIR / "man.mp3").exists():
383
  "their mouth. Wearing a vibrant red jacket with gold embroidery, the singer is speaking "
384
  "while smoke swirls around them, creating a dynamic and atmospheric scene.",
385
  "480p",
386
- 4.0,
387
  42,
388
  ])
389
 
@@ -395,6 +381,8 @@ with gr.Blocks(title="LongCat-Video-Avatar 1.5") as demo:
395
  Upload a **reference image** + **audio clip** + a short **text prompt**.
396
  The model generates a ~3.7-second lip-synced video using Meituan's
397
  LongCat-Video-Avatar 1.5 (INT8 DiT + DMD2 8-step distilled).
 
 
398
  """
399
  )
400
  with gr.Row():
@@ -409,13 +397,6 @@ with gr.Blocks(title="LongCat-Video-Avatar 1.5") as demo:
409
  with gr.Row():
410
  resolution = gr.Radio(["480p", "720p"], value="480p", label="Resolution")
411
  seed = gr.Number(value=42, precision=0, label="Seed")
412
- audio_cfg = gr.Slider(
413
- 1.0,
414
- 6.0,
415
- value=4.0,
416
- step=0.5,
417
- label="Audio CFG (higher = stronger lip sync, 3–5 recommended)",
418
- )
419
  go = gr.Button("Generate", variant="primary")
420
  with gr.Column():
421
  video_out = gr.Video(label="Output", autoplay=True)
@@ -423,7 +404,7 @@ with gr.Blocks(title="LongCat-Video-Avatar 1.5") as demo:
423
  if EXAMPLES:
424
  gr.Examples(
425
  examples=EXAMPLES,
426
- inputs=[image_in, audio_in, prompt, resolution, audio_cfg, seed],
427
  outputs=video_out,
428
  fn=generate,
429
  cache_examples=False,
@@ -431,7 +412,7 @@ with gr.Blocks(title="LongCat-Video-Avatar 1.5") as demo:
431
 
432
  go.click(
433
  generate,
434
- inputs=[image_in, audio_in, prompt, resolution, audio_cfg, seed],
435
  outputs=video_out,
436
  )
437
 
 
1
  """Gradio ZeroGPU Space for LongCat-Video-Avatar 1.5 (single-person AI2V).
2
 
3
+ Lazy-loads the INT8 DiT + DMD2 8-step LoRA + Whisper-Large-v3 inside the first
4
+ @spaces.GPU call (CPU RAM on ZeroGPU is too small for 22GB UMT5-XXL + 14GB DiT).
5
  """
6
 
7
  # IMPORTANT: spaces must be imported before torch (per HF guide).
 
27
  WEIGHTS_DIR.mkdir(parents=True, exist_ok=True)
28
  BASE_DIR = WEIGHTS_DIR / "LongCat-Video"
29
  AVATAR_DIR = WEIGHTS_DIR / "LongCat-Video-Avatar-1.5"
30
+ print(f"[boot] WEIGHTS_DIR={WEIGHTS_DIR}", flush=True)
31
 
32
  # Make vendored package importable
33
  sys.path.insert(0, str(Path(__file__).parent.resolve()))
 
42
 
43
  # ---------------------------------------------------------------------------
44
  # 0) Replace xformers.memory_efficient_attention with a PyTorch-SDPA shim.
45
+ # xformers wheels for torch 2.12+cu130 aren't published; SDPA is always
46
+ # in-tree and matches the inputs the model passes.
47
  # ---------------------------------------------------------------------------
48
 
49
  def _install_sdpa_shim():
50
+ import xformers.ops
51
 
 
 
52
  class _BDShim:
53
  def __init__(self, q_seqlen, kv_seqlen):
54
  self.q_seqlen = list(q_seqlen)
 
61
  xformers.ops.fmha.attn_bias.BlockDiagonalMask = _BDShim
62
 
63
  def _meff(q, k, v, attn_bias=None, op=None, **_):
64
+ # xformers convention: [B, M, H, D]; SDPA wants [B, H, M, D].
65
  if attn_bias is None:
66
  q_ = q.transpose(1, 2).contiguous()
67
  k_ = k.transpose(1, 2).contiguous()
68
  v_ = v.transpose(1, 2).contiguous()
69
+ return F.scaled_dot_product_attention(q_, k_, v_).transpose(1, 2)
 
70
  if isinstance(attn_bias, _BDShim):
71
+ outs, q_off, k_off = [], 0, 0
 
 
 
72
  for q_len, k_len in zip(attn_bias.q_seqlen, attn_bias.kv_seqlen):
73
  q_b = q[:, q_off:q_off + q_len].transpose(1, 2).contiguous()
74
  k_b = k[:, k_off:k_off + k_len].transpose(1, 2).contiguous()
75
  v_b = v[:, k_off:k_off + k_len].transpose(1, 2).contiguous()
76
+ outs.append(F.scaled_dot_product_attention(q_b, k_b, v_b).transpose(1, 2))
 
77
  q_off += q_len
78
  k_off += k_len
79
  return torch.cat(outs, dim=1)
80
  raise NotImplementedError(f"Unsupported attn_bias in SDPA shim: {type(attn_bias)}")
81
 
82
  xformers.ops.memory_efficient_attention = _meff
83
+ print("[boot] installed xformers→SDPA shim", flush=True)
84
 
85
 
86
  _install_sdpa_shim()
87
 
88
 
89
  # ---------------------------------------------------------------------------
90
+ # 1) Download weights (one-time per container thanks to /data bucket)
91
  # ---------------------------------------------------------------------------
92
 
93
  def _ensure_weights():
94
  token = os.environ.get("HF_TOKEN")
 
95
  base_marker = BASE_DIR / "vae" / "config.json"
96
  if not base_marker.exists():
97
+ print("[boot] downloading LongCat-Video (vae/text_encoder/tokenizer)…", flush=True)
98
  snapshot_download(
99
  "meituan-longcat/LongCat-Video",
100
  local_dir=str(BASE_DIR),
 
118
 
119
  avatar_marker = AVATAR_DIR / "base_model_int8" / "config.json"
120
  if not avatar_marker.exists():
121
+ print("[boot] downloading LongCat-Video-Avatar-1.5 (INT8 + lora + whisper + vocal_separator)…", flush=True)
122
  snapshot_download(
123
  "meituan-longcat/LongCat-Video-Avatar-1.5",
124
  local_dir=str(AVATAR_DIR),
 
128
  "lora/*",
129
  "scheduler/*",
130
  "vocal_separator/*",
 
131
  "whisper-large-v3/model.safetensors",
132
  "whisper-large-v3/*.json",
133
  "whisper-large-v3/*.txt",
134
  ],
135
  ignore_patterns=[
 
136
  "whisper-large-v3/model.fp32*",
137
  "whisper-large-v3/flax_model*",
138
  "whisper-large-v3/tf_model*",
139
  "whisper-large-v3/pytorch_model*",
140
  ],
141
  )
142
+ print("[boot] weights ready.", flush=True)
143
 
144
 
145
  _ensure_weights()
146
 
147
 
148
  # ---------------------------------------------------------------------------
149
+ # 2) Patch DiT config: prefer xformers (now our SDPA shim) over flash-attn.
 
150
  # ---------------------------------------------------------------------------
151
 
152
  def _patch_dit_config():
 
164
  changed = True
165
  if changed:
166
  cfg_path.write_text(json.dumps(cfg, indent=2))
167
+ print(f"[boot] patched {cfg_path.name} -> xformers/SDPA backend", flush=True)
168
 
169
 
170
  _patch_dit_config()
171
 
172
 
173
  # ---------------------------------------------------------------------------
174
+ # 3) Lazy pipeline cache. Built inside the first @spaces.GPU call.
175
  # ---------------------------------------------------------------------------
176
 
177
+ _PIPE = None
178
+ _VOCAL = None
179
 
 
 
 
 
 
 
 
 
 
 
 
180
 
181
def _build_pipeline():
    """Assemble the avatar pipeline and cache it in the module-level ``_PIPE``.

    The text encoder, VAE, DiT and audio encoder are each placed on "cuda" as
    they are loaded, so this must run where a CUDA device is available (it is
    called from the ``@spaces.GPU``-decorated ``generate``).
    """
    global _PIPE
    print("[load] building pipeline (first call may take ~60s)…", flush=True)
    started = time.time()

    # Deferred imports: only needed once the pipeline is actually built.
    from transformers import AutoTokenizer, UMT5EncoderModel
    from longcat_video.pipeline_longcat_video_avatar import LongCatVideoAvatarPipeline
    from longcat_video.modules.scheduling_flow_match_euler_discrete import FlowMatchEulerDiscreteScheduler
    from longcat_video.modules.autoencoder_kl_wan import AutoencoderKLWan
    from longcat_video.modules.quantization import load_quantized_dit
    from longcat_video.audio_process import get_audio_encoder, get_audio_feature_extractor

    bf16 = torch.bfloat16
    base_root = str(BASE_DIR)
    avatar_root = str(AVATAR_DIR)
    whisper_root = str(AVATAR_DIR / "whisper-large-v3")

    # Components from the base LongCat-Video checkpoint.
    tokenizer = AutoTokenizer.from_pretrained(base_root, subfolder="tokenizer")
    text_encoder = UMT5EncoderModel.from_pretrained(
        base_root,
        subfolder="text_encoder",
        torch_dtype=bf16,
        device_map="cuda",
    )
    vae = AutoencoderKLWan.from_pretrained(base_root, subfolder="vae", torch_dtype=bf16)
    vae = vae.to("cuda")

    # Components from the Avatar-1.5 checkpoint.
    scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(avatar_root, subfolder="scheduler")

    dit = load_quantized_dit(avatar_root, subfolder="base_model_int8", cp_split_hw=[1, 1])
    dit = dit.to("cuda")

    # The DMD LoRA is optional: it is only attached when the file is present.
    dmd_lora = AVATAR_DIR / "lora" / "dmd_lora.safetensors"
    if dmd_lora.exists():
        dit.load_lora(str(dmd_lora), "dmd", multiplier=1.0, lora_network_dim=128, lora_network_alpha=64)
        dit.enable_loras(["dmd"])
        print("[load] DMD2 8-step LoRA enabled", flush=True)

    audio_encoder = get_audio_encoder(whisper_root, "avatar-v1.5")
    audio_encoder = audio_encoder.to("cuda", dtype=bf16)
    audio_feature_extractor = get_audio_feature_extractor(whisper_root, "avatar-v1.5")

    components = {
        "tokenizer": tokenizer,
        "text_encoder": text_encoder,
        "vae": vae,
        "scheduler": scheduler,
        "dit": dit,
        "audio_encoder": audio_encoder,
        "audio_feature_extractor": audio_feature_extractor,
        "model_type": "avatar-v1.5",
    }
    _PIPE = LongCatVideoAvatarPipeline(**components)
    # Every module already lives on the GPU; record that on the pipeline object.
    _PIPE.device = "cuda"
    print(f"[load] pipeline ready in {time.time() - started:.1f}s", flush=True)
228
+
229
+
230
def _build_vocal_separator():
    """Create the vocal separator and cache it in the module-level ``_VOCAL``.

    The separator is configured with ``/tmp/vocal_out/vocals`` as its output
    directory and ``"vocals"`` as its single output stem, then loads the
    ``Kim_Vocal_2.onnx`` model from the avatar weights directory.
    """
    global _VOCAL
    from audio_separator.separator import Separator

    stems_dir = Path("/tmp/vocal_out") / "vocals"
    stems_dir.mkdir(parents=True, exist_ok=True)

    separator_options = {
        "output_dir": str(stems_dir),
        "output_single_stem": "vocals",
        "model_file_dir": str(AVATAR_DIR / "vocal_separator"),
    }
    _VOCAL = Separator(**separator_options)
    _VOCAL.load_model("Kim_Vocal_2.onnx")
    print("[load] vocal separator ready", flush=True)
244
 
245
 
246
  # ---------------------------------------------------------------------------
247
+ # 4) Inference helper
248
  # ---------------------------------------------------------------------------
249
 
250
  NEGATIVE_PROMPT = (
 
258
 
259
 
260
  def _extract_vocal(src: str) -> str:
261
+ if _VOCAL is None:
262
+ return src
263
  try:
264
+ outputs = _VOCAL.separate(src)
265
  if outputs:
266
+ return str((Path("/tmp/vocal_out") / "vocals" / outputs[0]).resolve())
267
  except Exception as e:
268
+ print(f"[vocal] separation failed, using raw audio: {e}", flush=True)
269
  return src
270
 
271
 
 
273
  # 5) GPU-bound inference function
274
  # ---------------------------------------------------------------------------
275
 
276
+ @spaces.GPU(duration=420)
277
  def generate(
278
  image_path: str,
279
  audio_path: str,
280
  prompt: str,
281
  resolution: str,
 
282
  seed: int,
283
  progress=gr.Progress(track_tqdm=True),
284
  ):
 
286
  raise gr.Error("Please upload a reference image.")
287
  if not audio_path:
288
  raise gr.Error("Please upload an audio clip.")
289
+ prompt = (prompt or "A person is talking naturally.").strip()
 
290
 
291
+ progress(0.02, desc="Warming up models (one-time on cold start)…")
292
+ if _PIPE is None:
293
+ _build_pipeline()
294
+ if _VOCAL is None:
295
+ _build_vocal_separator()
296
+
297
+ from longcat_video.audio_process.torch_utils import save_video_ffmpeg
298
+ import librosa
299
 
 
300
  save_fps = 25
301
  audio_stride = 1
302
+ num_frames = 93
 
 
303
 
304
  # 1) Vocal isolation
305
+ progress(0.10, desc="Isolating vocals…")
306
  vocal_path = _extract_vocal(audio_path)
307
 
308
+ # 2) Pad audio to target duration
309
  speech, sr = librosa.load(vocal_path, sr=16000)
310
+ pad = math.ceil((num_frames / save_fps - len(speech) / sr) * sr)
 
311
  if pad > 0:
312
  speech = np.concatenate([speech, np.zeros(pad, dtype=speech.dtype)])
313
 
314
  # 3) Whisper audio embedding
315
+ progress(0.20, desc="Encoding audio (Whisper-Large-v3)…")
316
+ full_audio_emb = _PIPE.get_audio_embedding(
317
  speech, fps=save_fps * audio_stride, device="cuda", sample_rate=sr, model_type="avatar-v1.5"
318
  )
319
  if torch.isnan(full_audio_emb).any():
320
+ raise gr.Error("Audio embedding contains NaN — try a different audio clip.")
321
 
322
+ # 4) Build windowed audio tensor: [1, T, 5, 5, D]
323
+ indices = torch.arange(2 * 2 + 1) - 2
324
  center = torch.arange(0, audio_stride * num_frames, audio_stride).unsqueeze(1) + indices.unsqueeze(0)
325
  center = torch.clamp(center, min=0, max=full_audio_emb.shape[0] - 1)
326
  audio_emb = full_audio_emb[center][None, ...].to("cuda")
327
 
328
+ # 5) Generate (8-step distilled; both guidance scales are 1.0, so 1 forward pass per step)
329
  progress(0.30, desc="Generating video (DMD2 8-step)…")
330
  image = Image.open(image_path).convert("RGB")
331
  generator = torch.Generator(device="cuda").manual_seed(int(seed))
332
 
333
+ output, _ = _PIPE.generate_ai2v(
334
  image=image,
335
+ prompt=prompt,
336
  negative_prompt=NEGATIVE_PROMPT,
337
  resolution=resolution,
338
  num_frames=num_frames,
339
  num_inference_steps=8,
340
  text_guidance_scale=1.0,
341
+ audio_guidance_scale=1.0,
342
  output_type="both",
343
  generator=generator,
344
  audio_emb=audio_emb,
345
  use_distill=True,
346
  )
347
 
348
+ # 6) Mux + save
349
  progress(0.92, desc="Muxing audio + video…")
350
  frames = (output[0] * 255).astype(np.uint8)
351
  out_tensor = torch.from_numpy(frames)
352
  out_base = Path(tempfile.gettempdir()) / f"longcat_{uuid.uuid4().hex[:8]}"
353
  save_video_ffmpeg(out_tensor, str(out_base), audio_path, fps=save_fps, quality=5)
354
  out_path = f"{out_base}.mp4"
355
+ print(f"[gen] wrote {out_path}", flush=True)
356
  return out_path
357
 
358
 
 
370
  "their mouth. Wearing a vibrant red jacket with gold embroidery, the singer is speaking "
371
  "while smoke swirls around them, creating a dynamic and atmospheric scene.",
372
  "480p",
 
373
  42,
374
  ])
375
 
 
381
  Upload a **reference image** + **audio clip** + a short **text prompt**.
382
  The model generates a ~3.7-second lip-synced video using Meituan's
383
  LongCat-Video-Avatar 1.5 (INT8 DiT + DMD2 8-step distilled).
384
+
385
+ *First call is slow: ~60s to warm up models on GPU.*
386
  """
387
  )
388
  with gr.Row():
 
397
  with gr.Row():
398
  resolution = gr.Radio(["480p", "720p"], value="480p", label="Resolution")
399
  seed = gr.Number(value=42, precision=0, label="Seed")
 
 
 
 
 
 
 
400
  go = gr.Button("Generate", variant="primary")
401
  with gr.Column():
402
  video_out = gr.Video(label="Output", autoplay=True)
 
404
  if EXAMPLES:
405
  gr.Examples(
406
  examples=EXAMPLES,
407
+ inputs=[image_in, audio_in, prompt, resolution, seed],
408
  outputs=video_out,
409
  fn=generate,
410
  cache_examples=False,
 
412
 
413
  go.click(
414
  generate,
415
+ inputs=[image_in, audio_in, prompt, resolution, seed],
416
  outputs=video_out,
417
  )
418