Spaces:
Running on Zero
Use proper spaces.aoti_capture + aoti_compile + aoti_apply for AOT compilation
- Use spaces.aoti_capture to capture real forward pass inputs
- Use spaces.aoti_compile with INDUCTOR_CONFIGS (including triton.cudagraphs)
- Use spaces.aoti_apply to apply compiled model
- Separate compilation (5 min one-time) and generation (2 min)
- First run: AOT compilation takes 3-5 minutes
- Subsequent runs: only 2 minutes for generation
This follows FLUX-Kontext-fp8 pattern exactly.
app.py
CHANGED
|
@@ -101,14 +101,9 @@ print("="*50)
|
|
| 101 |
# ============================================================
|
| 102 |
# AOT Optimization Configuration (from FLUX-Kontext-fp8)
|
| 103 |
# ============================================================
|
|
|
|
| 104 |
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
|
| 105 |
|
| 106 |
-
# ============================================================
# Fixed input dimensions for this model
# width=128, height=896 (7 chars × 128)
# After VAE packing: h=448, w=64, img_seq_len = 448*64 = 28672
# ============================================================
IMG_HEIGHT = 896  # pixel height (7 chars × 128)
IMG_WIDTH = 128   # pixel width
# Derived rather than hard-coded so the value can never drift from the
# formula documented above: (896/2) * (128/2) = 448 * 64 = 28672.
IMG_SEQ_LEN = (IMG_HEIGHT // 2) * (IMG_WIDTH // 2)  # image token count after VAE packing
TXT_SEQ_LEN = 512  # Fixed: T5 max_length
|
| 111 |
-
|
| 112 |
# Inductor configuration for optimal performance
|
| 113 |
INDUCTOR_CONFIGS = {
|
| 114 |
'conv_1x1_as_mm': True,
|
|
@@ -120,102 +115,9 @@ INDUCTOR_CONFIGS = {
|
|
| 120 |
}
|
| 121 |
|
| 122 |
|
| 123 |
-
def create_sample_inputs(device="cuda", dtype=torch.float32, img_seq_len=None, txt_seq_len=None):
    """
    Create sample inputs with fixed dimensions for torch.export.

    Returns a kwargs dict matching the transformer's forward signature,
    with every tensor pinned to fixed sequence lengths so the export
    needs no dynamic dimensions.

    Args:
        device: Device to allocate the dummy tensors on.
        dtype: Floating dtype used for every tensor (the id tensors are
            zeros in this dtype too, matching the original behavior).
        img_seq_len: Image token count; defaults to the module-level
            IMG_SEQ_LEN. Parameterized so other fixed resolutions can be
            exported without editing globals.
        txt_seq_len: Text token count; defaults to the module-level
            TXT_SEQ_LEN.

    Returns:
        dict: keyword arguments for the transformer forward pass.
    """
    if img_seq_len is None:
        img_seq_len = IMG_SEQ_LEN
    if txt_seq_len is None:
        txt_seq_len = TXT_SEQ_LEN

    batch_size = 1
    vec_dim = 768  # CLIP vec dim
    cond_txt_dim = 896  # Condition text embedding dim

    # NOTE: the unused local `hidden_size = 3072` from the previous
    # revision was removed; it never influenced the sample inputs.
    return {
        'img': torch.randn(batch_size, img_seq_len, 64, device=device, dtype=dtype),  # 64 = in_channels
        'img_ids': torch.zeros(batch_size, img_seq_len, 3, device=device, dtype=dtype),
        'txt': torch.randn(batch_size, txt_seq_len, 4096, device=device, dtype=dtype),  # 4096 = T5 dim
        'txt_ids': torch.zeros(batch_size, txt_seq_len, 3, device=device, dtype=dtype),
        'y': torch.randn(batch_size, vec_dim, device=device, dtype=dtype),
        'timesteps': torch.tensor([0.5], device=device, dtype=dtype),
        'timesteps2': torch.tensor([0.5], device=device, dtype=dtype),
        'cond_txt_latent': torch.randn(batch_size, 5, cond_txt_dim, device=device, dtype=dtype),  # 5 cond tokens
        'guidance': torch.tensor([3.5], device=device, dtype=dtype),
    }
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def apply_aot_optimization(model, device="cuda"):
    """
    Apply Float8 quantization and AOT compilation with torch.export.
    Based on FLUX-Kontext-fp8 optimization pattern.

    Args:
        model: Transformer module to optimize. Mutated in place by
            quantize_() even when compilation later fails.
        device: Device string for the dummy export inputs and for
            loading the compiled artifact.

    Returns:
        The AOT-compiled callable on success, or a torch.compile-wrapped
        model when export/AOT compilation raises.
    """
    # Imported lazily so importing this module doesn't pull in inductor internals.
    import torch._inductor.config as inductor_config

    # Apply inductor configurations
    # NOTE(review): hasattr() on the top-level config object silently skips
    # dotted keys (e.g. 'triton.cudagraphs' mentioned in the commit message);
    # those only take effect via options= in aot_compile below — confirm intended.
    for key, value in INDUCTOR_CONFIGS.items():
        if hasattr(inductor_config, key):
            setattr(inductor_config, key, value)

    print("="*50)
    print("Starting AOT optimization with fixed input shapes...")
    print(f" img_seq_len: {IMG_SEQ_LEN}")
    print(f" txt_seq_len: {TXT_SEQ_LEN}")
    print("="*50)

    # Step 1: Apply Float8 quantization (in-place; happens before export so
    # the quantized ops are what get compiled)
    print("Applying Float8 quantization...")
    quantize_(model, Float8DynamicActivationFloat8WeightConfig())
    print("✓ Float8 quantization complete!")

    # Step 2: Create sample inputs for export
    print("Creating sample inputs for torch.export...")
    sample_inputs = create_sample_inputs(device=device, dtype=torch.float32)

    # Step 3: Export model with fixed shapes (no dynamic dims needed)
    print("Exporting model with torch.export (fixed shapes)...")
    try:
        exported = torch.export.export(
            model,
            args=(),
            kwargs=sample_inputs,
            strict=False,  # Allow some graph breaks if needed
        )
        print("✓ Model exported!")

        # Step 4: AOT compile with inductor
        # NOTE(review): torch._inductor.aot_compile and torch._export.aot_load
        # are private APIs — pin the torch version this was validated against.
        print("AOT compiling with torch._inductor.aot_compile...")
        compiled_path = torch._inductor.aot_compile(
            exported.module(),
            args=(),
            kwargs=sample_inputs,
            options=INDUCTOR_CONFIGS,
        )
        print(f"✓ AOT compiled to: {compiled_path}")

        # Step 5: Load the compiled model (returns a callable backed by the
        # on-disk artifact)
        print("Loading AOT compiled model...")
        compiled_model = torch._export.aot_load(compiled_path, device=device)
        print("✓ AOT model loaded!")

        return compiled_model

    # Broad catch is deliberate: any export/compile failure falls back to JIT
    # rather than crashing the app.
    except Exception as e:
        print(f"AOT compilation failed: {e}")
        print("Falling back to torch.compile (JIT)...")

        # Fallback to JIT compilation
        compiled_model = torch.compile(
            model,
            mode="max-autotune",
            backend="inductor",
            fullgraph=False,
        )
        print("✓ torch.compile (JIT) applied!")
        return compiled_model
|
| 214 |
-
|
| 215 |
-
|
| 216 |
def init_generator():
|
| 217 |
-
"""Initialize the generator
|
| 218 |
-
global generator, _cached_model_dir
|
| 219 |
|
| 220 |
if generator is None:
|
| 221 |
# Enable CUDA optimizations
|
|
@@ -247,21 +149,71 @@ def init_generator():
|
|
| 247 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 248 |
use_deepspeed=False,
|
| 249 |
use_4bit_quantization=False,
|
| 250 |
-
use_float8_quantization=False,
|
| 251 |
-
use_torch_compile=False,
|
| 252 |
dtype="fp32",
|
| 253 |
)
|
| 254 |
-
|
| 255 |
-
# Apply Float8 quantization + AOT compilation (fixed input shapes)
|
| 256 |
-
if not _is_optimized:
|
| 257 |
-
print("Applying Float8 + AOT optimizations to transformer...")
|
| 258 |
-
generator.model = apply_aot_optimization(generator.model, device="cuda")
|
| 259 |
-
_is_optimized = True
|
| 260 |
-
print("✓ Transformer optimized with Float8 + AOT compilation!")
|
| 261 |
|
| 262 |
return generator
|
| 263 |
|
| 264 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 265 |
def update_font_choices(author: str):
|
| 266 |
"""
|
| 267 |
Update available font choices based on selected author
|
|
@@ -283,31 +235,13 @@ def parse_font_style(font_style: str) -> str:
|
|
| 283 |
return None
|
| 284 |
|
| 285 |
|
| 286 |
-
@spaces.GPU(duration=
|
| 287 |
-
def compile_and_warmup():
    """
    Compile the model with Float8 + AOT optimization (first time only).

    Returns:
        The initialized generator singleton, warmed up so later calls
        don't pay the compilation cost.
    """
    print("="*50)
    print("First-time compilation starting...")
    print("="*50)
    # init_generator() applies the Float8 + AOT optimization on first call
    # (guarded by the module-level _is_optimized flag).
    gen = init_generator()
    # Warmup run to trigger JIT compilation
    # (single character, 1 step — the cheapest possible generation)
    print("Running warmup generation...")
    gen.generate(text="测", font_style="楷", author=None, num_steps=1, seed=42)
    print("="*50)
    print("Compilation and warmup complete!")
    print("="*50)
    return gen
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
@spaces.GPU(duration=120) # 2 minutes for normal generation (20s + 25steps * 4s = ~120s)
|
| 305 |
def run_generation(text, font, author, num_steps, start_seed, num_images):
|
| 306 |
"""
|
| 307 |
-
Run generation
|
| 308 |
-
Duration:
|
| 309 |
"""
|
| 310 |
-
gen = init_generator()
|
| 311 |
|
| 312 |
results = []
|
| 313 |
seeds_used = []
|
|
@@ -337,18 +271,10 @@ def interactive_session(
|
|
| 337 |
progress=gr.Progress()
|
| 338 |
):
|
| 339 |
"""
|
| 340 |
-
Interactive session
|
| 341 |
-
|
| 342 |
-
Args:
|
| 343 |
-
text: Input text (1-7 characters)
|
| 344 |
-
author_dropdown: Selected author
|
| 345 |
-
font_style: Font style
|
| 346 |
-
num_steps: Inference steps
|
| 347 |
-
start_seed: Starting seed
|
| 348 |
-
num_images: Number of images to generate (each with different seed)
|
| 349 |
|
| 350 |
-
|
| 351 |
-
|
| 352 |
"""
|
| 353 |
global _is_optimized
|
| 354 |
|
|
@@ -366,14 +292,16 @@ def interactive_session(
|
|
| 366 |
# Determine author
|
| 367 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 368 |
|
| 369 |
-
# Step 1: Compile
|
| 370 |
if not _is_optimized:
|
| 371 |
-
yield "⏳ 首次运行,
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
|
|
|
|
|
|
| 377 |
progress(0.1, desc="生成中...")
|
| 378 |
|
| 379 |
results, seeds_used = run_generation(
|
|
@@ -382,7 +310,7 @@ def interactive_session(
|
|
| 382 |
|
| 383 |
progress(1.0, desc="完成!")
|
| 384 |
|
| 385 |
-
# Final
|
| 386 |
if num_images > 1:
|
| 387 |
final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
|
| 388 |
else:
|
|
|
|
| 101 |
# ============================================================
|
| 102 |
# AOT Optimization Configuration (from FLUX-Kontext-fp8)
|
| 103 |
# ============================================================
|
| 104 |
+
from torch.utils._pytree import tree_map_only
|
| 105 |
from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
|
| 106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 107 |
# Inductor configuration for optimal performance
|
| 108 |
INDUCTOR_CONFIGS = {
|
| 109 |
'conv_1x1_as_mm': True,
|
|
|
|
| 115 |
}
|
| 116 |
|
| 117 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 118 |
def init_generator():
|
| 119 |
+
"""Initialize the generator (without optimization - that's done separately)"""
|
| 120 |
+
global generator, _cached_model_dir
|
| 121 |
|
| 122 |
if generator is None:
|
| 123 |
# Enable CUDA optimizations
|
|
|
|
| 149 |
author_descriptions_path='dataset/calligraphy_styles_en.json',
|
| 150 |
use_deepspeed=False,
|
| 151 |
use_4bit_quantization=False,
|
| 152 |
+
use_float8_quantization=False,
|
| 153 |
+
use_torch_compile=False,
|
| 154 |
dtype="fp32",
|
| 155 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 156 |
|
| 157 |
return generator
|
| 158 |
|
| 159 |
|
| 160 |
+
def optimize_transformer_(gen):
    """
    Apply Float8 quantization + AOT compilation using spaces.aoti_capture.
    Based on FLUX-Kontext-fp8 pattern.

    Mutates gen.model in place (hence the trailing underscore): weights are
    Float8-quantized and the compiled forward is installed via
    spaces.aoti_apply. Returns None.

    Args:
        gen: Generator whose .model attribute is the transformer to optimize
            and whose .generate() drives the capture pass.
    """
    model = gen.model

    @spaces.GPU(duration=300)  # 5 minutes for compilation
    def compile_transformer():
        print("="*50)
        print("Starting AOT compilation with spaces.aoti_capture...")
        print("="*50)

        # Step 1: Capture model forward during a real generation, so the
        # exported graph sees genuine input shapes/dtypes (1 step keeps it cheap).
        print("Capturing model forward pass...")
        with spaces.aoti_capture(model) as call:
            gen.generate(
                text="测试",
                font_style="楷",
                author=None,
                num_steps=1,
                seed=42,
            )

        # Step 2: Build dynamic shapes (we use fixed shapes, so set to None)
        # Mapping every tensor (and bool) leaf to None makes torch.export
        # treat all dimensions as static.
        dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)

        # Step 3: Apply Float8 quantization (in-place). Done after capture —
        # the captured kwargs are plain input tensors and are unaffected.
        print("Applying Float8 quantization...")
        quantize_(model, Float8DynamicActivationFloat8WeightConfig())
        print("✓ Float8 quantization complete!")

        # Step 4: Export model with the captured real args/kwargs
        print("Exporting model with torch.export...")
        exported = torch.export.export(
            mod=model,
            args=call.args,
            kwargs=call.kwargs,
            dynamic_shapes=dynamic_shapes,
        )
        print("✓ Model exported!")

        # Step 5: AOT compile (INDUCTOR_CONFIGS may include dotted keys like
        # 'triton.cudagraphs'; aoti_compile forwards them to inductor options)
        print("AOT compiling with spaces.aoti_compile...")
        print(f" Inductor configs: {INDUCTOR_CONFIGS}")
        return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)

    # Run compilation and apply the result. compile_transformer() runs inside
    # its own @spaces.GPU(duration=300) window, separate from generation calls.
    print("="*50)
    print("Running AOT compilation (this takes 3-5 minutes)...")
    print("="*50)
    spaces.aoti_apply(compile_transformer(), model)
    print("="*50)
    print("✓ AOT compilation complete! Model is now optimized.")
    print("="*50)
|
| 215 |
+
|
| 216 |
+
|
| 217 |
def update_font_choices(author: str):
|
| 218 |
"""
|
| 219 |
Update available font choices based on selected author
|
|
|
|
| 235 |
return None
|
| 236 |
|
| 237 |
|
| 238 |
+
@spaces.GPU(duration=120) # 2 minutes for normal generation
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 239 |
def run_generation(text, font, author, num_steps, start_seed, num_images):
|
| 240 |
"""
|
| 241 |
+
Run generation with the optimized model.
|
| 242 |
+
Duration: 20s base + ~4s per step per image.
|
| 243 |
"""
|
| 244 |
+
gen = init_generator()
|
| 245 |
|
| 246 |
results = []
|
| 247 |
seeds_used = []
|
|
|
|
| 271 |
progress=gr.Progress()
|
| 272 |
):
|
| 273 |
"""
|
| 274 |
+
Interactive session with separate compilation and generation phases.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 275 |
|
| 276 |
+
- First time: 5 min for AOT compilation (one-time)
|
| 277 |
+
- After that: 2 min for generation
|
| 278 |
"""
|
| 279 |
global _is_optimized
|
| 280 |
|
|
|
|
| 292 |
# Determine author
|
| 293 |
author = author_dropdown if author_dropdown != "None (Synthetic / 合成风格)" else None
|
| 294 |
|
| 295 |
+
# Step 1: AOT Compile if not done yet (5 min, one-time only)
|
| 296 |
if not _is_optimized:
|
| 297 |
+
yield "⏳ 首次运行,需要编译优化模型(约3-5分钟,仅此一次)...", []
|
| 298 |
+
gen = init_generator()
|
| 299 |
+
optimize_transformer_(gen) # This uses @spaces.GPU(duration=300) internally
|
| 300 |
+
_is_optimized = True
|
| 301 |
+
yield "✅ 模型编译完成!后续生成将会很快。", []
|
| 302 |
+
|
| 303 |
+
# Step 2: Run generation (2 min)
|
| 304 |
+
yield f"🎨 开始生成 {num_images} 张图片...", []
|
| 305 |
progress(0.1, desc="生成中...")
|
| 306 |
|
| 307 |
results, seeds_used = run_generation(
|
|
|
|
| 310 |
|
| 311 |
progress(1.0, desc="完成!")
|
| 312 |
|
| 313 |
+
# Final status
|
| 314 |
if num_images > 1:
|
| 315 |
final_status = f"✅ 全部完成!共 {num_images} 张 (Seeds: {seeds_used[0]}-{seeds_used[-1]})"
|
| 316 |
else:
|