Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -119,14 +119,39 @@ GEMMA_REPO = "Lightricks/gemma-3-12b-it-qat-q4_0-unquantized"
|
|
| 119 |
# =============================================================================
|
| 120 |
|
| 121 |
class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
| 122 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
-
def __init__(
|
| 125 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 126 |
self._cached_state_stage1 = None
|
| 127 |
self._cached_state_stage2 = None
|
| 128 |
|
| 129 |
def apply_cached_lora_state(self, state_dict_stage1, state_dict_stage2=None):
|
|
|
|
| 130 |
self._cached_state_stage1 = state_dict_stage1
|
| 131 |
self._cached_state_stage2 = state_dict_stage2 if state_dict_stage2 else state_dict_stage1
|
| 132 |
|
|
@@ -147,6 +172,7 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 147 |
tiling_config: TilingConfig | None = None,
|
| 148 |
enhance_prompt: bool = False,
|
| 149 |
):
|
|
|
|
| 150 |
assert_resolution(height=height, width=width, is_two_stage=True)
|
| 151 |
|
| 152 |
device = self.device
|
|
@@ -154,7 +180,6 @@ class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
|
| 154 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 155 |
noiser = GaussianNoiser(generator=generator)
|
| 156 |
|
| 157 |
-
# Apply cached LoRA state if available
|
| 158 |
if self._cached_state_stage1 is not None:
|
| 159 |
print("[LoRA] Applying cached state to stage 1 transformer...")
|
| 160 |
t1 = self.stage_1_model_ledger.transformer()
|
|
@@ -354,12 +379,9 @@ print("Initializing HQ Pipeline...")
|
|
| 354 |
|
| 355 |
pipeline = HQPipelineWithCachedLoRA(
|
| 356 |
checkpoint_path=checkpoint_path,
|
| 357 |
-
distilled_lora=[], # No distilled LoRA at init - applied via cached state
|
| 358 |
-
distilled_lora_strength_stage_1=0.0, # Not used since loras is empty
|
| 359 |
-
distilled_lora_strength_stage_2=0.0, # Not used since loras is empty
|
| 360 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 361 |
gemma_root=gemma_root,
|
| 362 |
-
loras=
|
| 363 |
quantization=QuantizationPolicy.fp8_cast(),
|
| 364 |
)
|
| 365 |
|
|
|
|
| 119 |
# =============================================================================
|
| 120 |
|
| 121 |
class HQPipelineWithCachedLoRA(TI2VidTwoStagesHQPipeline):
|
| 122 |
+
"""
|
| 123 |
+
TI2VidTwoStagesHQPipeline modified to:
|
| 124 |
+
1. NOT accept distilled_lora; pass only an empty distilled_lora list (zero strengths) to parent init (enables preloading)
|
| 125 |
+
2. Handle ALL LoRAs via cached state (distilled + 12 custom)
|
| 126 |
+
"""
|
| 127 |
|
| 128 |
+
def __init__(
|
| 129 |
+
self,
|
| 130 |
+
checkpoint_path: str,
|
| 131 |
+
spatial_upsampler_path: str,
|
| 132 |
+
gemma_root: str,
|
| 133 |
+
loras: tuple = (),
|
| 134 |
+
quantization: QuantizationPolicy | None = None,
|
| 135 |
+
):
|
| 136 |
+
# Call parent with an empty distilled_lora list and zero stage strengths
|
| 137 |
+
# Only the caller-supplied loras (empty by default) are forwarded, keeping the stage ledgers minimal for preloading
|
| 138 |
+
super().__init__(
|
| 139 |
+
checkpoint_path=checkpoint_path,
|
| 140 |
+
distilled_lora=[], # Empty - satisfies signature
|
| 141 |
+
distilled_lora_strength_stage_1=0.0,
|
| 142 |
+
distilled_lora_strength_stage_2=0.0,
|
| 143 |
+
spatial_upsampler_path=spatial_upsampler_path,
|
| 144 |
+
gemma_root=gemma_root,
|
| 145 |
+
loras=loras,
|
| 146 |
+
quantization=quantization,
|
| 147 |
+
)
|
| 148 |
+
|
| 149 |
+
# Storage for cached LoRA states for each stage
|
| 150 |
self._cached_state_stage1 = None
|
| 151 |
self._cached_state_stage2 = None
|
| 152 |
|
| 153 |
def apply_cached_lora_state(self, state_dict_stage1, state_dict_stage2=None):
|
| 154 |
+
"""Store pre-cached LoRA state dicts for both stages; stage 2 reuses the stage 1 state when state_dict_stage2 is None or empty."""
|
| 155 |
self._cached_state_stage1 = state_dict_stage1
|
| 156 |
self._cached_state_stage2 = state_dict_stage2 if state_dict_stage2 else state_dict_stage1
|
| 157 |
|
|
|
|
| 172 |
tiling_config: TilingConfig | None = None,
|
| 173 |
enhance_prompt: bool = False,
|
| 174 |
):
|
| 175 |
+
# ... same as before ...
|
| 176 |
assert_resolution(height=height, width=width, is_two_stage=True)
|
| 177 |
|
| 178 |
device = self.device
|
|
|
|
| 180 |
generator = torch.Generator(device=device).manual_seed(seed)
|
| 181 |
noiser = GaussianNoiser(generator=generator)
|
| 182 |
|
|
|
|
| 183 |
if self._cached_state_stage1 is not None:
|
| 184 |
print("[LoRA] Applying cached state to stage 1 transformer...")
|
| 185 |
t1 = self.stage_1_model_ledger.transformer()
|
|
|
|
| 379 |
|
| 380 |
pipeline = HQPipelineWithCachedLoRA(
|
| 381 |
checkpoint_path=checkpoint_path,
|
|
|
|
|
|
|
|
|
|
| 382 |
spatial_upsampler_path=spatial_upsampler_path,
|
| 383 |
gemma_root=gemma_root,
|
| 384 |
+
loras=[], # No LoRAs at init - preloading works
|
| 385 |
quantization=QuantizationPolicy.fp8_cast(),
|
| 386 |
)
|
| 387 |
|