dagloop5 commited on
Commit
39d7936
·
verified ·
1 Parent(s): e9ea7e7

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +30 -8
app.py CHANGED
@@ -119,14 +119,39 @@ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
119
  # =============================================================================
120
 
121
  class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
122
- """TI2VidTwoStagesHQPipeline with support for cached LoRA state."""
 
 
 
 
123
 
124
- def __init__(self, *args, **kwargs):
125
- super().__init__(*args, **kwargs)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
126
  self._cached_state_stage1 = None
127
  self._cached_state_stage2 = None
128
 
129
  def apply_cached_lora_state(self, state_dict_stage1, state_dict_stage2=None):
 
130
  self._cached_state_stage1 = state_dict_stage1
131
  self._cached_state_stage2 = state_dict_stage2 if state_dict_stage2 else state_dict_stage1
132
 
@@ -147,6 +172,7 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
147
  tiling_config: TilingConfig | None = None,
148
  enhance_prompt: bool = False,
149
  ):
 
150
  assert_resolution(height=height, width=width, is_two_stage=True)
151
 
152
  device = self.device
@@ -154,7 +180,6 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
154
  generator = torch.Generator(device=device).manual_seed(seed)
155
  noiser = GaussianNoiser(generator=generator)
156
 
157
- # Apply cached LoRA state if available
158
  if self._cached_state_stage1 is not None:
159
  print("[LoRA] Applying cached state to stage 1 transformer...")
160
  t1 = self.stage_1_model_ledger.transformer()
@@ -354,12 +379,9 @@ print("Initializing HQ Pipeline...")
354
 
355
  pipeline = HQPipelineWithCachedLoRA(
356
  checkpoint_path=checkpoint_path,
357
- distilled_lora=[], # No distilled LoRA at init - applied via cached state
358
- distilled_lora_strength_stage_1=0.0, # Not used since loras is empty
359
- distilled_lora_strength_stage_2=0.0, # Not used since loras is empty
360
  spatial_upsampler_path=spatial_upsampler_path,
361
  gemma_root=gemma_root,
362
- loras=(), # No LoRAs at init - preloading will work
363
  quantization=QuantizationPolicy.fp8_cast(),
364
  )
365
 
 
119
  # =============================================================================
120
 
121
  class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
122
+ """
123
+ TI2VidTwoStagesHQPipeline modified to:
124
+ 1. NOT accept or pass distilled_lora to parent init (enables preloading)
125
+ 2. Handle ALL LoRAs via cached state (distilled + 12 custom)
126
+ """
127
 
128
+ def __init__(
129
+ self,
130
+ checkpoint_path: str,
131
+ spatial_upsampler_path: str,
132
+ gemma_root: str,
133
+ loras: tuple = (),
134
+ quantization: QuantizationPolicy | None = None,
135
+ ):
136
+ # Call parent WITHOUT distilled_lora parameters
137
+ # We create minimal stage ledgers (no LoRAs) for preloading
138
+ super().__init__(
139
+ checkpoint_path=checkpoint_path,
140
+ distilled_lora=[], # Empty - satisfies signature
141
+ distilled_lora_strength_stage_1=0.0,
142
+ distilled_lora_strength_stage_2=0.0,
143
+ spatial_upsampler_path=spatial_upsampler_path,
144
+ gemma_root=gemma_root,
145
+ loras=loras,
146
+ quantization=quantization,
147
+ )
148
+
149
+ # Storage for cached LoRA states for each stage
150
  self._cached_state_stage1 = None
151
  self._cached_state_stage2 = None
152
 
153
  def apply_cached_lora_state(self, state_dict_stage1, state_dict_stage2=None):
154
+ """Apply pre-cached LoRA state to both stage transformers."""
155
  self._cached_state_stage1 = state_dict_stage1
156
  self._cached_state_stage2 = state_dict_stage2 if state_dict_stage2 else state_dict_stage1
157
 
 
172
  tiling_config: TilingConfig | None = None,
173
  enhance_prompt: bool = False,
174
  ):
175
+ # ... same as before ...
176
  assert_resolution(height=height, width=width, is_two_stage=True)
177
 
178
  device = self.device
 
180
  generator = torch.Generator(device=device).manual_seed(seed)
181
  noiser = GaussianNoiser(generator=generator)
182
 
 
183
  if self._cached_state_stage1 is not None:
184
  print("[LoRA] Applying cached state to stage 1 transformer...")
185
  t1 = self.stage_1_model_ledger.transformer()
 
379
 
380
  pipeline = HQPipelineWithCachedLoRA(
381
  checkpoint_path=checkpoint_path,
 
 
 
382
  spatial_upsampler_path=spatial_upsampler_path,
383
  gemma_root=gemma_root,
384
+ loras=[], # No LoRAs at init - preloading works
385
  quantization=QuantizationPolicy.fp8_cast(),
386
  )
387