# NOTE: commented-out Hugging Face web-page residue that was pasted above this
# file ("linoyts's picture / linoyts HF Staff / Update app.py / 2d9b782 verified");
# it is not Python and prevented the module from being imported.
import sys
from pathlib import Path
# Add packages to Python path
# Vendored packages live under ./packages/<name>/src; prepend them so the
# `ltx_pipelines` / `ltx_core` imports below resolve without pip-installing.
current_dir = Path(__file__).parent
sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
import numpy as np
import random
import spaces
import gradio as gr
from gradio_client import Client, handle_file
import torch
from pathlib import Path
from typing import Optional
from huggingface_hub import hf_hub_download
from ltx_pipelines.keyframe_interpolation import KeyframeInterpolationPipeline
from ltx_core.tiling import TilingConfig
from ltx_pipelines.constants import (
DEFAULT_SEED,
DEFAULT_HEIGHT,
DEFAULT_WIDTH,
DEFAULT_NUM_FRAMES,
DEFAULT_FRAME_RATE,
DEFAULT_NUM_INFERENCE_STEPS,
DEFAULT_CFG_GUIDANCE_SCALE,
DEFAULT_LORA_STRENGTH,
)
# Upper bound for the seed slider (largest value that fits in an int32).
MAX_SEED = np.iinfo(np.int32).max
# Custom negative prompt
DEFAULT_NEGATIVE_PROMPT = "shaky, glitchy, low quality, worst quality, deformed, distorted, disfigured, motion smear, motion artifacts, fused fingers, bad anatomy, weird hand, ugly, transition, static"
# Default prompt for keyframe interpolation
DEFAULT_PROMPT = "Smooth cinematic transition between keyframes with natural motion and consistent lighting"
# HuggingFace Hub defaults
# Model weights fetched at startup: base transformer, distilled LoRA, and 2x spatial upscaler.
DEFAULT_REPO_ID = "Lightricks/LTX-2"
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev-fp8.safetensors"
DEFAULT_DISTILLED_LORA_FILENAME = "ltx-2-19b-distilled-lora-384.safetensors"
DEFAULT_SPATIAL_UPSAMPLER_FILENAME = "ltx-2-spatial-upscaler-x2-1.0.safetensors"
# Text encoder space URL
# (prompt encoding is delegated to this external Space instead of loading Gemma locally)
TEXT_ENCODER_SPACE = "linoyts/gemma-text-encoder"
# Image edit space URL
# (used to synthesize an end frame from the start frame)
IMAGE_EDIT_SPACE = "linoyts/Qwen-Image-Edit-2509-Fast"
def get_hub_or_local_checkpoint(repo_id: Optional[str] = None, filename: Optional[str] = None) -> str:
    """Resolve a checkpoint path, downloading from the HuggingFace Hub if needed.

    Args:
        repo_id: Hub repository id (e.g. "Lightricks/LTX-2"). If None,
            `filename` is treated as a local path and returned as-is.
        filename: File name inside the repo, or a local path when `repo_id`
            is None.

    Returns:
        Local filesystem path to the checkpoint.

    Raises:
        ValueError: If both arguments are None, or if `repo_id` is given
            without `filename`.
    """
    if repo_id is None and filename is None:
        raise ValueError("Please supply at least one of `repo_id` or `filename`")
    if repo_id is not None:
        if filename is None:
            raise ValueError("If repo_id is specified, filename must also be specified.")
        # Fix: report the actual file being fetched (message previously printed
        # a literal "(unknown)" placeholder instead of the filename).
        print(f"Downloading {filename} from {repo_id}...")
        ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
        print(f"Downloaded to {ckpt_path}")
    else:
        # No repo specified: treat the filename as an existing local checkpoint.
        ckpt_path = filename
    return ckpt_path
# Initialize pipeline at startup
print("=" * 80)
print("Loading LTX-2 Keyframe Interpolation pipeline...")
print("=" * 80)
# Fetch all three weight files from the Hub (cached locally after first run).
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
distilled_lora_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_DISTILLED_LORA_FILENAME)
spatial_upsampler_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_SPATIAL_UPSAMPLER_FILENAME)
print(f"Initializing pipeline with:")
print(f"  checkpoint_path={checkpoint_path}")
print(f"  distilled_lora_path={distilled_lora_path}")
print(f"  spatial_upsampler_path={spatial_upsampler_path}")
print(f"  text_encoder_space={TEXT_ENCODER_SPACE}")
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
# Text encoding will be done by external space
pipeline = KeyframeInterpolationPipeline(
    checkpoint_path=checkpoint_path,
    distilled_lora_path=distilled_lora_path,
    distilled_lora_strength=DEFAULT_LORA_STRENGTH,
    spatial_upsampler_path=spatial_upsampler_path,
    gemma_root=None,
    loras=[],
    fp8transformer=False,
)
# Initialize text encoder client
# On failure the client is left as None; generate_video() raises a clear
# RuntimeError instead of crashing at import time.
print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
try:
    text_encoder_client = Client(TEXT_ENCODER_SPACE)
    print("✓ Text encoder client connected!")
except Exception as e:
    print(f"⚠ Warning: Could not connect to text encoder space: {e}")
    text_encoder_client = None
# Initialize image edit client
# Same best-effort pattern: generate_end_frame() checks for None and errors there.
print(f"Connecting to image edit space: {IMAGE_EDIT_SPACE}")
try:
    image_edit_client = Client(IMAGE_EDIT_SPACE)
    print("✓ Image edit client connected!")
except Exception as e:
    print(f"⚠ Warning: Could not connect to image edit space: {e}")
    image_edit_client = None
def generate_end_frame(start_frame, edit_prompt: str):
    """Generate an end frame from the start frame using the Qwen Image Edit space.

    Args:
        start_frame: PIL image (from gr.Image type="pil") or a filepath.
        edit_prompt: Text instruction describing the transformation.

    Returns:
        The edited image returned by the remote space (first result entry).

    Raises:
        gr.Error: If no start frame is given, the edit client is unavailable,
            or the remote call fails.
    """
    # Validate inputs BEFORE the try block. Previously these gr.Error raises
    # were caught by the broad `except` below and re-wrapped with a full
    # traceback, showing an ugly internal error for a simple user mistake.
    if start_frame is None:
        raise gr.Error("Please provide a start frame first")
    if image_edit_client is None:
        raise gr.Error(
            f"Image edit client not connected. Please ensure the image edit space "
            f"({IMAGE_EDIT_SPACE}) is running and accessible."
        )
    try:
        # PIL images must be written to disk so gradio_client can upload them;
        # filepath inputs can be handed over directly.
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        temp_path = output_dir / "temp_start_for_edit.jpg"
        if hasattr(start_frame, 'save'):
            start_frame.save(temp_path)
            image_input = handle_file(str(temp_path))
        else:
            image_input = handle_file(str(start_frame))
        # Call Qwen Image Edit (second return value of /infer is unused here).
        result, _ = image_edit_client.predict(
            images=[{"image": image_input}],
            prompt=edit_prompt,
            api_name="/infer"
        )
        return result[0]['image']
    except Exception as e:
        import traceback
        error_msg = f"Error generating end frame: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        raise gr.Error(error_msg)
@spaces.GPU(duration=300)
def generate_video(
    start_frame,
    prompt: str,
    end_frame_upload=None,
    end_frame_generated=None,
    strength_start: float = 1.,
    strength_end: float = 1.,
    duration: float = 5,
    enhance_prompt: bool = True,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: int = 42,
    randomize_seed: bool = True,
    num_inference_steps: int = 20,
    cfg_guidance_scale: float = DEFAULT_CFG_GUIDANCE_SCALE,
    height: int = DEFAULT_HEIGHT,
    width: int = DEFAULT_WIDTH,
    progress=gr.Progress(track_tqdm=True)
):
    """Generate a video with keyframe interpolation between start and end frames.

    Encodes the prompt via the external text-encoder space, conditions the
    LTX-2 pipeline on the provided keyframes, and writes an mp4 under ./outputs.

    Returns:
        Tuple of (video path or None on error,
                  final prompt used or an error message,
                  seed actually used).
    """
    # Fix: resolve the seed BEFORE entering the try block. `current_seed` was
    # previously assigned inside the try, so any failure before that line made
    # the except handler raise NameError, masking the real error.
    current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
    try:
        # Calculate num_frames from duration (fixed 24 fps); +1 so the clip
        # covers the full requested duration.
        frame_rate = 24.0
        num_frames = int(duration * frame_rate) + 1
        # Create output directory if it doesn't exist
        output_dir = Path("outputs")
        output_dir.mkdir(exist_ok=True)
        output_path = output_dir / f"keyframe_video_{current_seed}.mp4"
        # Keyframe conditioning is a list of (image_path, frame_idx, strength).
        images = []
        temp_paths = []
        # A generated end frame takes precedence over an uploaded one.
        end_frame = end_frame_generated if end_frame_generated is not None else end_frame_upload
        if start_frame is None and end_frame is None:
            raise ValueError("Please provide at least one keyframe (start or end frame)")
        # Start frame conditions frame index 0; PIL images are saved to disk,
        # filepath inputs are used in place.
        if start_frame is not None:
            temp_start_path = output_dir / f"temp_start_{current_seed}.jpg"
            if hasattr(start_frame, 'save'):
                start_frame.save(temp_start_path)
            else:
                temp_start_path = Path(start_frame)
            temp_paths.append(temp_start_path)
            images.append((str(temp_start_path), 0, strength_start))
        # End frame conditions the final frame index.
        if end_frame is not None:
            temp_end_path = output_dir / f"temp_end_{current_seed}.jpg"
            if hasattr(end_frame, 'save'):
                end_frame.save(temp_end_path)
            else:
                temp_end_path = Path(end_frame)
            temp_paths.append(temp_end_path)
            images.append((str(temp_end_path), num_frames - 1, strength_end))
        # Get embeddings from the external text encoder space.
        print(f"Encoding prompt: {prompt}")
        if text_encoder_client is None:
            raise RuntimeError(
                f"Text encoder client not connected. Please ensure the text encoder space "
                f"({TEXT_ENCODER_SPACE}) is running and accessible."
            )
        try:
            # The first available keyframe is sent along for image-aware
            # prompt enhancement.
            first_frame_path = temp_paths[0] if temp_paths else None
            image_input = handle_file(str(first_frame_path)) if first_frame_path else None
            result = text_encoder_client.predict(
                prompt=prompt,
                enhance_prompt=enhance_prompt,
                input_image=image_input,
                seed=current_seed,
                negative_prompt=negative_prompt,
                api_name="/encode_prompt"
            )
            embedding_path = result[0]  # Path to .pt file
            print(f"Embeddings received from: {embedding_path}")
            # Load embeddings produced by the remote space.
            embeddings = torch.load(embedding_path)
            video_context_positive = embeddings['video_context']
            audio_context_positive = embeddings['audio_context']
            # Get the final prompt that was used (enhanced or original).
            final_prompt = embeddings.get('prompt', prompt)
            # Negative contexts are optional in the returned payload.
            video_context_negative = embeddings.get('video_context_negative', None)
            audio_context_negative = embeddings.get('audio_context_negative', None)
            print("✓ Embeddings loaded successfully")
            if video_context_negative is not None:
                print("  ✓ Negative prompt embeddings also loaded")
        except Exception as e:
            # Chain the cause so the original failure remains in the traceback.
            raise RuntimeError(
                f"Failed to get embeddings from text encoder space: {e}\n"
                f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
            ) from e
        # Run inference - progress automatically tracks tqdm from pipeline.
        # The prompt text is passed through, but conditioning comes from the
        # precomputed context embeddings above.
        pipeline(
            prompt=prompt,
            negative_prompt=negative_prompt,
            output_path=str(output_path),
            seed=current_seed,
            height=height,
            width=width,
            num_frames=num_frames,
            frame_rate=frame_rate,
            num_inference_steps=num_inference_steps,
            cfg_guidance_scale=cfg_guidance_scale,
            images=images,
            tiling_config=TilingConfig.default(),
            video_context_positive=video_context_positive,
            audio_context_positive=audio_context_positive,
            video_context_negative=video_context_negative,
            audio_context_negative=audio_context_negative,
        )
        return str(output_path), final_prompt, current_seed
    except Exception as e:
        import traceback
        error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        # current_seed is always bound here (assigned before the try block).
        return None, f"Error: {str(e)}", current_seed
# Create Gradio interface
# Layout: left column = keyframe inputs + prompt + advanced settings,
# right column = generated video and the (possibly enhanced) prompt actually used.
with gr.Blocks(title="LTX-2 Keyframe Interpolation 🎥🔈") as demo:
    gr.Markdown("# LTX-2 First-Last Frame 🎥🔈")
    gr.Markdown("Generate video& audio with smooth transitions between keyframes with Lightricks LTX-2. Read more: [[model]](https://huggingface.co/Lightricks/LTX-2), [[code]](https://github.com/Lightricks/LTX-2)")
    with gr.Row(elem_id="general_items"):
        with gr.Column():
            with gr.Group(elem_id="group_all"):
                with gr.Row():
                    start_frame = gr.Image(
                        label="Start Frame (Optional)",
                        type="pil",
                    )
                    # End frame can either be uploaded directly or generated
                    # from the start frame via the Qwen Image Edit space.
                    with gr.Tabs():
                        with gr.Tab("Upload"):
                            end_frame_upload = gr.Image(
                                label="End Frame",
                                type="pil",
                            )
                        with gr.Tab("Generate"):
                            end_frame_generated = gr.Image(
                                label="Generated End Frame",
                                type="pil",
                            )
            # gr.Markdown("Generate an end frame with Qwen Edit")
            edit_prompt = gr.Textbox(
                label="Edit Prompt for end frame",
                info ="Generate end frame with Qwen Edit",
                placeholder="Describe the transformation (e.g., '5 seconds later, sunset lighting')",
                lines=2,
                value="5 seconds in the future"
            )
            generate_end_btn = gr.Button("Generate End Frame", variant="secondary")
            prompt = gr.Textbox(
                label="Prompt",
                info="Describe the motion/transition between frames",
                value=DEFAULT_PROMPT,
                lines=3,
                placeholder="Describe the animation style and motion..."
            )
            generate_btn = gr.Button("Generate Video", variant="primary")
            with gr.Accordion("Advanced Settings", open=False):
                with gr.Row():
                    duration = gr.Slider(
                        label="Duration (seconds)",
                        minimum=1.0,
                        maximum=10.0,
                        value=5.0,
                        step=0.1
                    )
                    enhance_prompt = gr.Checkbox(
                        label="Enhance Prompt",
                        value=True
                    )
                with gr.Row():
                    # Conditioning strengths: 1.0 pins the keyframe exactly,
                    # lower values give the model more freedom.
                    strength_start = gr.Slider(
                        label="strength - start frame conditioning",
                        minimum=0.0,
                        maximum=1.0,
                        value=1.0,
                        step=0.05
                    )
                    strength_end = gr.Slider(
                        label="strength - end frame conditioning",
                        minimum=0.0,
                        maximum=1.0,
                        value=0.9,
                        step=0.05
                    )
                negative_prompt = gr.Textbox(
                    label="Negative Prompt",
                    value=DEFAULT_NEGATIVE_PROMPT,
                    lines=2
                )
                seed = gr.Slider(
                    label="Seed",
                    minimum=0,
                    maximum=MAX_SEED,
                    value=DEFAULT_SEED,
                    step=1
                )
                randomize_seed = gr.Checkbox(
                    label="Randomize Seed",
                    value=True
                )
                num_inference_steps = gr.Slider(
                    label="Inference Steps",
                    minimum=1,
                    maximum=DEFAULT_NUM_INFERENCE_STEPS,
                    value=20,
                    step=1
                )
                cfg_guidance_scale = gr.Slider(
                    label="CFG Guidance Scale",
                    minimum=1.0,
                    maximum=10.0,
                    value=DEFAULT_CFG_GUIDANCE_SCALE,
                    step=0.1
                )
                with gr.Row():
                    width = gr.Number(
                        label="Width",
                        value=DEFAULT_WIDTH,
                        precision=0
                    )
                    height = gr.Number(
                        label="Height",
                        value=DEFAULT_HEIGHT,
                        precision=0
                    )
        with gr.Column():
            output_video = gr.Video(label="Generated Video", autoplay=True)
            final_prompt_output = gr.Textbox(
                label="Final Prompt Used",
                lines=5,
                info="This is the prompt that was used for generation (enhanced if enabled)"
            )
    # Wire up generate end frame button
    generate_end_btn.click(
        fn=generate_end_frame,
        inputs=[start_frame, edit_prompt],
        outputs=[end_frame_generated]
    )
    # Wire up generate video button
    # Input order must match generate_video's positional parameters.
    generate_btn.click(
        fn=generate_video,
        inputs=[
            start_frame,
            prompt,
            end_frame_upload,
            end_frame_generated,
            strength_start,
            strength_end,
            duration,
            enhance_prompt,
            negative_prompt,
            seed,
            randomize_seed,
            num_inference_steps,
            cfg_guidance_scale,
            height,
            width,
        ],
        outputs=[output_video, final_prompt_output, seed]
    )
    # Examples are cached lazily: each is generated once on first click.
    gr.Examples(
        examples=[
            ["disaster_girl.jpg", "Starting frame is a close-up of a young girl with a mischievous smirk, a house engulfed in flames behind her with firefighters working in the background. The girl glances at the camera and says with faux innocence, 'Everyone thinks I did it, but honestly—' she steps aside and gestures downward as the camera pans down and pushes forward, '—talk to him.' The camera reveals a grumpy-faced cat walking slowly and deliberately toward the lens, the burning house and fire truck now behind it. The cat stops, stares directly into the camera with an unapologetic, stone-cold expression, and lets out a single dismissive 'meow.' End frame holds on the cat's grumpy face, flames reflecting in its eyes.", "image-127.webp"],
            ["wednesday.jpg", "Wednesday says 'im so not in the mood', Cookie monster enters the frame and hugs her, she rolls her eyes", "image-128.webp"],
        ],
        inputs=[start_frame, prompt, end_frame_upload],
        outputs=[output_video, final_prompt_output, seed],
        fn=generate_video,
        cache_examples=True,
        cache_mode="lazy"
    )
# Custom CSS applied at launch; selectors target elem_ids set in the Blocks
# layout above (component-N ids are Gradio-assigned and version-dependent).
css = '''
.fillable{max-width: 1100px !important}
.dark .progress-text {color: white}
#general_items{margin-top: 2em}
#group_all{overflow:visible}
#group_all .styler{overflow:visible}
#group_tabs .tabitem{padding: 0}
.tab-wrapper{margin-top: 0px;z-index: 999;position: absolute;width: 100%;background-color: var(--block-background-fill);padding: 0;}
#component-9-button{width: 50%;justify-content: center}
#component-11-button{width: 50%;justify-content: center}
#or_item{text-align: center; padding-top: 1em; padding-bottom: 1em; font-size: 1.1em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
#fivesec{margin-top: 5em;margin-left: .5em;margin-right: .5em;width: calc(100% - 1em)}
'''
if __name__ == "__main__":
demo.launch(theme=gr.themes.Citrus(), css=css, share=True)