linoyts HF Staff commited on
Commit
b68e12e
Β·
verified Β·
1 Parent(s): e5ef795

Revert to shipped scene-emb .pt; pre-warm pipeline blocks for ZeroGPU

Browse files

Two changes:

1. Drop Gemma from inference. Re-download the shipped `comfyui_models_loras_ltxv_ltx2_ltx-2.3-22b-ic-lora-hdr-scene-emb.pt` from diffusers-internal-dev/LTX-HDR-LoRA and hand it directly to HDRICLoraPipeline. Removes the prompt textbox and the per-call ~20s Gemma 12B load+encode. Matches the HDR IC-LoRA's training-time scene embedding.

2. Pre-warm the pipeline at module load. Build the fp8-cast LoRA-fused transformer once (shared between stage_1 and stage_2), plus ImageConditioner encoder, VideoUpsampler (encoder + upsampler), and VideoDecoder. Replace the pipeline's blocks with cached wrappers that reuse the built models without the gpu_model() meta-device free on exit. Avoids re-reading the 22B checkpoint + LoRA fusion + fp8 cast on every @spaces.GPU call.

Tradeoffs: startup takes longer (one build of every component), but each subsequent generation skips the ~30-60s rebuild.

Files changed (1) hide show
  1. app.py +95 -72
app.py CHANGED
@@ -42,13 +42,8 @@ if _tv.returncode == 0:
42
  )
43
 
44
  # ─────────────────────────────────────────────────────────────────────────────
45
- # ltx-core / ltx-pipelines source
46
- #
47
- # The HDRICLoraPipeline and its supporting modules (ltx_core.hdr,
48
- # ltx_pipelines.utils.blocks, load_video_conditioning_hdr, apply_hdr_decode_postprocess,
49
- # save_exr_tensor, encode_exr_sequence_to_mp4) are NOT on the public main
50
- # branch at the pinned commit used by the outpaint app. We install from the
51
- # local ltx-2-internal checkout so the HDR code path actually exists.
52
  # ─────────────────────────────────────────────────────────────────────────────
53
  LTX_INTERNAL = Path(os.environ.get(
54
  "LTX_INTERNAL_PATH",
@@ -72,6 +67,8 @@ import logging
72
  import random
73
  import tempfile
74
  import zipfile
 
 
75
 
76
  import torch
77
  torch._dynamo.config.suppress_errors = True
@@ -80,12 +77,12 @@ torch._dynamo.config.disable = True
80
  import spaces
81
  import gradio as gr
82
  import numpy as np
83
- from huggingface_hub import hf_hub_download, snapshot_download
84
 
 
85
  from ltx_core.model.video_vae import TilingConfig
86
  from ltx_core.quantization import QuantizationPolicy
87
  from ltx_pipelines.hdr_ic_lora import HDRICLoraPipeline, _make_tiling_config
88
- from ltx_pipelines.utils.blocks import PromptEncoder
89
  from ltx_pipelines.utils.media_io import (
90
  encode_exr_sequence_to_mp4,
91
  get_videostream_metadata,
@@ -93,7 +90,7 @@ from ltx_pipelines.utils.media_io import (
93
  )
94
  from ltx_pipelines.utils.types import OffloadMode
95
 
96
- # xformers attention patch (same as the outpaint app).
97
  from ltx_core.model.transformer import attention as _attn_mod
98
  print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
99
  try:
@@ -111,7 +108,7 @@ logging.getLogger().setLevel(logging.INFO)
111
  # ─────────────────────────────────────────────────────────────────────────────
112
  MAX_SEED = np.iinfo(np.int32).max
113
 
114
- # Frames must satisfy (n-1) % 8 == 0. Aspect-ratio canvas sizes (divisible by 32).
115
  RESOLUTIONS = {
116
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768),
117
  "4:3": (768, 576), "3:4": (576, 768), "21:9": (768, 384)},
@@ -122,80 +119,120 @@ RESOLUTIONS = {
122
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
123
  DISTILLED_CHECKPOINT = "ltx-2.3-22b-distilled-1.1.safetensors"
124
  SPATIAL_UPSCALER = "ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
125
- GEMMA_REPO = "google/gemma-3-12b-it-qat-q4_0-unquantized"
126
 
127
  HDR_LORA_REPO = "diffusers-internal-dev/LTX-HDR-LoRA"
128
  HDR_LORA_FILENAME = "comfyui_models_loras_ltxv_ltx2_ltx-2.3-22b-ic-lora-hdr-0.9 (4).safetensors"
 
129
 
130
  print("=" * 80)
131
- print("Downloading LTX-2.3 distilled + spatial upsampler + Gemma + HDR IC-LoRA...")
132
  print("=" * 80)
133
 
134
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=DISTILLED_CHECKPOINT)
135
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=SPATIAL_UPSCALER)
136
  hdr_lora_path = hf_hub_download(repo_id=HDR_LORA_REPO, filename=HDR_LORA_FILENAME)
137
- gemma_root = snapshot_download(repo_id=GEMMA_REPO)
138
 
139
  print(f"Checkpoint: {checkpoint_path}")
140
  print(f"Spatial upsampler: {spatial_upsampler_path}")
141
  print(f"HDR IC-LoRA: {hdr_lora_path}")
142
- print(f"Gemma root: {gemma_root}")
143
 
144
 
145
  # ─────────────────────────────────────────────────────────────────────────────
146
- # Text encoding: on-the-fly Gemma -> (video_context, audio_context) for each
147
- # prompt. HDRICLoraPipeline expects a `.pt` path at __init__, so we bootstrap
148
- # one from an empty prompt, then overwrite `pipeline.text_embeddings` in
149
- # memory each generate call.
150
  # ─────────────────────────────────────────────────────────────────────────────
151
  _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
152
  _DTYPE = torch.bfloat16
153
 
154
- prompt_encoder = PromptEncoder(
155
- checkpoint_path=checkpoint_path,
156
- gemma_root=gemma_root,
157
- dtype=_DTYPE,
158
- device=_DEVICE,
159
- )
160
-
161
-
162
- def encode_prompt_to_contexts(prompt: str) -> tuple[torch.Tensor, torch.Tensor]:
163
- """Run Gemma + embeddings processor to produce (video_context, audio_context).
164
-
165
- HDRICLoraPipeline only consumes video_context; audio_context is stored for
166
- shape-compat with the `.pt` interface but ignored during HDR generation.
167
- MUST be called from inside a @spaces.GPU context on ZeroGPU.
168
- """
169
- (out,) = prompt_encoder([prompt])
170
- v = out.video_encoding
171
- a = out.audio_encoding if out.audio_encoding is not None else torch.zeros(0, device=v.device, dtype=v.dtype)
172
- return v, a
173
-
174
-
175
- # HDRICLoraPipeline.__init__ requires a .pt it can torch.load, but it only
176
- # stores the tensors β€” __call__ reads `self.text_embeddings` which we overwrite
177
- # on every generate run. So write a placeholder .pt at module-load (CPU, no
178
- # Gemma run β€” Gemma can only touch GPU inside a @spaces.GPU function on ZeroGPU).
179
- _bootstrap_emb_path = Path(tempfile.gettempdir()) / "ltx_hdr_bootstrap_emb.pt"
180
- _placeholder = torch.zeros(1, 1, 4096, dtype=_DTYPE)
181
- torch.save({"video_context": _placeholder, "audio_context": _placeholder}, _bootstrap_emb_path)
182
-
183
-
184
- # ─────────────────────────────────────────────────────────────────────────────
185
- # Initialize pipeline
186
- # ─────────────────────────────────────────────────────────────────────────────
187
- # HDRICLoraPipeline is video-only (no audio path). HDR transform (LogC3) and
188
- # reference_downscale_factor are auto-detected from the LoRA metadata.
189
  pipeline = HDRICLoraPipeline(
190
  distilled_checkpoint_path=checkpoint_path,
191
  spatial_upsampler_path=spatial_upsampler_path,
192
  hdr_lora=hdr_lora_path,
193
- text_embeddings_path=str(_bootstrap_emb_path),
194
  quantization=QuantizationPolicy.fp8_cast(),
195
  offload_mode=OffloadMode.NONE,
196
  )
197
  print(f"HDRICLoraPipeline ready. HDR transform: {pipeline.hdr_transform}, "
198
  f"ref_downscale={pipeline.reference_downscale_factor}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
199
  print("=" * 80)
200
 
201
 
@@ -241,7 +278,6 @@ def on_video_upload(video):
241
  @torch.inference_mode()
242
  def generate_video(
243
  input_video,
244
- prompt: str,
245
  duration: float,
246
  frame_rate: float,
247
  target_aspect: str,
@@ -270,13 +306,7 @@ def generate_video(
270
  print(f"[HDR] {target_h}x{target_w}, frames={num_frames}, fps={frame_rate}, "
271
  f"seed={current_seed}, aspect={target_aspect}, hq_hdr={high_quality_hdr}")
272
 
273
- # Encode prompt -> (video_context, audio_context) and swap into the
274
- # pipeline. Gemma is loaded, used, and freed inside prompt_encoder.
275
- print(f"[HDR] Encoding prompt: {prompt!r}")
276
- video_context, audio_context = encode_prompt_to_contexts(prompt or "")
277
- pipeline.text_embeddings = (video_context, audio_context)
278
-
279
- # Tiling config: smaller spatial tile on lower-VRAM targets
280
  tiling_config = _make_tiling_config(spatial_tile=768 if not high_res else 1280)
281
 
282
  hdr_video = pipeline(
@@ -341,20 +371,13 @@ with gr.Blocks(title="LTX 2.3 HDR", css=css, theme=theme) as demo:
341
  gr.Markdown("""
342
  # LTX 2.3 HDR ✨
343
  Video-to-video HDR via LTX-2.3 + [HDR IC-LoRA](https://huggingface.co/diffusers-internal-dev/LTX-HDR-LoRA).
344
- Output is linear HDR (LogC3 inverse decoded β€” auto-detected from LoRA metadata). The preview mp4 is a fixed-EV sRGB tonemap; the EXR zip contains the full linear float frames for grading.
345
  """)
346
 
347
  with gr.Row():
348
  with gr.Column(scale=1):
349
  input_video = gr.Video(label="Source Video")
350
 
351
- prompt = gr.Textbox(
352
- label="Prompt",
353
- info="Describes the scene being regenerated in HDR. Encoded through Gemma on each run.",
354
- lines=2,
355
- placeholder="a cinematic sunset over mountains, high dynamic range, bright sky, deep shadows",
356
- )
357
-
358
  with gr.Row():
359
  target_aspect = gr.Dropdown(
360
  label="Aspect Ratio",
@@ -398,7 +421,7 @@ Output is linear HDR (LogC3 inverse decoded β€” auto-detected from LoRA metadata
398
  generate_btn.click(
399
  fn=generate_video,
400
  inputs=[
401
- input_video, prompt, duration, frame_rate, target_aspect, high_res,
402
  seed, randomize_seed, high_quality_hdr, export_exr,
403
  ],
404
  outputs=[output_video, output_exr, seed],
 
42
  )
43
 
44
  # ─────────────────────────────────────────────────────────────────────────────
45
+ # ltx-core / ltx-pipelines source (bundled β€” the HDR code path is not on
46
+ # public Lightricks/LTX-2 main).
 
 
 
 
 
47
  # ─────────────────────────────────────────────────────────────────────────────
48
  LTX_INTERNAL = Path(os.environ.get(
49
  "LTX_INTERNAL_PATH",
 
67
  import random
68
  import tempfile
69
  import zipfile
70
+ from collections.abc import Iterator
71
+ from contextlib import contextmanager
72
 
73
  import torch
74
  torch._dynamo.config.suppress_errors = True
 
77
  import spaces
78
  import gradio as gr
79
  import numpy as np
80
+ from huggingface_hub import hf_hub_download
81
 
82
+ from ltx_core.model.upsampler import upsample_video
83
  from ltx_core.model.video_vae import TilingConfig
84
  from ltx_core.quantization import QuantizationPolicy
85
  from ltx_pipelines.hdr_ic_lora import HDRICLoraPipeline, _make_tiling_config
 
86
  from ltx_pipelines.utils.media_io import (
87
  encode_exr_sequence_to_mp4,
88
  get_videostream_metadata,
 
90
  )
91
  from ltx_pipelines.utils.types import OffloadMode
92
 
93
+ # xformers attention patch
94
  from ltx_core.model.transformer import attention as _attn_mod
95
  print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
96
  try:
 
108
  # ─────────────────────────────────────────────────────────────────────────────
109
  MAX_SEED = np.iinfo(np.int32).max
110
 
111
+ # Canvas sizes divisible by 32; kept conservative for A10G 24 GB.
112
  RESOLUTIONS = {
113
  "low": {"16:9": (768, 512), "9:16": (512, 768), "1:1": (768, 768),
114
  "4:3": (768, 576), "3:4": (576, 768), "21:9": (768, 384)},
 
119
  LTX_MODEL_REPO = "Lightricks/LTX-2.3"
120
  DISTILLED_CHECKPOINT = "ltx-2.3-22b-distilled-1.1.safetensors"
121
  SPATIAL_UPSCALER = "ltx-2.3-spatial-upscaler-x2-1.1.safetensors"
 
122
 
123
  HDR_LORA_REPO = "diffusers-internal-dev/LTX-HDR-LoRA"
124
  HDR_LORA_FILENAME = "comfyui_models_loras_ltxv_ltx2_ltx-2.3-22b-ic-lora-hdr-0.9 (4).safetensors"
125
+ HDR_SCENE_EMB_FILENAME = "comfyui_models_loras_ltxv_ltx2_ltx-2.3-22b-ic-lora-hdr-scene-emb.pt"
126
 
127
  print("=" * 80)
128
+ print("Downloading LTX-2.3 distilled + spatial upsampler + HDR IC-LoRA + scene emb...")
129
  print("=" * 80)
130
 
131
  checkpoint_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=DISTILLED_CHECKPOINT)
132
  spatial_upsampler_path = hf_hub_download(repo_id=LTX_MODEL_REPO, filename=SPATIAL_UPSCALER)
133
  hdr_lora_path = hf_hub_download(repo_id=HDR_LORA_REPO, filename=HDR_LORA_FILENAME)
134
+ hdr_scene_emb_path = hf_hub_download(repo_id=HDR_LORA_REPO, filename=HDR_SCENE_EMB_FILENAME)
135
 
136
  print(f"Checkpoint: {checkpoint_path}")
137
  print(f"Spatial upsampler: {spatial_upsampler_path}")
138
  print(f"HDR IC-LoRA: {hdr_lora_path}")
139
+ print(f"HDR scene emb: {hdr_scene_emb_path}")
140
 
141
 
142
  # ─────────────────────────────────────────────────────────────────────────────
143
+ # Initialize pipeline β€” text conditioning comes from the shipped scene-emb
144
+ # .pt (no Gemma at inference time).
 
 
145
  # ─────────────────────────────────────────────────────────────────────────────
146
  _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
147
  _DTYPE = torch.bfloat16
148
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
149
  pipeline = HDRICLoraPipeline(
150
  distilled_checkpoint_path=checkpoint_path,
151
  spatial_upsampler_path=spatial_upsampler_path,
152
  hdr_lora=hdr_lora_path,
153
+ text_embeddings_path=hdr_scene_emb_path,
154
  quantization=QuantizationPolicy.fp8_cast(),
155
  offload_mode=OffloadMode.NONE,
156
  )
157
  print(f"HDRICLoraPipeline ready. HDR transform: {pipeline.hdr_transform}, "
158
  f"ref_downscale={pipeline.reference_downscale_factor}")
159
+
160
+
161
+ # ─────────────────────────────────────────────────────────────────────────────
162
+ # Pre-warm for ZeroGPU: build each component once at module load (ZeroGPU
163
+ # tensor-packing captures the weights), then replace the pipeline's blocks
164
+ # with tiny wrappers that reuse the cached models and skip gpu_model's
165
+ # meta-device freeing. Avoids re-reading the 22B checkpoint + re-fusing
166
+ # the LoRA + re-fp8-casting on every @spaces.GPU invocation.
167
+ # ─────────────────────────────────────────────────────────────────────────────
168
+ print("Pre-warming models (one-shot build)...")
169
+
170
+ _cached_image_encoder = pipeline.image_conditioner._build_encoder()
171
+ _cached_transformer = pipeline.stage_1._build_transformer()
172
+ _cached_upsampler_encoder = (
173
+ pipeline.upsampler._encoder_builder
174
+ .build(device=_DEVICE, dtype=_DTYPE).to(_DEVICE).eval()
175
+ )
176
+ _cached_upsampler = (
177
+ pipeline.upsampler._upsampler_builder
178
+ .build(device=_DEVICE, dtype=_DTYPE).to(_DEVICE).eval()
179
+ )
180
+ _cached_video_decoder = (
181
+ pipeline.video_decoder._decoder_builder
182
+ .build(device=_DEVICE, dtype=_DTYPE).to(_DEVICE).eval()
183
+ )
184
+
185
+
186
+ @contextmanager
187
+ def _yield_cached(model):
188
+ """Drop-in for gpu_model that does NOT move params to meta on exit."""
189
+ yield model
190
+
191
+
192
+ # Patch the transformer context manager on both stages to yield the cached
193
+ # transformer without freeing. stage_1 uses _transformer_ctx inside __call__;
194
+ # stage_2 uses model_context() -> _transformer_ctx.
195
+ def _cached_stage_ctx(**_kwargs):
196
+ return _yield_cached(_cached_transformer)
197
+
198
+
199
+ pipeline.stage_1._transformer_ctx = _cached_stage_ctx
200
+ pipeline.stage_2._transformer_ctx = _cached_stage_ctx
201
+
202
+
203
+ class _CachedImageConditioner:
204
+ def __call__(self, fn):
205
+ return fn(_cached_image_encoder)
206
+
207
+
208
+ class _CachedVideoUpsampler:
209
+ def __call__(self, latent):
210
+ return upsample_video(
211
+ latent=latent,
212
+ video_encoder=_cached_upsampler_encoder,
213
+ upsampler=_cached_upsampler,
214
+ )
215
+
216
+
217
+ class _CachedVideoDecoder:
218
+ def __call__(
219
+ self,
220
+ latent: torch.Tensor,
221
+ tiling_config=None,
222
+ generator=None,
223
+ *,
224
+ output_dtype: torch.dtype = torch.uint8,
225
+ ) -> Iterator[torch.Tensor]:
226
+ return _cached_video_decoder.decode_video(
227
+ latent, tiling_config, generator, output_dtype=output_dtype,
228
+ )
229
+
230
+
231
+ pipeline.image_conditioner = _CachedImageConditioner()
232
+ pipeline.upsampler = _CachedVideoUpsampler()
233
+ pipeline.video_decoder = _CachedVideoDecoder()
234
+
235
+ print("Pre-warm complete.")
236
  print("=" * 80)
237
 
238
 
 
278
  @torch.inference_mode()
279
  def generate_video(
280
  input_video,
 
281
  duration: float,
282
  frame_rate: float,
283
  target_aspect: str,
 
306
  print(f"[HDR] {target_h}x{target_w}, frames={num_frames}, fps={frame_rate}, "
307
  f"seed={current_seed}, aspect={target_aspect}, hq_hdr={high_quality_hdr}")
308
 
309
+ # Smaller spatial tile on non-high-res to keep VAE decode within A10G budget.
 
 
 
 
 
 
310
  tiling_config = _make_tiling_config(spatial_tile=768 if not high_res else 1280)
311
 
312
  hdr_video = pipeline(
 
371
  gr.Markdown("""
372
  # LTX 2.3 HDR ✨
373
  Video-to-video HDR via LTX-2.3 + [HDR IC-LoRA](https://huggingface.co/diffusers-internal-dev/LTX-HDR-LoRA).
374
+ Text conditioning uses the shipped pre-computed scene embedding (`scene-emb.pt`) β€” no prompt input, no per-call Gemma cost. Output is linear HDR (LogC3 inverse decoded, auto-detected from LoRA metadata). The preview mp4 is a fixed-EV sRGB tonemap; the EXR zip contains the full linear float frames for grading.
375
  """)
376
 
377
  with gr.Row():
378
  with gr.Column(scale=1):
379
  input_video = gr.Video(label="Source Video")
380
 
 
 
 
 
 
 
 
381
  with gr.Row():
382
  target_aspect = gr.Dropdown(
383
  label="Aspect Ratio",
 
421
  generate_btn.click(
422
  fn=generate_video,
423
  inputs=[
424
+ input_video, duration, frame_rate, target_aspect, high_res,
425
  seed, randomize_seed, high_quality_hdr, export_exr,
426
  ],
427
  outputs=[output_video, output_exr, seed],