Spaces:
Running on Zero
Running on Zero
Vicente Alvarez committed on
Commit ·
b1a127d
1
Parent(s): 824f9f7
Switch to DistilledPipeline with pre-distilled sulphur_distil_bf16 checkpoint
Browse files
app.py
CHANGED
|
@@ -61,8 +61,7 @@ import gradio as gr
|
|
| 61 |
import numpy as np
|
| 62 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 63 |
|
| 64 |
-
from
|
| 65 |
-
from ltx_pipelines.ti2vid_two_stages_hq import TI2VidTwoStagesHQPipeline
|
| 66 |
from ltx_pipelines.utils.args import ImageConditioningInput
|
| 67 |
from ltx_pipelines.utils.media_io import encode_video
|
| 68 |
|
|
@@ -111,21 +110,17 @@ RESOLUTIONS = {
|
|
| 111 |
|
| 112 |
# Model repos
|
| 113 |
CHECKPOINT_REPO = "SulphurAI/Sulphur-2-base"
|
| 114 |
-
DISTILL_LORA_REPO = "SulphurAI/Sulphur-2-base"
|
| 115 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 116 |
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 117 |
|
| 118 |
# Download model checkpoints in parallel for speed
|
| 119 |
print("=" * 80)
|
| 120 |
-
print("Downloading Element-16
|
| 121 |
print("=" * 80)
|
| 122 |
|
| 123 |
def download_checkpoint():
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
def download_lora():
|
| 127 |
-
# Skip distill LoRA for fp8 - not compatible with mxfp8mixed format
|
| 128 |
-
return None
|
| 129 |
|
| 130 |
def download_upsampler():
|
| 131 |
return hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
|
@@ -145,58 +140,27 @@ with ThreadPoolExecutor(max_workers=3) as executor:
|
|
| 145 |
print(f"Checkpoint: {checkpoint_path}")
|
| 146 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 147 |
print(f"Gemma root: {gemma_root}")
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
checkpoint_path=checkpoint_path,
|
| 153 |
-
distilled_lora=[],
|
| 154 |
-
distilled_lora_strength_stage_1=0.0,
|
| 155 |
-
distilled_lora_strength_stage_2=0.0,
|
| 156 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 157 |
gemma_root=gemma_root,
|
| 158 |
loras=(),
|
| 159 |
)
|
| 160 |
|
| 161 |
-
# Preload all models for ZeroGPU tensor packing
|
| 162 |
-
print("Preloading all
|
| 163 |
-
|
| 164 |
-
#
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
_text_encoder = stage_1_ledger.text_encoder()
|
| 174 |
-
_embeddings_processor = stage_1_ledger.gemma_embeddings_processor()
|
| 175 |
-
|
| 176 |
-
stage_1_ledger.transformer = lambda: _transformer
|
| 177 |
-
stage_1_ledger.video_encoder = lambda: _video_encoder
|
| 178 |
-
stage_1_ledger.video_decoder = lambda: _video_decoder
|
| 179 |
-
stage_1_ledger.audio_encoder = lambda: _audio_encoder
|
| 180 |
-
stage_1_ledger.audio_decoder = lambda: _audio_decoder
|
| 181 |
-
stage_1_ledger.vocoder = lambda: _vocoder
|
| 182 |
-
stage_1_ledger.spatial_upsampler = lambda: _spatial_upsampler_1
|
| 183 |
-
stage_1_ledger.text_encoder = lambda: _text_encoder
|
| 184 |
-
stage_1_ledger.gemma_embeddings_processor = lambda: _embeddings_processor
|
| 185 |
-
|
| 186 |
-
# Stage 2 models (critical - spatial upsampler is used here!)
|
| 187 |
-
print("Preloading stage 2 models...")
|
| 188 |
-
stage_2_ledger = pipeline.stage_2_model_ledger
|
| 189 |
-
_spatial_upsampler_2 = stage_2_ledger.spatial_upsampler()
|
| 190 |
-
_transformer_2 = stage_2_ledger.transformer()
|
| 191 |
-
_video_encoder_2 = stage_2_ledger.video_encoder()
|
| 192 |
-
_video_decoder_2 = stage_2_ledger.video_decoder()
|
| 193 |
-
|
| 194 |
-
stage_2_ledger.spatial_upsampler = lambda: _spatial_upsampler_2
|
| 195 |
-
stage_2_ledger.transformer = lambda: _transformer_2
|
| 196 |
-
stage_2_ledger.video_encoder = lambda: _video_encoder_2
|
| 197 |
-
stage_2_ledger.video_decoder = lambda: _video_decoder_2
|
| 198 |
-
|
| 199 |
-
print("All models preloaded (stage 1 + stage 2)!")
|
| 200 |
|
| 201 |
print("=" * 80)
|
| 202 |
print("Pipeline ready!")
|
|
@@ -244,7 +208,7 @@ def on_highres_toggle(first_image, last_image, high_res):
|
|
| 244 |
DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, blurry, glasses, deformed, subtitles, text, captions, worst quality, low quality, inconsistent motion, jittery, distorted"
|
| 245 |
|
| 246 |
|
| 247 |
-
@spaces.GPU(duration=
|
| 248 |
@torch.inference_mode()
|
| 249 |
def generate_video(
|
| 250 |
first_image,
|
|
@@ -291,7 +255,6 @@ def generate_video(
|
|
| 291 |
temp_last_path = Path(last_image)
|
| 292 |
images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
|
| 293 |
|
| 294 |
-
from ltx_core.components.guiders import MultiModalGuiderParams
|
| 295 |
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
|
| 296 |
|
| 297 |
tiling_config = TilingConfig.default()
|
|
@@ -299,38 +262,16 @@ def generate_video(
|
|
| 299 |
|
| 300 |
log_memory("before pipeline call")
|
| 301 |
|
| 302 |
-
#
|
| 303 |
-
video_guider_params = MultiModalGuiderParams(
|
| 304 |
-
cfg_scale=3.0,
|
| 305 |
-
stg_scale=0.0,
|
| 306 |
-
rescale_scale=0.45,
|
| 307 |
-
modality_scale=3.0,
|
| 308 |
-
skip_step=0,
|
| 309 |
-
stg_blocks=[],
|
| 310 |
-
)
|
| 311 |
-
|
| 312 |
-
audio_guider_params = MultiModalGuiderParams(
|
| 313 |
-
cfg_scale=7.0,
|
| 314 |
-
stg_scale=0.0,
|
| 315 |
-
rescale_scale=1.0,
|
| 316 |
-
modality_scale=3.0,
|
| 317 |
-
skip_step=0,
|
| 318 |
-
stg_blocks=[],
|
| 319 |
-
)
|
| 320 |
-
|
| 321 |
-
# Run inference - returns (video_frames_iter, audio)
|
| 322 |
video_frames_iter, audio = pipeline(
|
| 323 |
prompt=prompt,
|
| 324 |
-
negative_prompt=negative_prompt,
|
| 325 |
seed=current_seed,
|
| 326 |
height=int(height),
|
| 327 |
width=int(width),
|
| 328 |
num_frames=num_frames,
|
| 329 |
frame_rate=frame_rate,
|
| 330 |
-
num_inference_steps=30, # More steps needed without distill LoRA
|
| 331 |
-
video_guider_params=video_guider_params,
|
| 332 |
-
audio_guider_params=audio_guider_params,
|
| 333 |
images=images,
|
|
|
|
| 334 |
)
|
| 335 |
|
| 336 |
# Collect video frames
|
|
|
|
| 61 |
import numpy as np
|
| 62 |
from huggingface_hub import hf_hub_download, snapshot_download
|
| 63 |
|
| 64 |
+
from ltx_pipelines.distilled import DistilledPipeline
|
|
|
|
| 65 |
from ltx_pipelines.utils.args import ImageConditioningInput
|
| 66 |
from ltx_pipelines.utils.media_io import encode_video
|
| 67 |
|
|
|
|
| 110 |
|
| 111 |
# Model repos
|
| 112 |
CHECKPOINT_REPO = "SulphurAI/Sulphur-2-base"
|
|
|
|
| 113 |
LTX_MODEL_REPO = "Lightricks/LTX-2.3"
|
| 114 |
GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
| 115 |
|
| 116 |
# Download model checkpoints in parallel for speed
|
| 117 |
print("=" * 80)
|
| 118 |
+
print("Downloading Element-16 (pre-distilled) + Gemma (parallel)...")
|
| 119 |
print("=" * 80)
|
| 120 |
|
| 121 |
def download_checkpoint():
|
| 122 |
+
# Use pre-distilled checkpoint - no LoRA needed
|
| 123 |
+
return hf_hub_download(repo_id=CHECKPOINT_REPO, filename="sulphur_distil_bf16.safetensors")
|
|
|
|
|
|
|
|
|
|
| 124 |
|
| 125 |
def download_upsampler():
|
| 126 |
return hf_hub_download(repo_id=LTX_MODEL_REPO, filename="ltx-2.3-spatial-upscaler-x2-1.0.safetensors")
|
|
|
|
| 140 |
print(f"Checkpoint: {checkpoint_path}")
|
| 141 |
print(f"Spatial upsampler: {spatial_upsampler_path}")
|
| 142 |
print(f"Gemma root: {gemma_root}")
|
| 143 |
+
|
| 144 |
+
# Initialize pipeline with pre-distilled checkpoint (no LoRA needed)
|
| 145 |
+
pipeline = DistilledPipeline(
|
| 146 |
+
distilled_checkpoint_path=checkpoint_path,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 148 |
gemma_root=gemma_root,
|
| 149 |
loras=(),
|
| 150 |
)
|
| 151 |
|
| 152 |
+
# Preload all models for ZeroGPU tensor packing
|
| 153 |
+
print("Preloading all pipeline components...")
|
| 154 |
+
|
| 155 |
+
# DistilledPipeline components are already instantiated; access each one to ensure it is loaded
|
| 156 |
+
_ = pipeline.prompt_encoder
|
| 157 |
+
_ = pipeline.image_conditioner
|
| 158 |
+
_ = pipeline.stage
|
| 159 |
+
_ = pipeline.upsampler
|
| 160 |
+
_ = pipeline.video_decoder
|
| 161 |
+
_ = pipeline.audio_decoder
|
| 162 |
+
|
| 163 |
+
print("All models preloaded!")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
|
| 165 |
print("=" * 80)
|
| 166 |
print("Pipeline ready!")
|
|
|
|
| 208 |
DEFAULT_NEGATIVE_PROMPT = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走, blurry, glasses, deformed, subtitles, text, captions, worst quality, low quality, inconsistent motion, jittery, distorted"
|
| 209 |
|
| 210 |
|
| 211 |
+
@spaces.GPU(duration=90)
|
| 212 |
@torch.inference_mode()
|
| 213 |
def generate_video(
|
| 214 |
first_image,
|
|
|
|
| 255 |
temp_last_path = Path(last_image)
|
| 256 |
images.append(ImageConditioningInput(path=str(temp_last_path), frame_idx=num_frames - 1, strength=1.0))
|
| 257 |
|
|
|
|
| 258 |
from ltx_core.model.video_vae import TilingConfig, get_video_chunks_number
|
| 259 |
|
| 260 |
tiling_config = TilingConfig.default()
|
|
|
|
| 262 |
|
| 263 |
log_memory("before pipeline call")
|
| 264 |
|
| 265 |
+
# Run inference - DistilledPipeline has a simpler API
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
video_frames_iter, audio = pipeline(
|
| 267 |
prompt=prompt,
|
|
|
|
| 268 |
seed=current_seed,
|
| 269 |
height=int(height),
|
| 270 |
width=int(width),
|
| 271 |
num_frames=num_frames,
|
| 272 |
frame_rate=frame_rate,
|
|
|
|
|
|
|
|
|
|
| 273 |
images=images,
|
| 274 |
+
enhance_prompt=enhance_prompt,
|
| 275 |
)
|
| 276 |
|
| 277 |
# Collect video frames
|