Spaces:
Sleeping
Sleeping
Simplify AOT compilation: use aoti_compile_model API instead of full pipeline
Browse files
app.py
CHANGED
|
@@ -407,102 +407,42 @@ else:
|
|
| 407 |
@spaces.GPU(duration=1500) # Maximum allowed during startup
|
| 408 |
def compile_models_with_aoti():
|
| 409 |
"""
|
| 410 |
-
Pre-compile
|
| 411 |
-
|
| 412 |
"""
|
| 413 |
-
import torch.export
|
| 414 |
-
from spaces import aoti_capture, aoti_compile, aoti_apply
|
| 415 |
-
|
| 416 |
print("\n🔧 Starting AOT compilation warmup...")
|
| 417 |
-
print("
|
| 418 |
-
|
| 419 |
-
# Test parameters from generate_all_triton_kernels.py
|
| 420 |
-
TEST_PROMPT = "a beautiful landscape with mountains"
|
| 421 |
-
TEST_TEXT = "test.com"
|
| 422 |
-
TEST_SEED = 12345
|
| 423 |
|
| 424 |
try:
|
|
|
|
|
|
|
| 425 |
# ============================================================
|
| 426 |
-
# 1. Compile Standard Pipeline
|
| 427 |
# ============================================================
|
| 428 |
-
print("📦 [1/2] Compiling standard
|
| 429 |
-
|
| 430 |
standard_model = get_value_at_index(checkpointloadersimple_4, 0)
|
| 431 |
|
| 432 |
-
#
|
| 433 |
-
|
| 434 |
-
|
| 435 |
-
|
| 436 |
-
prompt=TEST_PROMPT,
|
| 437 |
-
qr_text=TEST_TEXT,
|
| 438 |
-
input_type="URL",
|
| 439 |
-
image_size=512,
|
| 440 |
-
border_size=0,
|
| 441 |
-
error_correction="M",
|
| 442 |
-
module_size=1,
|
| 443 |
-
module_drawer="square",
|
| 444 |
-
seed=TEST_SEED,
|
| 445 |
-
enable_upscale=False,
|
| 446 |
-
controlnet_strength_first=1.5,
|
| 447 |
-
controlnet_strength_final=0.9,
|
| 448 |
-
))
|
| 449 |
-
|
| 450 |
-
# Export and compile
|
| 451 |
-
exported_standard = torch.export.export(
|
| 452 |
-
standard_model.model.diffusion_model,
|
| 453 |
-
args=call.args,
|
| 454 |
-
kwargs=call.kwargs,
|
| 455 |
-
)
|
| 456 |
-
compiled_standard = aoti_compile(exported_standard)
|
| 457 |
-
aoti_apply(compiled_standard, standard_model.model.diffusion_model)
|
| 458 |
-
|
| 459 |
-
print(" ✓ Standard pipeline compiled successfully")
|
| 460 |
|
| 461 |
# ============================================================
|
| 462 |
-
# 2. Compile Artistic Pipeline
|
| 463 |
# ============================================================
|
| 464 |
-
print("📦 [2/2] Compiling artistic
|
| 465 |
-
|
| 466 |
artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
|
| 467 |
|
| 468 |
-
|
| 469 |
-
|
| 470 |
-
list(_pipeline_artistic(
|
| 471 |
-
prompt=TEST_PROMPT,
|
| 472 |
-
qr_text=TEST_TEXT,
|
| 473 |
-
input_type="URL",
|
| 474 |
-
image_size=640,
|
| 475 |
-
border_size=0,
|
| 476 |
-
error_correction="M",
|
| 477 |
-
module_size=1,
|
| 478 |
-
module_drawer="square",
|
| 479 |
-
seed=TEST_SEED,
|
| 480 |
-
enable_upscale=False,
|
| 481 |
-
controlnet_strength_first=1.5,
|
| 482 |
-
controlnet_strength_final=0.9,
|
| 483 |
-
freeu_b1=1.3,
|
| 484 |
-
freeu_b2=1.4,
|
| 485 |
-
freeu_s1=0.9,
|
| 486 |
-
freeu_s2=0.2,
|
| 487 |
-
enable_sag=True,
|
| 488 |
-
sag_scale=0.75,
|
| 489 |
-
sag_blur_sigma=2.0,
|
| 490 |
-
))
|
| 491 |
-
|
| 492 |
-
# Export and compile
|
| 493 |
-
exported_artistic = torch.export.export(
|
| 494 |
-
artistic_model.model.diffusion_model,
|
| 495 |
-
args=call.args,
|
| 496 |
-
kwargs=call.kwargs,
|
| 497 |
-
)
|
| 498 |
-
compiled_artistic = aoti_compile(exported_artistic)
|
| 499 |
-
aoti_apply(compiled_artistic, artistic_model.model.diffusion_model)
|
| 500 |
-
|
| 501 |
-
print(" ✓ Artistic pipeline compiled successfully")
|
| 502 |
-
print("\n✅ AOT compilation complete! Models ready for fast inference.\n")
|
| 503 |
|
|
|
|
| 504 |
return True
|
| 505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
except Exception as e:
|
| 507 |
print(f"\n⚠️ AOT compilation failed: {e}")
|
| 508 |
print(" Continuing with torch.compile fallback (still functional)\n")
|
|
|
|
| 407 |
@spaces.GPU(duration=1500) # Maximum allowed during startup
|
| 408 |
def compile_models_with_aoti():
|
| 409 |
"""
|
| 410 |
+
Pre-compile diffusion models using AOT compilation.
|
| 411 |
+
Uses example tensors instead of full pipeline to avoid Python ops.
|
| 412 |
"""
|
|
|
|
|
|
|
|
|
|
| 413 |
print("\n🔧 Starting AOT compilation warmup...")
|
| 414 |
+
print(" Compiling diffusion models with example tensors (faster approach)\n")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 415 |
|
| 416 |
try:
|
| 417 |
+
from spaces import aoti_compile_model
|
| 418 |
+
|
| 419 |
# ============================================================
|
| 420 |
+
# 1. Compile Standard Pipeline Diffusion Model
|
| 421 |
# ============================================================
|
| 422 |
+
print("📦 [1/2] Compiling standard diffusion model...")
|
|
|
|
| 423 |
standard_model = get_value_at_index(checkpointloadersimple_4, 0)
|
| 424 |
|
| 425 |
+
# Use spaces.aoti_compile_model which handles the compilation automatically
|
| 426 |
+
# It will capture example inputs during first inference
|
| 427 |
+
aoti_compile_model(standard_model.model.diffusion_model)
|
| 428 |
+
print(" ✓ Standard diffusion model marked for AOT compilation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 429 |
|
| 430 |
# ============================================================
|
| 431 |
+
# 2. Compile Artistic Pipeline Diffusion Model
|
| 432 |
# ============================================================
|
| 433 |
+
print("📦 [2/2] Compiling artistic diffusion model...")
|
|
|
|
| 434 |
artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
|
| 435 |
|
| 436 |
+
aoti_compile_model(artistic_model.model.diffusion_model)
|
| 437 |
+
print(" ✓ Artistic diffusion model marked for AOT compilation")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 438 |
|
| 439 |
+
print("\n✅ AOT compilation setup complete! Models will compile on first use.\n")
|
| 440 |
return True
|
| 441 |
|
| 442 |
+
except ImportError:
|
| 443 |
+
print("\n⚠️ AOT compilation not available (spaces.aoti_compile_model missing)")
|
| 444 |
+
print(" Continuing with torch.compile fallback (still functional)\n")
|
| 445 |
+
return False
|
| 446 |
except Exception as e:
|
| 447 |
print(f"\n⚠️ AOT compilation failed: {e}")
|
| 448 |
print(" Continuing with torch.compile fallback (still functional)\n")
|