Spaces:
Paused
Paused
alex commited on
Commit ·
fc134c2
1
Parent(s): f8c25ce
allow loRA
Browse files
app.py
CHANGED
|
@@ -30,6 +30,8 @@ from ltx_pipelines.utils.constants import (
|
|
| 30 |
DEFAULT_FRAME_RATE,
|
| 31 |
DEFAULT_LORA_STRENGTH,
|
| 32 |
)
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
MAX_SEED = np.iinfo(np.int32).max
|
|
@@ -182,24 +184,48 @@ print("Loading LTX-2 Distilled pipeline...")
|
|
| 182 |
print("=" * 80)
|
| 183 |
|
| 184 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
| 185 |
-
distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
|
| 186 |
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
|
| 187 |
|
| 188 |
print(f"Initializing pipeline with:")
|
| 189 |
print(f" checkpoint_path={checkpoint_path}")
|
| 190 |
-
print(f" distilled_lora_path={distilled_lora_path}")
|
| 191 |
print(f" spatial_upsampler_path={spatial_upsampler_path}")
|
| 192 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
# Load distilled LoRA as a regular LoRA
|
| 195 |
loras = [
|
|
|
|
| 196 |
LoraPathStrengthAndSDOps(
|
| 197 |
path=distilled_lora_path,
|
| 198 |
strength=DEFAULT_LORA_STRENGTH,
|
| 199 |
sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
|
| 200 |
-
)
|
|
|
|
|
|
|
|
|
|
| 201 |
]
|
| 202 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 203 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
| 204 |
# Text encoding will be done by external space
|
| 205 |
pipeline = DistilledPipeline(
|
|
@@ -222,23 +248,6 @@ print("=" * 80)
|
|
| 222 |
print("Pipeline fully loaded and ready!")
|
| 223 |
print("=" * 80)
|
| 224 |
|
| 225 |
-
def get_duration(
|
| 226 |
-
input_image,
|
| 227 |
-
prompt,
|
| 228 |
-
duration,
|
| 229 |
-
enhance_prompt,
|
| 230 |
-
seed,
|
| 231 |
-
randomize_seed,
|
| 232 |
-
height,
|
| 233 |
-
width,
|
| 234 |
-
progress
|
| 235 |
-
):
|
| 236 |
-
if duration <= 5:
|
| 237 |
-
return 80
|
| 238 |
-
elif duration <= 10:
|
| 239 |
-
return 120
|
| 240 |
-
else:
|
| 241 |
-
return 180
|
| 242 |
|
| 243 |
class RadioAnimated(gr.HTML):
|
| 244 |
"""
|
|
@@ -274,41 +283,254 @@ class RadioAnimated(gr.HTML):
|
|
| 274 |
|
| 275 |
js_on_load = r"""
|
| 276 |
(() => {
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
| 305 |
-
|
| 306 |
-
|
| 307 |
-
|
| 308 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 309 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 310 |
});
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
})();
|
|
|
|
| 312 |
"""
|
| 313 |
|
| 314 |
super().__init__(
|
|
@@ -318,10 +540,42 @@ class RadioAnimated(gr.HTML):
|
|
| 318 |
**kwargs
|
| 319 |
)
|
| 320 |
|
|
|
|
| 321 |
def generate_video_example(input_image, prompt, duration, progress=gr.Progress(track_tqdm=True)):
|
| 322 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 323 |
|
| 324 |
return output_video
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 325 |
|
| 326 |
@spaces.GPU(duration=get_duration)
|
| 327 |
def generate_video(
|
|
@@ -333,6 +587,7 @@ def generate_video(
|
|
| 333 |
randomize_seed: bool = True,
|
| 334 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 335 |
width: int = DEFAULT_1_STAGE_WIDTH,
|
|
|
|
| 336 |
progress=gr.Progress(track_tqdm=True),
|
| 337 |
):
|
| 338 |
"""
|
|
@@ -346,8 +601,10 @@ def generate_video(
|
|
| 346 |
randomize_seed: If True, a random seed is generated for each run.
|
| 347 |
height: Output video height in pixels.
|
| 348 |
width: Output video width in pixels.
|
|
|
|
| 349 |
progress: Gradio progress tracker.
|
| 350 |
Returns:
|
|
|
|
| 351 |
A tuple of:
|
| 352 |
- output_path: Path to the generated MP4 video file.
|
| 353 |
- seed: The seed used for generation.
|
|
@@ -396,6 +653,20 @@ def generate_video(
|
|
| 396 |
del embeddings, final_prompt, status
|
| 397 |
torch.cuda.empty_cache()
|
| 398 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 399 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 400 |
pipeline(
|
| 401 |
prompt=prompt,
|
|
@@ -431,7 +702,42 @@ def apply_duration(duration: str):
|
|
| 431 |
duration_s = int(duration[:-1])
|
| 432 |
return duration_s
|
| 433 |
|
|
|
|
| 434 |
css = """
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 435 |
#col-container {
|
| 436 |
margin: 0 auto;
|
| 437 |
max-width: 1600px;
|
|
@@ -570,6 +876,176 @@ css += """
|
|
| 570 |
}
|
| 571 |
"""
|
| 572 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 573 |
|
| 574 |
with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
| 575 |
gr.HTML(
|
|
@@ -605,12 +1081,19 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 605 |
height=512
|
| 606 |
)
|
| 607 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 608 |
prompt = gr.Textbox(
|
| 609 |
label="Prompt",
|
| 610 |
value="Make this image come alive with cinematic motion, smooth animation",
|
| 611 |
lines=3,
|
| 612 |
max_lines=3,
|
| 613 |
-
placeholder="Describe the motion and animation you want..."
|
|
|
|
| 614 |
)
|
| 615 |
|
| 616 |
enhance_prompt = gr.Checkbox(
|
|
@@ -633,10 +1116,9 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 633 |
|
| 634 |
with gr.Column(elem_id="step-column"):
|
| 635 |
output_video = gr.Video(label="Generated Video", autoplay=True, height=512)
|
| 636 |
-
|
| 637 |
-
with gr.Row():
|
| 638 |
-
|
| 639 |
-
with gr.Column():
|
| 640 |
radioanimated_duration = RadioAnimated(
|
| 641 |
choices=["3s", "5s", "10s", "15s"],
|
| 642 |
value="3s",
|
|
@@ -651,8 +1133,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 651 |
step=0.1,
|
| 652 |
visible=False
|
| 653 |
)
|
| 654 |
-
|
| 655 |
-
with gr.Column():
|
| 656 |
radioanimated_resolution = RadioAnimated(
|
| 657 |
choices=["768x512", "512x512", "512x768"],
|
| 658 |
value=f"{DEFAULT_1_STAGE_WIDTH}x{DEFAULT_1_STAGE_HEIGHT}",
|
|
@@ -661,10 +1142,30 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 661 |
|
| 662 |
width = gr.Number(label="Width", value=DEFAULT_1_STAGE_WIDTH, precision=0, visible=False)
|
| 663 |
height = gr.Number(label="Height", value=DEFAULT_1_STAGE_HEIGHT, precision=0, visible=False)
|
| 664 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 665 |
|
| 666 |
generate_btn = gr.Button("🤩 Generate Video", variant="primary", elem_classes="button-gradient")
|
| 667 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 668 |
|
| 669 |
radioanimated_duration.change(
|
| 670 |
fn=apply_duration,
|
|
@@ -678,6 +1179,13 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 678 |
outputs=[width, height],
|
| 679 |
api_visibility="private"
|
| 680 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 681 |
|
| 682 |
generate_btn.click(
|
| 683 |
fn=generate_video,
|
|
@@ -690,6 +1198,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 690 |
randomize_seed,
|
| 691 |
height,
|
| 692 |
width,
|
|
|
|
| 693 |
],
|
| 694 |
outputs=[output_video,seed]
|
| 695 |
)
|
|
@@ -716,7 +1225,7 @@ with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
|
| 716 |
|
| 717 |
],
|
| 718 |
fn=generate_video_example,
|
| 719 |
-
inputs=[input_image,
|
| 720 |
outputs = [output_video],
|
| 721 |
label="Example",
|
| 722 |
cache_examples=True,
|
|
|
|
| 30 |
DEFAULT_FRAME_RATE,
|
| 31 |
DEFAULT_LORA_STRENGTH,
|
| 32 |
)
|
| 33 |
+
from ltx_core.loader.single_gpu_model_builder import set_lora_enabled
|
| 34 |
+
|
| 35 |
|
| 36 |
|
| 37 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
| 184 |
print("=" * 80)
|
| 185 |
|
| 186 |
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
|
|
|
| 187 |
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
|
| 188 |
|
| 189 |
print(f"Initializing pipeline with:")
|
| 190 |
print(f" checkpoint_path={checkpoint_path}")
|
|
|
|
| 191 |
print(f" spatial_upsampler_path={spatial_upsampler_path}")
|
| 192 |
|
| 193 |
+
distilled_lora_path = get_hub_or_local_checkpoint(
|
| 194 |
+
DEFAULT_REPO_ID,
|
| 195 |
+
DEFAULT_DISTILLED_LORA_FILENAME,
|
| 196 |
+
)
|
| 197 |
+
|
| 198 |
+
dolly_in_lora_path = get_hub_or_local_checkpoint(
|
| 199 |
+
"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-In",
|
| 200 |
+
"ltx-2-19b-lora-camera-control-dolly-in.safetensors",
|
| 201 |
+
)
|
| 202 |
+
dolly_out_lora_path = get_hub_or_local_checkpoint(
|
| 203 |
+
"Lightricks/LTX-2-19b-LoRA-Camera-Control-Dolly-Out",
|
| 204 |
+
"ltx-2-19b-lora-camera-control-dolly-out.safetensors",
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
|
| 208 |
# Load distilled LoRA as a regular LoRA
|
| 209 |
loras = [
|
| 210 |
+
# --- fused / base behavior ---
|
| 211 |
LoraPathStrengthAndSDOps(
|
| 212 |
path=distilled_lora_path,
|
| 213 |
strength=DEFAULT_LORA_STRENGTH,
|
| 214 |
sd_ops=LTXV_LORA_COMFY_RENAMING_MAP,
|
| 215 |
+
),
|
| 216 |
+
# # --- runtime-toggle camera controls ---#
|
| 217 |
+
LoraPathStrengthAndSDOps(dolly_in_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
|
| 218 |
+
LoraPathStrengthAndSDOps(dolly_out_lora_path, DEFAULT_LORA_STRENGTH, LTXV_LORA_COMFY_RENAMING_MAP),
|
| 219 |
]
|
| 220 |
|
| 221 |
+
# Runtime-toggle LoRAs (exclude fused distilled at index 0)
|
| 222 |
+
RUNTIME_LORA_CHOICES = [
|
| 223 |
+
("No LoRA", -1),
|
| 224 |
+
("Dolly In", 0),
|
| 225 |
+
("Dolly Out", 1),
|
| 226 |
+
]
|
| 227 |
+
|
| 228 |
+
|
| 229 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
| 230 |
# Text encoding will be done by external space
|
| 231 |
pipeline = DistilledPipeline(
|
|
|
|
| 248 |
print("Pipeline fully loaded and ready!")
|
| 249 |
print("=" * 80)
|
| 250 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 251 |
|
| 252 |
class RadioAnimated(gr.HTML):
|
| 253 |
"""
|
|
|
|
| 283 |
|
| 284 |
js_on_load = r"""
|
| 285 |
(() => {
|
| 286 |
+
const wrap = element.querySelector('.ra-wrap');
|
| 287 |
+
const inner = element.querySelector('.ra-inner');
|
| 288 |
+
const highlight = element.querySelector('.ra-highlight');
|
| 289 |
+
const inputs = Array.from(element.querySelectorAll('.ra-input'));
|
| 290 |
+
const labels = Array.from(element.querySelectorAll('.ra-label'));
|
| 291 |
+
|
| 292 |
+
if (!inputs.length || !labels.length) return;
|
| 293 |
+
|
| 294 |
+
const choices = inputs.map(i => i.value);
|
| 295 |
+
const PAD = 6; // must match .ra-inner padding and .ra-highlight top/left
|
| 296 |
+
|
| 297 |
+
let currentIdx = 0;
|
| 298 |
+
|
| 299 |
+
function setHighlightByIndex(idx) {
|
| 300 |
+
currentIdx = idx;
|
| 301 |
+
|
| 302 |
+
const lbl = labels[idx];
|
| 303 |
+
if (!lbl) return;
|
| 304 |
+
|
| 305 |
+
const innerRect = inner.getBoundingClientRect();
|
| 306 |
+
const lblRect = lbl.getBoundingClientRect();
|
| 307 |
+
|
| 308 |
+
// width matches the label exactly
|
| 309 |
+
highlight.style.width = `${lblRect.width}px`;
|
| 310 |
+
|
| 311 |
+
// highlight has left: 6px, so subtract PAD to align
|
| 312 |
+
const x = (lblRect.left - innerRect.left - PAD);
|
| 313 |
+
highlight.style.transform = `translateX(${x}px)`;
|
| 314 |
+
}
|
| 315 |
+
|
| 316 |
+
function setCheckedByValue(val, shouldTrigger=false) {
|
| 317 |
+
const idx = Math.max(0, choices.indexOf(val));
|
| 318 |
+
inputs.forEach((inp, i) => { inp.checked = (i === idx); });
|
| 319 |
+
|
| 320 |
+
// Wait a frame in case fonts/layout settle (prevents rare drift)
|
| 321 |
+
requestAnimationFrame(() => setHighlightByIndex(idx));
|
| 322 |
+
|
| 323 |
+
props.value = choices[idx];
|
| 324 |
+
if (shouldTrigger) trigger('change', props.value);
|
| 325 |
+
}
|
| 326 |
+
|
| 327 |
+
// Init
|
| 328 |
+
setCheckedByValue(props.value ?? choices[0], false);
|
| 329 |
+
|
| 330 |
+
// Input handlers
|
| 331 |
+
inputs.forEach((inp) => {
|
| 332 |
+
inp.addEventListener('change', () => setCheckedByValue(inp.value, true));
|
| 333 |
});
|
| 334 |
+
|
| 335 |
+
// Recalc on resize (important in Gradio layouts)
|
| 336 |
+
window.addEventListener('resize', () => setHighlightByIndex(currentIdx));
|
| 337 |
+
})();
|
| 338 |
+
|
| 339 |
+
"""
|
| 340 |
+
|
| 341 |
+
super().__init__(
|
| 342 |
+
value=value,
|
| 343 |
+
html_template=html_template,
|
| 344 |
+
js_on_load=js_on_load,
|
| 345 |
+
**kwargs
|
| 346 |
+
)
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class PromptBox(gr.HTML):
|
| 350 |
+
"""
|
| 351 |
+
DeepSite-like prompt box (HTML textarea) that behaves like an input component.
|
| 352 |
+
Outputs: the current text value (string)
|
| 353 |
+
"""
|
| 354 |
+
def __init__(self, value="", placeholder="Describe the video with audio you want to generate...", **kwargs):
|
| 355 |
+
uid = uuid.uuid4().hex[:8]
|
| 356 |
+
|
| 357 |
+
html_template = f"""
|
| 358 |
+
<div style="text-align:center; font-weight:600; margin-bottom:6px;">
|
| 359 |
+
Prompt
|
| 360 |
+
</div>
|
| 361 |
+
<div class="ds-prompt" data-ds="{uid}">
|
| 362 |
+
<textarea class="ds-textarea" rows="3"
|
| 363 |
+
placeholder="{placeholder}"></textarea>
|
| 364 |
+
</div>
|
| 365 |
+
"""
|
| 366 |
+
|
| 367 |
+
js_on_load = r"""
|
| 368 |
+
(() => {
|
| 369 |
+
const textarea = element.querySelector(".ds-textarea");
|
| 370 |
+
if (!textarea) return;
|
| 371 |
+
|
| 372 |
+
// Auto-resize (optional, but nice)
|
| 373 |
+
const autosize = () => {
|
| 374 |
+
textarea.style.height = "0px";
|
| 375 |
+
textarea.style.height = Math.min(textarea.scrollHeight, 240) + "px";
|
| 376 |
+
};
|
| 377 |
+
|
| 378 |
+
// Set initial value from props.value
|
| 379 |
+
const setValue = (v, triggerChange=false) => {
|
| 380 |
+
const val = (v ?? "");
|
| 381 |
+
if (textarea.value !== val) textarea.value = val;
|
| 382 |
+
autosize();
|
| 383 |
+
|
| 384 |
+
props.value = textarea.value;
|
| 385 |
+
if (triggerChange) trigger("change", props.value);
|
| 386 |
+
};
|
| 387 |
+
|
| 388 |
+
setValue(props.value, false);
|
| 389 |
+
|
| 390 |
+
// Update Gradio value on input
|
| 391 |
+
textarea.addEventListener("input", () => {
|
| 392 |
+
autosize();
|
| 393 |
+
props.value = textarea.value;
|
| 394 |
+
trigger("change", props.value);
|
| 395 |
});
|
| 396 |
+
|
| 397 |
+
let last = props.value;
|
| 398 |
+
const syncFromProps = () => {
|
| 399 |
+
if (props.value !== last) {
|
| 400 |
+
last = props.value;
|
| 401 |
+
setValue(last, false); // don't re-trigger change loop
|
| 402 |
+
}
|
| 403 |
+
requestAnimationFrame(syncFromProps);
|
| 404 |
+
};
|
| 405 |
+
requestAnimationFrame(syncFromProps);
|
| 406 |
+
})();
|
| 407 |
+
"""
|
| 408 |
+
|
| 409 |
+
super().__init__(
|
| 410 |
+
value=value,
|
| 411 |
+
html_template=html_template,
|
| 412 |
+
js_on_load=js_on_load,
|
| 413 |
+
**kwargs
|
| 414 |
+
)
|
| 415 |
+
|
| 416 |
+
class CameraDropdown(gr.HTML):
|
| 417 |
+
"""
|
| 418 |
+
Custom dropdown (More-style).
|
| 419 |
+
Outputs: selected option string, e.g. "Dolly Left"
|
| 420 |
+
"""
|
| 421 |
+
def __init__(self, choices, value="None", title="Camera LoRA", **kwargs):
|
| 422 |
+
if not choices:
|
| 423 |
+
raise ValueError("CameraDropdown requires choices.")
|
| 424 |
+
|
| 425 |
+
uid = uuid.uuid4().hex[:8]
|
| 426 |
+
safe_choices = [str(c) for c in choices]
|
| 427 |
+
|
| 428 |
+
items_html = "\n".join(
|
| 429 |
+
f"""<button type="button" class="cd-item" data-value="{c}">{c}</button>"""
|
| 430 |
+
for c in safe_choices
|
| 431 |
+
)
|
| 432 |
+
|
| 433 |
+
html_template = f"""
|
| 434 |
+
<div class="cd-wrap" data-cd="{uid}">
|
| 435 |
+
<button type="button" class="cd-trigger" aria-haspopup="menu" aria-expanded="false">
|
| 436 |
+
<span class="cd-trigger-text">More</span>
|
| 437 |
+
<span class="cd-caret">▾</span>
|
| 438 |
+
</button>
|
| 439 |
+
|
| 440 |
+
<div class="cd-menu" role="menu" aria-hidden="true">
|
| 441 |
+
<div class="cd-title">{title}</div>
|
| 442 |
+
<div class="cd-items">
|
| 443 |
+
{items_html}
|
| 444 |
+
</div>
|
| 445 |
+
</div>
|
| 446 |
+
</div>
|
| 447 |
+
"""
|
| 448 |
+
|
| 449 |
+
js_on_load = r"""
|
| 450 |
+
(() => {
|
| 451 |
+
const wrap = element.querySelector(".cd-wrap");
|
| 452 |
+
const trigger = element.querySelector(".cd-trigger");
|
| 453 |
+
const triggerText = element.querySelector(".cd-trigger-text");
|
| 454 |
+
const menu = element.querySelector(".cd-menu");
|
| 455 |
+
const items = Array.from(element.querySelectorAll(".cd-item"));
|
| 456 |
+
|
| 457 |
+
if (!wrap || !trigger || !menu || !items.length) return;
|
| 458 |
+
|
| 459 |
+
function closeMenu() {
|
| 460 |
+
menu.classList.remove("open");
|
| 461 |
+
trigger.setAttribute("aria-expanded", "false");
|
| 462 |
+
menu.setAttribute("aria-hidden", "true");
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
function openMenu() {
|
| 466 |
+
menu.classList.add("open");
|
| 467 |
+
trigger.setAttribute("aria-expanded", "true");
|
| 468 |
+
menu.setAttribute("aria-hidden", "false");
|
| 469 |
+
}
|
| 470 |
+
|
| 471 |
+
function setValue(val, shouldTrigger = false) {
|
| 472 |
+
const v = (val ?? "None");
|
| 473 |
+
props.value = v;
|
| 474 |
+
triggerText.textContent = v;
|
| 475 |
+
|
| 476 |
+
items.forEach(btn => {
|
| 477 |
+
btn.classList.toggle("selected", btn.dataset.value === v);
|
| 478 |
+
});
|
| 479 |
+
|
| 480 |
+
if (shouldTrigger) trigger("change", props.value);
|
| 481 |
+
}
|
| 482 |
+
|
| 483 |
+
// Toggle menu
|
| 484 |
+
trigger.addEventListener("pointerdown", (e) => {
|
| 485 |
+
e.preventDefault(); // prevents focus/blur weirdness
|
| 486 |
+
e.stopPropagation();
|
| 487 |
+
if (menu.classList.contains("open")) closeMenu();
|
| 488 |
+
else openMenu();
|
| 489 |
+
});
|
| 490 |
+
|
| 491 |
+
// Close on outside interaction (use capture so it wins)
|
| 492 |
+
document.addEventListener("pointerdown", (e) => {
|
| 493 |
+
if (!wrap.contains(e.target)) closeMenu();
|
| 494 |
+
}, true);
|
| 495 |
+
|
| 496 |
+
// Close on ESC
|
| 497 |
+
document.addEventListener("keydown", (e) => {
|
| 498 |
+
if (e.key === "Escape") closeMenu();
|
| 499 |
+
});
|
| 500 |
+
|
| 501 |
+
// Close when focus leaves the dropdown (keyboard users)
|
| 502 |
+
wrap.addEventListener("focusout", (e) => {
|
| 503 |
+
// if the newly-focused element isn't inside wrap, close
|
| 504 |
+
if (!wrap.contains(e.relatedTarget)) closeMenu();
|
| 505 |
+
});
|
| 506 |
+
|
| 507 |
+
// Item selection: use pointerdown so it closes immediately
|
| 508 |
+
items.forEach((btn) => {
|
| 509 |
+
btn.addEventListener("pointerdown", (e) => {
|
| 510 |
+
e.preventDefault();
|
| 511 |
+
e.stopPropagation();
|
| 512 |
+
|
| 513 |
+
// close first so it never "sticks" open
|
| 514 |
+
closeMenu();
|
| 515 |
+
setValue(btn.dataset.value, true);
|
| 516 |
+
});
|
| 517 |
+
});
|
| 518 |
+
|
| 519 |
+
// init
|
| 520 |
+
setValue((props.value ?? "None"), false);
|
| 521 |
+
|
| 522 |
+
// sync from Python updates
|
| 523 |
+
let last = props.value;
|
| 524 |
+
const syncFromProps = () => {
|
| 525 |
+
if (props.value !== last) {
|
| 526 |
+
last = props.value;
|
| 527 |
+
setValue(last, false);
|
| 528 |
+
}
|
| 529 |
+
requestAnimationFrame(syncFromProps);
|
| 530 |
+
};
|
| 531 |
+
requestAnimationFrame(syncFromProps);
|
| 532 |
})();
|
| 533 |
+
|
| 534 |
"""
|
| 535 |
|
| 536 |
super().__init__(
|
|
|
|
| 540 |
**kwargs
|
| 541 |
)
|
| 542 |
|
| 543 |
+
|
| 544 |
def generate_video_example(input_image, prompt, duration, progress=gr.Progress(track_tqdm=True)):
|
| 545 |
+
|
| 546 |
+
output_video, seed = generate_video(
|
| 547 |
+
input_image,
|
| 548 |
+
prompt,
|
| 549 |
+
5, # duration seconds
|
| 550 |
+
True, # enhance_prompt
|
| 551 |
+
42, # seed
|
| 552 |
+
True, # randomize_seed
|
| 553 |
+
DEFAULT_1_STAGE_HEIGHT, # height
|
| 554 |
+
DEFAULT_1_STAGE_WIDTH, # width
|
| 555 |
+
"No LoRA",
|
| 556 |
+
progress
|
| 557 |
+
)
|
| 558 |
|
| 559 |
return output_video
|
| 560 |
+
|
| 561 |
+
def get_duration(
|
| 562 |
+
input_image,
|
| 563 |
+
prompt,
|
| 564 |
+
duration,
|
| 565 |
+
enhance_prompt,
|
| 566 |
+
seed,
|
| 567 |
+
randomize_seed,
|
| 568 |
+
height,
|
| 569 |
+
width,
|
| 570 |
+
camera_lora,
|
| 571 |
+
progress
|
| 572 |
+
):
|
| 573 |
+
if duration <= 5:
|
| 574 |
+
return 80
|
| 575 |
+
elif duration <= 10:
|
| 576 |
+
return 120
|
| 577 |
+
else:
|
| 578 |
+
return 180
|
| 579 |
|
| 580 |
@spaces.GPU(duration=get_duration)
|
| 581 |
def generate_video(
|
|
|
|
| 587 |
randomize_seed: bool = True,
|
| 588 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 589 |
width: int = DEFAULT_1_STAGE_WIDTH,
|
| 590 |
+
camera_lora: str = "No LoRA",
|
| 591 |
progress=gr.Progress(track_tqdm=True),
|
| 592 |
):
|
| 593 |
"""
|
|
|
|
| 601 |
randomize_seed: If True, a random seed is generated for each run.
|
| 602 |
height: Output video height in pixels.
|
| 603 |
width: Output video width in pixels.
|
| 604 |
+
camera_lora: Camera motion control LoRA to apply during generation (enables exactly one at runtime).
|
| 605 |
progress: Gradio progress tracker.
|
| 606 |
Returns:
|
| 607 |
+
|
| 608 |
A tuple of:
|
| 609 |
- output_path: Path to the generated MP4 video file.
|
| 610 |
- seed: The seed used for generation.
|
|
|
|
| 653 |
del embeddings, final_prompt, status
|
| 654 |
torch.cuda.empty_cache()
|
| 655 |
|
| 656 |
+
|
| 657 |
+
# Map dropdown name -> adapter index
|
| 658 |
+
name_to_idx = {name: idx for name, idx in RUNTIME_LORA_CHOICES}
|
| 659 |
+
selected_idx = name_to_idx.get(camera_lora, -1)
|
| 660 |
+
|
| 661 |
+
# Disable all runtime adapters first (0..N-1)
|
| 662 |
+
# N here is len(RUNTIME_LORA_CHOICES)-1 because "None" isn't an adapter
|
| 663 |
+
for i in range(len(RUNTIME_LORA_CHOICES) - 1):
|
| 664 |
+
set_lora_enabled(pipeline._transformer, i, False)
|
| 665 |
+
|
| 666 |
+
# Enable selected one (if any)
|
| 667 |
+
if selected_idx >= 0:
|
| 668 |
+
set_lora_enabled(pipeline._transformer, selected_idx, True)
|
| 669 |
+
|
| 670 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 671 |
pipeline(
|
| 672 |
prompt=prompt,
|
|
|
|
| 702 |
duration_s = int(duration[:-1])
|
| 703 |
return duration_s
|
| 704 |
|
| 705 |
+
|
| 706 |
css = """
|
| 707 |
+
|
| 708 |
+
/* Make the row behave nicely */
|
| 709 |
+
#controls-row {
|
| 710 |
+
display: flex;
|
| 711 |
+
align-items: center;
|
| 712 |
+
gap: 12px;
|
| 713 |
+
flex-wrap: nowrap; /* or wrap if you prefer on small screens */
|
| 714 |
+
}
|
| 715 |
+
|
| 716 |
+
/* Stop these components from stretching */
|
| 717 |
+
#controls-row > * {
|
| 718 |
+
flex: 0 0 auto !important;
|
| 719 |
+
width: auto !important;
|
| 720 |
+
min-width: 0 !important;
|
| 721 |
+
}
|
| 722 |
+
|
| 723 |
+
#controls-row #camera_lora_ui {
|
| 724 |
+
margin-left: auto !important;
|
| 725 |
+
}
|
| 726 |
+
|
| 727 |
+
/* Gradio HTML components often have an inner wrapper div that is width:100% */
|
| 728 |
+
#camera_lora_ui,
|
| 729 |
+
#camera_lora_ui > div {
|
| 730 |
+
width: fit-content !important;
|
| 731 |
+
}
|
| 732 |
+
|
| 733 |
+
/* Same idea for your radio HTML blocks (optional but helps) */
|
| 734 |
+
#radioanimated_duration,
|
| 735 |
+
#radioanimated_duration > div,
|
| 736 |
+
#radioanimated_resolution,
|
| 737 |
+
#radioanimated_resolution > div {
|
| 738 |
+
width: fit-content !important;
|
| 739 |
+
}
|
| 740 |
+
|
| 741 |
#col-container {
|
| 742 |
margin: 0 auto;
|
| 743 |
max-width: 1600px;
|
|
|
|
| 876 |
}
|
| 877 |
"""
|
| 878 |
|
| 879 |
+
css += """
|
| 880 |
+
/* --- prompt box --- */
|
| 881 |
+
.ds-prompt{
|
| 882 |
+
width: 100%;
|
| 883 |
+
max-width: 720px;
|
| 884 |
+
margin-top: 3px;
|
| 885 |
+
}
|
| 886 |
+
|
| 887 |
+
.ds-textarea{
|
| 888 |
+
width: 100%;
|
| 889 |
+
box-sizing: border-box;
|
| 890 |
+
|
| 891 |
+
background: #2b2b2b;
|
| 892 |
+
color: rgba(255,255,255,0.9);
|
| 893 |
+
|
| 894 |
+
border: 1px solid rgba(255,255,255,0.12);
|
| 895 |
+
border-radius: 14px;
|
| 896 |
+
|
| 897 |
+
padding: 14px 16px;
|
| 898 |
+
outline: none;
|
| 899 |
+
|
| 900 |
+
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial;
|
| 901 |
+
font-size: 15px;
|
| 902 |
+
line-height: 1.35;
|
| 903 |
+
|
| 904 |
+
resize: none;
|
| 905 |
+
height: 94px;
|
| 906 |
+
min-height: 94px;
|
| 907 |
+
max-height: 94px;
|
| 908 |
+
overflow-y: auto;
|
| 909 |
+
}
|
| 910 |
+
|
| 911 |
+
.ds-textarea::placeholder{
|
| 912 |
+
color: rgba(255,255,255,0.55);
|
| 913 |
+
}
|
| 914 |
+
|
| 915 |
+
.ds-textarea:focus{
|
| 916 |
+
border-color: rgba(255,255,255,0.22);
|
| 917 |
+
box-shadow: 0 0 0 3px rgba(255,255,255,0.06);
|
| 918 |
+
}
|
| 919 |
+
"""
|
| 920 |
+
|
| 921 |
+
css += """
|
| 922 |
+
/* ---- camera dropdown ---- */
|
| 923 |
+
|
| 924 |
+
/* 1) Fix overlap: make the Gradio HTML block shrink-to-fit when it contains a CameraDropdown.
|
| 925 |
+
Gradio uses .gr-html for HTML components in most versions; older themes sometimes use .gradio-html.
|
| 926 |
+
This keeps your big header HTML unaffected because it doesn't contain .cd-wrap.
|
| 927 |
+
*/
|
| 928 |
+
|
| 929 |
+
/* 2) Actual dropdown layout */
|
| 930 |
+
.cd-wrap{
|
| 931 |
+
position: relative;
|
| 932 |
+
display: inline-block;
|
| 933 |
+
}
|
| 934 |
+
|
| 935 |
+
/* 3) Match RadioAnimated pill size/feel */
|
| 936 |
+
.cd-trigger{
|
| 937 |
+
margin-top: 2px;
|
| 938 |
+
display: inline-flex;
|
| 939 |
+
align-items: center;
|
| 940 |
+
justify-content: center;
|
| 941 |
+
gap: 10px;
|
| 942 |
+
|
| 943 |
+
border: none;
|
| 944 |
+
|
| 945 |
+
box-sizing: border-box;
|
| 946 |
+
padding: 10px 18px;
|
| 947 |
+
min-height: 52px;
|
| 948 |
+
line-height: 1.2;
|
| 949 |
+
|
| 950 |
+
border-radius: 9999px;
|
| 951 |
+
background: #0b0b0b;
|
| 952 |
+
|
| 953 |
+
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Arial;
|
| 954 |
+
font-size: 14px;
|
| 955 |
+
|
| 956 |
+
/* ✅ match .ra-label exactly */
|
| 957 |
+
color: rgba(255,255,255,0.7) !important;
|
| 958 |
+
font-weight: 600 !important;
|
| 959 |
+
|
| 960 |
+
cursor: pointer;
|
| 961 |
+
user-select: none;
|
| 962 |
+
white-space: nowrap;
|
| 963 |
+
}
|
| 964 |
+
|
| 965 |
+
/* Ensure inner spans match too */
|
| 966 |
+
.cd-trigger .cd-trigger-text,
|
| 967 |
+
.cd-trigger .cd-caret{
|
| 968 |
+
color: rgba(255,255,255,0.7) !important;
|
| 969 |
+
}
|
| 970 |
+
|
| 971 |
+
/* keep caret styling */
|
| 972 |
+
.cd-caret{
|
| 973 |
+
opacity: 0.8;
|
| 974 |
+
font-weight: 900;
|
| 975 |
+
}
|
| 976 |
+
|
| 977 |
+
/* 4) Ensure menu overlays neighbors and isn't clipped */
|
| 978 |
+
.cd-menu{
|
| 979 |
+
position: absolute;
|
| 980 |
+
top: calc(100% + 10px);
|
| 981 |
+
left: 0;
|
| 982 |
+
|
| 983 |
+
min-width: 240px;
|
| 984 |
+
background: #2b2b2b;
|
| 985 |
+
border: 1px solid rgba(255,255,255,0.14);
|
| 986 |
+
border-radius: 14px;
|
| 987 |
+
box-shadow: 0 18px 40px rgba(0,0,0,0.35);
|
| 988 |
+
padding: 10px;
|
| 989 |
+
|
| 990 |
+
opacity: 0;
|
| 991 |
+
transform: translateY(-6px);
|
| 992 |
+
pointer-events: none;
|
| 993 |
+
transition: opacity 160ms ease, transform 160ms ease;
|
| 994 |
+
|
| 995 |
+
z-index: 9999; /* was 50 */
|
| 996 |
+
}
|
| 997 |
+
|
| 998 |
+
.cd-menu.open{
|
| 999 |
+
opacity: 1;
|
| 1000 |
+
transform: translateY(0);
|
| 1001 |
+
pointer-events: auto;
|
| 1002 |
+
}
|
| 1003 |
+
|
| 1004 |
+
.cd-title{
|
| 1005 |
+
padding: 6px 8px 10px 8px;
|
| 1006 |
+
font-size: 12px;
|
| 1007 |
+
font-weight: 800;
|
| 1008 |
+
letter-spacing: 0.02em;
|
| 1009 |
+
color: rgba(255,255,255,0.55);
|
| 1010 |
+
text-transform: none;
|
| 1011 |
+
}
|
| 1012 |
+
|
| 1013 |
+
.cd-items{
|
| 1014 |
+
display: flex;
|
| 1015 |
+
flex-direction: column;
|
| 1016 |
+
gap: 6px;
|
| 1017 |
+
}
|
| 1018 |
+
|
| 1019 |
+
.cd-item{
|
| 1020 |
+
width: 100%;
|
| 1021 |
+
text-align: left;
|
| 1022 |
+
border: none;
|
| 1023 |
+
background: rgba(255,255,255,0.06);
|
| 1024 |
+
color: rgba(255,255,255,0.92);
|
| 1025 |
+
padding: 10px 10px;
|
| 1026 |
+
border-radius: 12px;
|
| 1027 |
+
cursor: pointer;
|
| 1028 |
+
font-size: 14px;
|
| 1029 |
+
font-weight: 700;
|
| 1030 |
+
transition: background 120ms ease, transform 80ms ease;
|
| 1031 |
+
}
|
| 1032 |
+
|
| 1033 |
+
.cd-item:hover{
|
| 1034 |
+
background: rgba(255,255,255,0.10);
|
| 1035 |
+
}
|
| 1036 |
+
|
| 1037 |
+
.cd-item:active{
|
| 1038 |
+
transform: translateY(1px);
|
| 1039 |
+
}
|
| 1040 |
+
|
| 1041 |
+
.cd-item.selected{
|
| 1042 |
+
background: rgba(139,255,151,0.22);
|
| 1043 |
+
border: 1px solid rgba(139,255,151,0.35);
|
| 1044 |
+
}
|
| 1045 |
+
|
| 1046 |
+
"""
|
| 1047 |
+
|
| 1048 |
+
|
| 1049 |
|
| 1050 |
with gr.Blocks(title="LTX-2 Video Distilled 🎥🔈") as demo:
|
| 1051 |
gr.HTML(
|
|
|
|
| 1081 |
height=512
|
| 1082 |
)
|
| 1083 |
|
| 1084 |
+
|
| 1085 |
+
prompt_ui = PromptBox(
|
| 1086 |
+
value="Make this image come alive with cinematic motion, smooth animation",
|
| 1087 |
+
elem_id="prompt_ui",
|
| 1088 |
+
)
|
| 1089 |
+
|
| 1090 |
prompt = gr.Textbox(
|
| 1091 |
label="Prompt",
|
| 1092 |
value="Make this image come alive with cinematic motion, smooth animation",
|
| 1093 |
lines=3,
|
| 1094 |
max_lines=3,
|
| 1095 |
+
placeholder="Describe the motion and animation you want...",
|
| 1096 |
+
visible=False
|
| 1097 |
)
|
| 1098 |
|
| 1099 |
enhance_prompt = gr.Checkbox(
|
|
|
|
| 1116 |
|
| 1117 |
with gr.Column(elem_id="step-column"):
|
| 1118 |
output_video = gr.Video(label="Generated Video", autoplay=True, height=512)
|
| 1119 |
+
|
| 1120 |
+
with gr.Row(elem_id="controls-row"):
|
| 1121 |
+
|
|
|
|
| 1122 |
radioanimated_duration = RadioAnimated(
|
| 1123 |
choices=["3s", "5s", "10s", "15s"],
|
| 1124 |
value="3s",
|
|
|
|
| 1133 |
step=0.1,
|
| 1134 |
visible=False
|
| 1135 |
)
|
| 1136 |
+
|
|
|
|
| 1137 |
radioanimated_resolution = RadioAnimated(
|
| 1138 |
choices=["768x512", "512x512", "512x768"],
|
| 1139 |
value=f"{DEFAULT_1_STAGE_WIDTH}x{DEFAULT_1_STAGE_HEIGHT}",
|
|
|
|
| 1142 |
|
| 1143 |
width = gr.Number(label="Width", value=DEFAULT_1_STAGE_WIDTH, precision=0, visible=False)
|
| 1144 |
height = gr.Number(label="Height", value=DEFAULT_1_STAGE_HEIGHT, precision=0, visible=False)
|
| 1145 |
+
|
| 1146 |
+
camera_lora_ui = CameraDropdown(
|
| 1147 |
+
choices=[name for name, _ in RUNTIME_LORA_CHOICES],
|
| 1148 |
+
value="No LoRA",
|
| 1149 |
+
title="Camera LoRA",
|
| 1150 |
+
elem_id="camera_lora_ui",
|
| 1151 |
+
)
|
| 1152 |
+
|
| 1153 |
+
# Hidden real dropdown (backend value)
|
| 1154 |
+
camera_lora = gr.Dropdown(
|
| 1155 |
+
label="Camera Control LoRA",
|
| 1156 |
+
choices=[name for name, _ in RUNTIME_LORA_CHOICES],
|
| 1157 |
+
value="No LoRA",
|
| 1158 |
+
visible=False
|
| 1159 |
+
)
|
| 1160 |
|
| 1161 |
generate_btn = gr.Button("🤩 Generate Video", variant="primary", elem_classes="button-gradient")
|
| 1162 |
|
| 1163 |
+
camera_lora_ui.change(
|
| 1164 |
+
fn=lambda x: x,
|
| 1165 |
+
inputs=camera_lora_ui,
|
| 1166 |
+
outputs=camera_lora,
|
| 1167 |
+
api_visibility="private"
|
| 1168 |
+
)
|
| 1169 |
|
| 1170 |
radioanimated_duration.change(
|
| 1171 |
fn=apply_duration,
|
|
|
|
| 1179 |
outputs=[width, height],
|
| 1180 |
api_visibility="private"
|
| 1181 |
)
|
| 1182 |
+
prompt_ui.change(
|
| 1183 |
+
fn=lambda x: x,
|
| 1184 |
+
inputs=prompt_ui,
|
| 1185 |
+
outputs=prompt,
|
| 1186 |
+
api_visibility="private"
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
|
| 1190 |
generate_btn.click(
|
| 1191 |
fn=generate_video,
|
|
|
|
| 1198 |
randomize_seed,
|
| 1199 |
height,
|
| 1200 |
width,
|
| 1201 |
+
camera_lora,
|
| 1202 |
],
|
| 1203 |
outputs=[output_video,seed]
|
| 1204 |
)
|
|
|
|
| 1225 |
|
| 1226 |
],
|
| 1227 |
fn=generate_video_example,
|
| 1228 |
+
inputs=[input_image, prompt_ui],
|
| 1229 |
outputs = [output_video],
|
| 1230 |
label="Example",
|
| 1231 |
cache_examples=True,
|
packages/ltx-core/src/ltx_core/loader/fuse_loras.py
CHANGED
|
@@ -3,6 +3,7 @@ import triton
|
|
| 3 |
|
| 4 |
from ltx_core.loader.kernels import fused_add_round_kernel
|
| 5 |
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
|
|
|
|
| 6 |
|
| 7 |
BLOCK_SIZE = 1024
|
| 8 |
|
|
@@ -59,42 +60,59 @@ def _prepare_deltas(
|
|
| 59 |
return deltas[0]
|
| 60 |
return torch.sum(torch.stack(deltas, dim=0), dim=0)
|
| 61 |
|
| 62 |
-
|
| 63 |
def apply_loras(
|
| 64 |
model_sd: StateDict,
|
| 65 |
lora_sd_and_strengths: list[LoraStateDictWithStrength],
|
| 66 |
dtype: torch.dtype,
|
| 67 |
destination_sd: StateDict | None = None,
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
if destination_sd is not None
|
| 71 |
-
sd = destination_sd.sd
|
| 72 |
size = 0
|
| 73 |
device = torch.device("meta")
|
| 74 |
inner_dtypes = set()
|
|
|
|
|
|
|
|
|
|
|
|
|
| 75 |
for key, weight in model_sd.sd.items():
|
| 76 |
if weight is None:
|
| 77 |
continue
|
|
|
|
|
|
|
|
|
|
|
|
|
| 78 |
device = weight.device
|
| 79 |
target_dtype = dtype if dtype is not None else weight.dtype
|
| 80 |
-
deltas_dtype = target_dtype
|
|
|
|
| 81 |
deltas = _prepare_deltas(lora_sd_and_strengths, key, deltas_dtype, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
if deltas is None:
|
| 83 |
if key in sd:
|
| 84 |
continue
|
| 85 |
-
|
| 86 |
-
elif weight.dtype == torch.float8_e4m3fn:
|
| 87 |
-
if str(device).startswith("cuda"):
|
| 88 |
-
deltas = calculate_weight_float8_(deltas, weight)
|
| 89 |
-
else:
|
| 90 |
-
deltas.add_(weight.to(dtype=deltas.dtype, device=device))
|
| 91 |
-
elif weight.dtype == torch.bfloat16:
|
| 92 |
-
deltas.add_(weight)
|
| 93 |
else:
|
| 94 |
-
|
| 95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
inner_dtypes.add(target_dtype)
|
| 97 |
-
size +=
|
| 98 |
-
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
|
| 4 |
from ltx_core.loader.kernels import fused_add_round_kernel
|
| 5 |
from ltx_core.loader.primitives import LoraStateDictWithStrength, StateDict
|
| 6 |
+
from typing import Iterable
|
| 7 |
|
| 8 |
BLOCK_SIZE = 1024
|
| 9 |
|
|
|
|
| 60 |
return deltas[0]
|
| 61 |
return torch.sum(torch.stack(deltas, dim=0), dim=0)
|
| 62 |
|
|
|
|
| 63 |
def apply_loras(
    model_sd: StateDict,
    lora_sd_and_strengths: list[LoraStateDictWithStrength],
    dtype: torch.dtype,
    destination_sd: StateDict | None = None,
    return_affected: bool = False,
) -> StateDict | tuple[StateDict, list[str]]:
    """Merge LoRA deltas into a model state dict.

    Every tensor in ``model_sd`` is copied into the result (cast to ``dtype``
    when one is given); tensors targeted by at least one LoRA additionally get
    the summed, strength-scaled deltas added on top of the base weight.

    Args:
        model_sd: Base model state dict.
        lora_sd_and_strengths: LoRA state dicts paired with their merge strengths.
        dtype: Target dtype for merged tensors; ``None`` keeps each tensor's own dtype.
        destination_sd: Optional state dict to merge into in place. Keys already
            present in it that receive no deltas are left untouched.
        return_affected: When True, also return the sorted module prefixes
            (weight key minus the ``.weight`` suffix) actually modified by a LoRA.

    Returns:
        The merged ``StateDict``, or ``(state_dict, affected_module_prefixes)``
        when ``return_affected`` is True.
    """
    sd = destination_sd.sd if destination_sd is not None else {}
    size = 0
    device = torch.device("meta")
    inner_dtypes = set()

    affected_weight_keys: list[str] = []
    affected_module_prefixes: set[str] = set()

    # NOTE(review): the float8 (fp8) merge path from the pre-LoRA version was
    # intentionally dropped here — confirm fp8 checkpoints are not used.
    for key, weight in model_sd.sd.items():
        if weight is None:
            continue

        device = weight.device
        target_dtype = dtype if dtype is not None else weight.dtype

        deltas = _prepare_deltas(lora_sd_and_strengths, key, target_dtype, device)

        # Record which modules are actually modified by a LoRA. Only
        # ``*.weight`` keys can yield a module prefix.
        if deltas is not None and key.endswith(".weight"):
            affected_weight_keys.append(key)
            affected_module_prefixes.add(key[: -len(".weight")])

        if deltas is None:
            # No LoRA touches this tensor: keep a preexisting destination copy,
            # otherwise materialize a cast copy of the base tensor.
            # BUGFIX: non-weight tensors (biases, norm params, ...) flow through
            # here too — previously they were skipped entirely, silently
            # dropping them from the rebuilt state dict.
            if key in sd:
                continue
            out = weight.clone().to(dtype=target_dtype, device=device)
        else:
            # merged = sum(strength-scaled deltas) + base weight
            out = deltas.to(dtype=target_dtype)
            out.add_(weight.to(dtype=out.dtype, device=device))

        sd[key] = out
        inner_dtypes.add(target_dtype)
        size += out.nbytes

    result = destination_sd if destination_sd is not None else StateDict(sd, device, size, inner_dtypes)

    if return_affected:
        # Sorted for deterministic downstream iteration order.
        return result, sorted(affected_module_prefixes)

    return result
|
| 118 |
+
|
packages/ltx-core/src/ltx_core/loader/single_gpu_model_builder.py
CHANGED
|
@@ -3,6 +3,7 @@ from dataclasses import dataclass, field, replace
|
|
| 3 |
from typing import Generic
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
|
| 7 |
from ltx_core.loader.fuse_loras import apply_loras
|
| 8 |
from ltx_core.loader.module_ops import ModuleOps
|
|
@@ -22,6 +23,109 @@ from ltx_core.model.model_protocol import ModelConfigurator, ModelType
|
|
| 22 |
logger: logging.Logger = logging.getLogger(__name__)
|
| 23 |
|
| 24 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 25 |
@dataclass(frozen=True)
|
| 26 |
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
|
| 27 |
"""
|
|
@@ -93,9 +197,29 @@ class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType],
|
|
| 93 |
]
|
| 94 |
final_sd = apply_loras(
|
| 95 |
model_sd=model_state_dict,
|
| 96 |
-
lora_sd_and_strengths=lora_sd_and_strengths,
|
| 97 |
dtype=dtype,
|
| 98 |
destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
|
| 99 |
)
|
| 100 |
meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
|
| 101 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
from typing import Generic
|
| 4 |
|
| 5 |
import torch
|
| 6 |
+
import torch.nn as nn
|
| 7 |
|
| 8 |
from ltx_core.loader.fuse_loras import apply_loras
|
| 9 |
from ltx_core.loader.module_ops import ModuleOps
|
|
|
|
| 23 |
logger: logging.Logger = logging.getLogger(__name__)
|
| 24 |
|
| 25 |
|
| 26 |
+
def get_submodule_and_parent(root: nn.Module, path: str):
    """Resolve ``path`` under ``root`` and return ``(parent, child_name, child)``.

    ``path`` is a dotted module path as produced by ``named_modules()``;
    numeric segments index into ``nn.Sequential`` / ``nn.ModuleList``.
    """

    def _descend(module, segment):
        # Numeric segments are container indices; everything else is an attribute.
        return module[int(segment)] if segment.isdigit() else getattr(module, segment)

    *ancestors, leaf = path.split(".")
    parent = root
    for segment in ancestors:
        parent = _descend(parent, segment)
    return parent, leaf, _descend(parent, leaf)
|
| 45 |
+
|
| 46 |
+
def set_submodule(root: nn.Module, path: str, new_module: nn.Module):
    """Replace the module located at ``path`` under ``root`` with ``new_module``."""
    parent, leaf, _ = get_submodule_and_parent(root, path)
    if leaf.isdigit():
        # Container child (Sequential / ModuleList): assign by index.
        parent[int(leaf)] = new_module
        return
    setattr(parent, leaf, new_module)
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MultiLoraLinear(nn.Module):
    """An ``nn.Linear`` wrapper carrying any number of individually toggleable
    LoRA adapters.

    Each adapter is a low-rank pair ``(A, B)`` with a fixed scale; enabled
    adapters contribute ``((x @ A.T) @ B.T) * scale`` on top of the base
    layer's output.
    """

    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base
        # One (A, B, scale) tuple per adapter slot, in registration order.
        self.adapters: list[tuple[torch.Tensor, torch.Tensor, float]] = []
        # Per-slot on/off switch, parallel to ``adapters``.
        self.enabled: list[bool] = []

    def add_adapter(self, A: torch.Tensor, B: torch.Tensor, scale: float, enabled: bool = True):
        """Register one adapter; buffers keep A/B off ``.parameters()`` (inference only)."""
        slot = len(self.adapters)
        self.register_buffer(f"lora_A_{slot}", A, persistent=False)
        self.register_buffer(f"lora_B_{slot}", B, persistent=False)
        self.adapters.append((A, B, float(scale)))
        self.enabled.append(bool(enabled))

    def set_enabled(self, idx: int, enabled: bool):
        """Toggle adapter slot ``idx``; out-of-range indices are ignored."""
        if 0 <= idx < len(self.enabled):
            self.enabled[idx] = enabled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        result = self.base(x)
        for slot, active in enumerate(self.enabled):
            if active:
                A = getattr(self, f"lora_A_{slot}")
                B = getattr(self, f"lora_B_{slot}")
                result = result + ((x @ A.t()) @ B.t()) * self.adapters[slot][2]
        return result
|
| 84 |
+
|
| 85 |
+
def set_lora_enabled(model: nn.Module, adapter_idx: int, enabled: bool):
    """Toggle adapter slot ``adapter_idx`` on every ``MultiLoraLinear`` in ``model``."""
    wrapped_layers = (m for m in model.modules() if isinstance(m, MultiLoraLinear))
    for layer in wrapped_layers:
        layer.set_enabled(adapter_idx, enabled)
|
| 89 |
+
|
| 90 |
+
def patch_only_affected_linears(
    model: nn.Module,
    lora_sd: dict,
    affected_modules: list[str],
    strength: float,
    adapter_idx: int,
    default_enabled: bool = False,
):
    """Wrap the LoRA-affected ``nn.Linear`` modules with ``MultiLoraLinear``
    and register this LoRA as one adapter slot on each of them.

    Args:
        model: Model patched in place.
        lora_sd: Raw LoRA state dict; expects ``<prefix>.lora_A.weight`` /
            ``<prefix>.lora_B.weight`` keys.
        affected_modules: Module prefixes any LoRA can touch (same list must be
            passed for every adapter so slots line up).
        strength: Scale applied to this adapter's delta.
        adapter_idx: Slot this adapter should occupy on every wrapped layer.
        default_enabled: Initial on/off state of the adapter.
    """
    for prefix in affected_modules:
        _, _, mod = get_submodule_and_parent(model, prefix)

        # Wrap once; reuse the wrapper if a previous adapter already patched it.
        if isinstance(mod, MultiLoraLinear):
            wrapped = mod
        else:
            if not isinstance(mod, nn.Linear):
                continue
            wrapped = MultiLoraLinear(mod)
            set_submodule(model, prefix, wrapped)

        base_device = wrapped.base.weight.device
        base_dtype = wrapped.base.weight.dtype

        key_a = f"{prefix}.lora_A.weight"
        key_b = f"{prefix}.lora_B.weight"
        if key_a not in lora_sd or key_b not in lora_sd:
            # BUGFIX: this LoRA does not touch the layer, but a placeholder
            # must still be registered so slot ``adapter_idx`` refers to the
            # SAME adapter on every wrapped layer. Previously this branch
            # skipped registration entirely, so adapter indices drifted and
            # set_lora_enabled(model, idx, ...) toggled different adapters on
            # different layers. A rank-1 zero pair is inert and stays disabled.
            zero_a = torch.zeros(1, wrapped.base.in_features, device=base_device, dtype=base_dtype)
            zero_b = torch.zeros(wrapped.base.out_features, 1, device=base_device, dtype=base_dtype)
            wrapped.add_adapter(zero_a, zero_b, scale=0.0, enabled=False)
            continue

        A = lora_sd[key_a].to(device=base_device, dtype=base_dtype)
        B = lora_sd[key_b].to(device=base_device, dtype=base_dtype)

        # Parity with the merged-weight path: plain strength, no rank scaling.
        wrapped.add_adapter(A, B, scale=strength, enabled=default_enabled)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
@dataclass(frozen=True)
|
| 130 |
class SingleGPUModelBuilder(Generic[ModelType], ModelBuilderProtocol[ModelType], LoRAAdaptableProtocol):
|
| 131 |
"""
|
|
|
|
| 197 |
]
|
| 198 |
final_sd = apply_loras(
|
| 199 |
model_sd=model_state_dict,
|
| 200 |
+
lora_sd_and_strengths=[lora_sd_and_strengths[0]],
|
| 201 |
dtype=dtype,
|
| 202 |
destination_sd=model_state_dict if isinstance(self.registry, DummyRegistry) else None,
|
| 203 |
)
|
| 204 |
meta_model.load_state_dict(final_sd.sd, strict=False, assign=True)
|
| 205 |
+
model = self._return_model(meta_model, device)
|
| 206 |
+
|
| 207 |
+
_, affected_modules = apply_loras(
|
| 208 |
+
model_sd=model_state_dict,
|
| 209 |
+
lora_sd_and_strengths=lora_sd_and_strengths,
|
| 210 |
+
dtype=dtype,
|
| 211 |
+
destination_sd=None,
|
| 212 |
+
return_affected=True,
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
for runtime_idx, (lora_sd, strength) in enumerate(zip(lora_state_dicts[1:], lora_strengths[1:], strict=True)):
|
| 216 |
+
patch_only_affected_linears(
|
| 217 |
+
model,
|
| 218 |
+
lora_sd.sd,
|
| 219 |
+
affected_modules,
|
| 220 |
+
strength=strength,
|
| 221 |
+
adapter_idx=runtime_idx,
|
| 222 |
+
default_enabled=False, # start off
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
return model
|