Fix LoRA set_adapters when transformer compiled
app.py CHANGED
`set_lora_scale` previously called `pipeline.set_adapters` directly, which fails once the transformer has been wrapped by `torch.compile`. The fix unwraps the compiled module (via `_orig_mod`) for the duration of the call, tolerates both `set_adapters` keyword spellings (`adapter_weights` and `weights`), and recovers the registered adapter name from the `ValueError` text when the expected name is missing. The `global` declaration is placed at the top of the function, since `lora_adapter_name` is read before it is reassigned:

```diff
@@ -273,10 +273,37 @@ def attach_lora(pipeline: ZImagePipeline) -> Tuple[bool, str | None]:
 def set_lora_scale(pipeline: ZImagePipeline, scale: float) -> None:
+    global lora_adapter_name
     weight = max(float(scale), 0.0)
     adapter = lora_adapter_name or "default"
+    compiled_transformer = None
     try:
-        pipeline.set_adapters([adapter], adapter_weights=[weight])
-    except TypeError:
-        pipeline.set_adapters([adapter], weights=[weight])
+        transformer = getattr(pipeline, "transformer", None)
+        if transformer is not None and hasattr(transformer, "_orig_mod"):
+            compiled_transformer = transformer
+            pipeline.transformer = transformer._orig_mod
+
+        try:
+            pipeline.set_adapters([adapter], adapter_weights=[weight])
+        except TypeError:
+            pipeline.set_adapters([adapter], weights=[weight])
+        except ValueError as exc:
+            msg = str(exc)
+            present_match = re.search(r"present adapters:\s*(\{[^}]*\}|set\([^)]*\))", msg)
+            if present_match:
+                present_names = re.findall(r"'([^']+)'", present_match.group(1))
+            else:
+                present_names = []
+            if present_names:
+                lora_adapter_name = present_names[0]
+                adapter = lora_adapter_name
+                try:
+                    pipeline.set_adapters([adapter], adapter_weights=[weight])
+                except TypeError:
+                    pipeline.set_adapters([adapter], weights=[weight])
+            else:
+                raise
+    finally:
+        if compiled_transformer is not None:
+            pipeline.transformer = compiled_transformer
 
 
 def load_models() -> Tuple[ZImagePipeline, bool, str | None]:
```
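For background: `torch.compile(module)` returns an `OptimizedModule` wrapper that keeps the original module at the private `_orig_mod` attribute, which is why the fix temporarily swaps `pipeline.transformer` back to the underlying module before touching adapters. A minimal sketch of that unwrap pattern in isolation (the `Toy` module is hypothetical, for illustration only):

```python
import torch
import torch.nn as nn


class Toy(nn.Module):  # hypothetical stand-in for the pipeline's transformer
    def __init__(self) -> None:
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)


compiled = torch.compile(Toy())

# torch.compile wraps the module immediately (actual compilation is lazy),
# and the wrapper exposes the original module at `_orig_mod`.
assert hasattr(compiled, "_orig_mod")
inner = compiled._orig_mod

# Mutate the real module here (adapter changes, weight tweaks, ...), then
# keep calling the compiled wrapper so the compiled graph is not discarded.
_ = compiled(torch.randn(1, 8))
```

The `finally` block in the diff is the restore half of this pattern: whatever happens inside `set_adapters`, the compiled wrapper is put back on the pipeline afterwards.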
In `generate_image`, the LoRA scale update is now wrapped so a failure is logged and generation proceeds with the existing LoRA state instead of aborting:

```diff
@@ -498,10 +525,13 @@ def generate_image(
     )
 
     if lora_loaded:
-        if use_lora:
-            set_lora_scale(pipeline, float(lora_scale))
-        else:
-            set_lora_scale(pipeline, 0.0)
+        try:
+            if use_lora:
+                set_lora_scale(pipeline, float(lora_scale))
+            else:
+                set_lora_scale(pipeline, 0.0)
+        except Exception as exc:  # noqa: BLE001
+            print(f"LoRA scale update failed (continuing without changing LoRA state): {exc}")
 
     with torch.inference_mode():
         image = pipeline(
```
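The `ValueError` branch in the first hunk assumes diffusers names the adapters that are actually registered in its error text; the regex accepts either a `{...}` set literal or a `set(...)` repr and pulls out the quoted names. A standalone check of that parsing, against a made-up message in the shape the fix targets (the exact wording varies across diffusers versions):

```python
import re

# Hypothetical error text; only the "present adapters: {...}" tail matters.
msg = "Adapter name(s) {'default'} not in the list of present adapters: {'zimage_lora'}."

present_match = re.search(r"present adapters:\s*(\{[^}]*\}|set\([^)]*\))", msg)
present_names = re.findall(r"'([^']+)'", present_match.group(1)) if present_match else []

print(present_names)  # ['zimage_lora'] -- the name the fix retries with
```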
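Net effect: the app can toggle LoRA strength between generations even when the transformer is compiled, and a failed update degrades to a log line rather than killing the request. A hypothetical call pattern, assuming `pipeline` plus the module-level `re` import and `lora_adapter_name` global that `set_lora_scale` relies on are defined elsewhere in app.py:

```python
set_lora_scale(pipeline, 0.8)  # apply the adapter at 80% strength
set_lora_scale(pipeline, 0.0)  # a weight of 0.0 effectively disables it
```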