dagloop5 committed on
Commit
a5f8b2b
·
verified ·
1 Parent(s): 39d7936

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +50 -15
app.py CHANGED
@@ -118,11 +118,12 @@ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
118
  # Custom HQ Pipeline with LoRA Cache Support
119
  # =============================================================================
120
 
121
- class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
122
  """
123
- TI2VidTwoStagesHQPipeline modified to:
124
- 1. NOT accept or pass distilled_lora to parent init (enables preloading)
125
- 2. Handle ALL LoRAs via cached state (distilled + 12 custom)
 
126
  """
127
 
128
  def __init__(
@@ -130,23 +131,45 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
130
  checkpoint_path: str,
131
  spatial_upsampler_path: str,
132
  gemma_root: str,
133
- loras: tuple = (),
134
  quantization: QuantizationPolicy | None = None,
135
  ):
136
- # Call parent WITHOUT distilled_lora parameters
137
- # We create minimal stage ledgers (no LoRAs) for preloading
138
- super().__init__(
 
 
 
 
 
 
 
 
 
139
  checkpoint_path=checkpoint_path,
140
- distilled_lora=[], # Empty - satisfies signature
141
- distilled_lora_strength_stage_1=0.0,
142
- distilled_lora_strength_stage_2=0.0,
143
  spatial_upsampler_path=spatial_upsampler_path,
144
- gemma_root=gemma_root,
145
- loras=loras,
146
  quantization=quantization,
147
  )
148
 
149
- # Storage for cached LoRA states for each stage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
150
  self._cached_state_stage1 = None
151
  self._cached_state_stage2 = None
152
 
@@ -172,7 +195,16 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
172
  tiling_config: TilingConfig | None = None,
173
  enhance_prompt: bool = False,
174
  ):
175
- # ... same as before ...
 
 
 
 
 
 
 
 
 
176
  assert_resolution(height=height, width=width, is_two_stage=True)
177
 
178
  device = self.device
@@ -180,6 +212,7 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
180
  generator = torch.Generator(device=device).manual_seed(seed)
181
  noiser = GaussianNoiser(generator=generator)
182
 
 
183
  if self._cached_state_stage1 is not None:
184
  print("[LoRA] Applying cached state to stage 1 transformer...")
185
  t1 = self.stage_1_model_ledger.transformer()
@@ -283,6 +316,8 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
283
  cleanup_memory()
284
 
285
  transformer = self.stage_2_model_ledger.transformer()
 
 
286
  distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
287
 
288
  def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):
 
118
  # Custom HQ Pipeline with LoRA Cache Support
119
  # =============================================================================
120
 
121
+ class HQPipelineWithCachedLoRA:
122
  """
123
+ Custom HQ pipeline that:
124
+ 1. Creates ModelLedgers WITHOUT LoRAs (enables preloading)
125
+ 2. Handles ALL LoRAs via cached state (distilled + 12 custom)
126
+ 3. Supports CFG/negative prompts and guidance parameters
127
  """
128
 
129
  def __init__(
 
131
  checkpoint_path: str,
132
  spatial_upsampler_path: str,
133
  gemma_root: str,
 
134
  quantization: QuantizationPolicy | None = None,
135
  ):
136
+ from ltx_pipelines.utils import ModelLedger
137
+ from ltx_core.loader import LoraPathStrengthAndSDOps, LTXV_LORA_COMFY_RENAMING_MAP
138
+ from ltx_core.types import PipelineComponents
139
+
140
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
141
+ self.dtype = torch.bfloat16
142
+
143
+ # Create ModelLedgers WITHOUT LoRAs - this allows preloading
144
+ print(" Creating stage 1 ModelLedger (no LoRAs)...")
145
+ self.stage_1_model_ledger = ModelLedger(
146
+ dtype=self.dtype,
147
+ device=self.device,
148
  checkpoint_path=checkpoint_path,
149
+ gemma_root_path=gemma_root,
 
 
150
  spatial_upsampler_path=spatial_upsampler_path,
151
+ loras=(), # NO LoRAs - preloading works
 
152
  quantization=quantization,
153
  )
154
 
155
+ print(" Creating stage 2 ModelLedger (no LoRAs)...")
156
+ self.stage_2_model_ledger = ModelLedger(
157
+ dtype=self.dtype,
158
+ device=self.device,
159
+ checkpoint_path=checkpoint_path,
160
+ gemma_root_path=gemma_root,
161
+ spatial_upsampler_path=spatial_upsampler_path,
162
+ loras=(), # NO LoRAs - preloading works
163
+ quantization=quantization,
164
+ )
165
+
166
+ # Pipeline components (similar to parent)
167
+ self.pipeline_components = PipelineComponents(
168
+ dtype=self.dtype,
169
+ device=self.device,
170
+ )
171
+
172
+ # Storage for cached LoRA states
173
  self._cached_state_stage1 = None
174
  self._cached_state_stage2 = None
175
 
 
195
  tiling_config: TilingConfig | None = None,
196
  enhance_prompt: bool = False,
197
  ):
198
+ from ltx_pipelines.utils import assert_resolution, cleanup_memory, combined_image_conditionings, encode_prompts, res2s_audio_video_denoising_loop, multi_modal_guider_denoising_func, simple_denoising_func, denoise_audio_video
199
+ from ltx_core.tools import VideoLatentShape
200
+ from ltx_core.components.noisers import GaussianNoiser
201
+ from ltx_core.components.diffusion_steps import Res2sDiffusionStep
202
+ from ltx_core.components.schedulers import LTX2Scheduler
203
+ from ltx_core.types import VideoPixelShape
204
+ from ltx_core.model.upsampler import upsample_video
205
+ from ltx_core.model.video_vae import decode_video as vae_decode_video
206
+ from ltx_core.model.audio_vae import decode_audio as vae_decode_audio
207
+
208
  assert_resolution(height=height, width=width, is_two_stage=True)
209
 
210
  device = self.device
 
212
  generator = torch.Generator(device=device).manual_seed(seed)
213
  noiser = GaussianNoiser(generator=generator)
214
 
215
+ # Apply cached LoRA state if available
216
  if self._cached_state_stage1 is not None:
217
  print("[LoRA] Applying cached state to stage 1 transformer...")
218
  t1 = self.stage_1_model_ledger.transformer()
 
316
  cleanup_memory()
317
 
318
  transformer = self.stage_2_model_ledger.transformer()
319
+
320
+ from ltx_pipelines.utils.constants import STAGE_2_DISTILLED_SIGMA_VALUES
321
  distilled_sigmas = torch.tensor(STAGE_2_DISTILLED_SIGMA_VALUES, device=device)
322
 
323
  def second_stage_denoising_loop(sigmas, video_state, audio_state, stepper):