Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -118,11 +118,12 @@ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
| 118 |
# Custom HQ Pipeline with LoRA Cache Support
|
| 119 |
# =============================================================================
|
| 120 |
|
| 121 |
-
class HQPipelineWithCachedLoRA
|
| 122 |
"""
|
| 123 |
-
|
| 124 |
-
1.
|
| 125 |
-
2.
|
|
|
|
| 126 |
"""
|
| 127 |
|
| 128 |
def __init__(
|
|
@@ -130,23 +131,45 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 130 |
checkpoint_path: str,
|
| 131 |
spatial_upsampler_path: str,
|
| 132 |
gemma_root: str,
|
| 133 |
-
loras: tuple = (),
|
| 134 |
quantization: QuantizationPolicy | None = None,
|
| 135 |
):
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 139 |
checkpoint_path=checkpoint_path,
|
| 140 |
-
|
| 141 |
-
distilled_lora_strength_stage_1=0.0,
|
| 142 |
-
distilled_lora_strength_stage_2=0.0,
|
| 143 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 144 |
-
|
| 145 |
-
loras=loras,
|
| 146 |
quantization=quantization,
|
| 147 |
)
|
| 148 |
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
self._cached_state_stage1 = None
|
| 151 |
self._cached_state_stage2 = None
|
| 152 |
|
|
@@ -172,7 +195,16 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 172 |
tiling_config: TilingConfig | None = None,
|
| 173 |
enhance_prompt: bool = False,
|
| 174 |
):
|
| 175 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 176 |
assert_resolution(height=height, width=width, is_two_stage=True)
|
| 177 |
|
| 178 |
device = self.device
|
|
@@ -180,6 +212,7 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 180 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 181 |
noiser = GaussianNoiser(generator=generator)
|
| 182 |
|
|
|
|
| 183 |
if self._cached_state_stage1 is not None:
|
| 184 |
print("[LoRA] Applying cached state to stage 1 transformer...")
|
| 185 |
t1 = self.stage_1_model_ledger.transformer()
|
|
@@ -283,6 +316,8 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 283 |
cleanup_memory()
|
| 284 |
|
| 285 |
transformer = self.stage_2_model_ledger.transformer()
|
|
|
|
|
|
|
| 286 |
distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
|
| 287 |
|
| 288 |
def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
|
|
|
|
| 118 |
# Custom HQ Pipeline with LoRA Cache Support
|
| 119 |
# =============================================================================
|
| 120 |
|
| 121 |
+
class HQPipelineWithCachedLoRA:
|
| 122 |
"""
|
| 123 |
+
Custom HQ pipeline that:
|
| 124 |
+
1. Creates ModelLedgers WITHOUT LoRAs (enables preloading)
|
| 125 |
+
2. Handles ALL LoRAs via cached state (distilled + 12 custom)
|
| 126 |
+
3. Supports CFG/negative prompts and guidance parameters
|
| 127 |
"""
|
| 128 |
|
| 129 |
def __init__(
|
|
|
|
| 131 |
checkpoint_path: str,
|
| 132 |
spatial_upsampler_path: str,
|
| 133 |
gemma_root: str,
|
|
|
|
| 134 |
quantization: QuantizationPolicy | None = None,
|
| 135 |
):
|
| 136 |
+
from ltx_pipelines.utils import ModelLedger
|
| 137 |
+
from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
|
| 138 |
+
from ltx_core.types import PipelineComponents
|
| 139 |
+
|
| 140 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 141 |
+
self.dtype = torch.bfloat16
|
| 142 |
+
|
| 143 |
+
# Create ModelLedgers WITHOUT LoRAs - this allows preloading
|
| 144 |
+
print(" Creating stage 1 ModelLedger (no LoRAs)...")
|
| 145 |
+
self.stage_1_model_ledger = ModelLedger(
|
| 146 |
+
dtype=self.dtype,
|
| 147 |
+
device=self.device,
|
| 148 |
checkpoint_path=checkpoint_path,
|
| 149 |
+
gemma_root_path=gemma_root,
|
|
|
|
|
|
|
| 150 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 151 |
+
loras=(), # NO LoRAs - preloading works
|
|
|
|
| 152 |
quantization=quantization,
|
| 153 |
)
|
| 154 |
|
| 155 |
+
print(" Creating stage 2 ModelLedger (no LoRAs)...")
|
| 156 |
+
self.stage_2_model_ledger = ModelLedger(
|
| 157 |
+
dtype=self.dtype,
|
| 158 |
+
device=self.device,
|
| 159 |
+
checkpoint_path=checkpoint_path,
|
| 160 |
+
gemma_root_path=gemma_root,
|
| 161 |
+
spatial_upsampler_path=spatial_upsampler_path,
|
| 162 |
+
loras=(), # NO LoRAs - preloading works
|
| 163 |
+
quantization=quantization,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
# Pipeline components (similar to parent)
|
| 167 |
+
self.pipeline_components = PipelineComponents(
|
| 168 |
+
dtype=self.dtype,
|
| 169 |
+
device=self.device,
|
| 170 |
+
)
|
| 171 |
+
|
| 172 |
+
# Storage for cached LoRA states
|
| 173 |
self._cached_state_stage1 = None
|
| 174 |
self._cached_state_stage2 = None
|
| 175 |
|
|
|
|
| 195 |
tiling_config: TilingConfig | None = None,
|
| 196 |
enhance_prompt: bool = False,
|
| 197 |
):
|
| 198 |
+
from ltx_pipelines.utils import assert_resolution, cleanup_memory, combined_image_conditionings, encode_prompts, res2s_audio_video_denoising_loop, multi_modal_guider_denoising_func, simple_denoising_func, denoise_audio_video
|
| 199 |
+
from ltx_core.tools import VideoLatentShape
|
| 200 |
+
from ltx_core.components.noisers import GaussianNoiser
|
| 201 |
+
from ltx_core.components.diffusion_steps import Res2sDiffusionStep
|
| 202 |
+
from ltx_core.components.schedulers import LTX2Scheduler
|
| 203 |
+
from ltx_core.types import VideoPixelShape
|
| 204 |
+
from ltx_core.model.upsampler import upsample_video
|
| 205 |
+
from ltx_core.model.video_vae import decode_video as vae_decode_video
|
| 206 |
+
from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
|
| 207 |
+
|
| 208 |
assert_resolution(height=height, width=width, is_two_stage=True)
|
| 209 |
|
| 210 |
device = self.device
|
|
|
|
| 212 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 213 |
noiser = GaussianNoiser(generator=generator)
|
| 214 |
|
| 215 |
+
# Apply cached LoRA state if available
|
| 216 |
if self._cached_state_stage1 is not None:
|
| 217 |
print("[LoRA] Applying cached state to stage 1 transformer...")
|
| 218 |
t1 = self.stage_1_model_ledger.transformer()
|
|
|
|
| 316 |
cleanup_memory()
|
| 317 |
|
| 318 |
transformer = self.stage_2_model_ledger.transformer()
|
| 319 |
+
|
| 320 |
+
from ltx_pipelines.utils.constants import STAGE_2_DISTILLED_SIGMA_VALUES
|
| 321 |
distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
|
| 322 |
|
| 323 |
def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
|