Spaces:
Sleeping
Sleeping
Fix AOTI: use correct imports (aoti_capture, aoti_compile, aoti_apply) with torch.compile fallback
Browse files
app.py
CHANGED
|
@@ -390,63 +390,167 @@ def _apply_torch_compile_optimizations():
|
|
| 390 |
print(" Continuing without compilation (slower but functional)\n")
|
| 391 |
|
| 392 |
|
| 393 |
-
# Enable torch.compile optimizations (timestep_embedding fixed!)
|
| 394 |
-
# Now works with fullgraph=False for compatibility with SAG
|
| 395 |
-
# FreeU now runs FFT on GPU to enable CUDAGraphs
|
| 396 |
-
# Skip on MPS (MacBooks) - torch.compile with MPS can cause issues
|
| 397 |
-
if not torch.backends.mps.is_available():
|
| 398 |
-
_apply_torch_compile_optimizations()
|
| 399 |
-
else:
|
| 400 |
-
print(
|
| 401 |
-
"ℹ️ torch.compile skipped on MPS (MacBook) - using fp32 optimizations instead"
|
| 402 |
-
)
|
| 403 |
-
|
| 404 |
-
|
| 405 |
# AOT Compilation with ZeroGPU for faster cold starts
|
| 406 |
-
# Runs once at startup to pre-compile models
|
|
|
|
| 407 |
@spaces.GPU(duration=1500) # Maximum allowed during startup
|
| 408 |
def compile_models_with_aoti():
|
| 409 |
"""
|
| 410 |
Pre-compile diffusion models using AOT compilation.
|
| 411 |
-
|
|
|
|
| 412 |
"""
|
| 413 |
-
print("\n🔧 Starting
|
| 414 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
try:
|
| 417 |
-
from spaces import
|
|
|
|
|
|
|
|
|
|
| 418 |
|
| 419 |
# ============================================================
|
| 420 |
-
# 1. Compile Standard Pipeline
|
| 421 |
# ============================================================
|
| 422 |
-
print("📦 [1/2]
|
| 423 |
standard_model = get_value_at_index(checkpointloadersimple_4, 0)
|
| 424 |
|
| 425 |
-
#
|
| 426 |
-
|
| 427 |
-
|
| 428 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
# ============================================================
|
| 431 |
-
# 2. Compile Artistic Pipeline
|
| 432 |
# ============================================================
|
| 433 |
-
print("📦 [2/2]
|
| 434 |
artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
|
| 435 |
|
| 436 |
-
|
| 437 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
-
print("\n✅ AOT compilation
|
| 440 |
return True
|
| 441 |
|
| 442 |
-
except ImportError:
|
| 443 |
-
|
| 444 |
-
print("
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
|
| 449 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 450 |
|
| 451 |
|
| 452 |
@spaces.GPU(duration=60) # Reduced from 720s - AOTI compilation speeds up inference
|
|
|
|
| 390 |
print(" Continuing without compilation (slower but functional)\n")
|
| 391 |
|
| 392 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 393 |
# AOT Compilation with ZeroGPU for faster cold starts
|
| 394 |
+
# Runs once at startup to pre-compile models
|
| 395 |
+
# Falls back to torch.compile with warmup inference if AOTI unavailable
|
| 396 |
@spaces.GPU(duration=1500)  # Maximum allowed during startup
def compile_models_with_aoti():
    """
    Pre-compile the standard and artistic diffusion models at startup.

    First attempts ahead-of-time (AOT) compilation: for each checkpoint, one
    real pipeline run is captured with ``spaces.aoti_capture``, the captured
    call is exported with ``torch.export.export``, compiled with
    ``spaces.aoti_compile`` and installed with ``spaces.aoti_apply``.

    If anything in that path raises (including the imports), falls back to
    ``_apply_torch_compile_optimizations()`` followed by one warmup run of
    each pipeline so that torch.compile is triggered before real requests.

    Returns:
        bool: True if either the AOT path or the torch.compile warmup ran to
        completion; False if the warmup inference itself raised.

    NOTE(review): if the AOT path fails on the artistic model, the standard
    model has already been patched by ``aoti_apply`` when the torch.compile
    fallback runs -- confirm that combination is acceptable.
    NOTE(review): the ZeroGPU AoT guide shows ``aoti_apply`` being called on
    the value returned from the ``@spaces.GPU`` function rather than inside
    it -- confirm that applying here persists for later inference calls.
    """
    print("\n🔧 Starting model compilation warmup...")

    # Warmup inputs, defined once so the AOT capture run and the torch.compile
    # fallback run can never drift apart.
    standard_kwargs = dict(
        prompt="a beautiful landscape with mountains",
        qr_text="test.com",
        input_type="URL",
        image_size=512,
        border_size=0,
        error_correction="M",
        module_size=1,
        module_drawer="square",
        seed=12345,
        enable_upscale=False,
        controlnet_strength_first=1.5,
        controlnet_strength_final=0.9,
    )
    # Same base inputs at 640px, plus the FreeU / SAG options that only the
    # artistic pipeline is called with.
    artistic_kwargs = dict(
        standard_kwargs,
        image_size=640,
        freeu_b1=1.3,
        freeu_b2=1.4,
        freeu_s1=0.9,
        freeu_s2=0.2,
        enable_sag=True,
        sag_scale=0.75,
        sag_blur_sigma=2.0,
    )

    try:
        from spaces import aoti_capture, aoti_compile, aoti_apply
        # Imported by name: a function-level ``import torch.export`` would make
        # ``torch`` a local of this function and shadow the module-level name.
        from torch.export import export as export_program

        def _aoti_compile_module(module, run_pipeline, pipeline_kwargs):
            """Capture one call of *module*, export it, and install the AOT build."""
            with aoti_capture(module) as captured_call:
                # list() drains whatever iterable the pipeline returns so the
                # module is actually invoked inside the capture context.
                list(run_pipeline(**pipeline_kwargs))
            exported = export_program(
                module,
                args=captured_call.args,
                kwargs=captured_call.kwargs,
            )
            aoti_apply(aoti_compile(exported), module)

        print(" Attempting AOT compilation...\n")

        # ============================================================
        # 1. Compile Standard Pipeline @ 512px
        # ============================================================
        print("📦 [1/2] AOT compiling standard pipeline (512px)...")
        standard_model = get_value_at_index(checkpointloadersimple_4, 0)
        _aoti_compile_module(
            standard_model.model.diffusion_model,
            _pipeline_standard,
            standard_kwargs,
        )
        print(" ✓ Standard pipeline compiled")

        # ============================================================
        # 2. Compile Artistic Pipeline @ 640px
        # ============================================================
        print("📦 [2/2] AOT compiling artistic pipeline (640px)...")
        artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
        _aoti_compile_module(
            artistic_model.model.diffusion_model,
            _pipeline_artistic,
            artistic_kwargs,
        )
        print(" ✓ Artistic pipeline compiled")

        print("\n✅ AOT compilation complete! Models ready for fast inference.\n")
        return True

    except Exception as e:
        # ImportError is a subclass of Exception, so one clause covers both
        # "AOTI helpers not installed" and "AOTI attempted but failed".
        error_type = "not available" if isinstance(e, ImportError) else f"failed: {e}"
        print(f"\n⚠️ AOT compilation {error_type}")
        print(" Falling back to torch.compile with warmup inference...\n")

    # ---- Fallback: torch.compile + warmup --------------------------------
    # Deliberately outside the warmup try below, as in the original: an error
    # raised here propagates to the caller.
    _apply_torch_compile_optimizations()

    # Run warmup inference to trigger torch.compile compilation
    print("🔥 Running warmup inference to compile models (this takes 2-3 minutes)...")

    try:
        # Warmup standard pipeline @ 512px
        print(" [1/2] Warming up standard pipeline...")
        list(_pipeline_standard(**standard_kwargs))
        print(" ✓ Standard pipeline compiled")

        # Warmup artistic pipeline @ 640px
        print(" [2/2] Warming up artistic pipeline...")
        list(_pipeline_artistic(**artistic_kwargs))
        print(" ✓ Artistic pipeline compiled")

        print("\n✅ torch.compile warmup complete! Models ready for fast inference.\n")
        return True

    except Exception as warmup_error:
        # Best-effort: startup continues and compilation happens lazily.
        print(f"\n⚠️ Warmup inference failed: {warmup_error}")
        print(" Models will compile on first real inference (slower first run)\n")
        return False
|
| 554 |
|
| 555 |
|
| 556 |
@spaces.GPU(duration=60) # Reduced from 720s - AOTI compilation speeds up inference
|