Spaces:
Paused
Paused
Update app.py
Browse files
app.py
CHANGED
|
@@ -7,15 +7,15 @@ sys.path.insert(0, str(current_dir / "packages" / "ltx-pipelines" / "src"))
|
|
| 7 |
sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
|
| 8 |
|
| 9 |
import spaces
|
|
|
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
-
from gradio_client import Client, handle_file
|
| 12 |
import numpy as np
|
| 13 |
import random
|
| 14 |
import torch
|
| 15 |
from typing import Optional
|
| 16 |
from pathlib import Path
|
| 17 |
-
from huggingface_hub import hf_hub_download
|
| 18 |
-
from gradio_client import Client
|
| 19 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 20 |
from ltx_core.model.video_vae import TilingConfig
|
| 21 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
|
@@ -29,7 +29,165 @@ from ltx_pipelines.utils.constants import (
|
|
| 29 |
DEFAULT_LORA_STRENGTH,
|
| 30 |
)
|
| 31 |
|
|
|
|
| 32 |
MAX_SEED = np.iinfo(np.int32).max
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
# Default prompt from docstring example
|
| 34 |
DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
|
| 35 |
|
|
@@ -86,6 +244,7 @@ loras = [
|
|
| 86 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
| 87 |
# Text encoding will be done by external space
|
| 88 |
pipeline = DistilledPipeline(
|
|
|
|
| 89 |
checkpoint_path=checkpoint_path,
|
| 90 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 91 |
gemma_root=None, # No text encoder in this space
|
|
@@ -93,23 +252,18 @@ pipeline = DistilledPipeline(
|
|
| 93 |
fp8transformer=False,
|
| 94 |
local_files_only=False,
|
| 95 |
)
|
|
|
|
| 96 |
pipeline._video_encoder = pipeline.model_ledger.video_encoder()
|
| 97 |
pipeline._transformer = pipeline.model_ledger.transformer()
|
|
|
|
|
|
|
| 98 |
|
| 99 |
-
# Initialize text encoder client
|
| 100 |
-
print(f"Connecting to text encoder space: {TEXT_ENCODER_SPACE}")
|
| 101 |
-
try:
|
| 102 |
-
text_encoder_client = Client(TEXT_ENCODER_SPACE)
|
| 103 |
-
print("✓ Text encoder client connected!")
|
| 104 |
-
except Exception as e:
|
| 105 |
-
print(f"⚠ Warning: Could not connect to text encoder space: {e}")
|
| 106 |
-
text_encoder_client = None
|
| 107 |
|
| 108 |
print("=" * 80)
|
| 109 |
print("Pipeline fully loaded and ready!")
|
| 110 |
print("=" * 80)
|
| 111 |
|
| 112 |
-
@spaces.GPU(duration=
|
| 113 |
def generate_video(
|
| 114 |
input_image,
|
| 115 |
prompt: str,
|
|
@@ -118,9 +272,10 @@ def generate_video(
|
|
| 118 |
seed: int = 42,
|
| 119 |
randomize_seed: bool = True,
|
| 120 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 121 |
-
width: int = DEFAULT_1_STAGE_WIDTH,
|
| 122 |
progress=gr.Progress(track_tqdm=True)
|
| 123 |
):
|
|
|
|
| 124 |
"""Generate a video based on the given parameters."""
|
| 125 |
try:
|
| 126 |
# Randomize seed if checkbox is enabled
|
|
@@ -153,39 +308,25 @@ def generate_video(
|
|
| 153 |
# Get embeddings from text encoder space
|
| 154 |
print(f"Encoding prompt: {prompt}")
|
| 155 |
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
)
|
| 176 |
-
embedding_path = result[0] # Path to .pt file
|
| 177 |
-
print(f"Embeddings received from: {embedding_path}")
|
| 178 |
-
|
| 179 |
-
# Load embeddings
|
| 180 |
-
embeddings = torch.load(embedding_path)
|
| 181 |
-
video_context = embeddings['video_context']
|
| 182 |
-
audio_context = embeddings['audio_context']
|
| 183 |
-
print("✓ Embeddings loaded successfully")
|
| 184 |
-
except Exception as e:
|
| 185 |
-
raise RuntimeError(
|
| 186 |
-
f"Failed to get embeddings from text encoder space: {e}\n"
|
| 187 |
-
f"Please ensure {TEXT_ENCODER_SPACE} is running properly."
|
| 188 |
-
)
|
| 189 |
|
| 190 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 191 |
pipeline(
|
|
@@ -321,4 +462,4 @@ css = '''
|
|
| 321 |
.gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
|
| 322 |
'''
|
| 323 |
if __name__ == "__main__":
|
| 324 |
-
demo.launch(theme=gr.themes.Citrus(), css=css)
|
|
|
|
| 7 |
sys.path.insert(0, str(current_dir / "packages" / "ltx-core" / "src"))
|
| 8 |
|
| 9 |
import spaces
|
| 10 |
+
import flash_attn_interface
|
| 11 |
+
import time
|
| 12 |
import gradio as gr
|
|
|
|
| 13 |
import numpy as np
|
| 14 |
import random
|
| 15 |
import torch
|
| 16 |
from typing import Optional
|
| 17 |
from pathlib import Path
|
| 18 |
+
from huggingface_hub import hf_hub_download, snapshot_download
|
|
|
|
| 19 |
from ltx_pipelines.distilled import DistilledPipeline
|
| 20 |
from ltx_core.model.video_vae import TilingConfig
|
| 21 |
from ltx_core.loader.primitives import LoraPathStrengthAndSDOps
|
|
|
|
| 29 |
DEFAULT_LORA_STRENGTH,
|
| 30 |
)
|
| 31 |
|
| 32 |
+
|
| 33 |
MAX_SEED = np.iinfo(np.int32).max
|
| 34 |
+
# Import from public LTX-2 package
|
| 35 |
+
# Install with: pip install git+https://github.com/Lightricks/LTX-2.git
|
| 36 |
+
from ltx_pipelines.utils import ModelLedger
|
| 37 |
+
from ltx_pipelines.utils.helpers import generate_enhanced_prompt
|
| 38 |
+
|
| 39 |
+
# HuggingFace Hub defaults
|
| 40 |
+
DEFAULT_REPO_ID = "Lightricks/LTX-2"
|
| 41 |
+
DEFAULT_GEMMA_REPO_ID = "unsloth/gemma-3-12b-it-qat-bnb-4bit"
|
| 42 |
+
DEFAULT_CHECKPOINT_FILENAME = "ltx-2-19b-dev.safetensors"
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def get_hub_or_local_checkpoint(repo_id: str, filename: str):
|
| 46 |
+
"""Download from HuggingFace Hub."""
|
| 47 |
+
print(f"Downloading {filename} from {repo_id}...")
|
| 48 |
+
ckpt_path = hf_hub_download(repo_id=repo_id, filename=filename)
|
| 49 |
+
print(f"Downloaded to {ckpt_path}")
|
| 50 |
+
return ckpt_path
|
| 51 |
+
|
| 52 |
+
def download_gemma_model(repo_id: str):
|
| 53 |
+
"""Download the full Gemma model directory."""
|
| 54 |
+
print(f"Downloading Gemma model from {repo_id}...")
|
| 55 |
+
local_dir = snapshot_download(repo_id=repo_id)
|
| 56 |
+
print(f"Gemma model downloaded to {local_dir}")
|
| 57 |
+
return local_dir
|
| 58 |
+
|
| 59 |
+
# Initialize model ledger and text encoder at startup (load once, keep in memory)
|
| 60 |
+
print("=" * 80)
|
| 61 |
+
print("Loading Gemma Text Encoder...")
|
| 62 |
+
print("=" * 80)
|
| 63 |
+
|
| 64 |
+
checkpoint_path = get_hub_or_local_checkpoint(DEFAULT_REPO_ID, DEFAULT_CHECKPOINT_FILENAME)
|
| 65 |
+
gemma_local_path = download_gemma_model(DEFAULT_GEMMA_REPO_ID)
|
| 66 |
+
device = "cuda"
|
| 67 |
+
|
| 68 |
+
print(f"Initializing text encoder with:")
|
| 69 |
+
print(f" checkpoint_path={checkpoint_path}")
|
| 70 |
+
print(f" gemma_root={gemma_local_path}")
|
| 71 |
+
print(f" device={device}")
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
model_ledger = ModelLedger(
|
| 75 |
+
dtype=torch.bfloat16,
|
| 76 |
+
device=device,
|
| 77 |
+
checkpoint_path=checkpoint_path,
|
| 78 |
+
gemma_root_path=DEFAULT_GEMMA_REPO_ID,
|
| 79 |
+
local_files_only=False
|
| 80 |
+
)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
# Load text encoder once and keep it in memory
|
| 84 |
+
text_encoder = model_ledger.text_encoder()
|
| 85 |
+
|
| 86 |
+
print("=" * 80)
|
| 87 |
+
print("Text encoder loaded and ready!")
|
| 88 |
+
print("=" * 80)
|
| 89 |
+
|
| 90 |
+
def encode_text_simple(text_encoder, prompt: str):
|
| 91 |
+
"""Simple text encoding without using pipeline_utils."""
|
| 92 |
+
v_context, a_context, _ = text_encoder(prompt)
|
| 93 |
+
return v_context, a_context
|
| 94 |
+
|
| 95 |
+
@spaces.GPU()
|
| 96 |
+
def encode_prompt(
|
| 97 |
+
prompt: str,
|
| 98 |
+
enhance_prompt: bool = True,
|
| 99 |
+
input_image = None,
|
| 100 |
+
seed: int = 42,
|
| 101 |
+
negative_prompt: str = ""
|
| 102 |
+
):
|
| 103 |
+
"""
|
| 104 |
+
Encode a text prompt using Gemma text encoder.
|
| 105 |
+
Args:
|
| 106 |
+
prompt: Text prompt to encode
|
| 107 |
+
enhance_prompt: Whether to use AI to enhance the prompt
|
| 108 |
+
input_image: Optional image for image-to-video enhancement
|
| 109 |
+
seed: Random seed for prompt enhancement
|
| 110 |
+
negative_prompt: Optional negative prompt for CFG (two-stage pipeline)
|
| 111 |
+
Returns:
|
| 112 |
+
tuple: (file_path, enhanced_prompt_text, status_message)
|
| 113 |
+
"""
|
| 114 |
+
start_time = time.time()
|
| 115 |
+
|
| 116 |
+
try:
|
| 117 |
+
# Enhance prompt if requested
|
| 118 |
+
final_prompt = prompt
|
| 119 |
+
if enhance_prompt:
|
| 120 |
+
if input_image is not None:
|
| 121 |
+
# Save image temporarily
|
| 122 |
+
temp_dir = Path("temp_images")
|
| 123 |
+
temp_dir.mkdir(exist_ok=True)
|
| 124 |
+
temp_image_path = temp_dir / f"temp_{int(time.time())}.jpg"
|
| 125 |
+
if hasattr(input_image, 'save'):
|
| 126 |
+
input_image.save(temp_image_path)
|
| 127 |
+
else:
|
| 128 |
+
temp_image_path = input_image
|
| 129 |
+
|
| 130 |
+
final_prompt = generate_enhanced_prompt(
|
| 131 |
+
text_encoder=text_encoder,
|
| 132 |
+
prompt=prompt,
|
| 133 |
+
image_path=str(temp_image_path),
|
| 134 |
+
seed=seed
|
| 135 |
+
)
|
| 136 |
+
else:
|
| 137 |
+
final_prompt = generate_enhanced_prompt(
|
| 138 |
+
text_encoder=text_encoder,
|
| 139 |
+
prompt=prompt,
|
| 140 |
+
image_path=None,
|
| 141 |
+
seed=seed
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# Encode the positive prompt using the pre-loaded text encoder
|
| 145 |
+
video_context, audio_context = encode_text_simple(text_encoder, final_prompt)
|
| 146 |
+
|
| 147 |
+
# Encode negative prompt if provided
|
| 148 |
+
video_context_negative = None
|
| 149 |
+
audio_context_negative = None
|
| 150 |
+
if negative_prompt:
|
| 151 |
+
video_context_negative, audio_context_negative = encode_text_simple(text_encoder, negative_prompt)
|
| 152 |
+
|
| 153 |
+
# Save embeddings to file
|
| 154 |
+
output_dir = Path("embeddings")
|
| 155 |
+
output_dir.mkdir(exist_ok=True)
|
| 156 |
+
output_path = output_dir / f"embedding_{int(time.time())}.pt"
|
| 157 |
+
|
| 158 |
+
# Save embeddings (with negative contexts if provided)
|
| 159 |
+
embedding_data = {
|
| 160 |
+
'video_context': video_context.cpu(),
|
| 161 |
+
'audio_context': audio_context.cpu(),
|
| 162 |
+
'prompt': final_prompt,
|
| 163 |
+
'original_prompt': prompt if enhance_prompt else final_prompt,
|
| 164 |
+
}
|
| 165 |
+
|
| 166 |
+
# Add negative contexts if they were encoded
|
| 167 |
+
if video_context_negative is not None:
|
| 168 |
+
embedding_data['video_context_negative'] = video_context_negative.cpu()
|
| 169 |
+
embedding_data['audio_context_negative'] = audio_context_negative.cpu()
|
| 170 |
+
embedding_data['negative_prompt'] = negative_prompt
|
| 171 |
+
|
| 172 |
+
torch.save(embedding_data, output_path)
|
| 173 |
+
|
| 174 |
+
# Get memory stats
|
| 175 |
+
elapsed_time = time.time() - start_time
|
| 176 |
+
if torch.cuda.is_available():
|
| 177 |
+
allocated = torch.cuda.memory_allocated() / 1024**3
|
| 178 |
+
peak = torch.cuda.max_memory_allocated() / 1024**3
|
| 179 |
+
status = f"✓ Encoded in {elapsed_time:.2f}s | VRAM: {allocated:.2f}GB allocated, {peak:.2f}GB peak"
|
| 180 |
+
else:
|
| 181 |
+
status = f"✓ Encoded in {elapsed_time:.2f}s (CPU mode)"
|
| 182 |
+
|
| 183 |
+
return str(output_path), final_prompt, status
|
| 184 |
+
|
| 185 |
+
except Exception as e:
|
| 186 |
+
import traceback
|
| 187 |
+
error_msg = f"Error: {str(e)}\n{traceback.format_exc()}"
|
| 188 |
+
print(error_msg)
|
| 189 |
+
return None, prompt, error_msg
|
| 190 |
+
|
| 191 |
# Default prompt from docstring example
|
| 192 |
DEFAULT_PROMPT = "An astronaut hatches from a fragile egg on the surface of the Moon, the shell cracking and peeling apart in gentle low-gravity motion. Fine lunar dust lifts and drifts outward with each movement, floating in slow arcs before settling back onto the ground. The astronaut pushes free in a deliberate, weightless motion, small fragments of the egg tumbling and spinning through the air. In the background, the deep darkness of space subtly shifts as stars glide with the camera's movement, emphasizing vast depth and scale. The camera performs a smooth, cinematic slow push-in, with natural parallax between the foreground dust, the astronaut, and the distant starfield. Ultra-realistic detail, physically accurate low-gravity motion, cinematic lighting, and a breath-taking, movie-like shot."
|
| 193 |
|
|
|
|
| 244 |
# Initialize pipeline WITHOUT text encoder (gemma_root=None)
|
| 245 |
# Text encoding will be done by external space
|
| 246 |
pipeline = DistilledPipeline(
|
| 247 |
+
device=torch.device("cuda"),
|
| 248 |
checkpoint_path=checkpoint_path,
|
| 249 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 250 |
gemma_root=None, # No text encoder in this space
|
|
|
|
| 252 |
fp8transformer=False,
|
| 253 |
local_files_only=False,
|
| 254 |
)
|
| 255 |
+
|
| 256 |
pipeline._video_encoder = pipeline.model_ledger.video_encoder()
|
| 257 |
pipeline._transformer = pipeline.model_ledger.transformer()
|
| 258 |
+
# pipeline.device = torch.device("cuda")
|
| 259 |
+
# pipeline.model_ledger.device = torch.device("cuda")
|
| 260 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 261 |
|
| 262 |
print("=" * 80)
|
| 263 |
print("Pipeline fully loaded and ready!")
|
| 264 |
print("=" * 80)
|
| 265 |
|
| 266 |
+
@spaces.GPU(duration=80)
|
| 267 |
def generate_video(
|
| 268 |
input_image,
|
| 269 |
prompt: str,
|
|
|
|
| 272 |
seed: int = 42,
|
| 273 |
randomize_seed: bool = True,
|
| 274 |
height: int = DEFAULT_1_STAGE_HEIGHT,
|
| 275 |
+
width: int = DEFAULT_1_STAGE_WIDTH ,
|
| 276 |
progress=gr.Progress(track_tqdm=True)
|
| 277 |
):
|
| 278 |
+
|
| 279 |
"""Generate a video based on the given parameters."""
|
| 280 |
try:
|
| 281 |
# Randomize seed if checkbox is enabled
|
|
|
|
| 308 |
# Get embeddings from text encoder space
|
| 309 |
print(f"Encoding prompt: {prompt}")
|
| 310 |
|
| 311 |
+
# Prepare image for upload if it exists
|
| 312 |
+
image_input = None
|
| 313 |
+
|
| 314 |
+
|
| 315 |
+
result = encode_prompt(
|
| 316 |
+
prompt=prompt,
|
| 317 |
+
enhance_prompt=enhance_prompt,
|
| 318 |
+
input_image=input_image,
|
| 319 |
+
seed=current_seed,
|
| 320 |
+
negative_prompt="",
|
| 321 |
+
)
|
| 322 |
+
embedding_path = result[0] # Path to .pt file
|
| 323 |
+
print(f"Embeddings received from: {embedding_path}")
|
| 324 |
+
|
| 325 |
+
# Load embeddings
|
| 326 |
+
embeddings = torch.load(embedding_path)
|
| 327 |
+
video_context = embeddings['video_context']
|
| 328 |
+
audio_context = embeddings['audio_context']
|
| 329 |
+
print("✓ Embeddings loaded successfully")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 330 |
|
| 331 |
# Run inference - progress automatically tracks tqdm from pipeline
|
| 332 |
pipeline(
|
|
|
|
| 462 |
.gradio-container .contain{max-width: 1200px !important; margin: 0 auto !important}
|
| 463 |
'''
|
| 464 |
if __name__ == "__main__":
|
| 465 |
+
demo.launch(theme=gr.themes.Citrus(), css=css)
|