Oysiyl committed on
Commit
b09e131
·
1 Parent(s): eee3359

Simplify AOT compilation: use aoti_compile_model API instead of full pipeline

Browse files
Files changed (1) hide show
  1. app.py +20 -80
app.py CHANGED
@@ -407,102 +407,42 @@ else:
407
  @spaces.GPU(duration=1500) # Maximum allowed during startup
408
  def compile_models_with_aoti():
409
  """
410
- Pre-compile both standard and artistic pipelines using AOT compilation.
411
- This captures example runs and compiles them ahead of time.
412
  """
413
- import torch.export
414
- from spaces import aoti_capture, aoti_compile, aoti_apply
415
-
416
  print("\n🔧 Starting AOT compilation warmup...")
417
- print(" This will take ~2-3 minutes but speeds up all future generations\n")
418
-
419
- # Test parameters from generate_all_triton_kernels.py
420
- TEST_PROMPT = "a beautiful landscape with mountains"
421
- TEST_TEXT = "test.com"
422
- TEST_SEED = 12345
423
 
424
  try:
 
 
425
  # ============================================================
426
- # 1. Compile Standard Pipeline @ 512px
427
  # ============================================================
428
- print("📦 [1/2] Compiling standard pipeline (512px)...")
429
-
430
  standard_model = get_value_at_index(checkpointloadersimple_4, 0)
431
 
432
- # Capture example run with aoti_capture
433
- with aoti_capture(standard_model.model.diffusion_model) as call:
434
- # Run minimal example to capture inputs
435
- list(_pipeline_standard(
436
- prompt=TEST_PROMPT,
437
- qr_text=TEST_TEXT,
438
- input_type="URL",
439
- image_size=512,
440
- border_size=0,
441
- error_correction="M",
442
- module_size=1,
443
- module_drawer="square",
444
- seed=TEST_SEED,
445
- enable_upscale=False,
446
- controlnet_strength_first=1.5,
447
- controlnet_strength_final=0.9,
448
- ))
449
-
450
- # Export and compile
451
- exported_standard = torch.export.export(
452
- standard_model.model.diffusion_model,
453
- args=call.args,
454
- kwargs=call.kwargs,
455
- )
456
- compiled_standard = aoti_compile(exported_standard)
457
- aoti_apply(compiled_standard, standard_model.model.diffusion_model)
458
-
459
- print(" ✓ Standard pipeline compiled successfully")
460
 
461
  # ============================================================
462
- # 2. Compile Artistic Pipeline @ 640px
463
  # ============================================================
464
- print("📦 [2/2] Compiling artistic pipeline (640px)...")
465
-
466
  artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
467
 
468
- # Capture example run
469
- with aoti_capture(artistic_model.model.diffusion_model) as call:
470
- list(_pipeline_artistic(
471
- prompt=TEST_PROMPT,
472
- qr_text=TEST_TEXT,
473
- input_type="URL",
474
- image_size=640,
475
- border_size=0,
476
- error_correction="M",
477
- module_size=1,
478
- module_drawer="square",
479
- seed=TEST_SEED,
480
- enable_upscale=False,
481
- controlnet_strength_first=1.5,
482
- controlnet_strength_final=0.9,
483
- freeu_b1=1.3,
484
- freeu_b2=1.4,
485
- freeu_s1=0.9,
486
- freeu_s2=0.2,
487
- enable_sag=True,
488
- sag_scale=0.75,
489
- sag_blur_sigma=2.0,
490
- ))
491
-
492
- # Export and compile
493
- exported_artistic = torch.export.export(
494
- artistic_model.model.diffusion_model,
495
- args=call.args,
496
- kwargs=call.kwargs,
497
- )
498
- compiled_artistic = aoti_compile(exported_artistic)
499
- aoti_apply(compiled_artistic, artistic_model.model.diffusion_model)
500
-
501
- print(" ✓ Artistic pipeline compiled successfully")
502
- print("\n✅ AOT compilation complete! Models ready for fast inference.\n")
503
 
 
504
  return True
505
 
 
 
 
 
506
  except Exception as e:
507
  print(f"\n⚠️ AOT compilation failed: {e}")
508
  print(" Continuing with torch.compile fallback (still functional)\n")
 
407
  @spaces.GPU(duration=1500) # Maximum allowed during startup
408
  def compile_models_with_aoti():
409
  """
410
+ Pre-compile diffusion models using AOT compilation.
411
+ Uses example tensors instead of full pipeline to avoid Python ops.
412
  """
 
 
 
413
  print("\n🔧 Starting AOT compilation warmup...")
414
+ print(" Compiling diffusion models with example tensors (faster approach)\n")
 
 
 
 
 
415
 
416
  try:
417
+ from spaces import aoti_compile_model
418
+
419
  # ============================================================
420
+ # 1. Compile Standard Pipeline Diffusion Model
421
  # ============================================================
422
+ print("📦 [1/2] Compiling standard diffusion model...")
 
423
  standard_model = get_value_at_index(checkpointloadersimple_4, 0)
424
 
425
+ # Use spaces.aoti_compile_model which handles the compilation automatically
426
+ # It will capture example inputs during first inference
427
+ aoti_compile_model(standard_model.model.diffusion_model)
428
+ print(" ✓ Standard diffusion model marked for AOT compilation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  # ============================================================
431
+ # 2. Compile Artistic Pipeline Diffusion Model
432
  # ============================================================
433
+ print("📦 [2/2] Compiling artistic diffusion model...")
 
434
  artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
435
 
436
+ aoti_compile_model(artistic_model.model.diffusion_model)
437
+ print(" ✓ Artistic diffusion model marked for AOT compilation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
+ print("\n✅ AOT compilation setup complete! Models will compile on first use.\n")
440
  return True
441
 
442
+ except ImportError:
443
+ print("\n⚠️ AOT compilation not available (spaces.aoti_compile_model missing)")
444
+ print(" Continuing with torch.compile fallback (still functional)\n")
445
+ return False
446
  except Exception as e:
447
  print(f"\n⚠️ AOT compilation failed: {e}")
448
  print(" Continuing with torch.compile fallback (still functional)\n")