Oysiyl committed on
Commit
f205cc5
·
1 Parent(s): b09e131

Fix AOTI: use correct imports (aoti_capture, aoti_compile, aoti_apply) with torch.compile fallback

Browse files
Files changed (1) hide show
  1. app.py +140 -36
app.py CHANGED
@@ -390,63 +390,167 @@ def _apply_torch_compile_optimizations():
390
  print(" Continuing without compilation (slower but functional)\n")
391
 
392
 
393
- # Enable torch.compile optimizations (timestep_embedding fixed!)
394
- # Now works with fullgraph=False for compatibility with SAG
395
- # FreeU now runs FFT on GPU to enable CUDAGraphs
396
- # Skip on MPS (MacBooks) - torch.compile with MPS can cause issues
397
- if not torch.backends.mps.is_available():
398
- _apply_torch_compile_optimizations()
399
- else:
400
- print(
401
- "ℹ️ torch.compile skipped on MPS (MacBook) - using fp32 optimizations instead"
402
- )
403
-
404
-
405
  # AOT Compilation with ZeroGPU for faster cold starts
406
- # Runs once at startup to pre-compile models with example inputs
 
407
  @spaces.GPU(duration=1500) # Maximum allowed during startup
408
  def compile_models_with_aoti():
409
  """
410
  Pre-compile diffusion models using AOT compilation.
411
- Uses example tensors instead of full pipeline to avoid Python ops.
 
412
  """
413
- print("\n🔧 Starting AOT compilation warmup...")
414
- print(" Compiling diffusion models with example tensors (faster approach)\n")
 
 
 
 
415
 
416
  try:
417
- from spaces import aoti_compile_model
 
 
 
418
 
419
  # ============================================================
420
- # 1. Compile Standard Pipeline Diffusion Model
421
  # ============================================================
422
- print("📦 [1/2] Compiling standard diffusion model...")
423
  standard_model = get_value_at_index(checkpointloadersimple_4, 0)
424
 
425
- # Use spaces.aoti_compile_model which handles the compilation automatically
426
- # It will capture example inputs during first inference
427
- aoti_compile_model(standard_model.model.diffusion_model)
428
- print(" ✓ Standard diffusion model marked for AOT compilation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
429
 
430
  # ============================================================
431
- # 2. Compile Artistic Pipeline Diffusion Model
432
  # ============================================================
433
- print("📦 [2/2] Compiling artistic diffusion model...")
434
  artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
435
 
436
- aoti_compile_model(artistic_model.model.diffusion_model)
437
- print(" ✓ Artistic diffusion model marked for AOT compilation")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
438
 
439
- print("\n✅ AOT compilation setup complete! Models will compile on first use.\n")
440
  return True
441
 
442
- except ImportError:
443
- print("\n⚠️ AOT compilation not available (spaces.aoti_compile_model missing)")
444
- print(" Continuing with torch.compile fallback (still functional)\n")
445
- return False
446
- except Exception as e:
447
- print(f"\n⚠️ AOT compilation failed: {e}")
448
- print(" Continuing with torch.compile fallback (still functional)\n")
449
- return False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
450
 
451
 
452
  @spaces.GPU(duration=60) # Reduced from 720s - AOTI compilation speeds up inference
 
390
  print(" Continuing without compilation (slower but functional)\n")
391
 
392
 
 
 
 
 
 
 
 
 
 
 
 
 
393
  # AOT Compilation with ZeroGPU for faster cold starts
394
+ # Runs once at startup to pre-compile models
395
+ # Falls back to torch.compile with warmup inference if AOTI unavailable
396
  @spaces.GPU(duration=1500) # Maximum allowed during startup
397
  def compile_models_with_aoti():
398
  """
399
  Pre-compile diffusion models using AOT compilation.
400
+ If AOTI fails, falls back to torch.compile with warmup inference.
401
+ Uses the full 1500s GPU allocation to ensure models are compiled.
402
  """
403
+ print("\n🔧 Starting model compilation warmup...")
404
+
405
+ # Test parameters for warmup inference
406
+ TEST_PROMPT = "a beautiful landscape with mountains"
407
+ TEST_TEXT = "test.com"
408
+ TEST_SEED = 12345
409
 
410
  try:
411
+ from spaces import aoti_capture, aoti_compile, aoti_apply
412
+ import torch.export
413
+
414
+ print(" Attempting AOT compilation...\n")
415
 
416
  # ============================================================
417
+ # 1. Compile Standard Pipeline @ 512px
418
  # ============================================================
419
+ print("📦 [1/2] AOT compiling standard pipeline (512px)...")
420
  standard_model = get_value_at_index(checkpointloadersimple_4, 0)
421
 
422
+ # Capture example run
423
+ with aoti_capture(standard_model.model.diffusion_model) as call_standard:
424
+ list(_pipeline_standard(
425
+ prompt=TEST_PROMPT,
426
+ qr_text=TEST_TEXT,
427
+ input_type="URL",
428
+ image_size=512,
429
+ border_size=0,
430
+ error_correction="M",
431
+ module_size=1,
432
+ module_drawer="square",
433
+ seed=TEST_SEED,
434
+ enable_upscale=False,
435
+ controlnet_strength_first=1.5,
436
+ controlnet_strength_final=0.9,
437
+ ))
438
+
439
+ # Export and compile
440
+ exported_standard = torch.export.export(
441
+ standard_model.model.diffusion_model,
442
+ args=call_standard.args,
443
+ kwargs=call_standard.kwargs,
444
+ )
445
+ compiled_standard = aoti_compile(exported_standard)
446
+ aoti_apply(compiled_standard, standard_model.model.diffusion_model)
447
+ print(" ✓ Standard pipeline compiled")
448
 
449
  # ============================================================
450
+ # 2. Compile Artistic Pipeline @ 640px
451
  # ============================================================
452
+ print("📦 [2/2] AOT compiling artistic pipeline (640px)...")
453
  artistic_model = get_value_at_index(checkpointloadersimple_artistic, 0)
454
 
455
+ # Capture example run
456
+ with aoti_capture(artistic_model.model.diffusion_model) as call_artistic:
457
+ list(_pipeline_artistic(
458
+ prompt=TEST_PROMPT,
459
+ qr_text=TEST_TEXT,
460
+ input_type="URL",
461
+ image_size=640,
462
+ border_size=0,
463
+ error_correction="M",
464
+ module_size=1,
465
+ module_drawer="square",
466
+ seed=TEST_SEED,
467
+ enable_upscale=False,
468
+ controlnet_strength_first=1.5,
469
+ controlnet_strength_final=0.9,
470
+ freeu_b1=1.3,
471
+ freeu_b2=1.4,
472
+ freeu_s1=0.9,
473
+ freeu_s2=0.2,
474
+ enable_sag=True,
475
+ sag_scale=0.75,
476
+ sag_blur_sigma=2.0,
477
+ ))
478
+
479
+ # Export and compile
480
+ exported_artistic = torch.export.export(
481
+ artistic_model.model.diffusion_model,
482
+ args=call_artistic.args,
483
+ kwargs=call_artistic.kwargs,
484
+ )
485
+ compiled_artistic = aoti_compile(exported_artistic)
486
+ aoti_apply(compiled_artistic, artistic_model.model.diffusion_model)
487
+ print(" ✓ Artistic pipeline compiled")
488
 
489
+ print("\n✅ AOT compilation complete! Models ready for fast inference.\n")
490
  return True
491
 
492
+ except (ImportError, Exception) as e:
493
+ error_type = "not available" if isinstance(e, ImportError) else f"failed: {e}"
494
+ print(f"\n⚠️ AOT compilation {error_type}")
495
+ print(" Falling back to torch.compile with warmup inference...\n")
496
+
497
+ # Apply torch.compile optimizations
498
+ _apply_torch_compile_optimizations()
499
+
500
+ # Run warmup inference to trigger torch.compile compilation
501
+ print("🔥 Running warmup inference to compile models (this takes 2-3 minutes)...")
502
+
503
+ try:
504
+ # Warmup standard pipeline @ 512px
505
+ print(" [1/2] Warming up standard pipeline...")
506
+ list(_pipeline_standard(
507
+ prompt=TEST_PROMPT,
508
+ qr_text=TEST_TEXT,
509
+ input_type="URL",
510
+ image_size=512,
511
+ border_size=0,
512
+ error_correction="M",
513
+ module_size=1,
514
+ module_drawer="square",
515
+ seed=TEST_SEED,
516
+ enable_upscale=False,
517
+ controlnet_strength_first=1.5,
518
+ controlnet_strength_final=0.9,
519
+ ))
520
+ print(" ✓ Standard pipeline compiled")
521
+
522
+ # Warmup artistic pipeline @ 640px
523
+ print(" [2/2] Warming up artistic pipeline...")
524
+ list(_pipeline_artistic(
525
+ prompt=TEST_PROMPT,
526
+ qr_text=TEST_TEXT,
527
+ input_type="URL",
528
+ image_size=640,
529
+ border_size=0,
530
+ error_correction="M",
531
+ module_size=1,
532
+ module_drawer="square",
533
+ seed=TEST_SEED,
534
+ enable_upscale=False,
535
+ controlnet_strength_first=1.5,
536
+ controlnet_strength_final=0.9,
537
+ freeu_b1=1.3,
538
+ freeu_b2=1.4,
539
+ freeu_s1=0.9,
540
+ freeu_s2=0.2,
541
+ enable_sag=True,
542
+ sag_scale=0.75,
543
+ sag_blur_sigma=2.0,
544
+ ))
545
+ print(" ✓ Artistic pipeline compiled")
546
+
547
+ print("\n✅ torch.compile warmup complete! Models ready for fast inference.\n")
548
+ return True
549
+
550
+ except Exception as warmup_error:
551
+ print(f"\n⚠️ Warmup inference failed: {warmup_error}")
552
+ print(" Models will compile on first real inference (slower first run)\n")
553
+ return False
554
 
555
 
556
  @spaces.GPU(duration=60) # Reduced from 720s - AOTI compilation speeds up inference