TSXu committed on
Commit
a1f5b88
·
1 Parent(s): afa3a21

Refactor to use Float8 + torch.compile from FLUX-Kontext-fp8

Browse files

- Apply Float8DynamicActivationFloat8WeightConfig via torchao
- Use torch.compile with max-autotune mode and inductor configs:
- conv_1x1_as_mm, coordinate_descent_tuning, max_autotune, triton.cudagraphs
- Optimizations applied in app.py init_generator() for ZeroGPU
- CLI mode in inference.py also supports --float8 and --compile flags

Files changed (3) hide show
  1. app.py +84 -44
  2. inference.py +30 -24
  3. optimization.py +162 -203
app.py CHANGED
@@ -1,7 +1,7 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
- With interactive session mode to avoid model reloading
5
  """
6
 
7
  # IMPORTANT: import spaces first before any CUDA-related packages
@@ -11,6 +11,8 @@ import gradio as gr
11
  import json
12
  import csv
13
  import time
 
 
14
 
15
  # Load author and font mappings from CSV
16
  def load_author_fonts_from_csv(csv_path):
@@ -57,20 +59,22 @@ try:
57
  except:
58
  author_styles = {}
59
 
60
- # Initialize generator (will be done lazily on first generation)
61
  generator = None
 
 
62
 
63
- # Pre-download model files at startup (before user clicks)
 
 
64
  def preload_model_files():
65
  """Pre-download model files to cache at startup (no GPU needed)"""
66
- import os
67
  from huggingface_hub import snapshot_download
68
 
69
  hf_token = os.environ.get("HF_TOKEN", None)
70
  print("Pre-downloading model files to cache...")
71
 
72
  try:
73
- # Only download safetensors, embedding, and font files (not the .bin files)
74
  local_dir = snapshot_download(
75
  repo_id="TSXu/Unicalli_Pro",
76
  allow_patterns=[
@@ -79,9 +83,7 @@ def preload_model_files():
79
  "internvl_embedding/*",
80
  "*.ttf",
81
  ],
82
- ignore_patterns=[
83
- "*.bin", # Skip large .bin files
84
- ],
85
  token=hf_token
86
  )
87
  print(f"✓ Model files cached at: {local_dir}")
@@ -90,67 +92,105 @@ def preload_model_files():
90
  print(f"Warning: Could not pre-download model files: {e}")
91
  return None
92
 
93
- # Pre-download at startup
94
  print("="*50)
95
  print("Starting model pre-download...")
96
  _cached_model_dir = preload_model_files()
97
  print("="*50)
98
 
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def init_generator():
101
- """Initialize the generator (lazy loading)"""
102
- global generator, _cached_model_dir
103
 
104
  if generator is None:
105
- # Enable optimized SDPA attention backends for faster inference
106
- import torch
107
- import os
108
-
109
  try:
 
 
 
110
  torch.backends.cuda.enable_flash_sdp(True)
111
  torch.backends.cuda.enable_mem_efficient_sdp(True)
112
- torch.backends.cuda.enable_math_sdp(False) # Disable slow math backend
113
- print("✓ Enabled Flash Attention / Memory-Efficient SDPA backends")
114
  except Exception as e:
115
- print(f"Note: Could not configure SDPA backends: {e}")
116
-
117
- # Use pre-cached model directory
118
- if _cached_model_dir:
119
- intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
120
- checkpoint_path = _cached_model_dir
121
- print(f"Using pre-cached model from: {_cached_model_dir}")
122
- else:
123
- # Fallback to HF Hub download if pre-download failed
124
- from huggingface_hub import snapshot_download
125
- hf_token = os.environ.get("HF_TOKEN", None)
126
- print("Downloading model from TSXu/Unicalli_Pro...")
127
- _cached_model_dir = snapshot_download(
128
- repo_id="TSXu/Unicalli_Pro",
129
- token=hf_token
130
- )
131
- intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
132
- checkpoint_path = _cached_model_dir
133
 
 
 
 
134
  print(f"Using lightweight embedding from: {intern_vlm_path}")
135
 
136
- # Lazy import to avoid CUDA initialization at module load time
137
  from inference import CalligraphyGenerator
138
 
139
  generator = CalligraphyGenerator(
140
  model_name="flux-dev",
141
  device="cuda",
142
- offload=True, # Enable offload to save GPU memory
143
- intern_vlm_path=intern_vlm_path, # Lightweight embedding (~500MB vs ~2GB)
144
- checkpoint_path=checkpoint_path, # Use pre-cached model
145
  font_descriptions_path='dataset/chirography.json',
146
  author_descriptions_path='dataset/calligraphy_styles_en.json',
147
  use_deepspeed=False,
148
- use_4bit_quantization=False, # Full precision model
149
- use_float8_quantization=True, # Enable Float8 quantization for faster inference
150
- use_torch_compile=True, # Enable torch.compile with inductor optimizations
151
- compile_mode="reduce-overhead", # Best for inference speed
152
- dtype="fp32", # Use fp32 to avoid CUBLAS errors on ZeroGPU
153
  )
 
 
 
 
 
 
 
 
154
  return generator
155
 
156
 
 
1
  # -*- coding: utf-8 -*-
2
  """
3
  Gradio Demo for Chinese Calligraphy Generation - HuggingFace Space Version
4
+ With Float8 quantization and AOT compilation for faster inference
5
  """
6
 
7
  # IMPORTANT: import spaces first before any CUDA-related packages
 
11
  import json
12
  import csv
13
  import time
14
+ import torch
15
+ import os
16
 
17
  # Load author and font mappings from CSV
18
  def load_author_fonts_from_csv(csv_path):
 
59
  except:
60
  author_styles = {}
61
 
62
+ # Global generator instance
63
  generator = None
64
+ _cached_model_dir = None
65
+ _is_optimized = False
66
 
67
+ # ============================================================
68
+ # Pre-download model files at startup (no GPU needed)
69
+ # ============================================================
70
  def preload_model_files():
71
  """Pre-download model files to cache at startup (no GPU needed)"""
 
72
  from huggingface_hub import snapshot_download
73
 
74
  hf_token = os.environ.get("HF_TOKEN", None)
75
  print("Pre-downloading model files to cache...")
76
 
77
  try:
 
78
  local_dir = snapshot_download(
79
  repo_id="TSXu/Unicalli_Pro",
80
  allow_patterns=[
 
83
  "internvl_embedding/*",
84
  "*.ttf",
85
  ],
86
+ ignore_patterns=["*.bin"],
 
 
87
  token=hf_token
88
  )
89
  print(f"✓ Model files cached at: {local_dir}")
 
92
  print(f"Warning: Could not pre-download model files: {e}")
93
  return None
94
 
 
95
  print("="*50)
96
  print("Starting model pre-download...")
97
  _cached_model_dir = preload_model_files()
98
  print("="*50)
99
 
100
 
101
+ # ============================================================
102
+ # AOT Optimization Configuration (from FLUX-Kontext-fp8)
103
+ # ============================================================
104
+ from torch.utils._pytree import tree_map_only
105
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
106
+
107
+ # Inductor configuration for optimal performance
108
+ INDUCTOR_CONFIGS = {
109
+ 'conv_1x1_as_mm': True,
110
+ 'epilogue_fusion': False,
111
+ 'coordinate_descent_tuning': True,
112
+ 'coordinate_descent_check_all_directions': True,
113
+ 'max_autotune': True,
114
+ 'triton.cudagraphs': True,
115
+ }
116
+
117
+
118
+ def apply_optimizations(model):
119
+ """
120
+ Apply Float8 quantization and torch.compile with inductor optimizations.
121
+ Based on FLUX-Kontext-fp8 optimization techniques.
122
+ """
123
+ import torch._inductor.config as inductor_config
124
+
125
+ # Apply inductor configurations
126
+ for key, value in INDUCTOR_CONFIGS.items():
127
+ if hasattr(inductor_config, key):
128
+ setattr(inductor_config, key, value)
129
+
130
+ print("="*50)
131
+ print("Applying Float8 quantization...")
132
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig())
133
+ print("✓ Float8 quantization complete!")
134
+
135
+ print("Applying torch.compile with inductor optimizations...")
136
+ compiled_model = torch.compile(
137
+ model,
138
+ mode="max-autotune",
139
+ backend="inductor",
140
+ dynamic=True,
141
+ )
142
+ print("✓ torch.compile applied!")
143
+ print("="*50)
144
+
145
+ return compiled_model
146
+
147
+
148
  def init_generator():
149
+ """Initialize the generator with Float8 + torch.compile optimization"""
150
+ global generator, _cached_model_dir, _is_optimized
151
 
152
  if generator is None:
153
+ # Enable CUDA optimizations
 
 
 
154
  try:
155
+ torch.backends.cuda.matmul.allow_tf32 = True
156
+ torch.backends.cudnn.allow_tf32 = True
157
+ torch.backends.cudnn.benchmark = True
158
  torch.backends.cuda.enable_flash_sdp(True)
159
  torch.backends.cuda.enable_mem_efficient_sdp(True)
160
+ torch.backends.cuda.enable_math_sdp(False)
161
+ print("✓ CUDA optimizations enabled (TF32, cuDNN benchmark, Flash SDPA)")
162
  except Exception as e:
163
+ print(f"Note: Some CUDA optimizations failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ intern_vlm_path = os.path.join(_cached_model_dir, "internvl_embedding")
166
+ checkpoint_path = _cached_model_dir
167
+ print(f"Using pre-cached model from: {_cached_model_dir}")
168
  print(f"Using lightweight embedding from: {intern_vlm_path}")
169
 
 
170
  from inference import CalligraphyGenerator
171
 
172
  generator = CalligraphyGenerator(
173
  model_name="flux-dev",
174
  device="cuda",
175
+ offload=True,
176
+ intern_vlm_path=intern_vlm_path,
177
+ checkpoint_path=checkpoint_path,
178
  font_descriptions_path='dataset/chirography.json',
179
  author_descriptions_path='dataset/calligraphy_styles_en.json',
180
  use_deepspeed=False,
181
+ use_4bit_quantization=False,
182
+ use_float8_quantization=False, # Apply manually below
183
+ use_torch_compile=False, # Apply manually below
184
+ dtype="fp32",
 
185
  )
186
+
187
+ # Apply Float8 quantization + torch.compile
188
+ if not _is_optimized:
189
+ print("Applying optimizations to transformer...")
190
+ generator.model = apply_optimizations(generator.model)
191
+ _is_optimized = True
192
+ print("✓ Transformer optimized with Float8 + torch.compile!")
193
+
194
  return generator
195
 
196
 
inference.py CHANGED
@@ -18,7 +18,6 @@ from huggingface_hub import hf_hub_download, snapshot_download
18
  from src.flux.util import configs, load_ae, load_clip, load_t5
19
  from src.flux.model import Flux
20
  from src.flux.xflux_pipeline import XFluxSampler
21
- from optimization import optimize_model, enable_cuda_optimizations, check_optimization_support
22
 
23
 
24
  # HuggingFace Hub model IDs
@@ -188,9 +187,6 @@ class CalligraphyGenerator:
188
  self.use_torch_compile = use_torch_compile
189
  self.compile_mode = compile_mode
190
  self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
191
-
192
- # Enable CUDA optimizations early
193
- enable_cuda_optimizations()
194
 
195
  # Load font and author style descriptions
196
  if os.path.exists(font_descriptions_path):
@@ -247,16 +243,9 @@ class CalligraphyGenerator:
247
  if self.use_deepspeed:
248
  self.model = self._init_deepspeed(self.model)
249
 
250
- # Apply Float8 quantization and torch.compile optimizations
251
- if not self.use_deepspeed and not self.use_4bit_quantization:
252
- if self.use_float8_quantization or self.use_torch_compile:
253
- self.model = optimize_model(
254
- self.model,
255
- device=str(self.device),
256
- use_float8=self.use_float8_quantization,
257
- use_compile=self.use_torch_compile,
258
- compile_mode=self.compile_mode
259
- )
260
 
261
  # Load VAE
262
  if self.use_deepspeed or offload:
@@ -1115,27 +1104,44 @@ if __name__ == "__main__":
1115
  parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
1116
  parser.add_argument("--float8", action="store_true", help="Use Float8 quantization (torchao) for faster inference")
1117
  parser.add_argument("--compile", action="store_true", help="Use torch.compile for optimized inference")
1118
- parser.add_argument("--compile-mode", type=str, default="reduce-overhead",
1119
  choices=["reduce-overhead", "max-autotune", "default"],
1120
  help="torch.compile mode")
1121
- parser.add_argument("--check-optimizations", action="store_true", help="Check available optimization support")
1122
 
1123
  args = parser.parse_args()
1124
-
1125
- # Check optimization support if requested
1126
- if args.check_optimizations:
1127
- check_optimization_support()
1128
- exit(0)
1129
 
1130
  # Initialize generator
1131
  generator = CalligraphyGenerator(
1132
  model_name="flux-dev",
1133
  device=args.device,
1134
  checkpoint_path=args.checkpoint,
1135
- use_float8_quantization=args.float8,
1136
- use_torch_compile=args.compile,
1137
- compile_mode=args.compile_mode
1138
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1139
 
1140
  # List available options
1141
  if args.list_authors:
 
18
  from src.flux.util import configs, load_ae, load_clip, load_t5
19
  from src.flux.model import Flux
20
  from src.flux.xflux_pipeline import XFluxSampler
 
21
 
22
 
23
  # HuggingFace Hub model IDs
 
187
  self.use_torch_compile = use_torch_compile
188
  self.compile_mode = compile_mode
189
  self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
 
 
 
190
 
191
  # Load font and author style descriptions
192
  if os.path.exists(font_descriptions_path):
 
243
  if self.use_deepspeed:
244
  self.model = self._init_deepspeed(self.model)
245
 
246
+ # Note: Float8 quantization and torch.compile optimizations
247
+ # are applied externally (e.g., in app.py) for better control
248
+ # over the optimization process with ZeroGPU AOT compilation.
 
 
 
 
 
 
 
249
 
250
  # Load VAE
251
  if self.use_deepspeed or offload:
 
1104
  parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
1105
  parser.add_argument("--float8", action="store_true", help="Use Float8 quantization (torchao) for faster inference")
1106
  parser.add_argument("--compile", action="store_true", help="Use torch.compile for optimized inference")
1107
+ parser.add_argument("--compile-mode", type=str, default="max-autotune",
1108
  choices=["reduce-overhead", "max-autotune", "default"],
1109
  help="torch.compile mode")
 
1110
 
1111
  args = parser.parse_args()
 
 
 
 
 
1112
 
1113
  # Initialize generator
1114
  generator = CalligraphyGenerator(
1115
  model_name="flux-dev",
1116
  device=args.device,
1117
  checkpoint_path=args.checkpoint,
 
 
 
1118
  )
1119
+
1120
+ # Apply optimizations if requested (CLI mode)
1121
+ if args.float8 or args.compile:
1122
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
1123
+ import torch._inductor.config as inductor_config
1124
+
1125
+ # Inductor configs from FLUX-Kontext-fp8
1126
+ inductor_config.conv_1x1_as_mm = True
1127
+ inductor_config.coordinate_descent_tuning = True
1128
+ inductor_config.coordinate_descent_check_all_directions = True
1129
+ inductor_config.max_autotune = True
1130
+
1131
+ if args.float8:
1132
+ print("Applying Float8 quantization...")
1133
+ quantize_(generator.model, Float8DynamicActivationFloat8WeightConfig())
1134
+ print("✓ Float8 quantization complete!")
1135
+
1136
+ if args.compile:
1137
+ print(f"Applying torch.compile (mode={args.compile_mode})...")
1138
+ generator.model = torch.compile(
1139
+ generator.model,
1140
+ mode=args.compile_mode,
1141
+ backend="inductor",
1142
+ dynamic=True,
1143
+ )
1144
+ print("✓ torch.compile applied!")
1145
 
1146
  # List available options
1147
  if args.list_authors:
optimization.py CHANGED
@@ -1,16 +1,34 @@
1
  """
2
- Model optimization utilities for faster inference using:
3
- - Float8 quantization via torchao
4
- - torch.compile with inductor optimizations
5
- - CUDA graph capture for reduced kernel launch overhead
6
 
7
- Inspired by FLUX-Kontext-fp8 optimization techniques.
 
 
 
8
  """
9
 
10
- from typing import Optional, Callable, Any
 
 
 
11
  import torch
12
  import torch.nn as nn
 
 
 
 
 
 
13
 
 
 
 
 
 
 
 
 
14
 
15
  # Inductor configuration for optimal performance
16
  INDUCTOR_CONFIGS = {
@@ -19,94 +37,114 @@ INDUCTOR_CONFIGS = {
19
  'coordinate_descent_tuning': True,
20
  'coordinate_descent_check_all_directions': True,
21
  'max_autotune': True,
 
22
  }
23
 
24
 
25
- def apply_float8_quantization(model: nn.Module, device: str = "cuda") -> nn.Module:
 
 
 
 
26
  """
27
- Apply Float8 dynamic activation and weight quantization using torchao.
28
 
29
- This provides significant speedup on GPUs with native FP8 support (H100, etc.)
30
- and reasonable speedup on other GPUs through reduced memory bandwidth.
 
 
 
 
31
 
32
  Args:
33
- model: PyTorch model to quantize
34
- device: Target device for the model
 
35
 
36
  Returns:
37
- Quantized model
38
  """
39
- try:
40
- from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
41
-
42
- print("Applying Float8 dynamic activation + Float8 weight quantization...")
 
 
43
 
44
- # Move model to device first if not already there
45
- if next(model.parameters()).device.type != device:
46
- model = model.to(device)
47
 
48
- # Apply float8 quantization
 
49
  quantize_(model, Float8DynamicActivationFloat8WeightConfig())
 
50
 
51
- print("Float8 quantization applied successfully!")
52
- return model
 
 
 
 
 
 
 
53
 
54
- except ImportError as e:
55
- print(f"torchao not available for Float8 quantization: {e}")
56
- print("Install with: pip install torchao")
57
- return model
58
- except Exception as e:
59
- print(f"Float8 quantization failed: {e}")
60
- print("Falling back to unquantized model")
61
- return model
 
 
 
 
 
 
62
 
63
 
64
- def apply_torch_compile(
65
- model: nn.Module,
66
- mode: str = "reduce-overhead",
67
- fullgraph: bool = False,
68
- dynamic: bool = True,
69
- backend: str = "inductor"
70
- ) -> nn.Module:
71
  """
72
- Apply torch.compile with optimized settings for inference.
 
73
 
74
  Args:
75
- model: PyTorch model to compile
76
- mode: Compilation mode - "reduce-overhead" (best for inference),
77
- "max-autotune" (slower compile, faster runtime), or "default"
78
- fullgraph: If True, requires entire forward to be capturable (faster but stricter)
79
- dynamic: If True, allows dynamic shapes (recommended for variable input sizes)
80
- backend: Compiler backend - "inductor" is recommended
81
-
82
- Returns:
83
- Compiled model
84
  """
85
- try:
86
- import torch._inductor.config as inductor_config
 
 
 
87
 
88
- # Apply inductor configurations
89
- for key, value in INDUCTOR_CONFIGS.items():
90
- if hasattr(inductor_config, key):
91
- setattr(inductor_config, key, value)
92
 
93
- print(f"Applying torch.compile with mode='{mode}', backend='{backend}'...")
 
 
94
 
95
- compiled_model = torch.compile(
96
- model,
97
- mode=mode,
98
- fullgraph=fullgraph,
99
- dynamic=dynamic,
100
- backend=backend
101
- )
102
 
103
- print("torch.compile applied successfully!")
104
- return compiled_model
 
 
 
 
105
 
106
- except Exception as e:
107
- print(f"torch.compile failed: {e}")
108
- print("Falling back to uncompiled model")
109
- return model
110
 
111
 
112
  def enable_cuda_optimizations():
@@ -136,6 +174,48 @@ def enable_cuda_optimizations():
136
  print(f"Some CUDA optimizations failed: {e}")
137
 
138
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  def optimize_model(
140
  model: nn.Module,
141
  device: str = "cuda",
@@ -144,42 +224,24 @@ def optimize_model(
144
  compile_mode: str = "reduce-overhead"
145
  ) -> nn.Module:
146
  """
147
- Apply all optimizations to the model for maximum inference speed.
148
-
149
- Optimizations applied:
150
- 1. CUDA backend optimizations (TF32, Flash SDPA, etc.)
151
- 2. Float8 quantization via torchao (if available)
152
- 3. torch.compile with inductor optimizations
153
-
154
- Args:
155
- model: PyTorch model to optimize
156
- device: Target device
157
- use_float8: Whether to apply Float8 quantization
158
- use_compile: Whether to apply torch.compile
159
- compile_mode: Mode for torch.compile
160
-
161
- Returns:
162
- Optimized model
163
  """
164
  print("=" * 50)
165
  print("Applying model optimizations...")
166
  print("=" * 50)
167
 
168
- # 1. Enable CUDA optimizations
169
  enable_cuda_optimizations()
170
 
171
- # 2. Move model to device
172
  if next(model.parameters()).device.type != device:
173
  print(f"Moving model to {device}...")
174
  model = model.to(device)
175
 
176
- # 3. Apply Float8 quantization
177
  if use_float8:
178
- model = apply_float8_quantization(model, device)
179
 
180
- # 4. Apply torch.compile
181
  if use_compile:
182
- model = apply_torch_compile(model, mode=compile_mode)
183
 
184
  print("=" * 50)
185
  print("Model optimization complete!")
@@ -188,132 +250,29 @@ def optimize_model(
188
  return model
189
 
190
 
191
- def warmup_model(
192
- model: nn.Module,
193
- warmup_fn: Callable[[], Any],
194
- num_warmup: int = 3
195
- ):
196
- """
197
- Warmup the compiled model to trigger JIT compilation.
198
-
199
- Args:
200
- model: The model (should already be compiled)
201
- warmup_fn: Function that runs a forward pass
202
- num_warmup: Number of warmup iterations
203
- """
204
- print(f"Warming up model with {num_warmup} iterations...")
205
-
206
- with torch.no_grad():
207
- for i in range(num_warmup):
208
- try:
209
- warmup_fn()
210
- print(f" Warmup {i+1}/{num_warmup} complete")
211
- except Exception as e:
212
- print(f" Warmup {i+1}/{num_warmup} failed: {e}")
213
-
214
- # Synchronize CUDA
215
- if torch.cuda.is_available():
216
- torch.cuda.synchronize()
217
-
218
- print("Model warmup complete!")
219
-
220
-
221
- class CUDAGraphWrapper(nn.Module):
222
- """
223
- Wrapper that captures and replays CUDA graphs for reduced kernel launch overhead.
224
-
225
- Note: CUDA graphs require static input shapes. Use this only if your input
226
- dimensions are fixed.
227
- """
228
-
229
- def __init__(self, model: nn.Module, warmup_fn: Callable[[], tuple]):
230
- super().__init__()
231
- self.model = model
232
- self.graph = None
233
- self.static_inputs = None
234
- self.static_outputs = None
235
- self._captured = False
236
-
237
- def capture(self, *sample_inputs):
238
- """
239
- Capture the CUDA graph with sample inputs.
240
-
241
- Args:
242
- *sample_inputs: Sample inputs with the exact shapes that will be used
243
- """
244
- if not torch.cuda.is_available():
245
- print("CUDA not available, skipping graph capture")
246
- return
247
-
248
- print("Capturing CUDA graph...")
249
-
250
- # Warmup
251
- with torch.no_grad():
252
- for _ in range(3):
253
- _ = self.model(*sample_inputs)
254
-
255
- torch.cuda.synchronize()
256
-
257
- # Capture
258
- self.graph = torch.cuda.CUDAGraph()
259
-
260
- # Create static tensors
261
- self.static_inputs = tuple(inp.clone() for inp in sample_inputs)
262
-
263
- with torch.cuda.graph(self.graph):
264
- self.static_outputs = self.model(*self.static_inputs)
265
-
266
- self._captured = True
267
- print("CUDA graph captured successfully!")
268
-
269
- def forward(self, *inputs):
270
- if not self._captured:
271
- return self.model(*inputs)
272
-
273
- # Copy inputs to static buffers
274
- for static_inp, inp in zip(self.static_inputs, inputs):
275
- static_inp.copy_(inp)
276
-
277
- # Replay graph
278
- self.graph.replay()
279
-
280
- return self.static_outputs
281
-
282
-
283
- # Utility function to check available optimizations
284
  def check_optimization_support():
285
- """
286
- Check which optimizations are available on the current system.
287
- """
288
  print("Checking optimization support...")
289
  print("-" * 40)
290
 
291
- # CUDA
292
  print(f"CUDA available: {torch.cuda.is_available()}")
293
  if torch.cuda.is_available():
294
  print(f" Device: {torch.cuda.get_device_name()}")
295
- print(f" Capability: {torch.cuda.get_device_capability()}")
296
 
297
- # torch.compile
298
- try:
299
- import torch._dynamo
300
- print(f"torch.compile available: True")
301
- except ImportError:
302
- print(f"torch.compile available: False")
303
-
304
- # torchao Float8
305
  try:
306
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
307
- print(f"torchao Float8 available: True")
308
  except ImportError:
309
- print(f"torchao Float8 available: False")
310
 
311
- # Flash Attention
312
  try:
313
- flash_available = torch.backends.cuda.flash_sdp_enabled() if torch.cuda.is_available() else False
314
- print(f"Flash SDPA available: {flash_available}")
315
- except:
316
- print(f"Flash SDPA available: Unknown")
 
 
 
317
 
318
  print("-" * 40)
319
 
 
1
  """
2
+ Model optimization utilities using AOT compilation and Float8 quantization.
3
+ Based on FLUX-Kontext-fp8 optimization techniques.
 
 
4
 
5
+ Key optimizations:
6
+ 1. Float8 dynamic activation + Float8 weight quantization via torchao
7
+ 2. AOT (Ahead-of-Time) compilation via torch.export + spaces.aoti_compile
8
+ 3. CUDA graph capture for reduced kernel launch overhead
9
  """
10
 
11
+ from typing import Any, Callable, Optional
12
+ from typing import ParamSpec
13
+
14
+ import spaces
15
  import torch
16
  import torch.nn as nn
17
+ from torch.utils._pytree import tree_map_only
18
+ from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig
19
+
20
+
21
+ P = ParamSpec('P')
22
+
23
 
24
+ # Dynamic shape specifications for the Flux transformer
25
+ # These allow variable sequence lengths during inference
26
+ TRANSFORMER_HIDDEN_DIM = torch.export.Dim('hidden', min=512, max=16384)
27
+
28
+ TRANSFORMER_DYNAMIC_SHAPES = {
29
+ 'img': {1: TRANSFORMER_HIDDEN_DIM},
30
+ 'img_ids': {1: TRANSFORMER_HIDDEN_DIM},
31
+ }
32
 
33
  # Inductor configuration for optimal performance
34
  INDUCTOR_CONFIGS = {
 
37
  'coordinate_descent_tuning': True,
38
  'coordinate_descent_check_all_directions': True,
39
  'max_autotune': True,
40
+ 'triton.cudagraphs': True,
41
  }
42
 
43
 
44
+ def optimize_flux_model_(
45
+ model: nn.Module,
46
+ sample_forward_fn: Callable[[], Any],
47
+ device: str = "cuda"
48
+ ):
49
  """
50
+ Optimize the Flux model using Float8 quantization and AOT compilation.
51
 
52
+ This function:
53
+ 1. Captures a sample forward pass to determine input shapes
54
+ 2. Applies Float8 quantization to the model
55
+ 3. Exports the model with dynamic shapes
56
+ 4. Compiles using AOT inductor with optimized configs
57
+ 5. Applies the compiled model back
58
 
59
  Args:
60
+ model: The Flux transformer model to optimize
61
+ sample_forward_fn: A function that runs a sample forward pass
62
+ device: Target device
63
 
64
  Returns:
65
+ Optimized model
66
  """
67
+
68
+ @spaces.GPU(duration=1500)
69
+ def compile_transformer():
70
+ # Step 1: Capture the forward pass to get input shapes
71
+ with spaces.aoti_capture(model) as call:
72
+ sample_forward_fn()
73
 
74
+ # Step 2: Build dynamic shapes from captured call
75
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
76
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
77
 
78
+ # Step 3: Apply Float8 quantization
79
+ print("Applying Float8 dynamic activation + Float8 weight quantization...")
80
  quantize_(model, Float8DynamicActivationFloat8WeightConfig())
81
+ print("Float8 quantization complete!")
82
 
83
+ # Step 4: Export the model with dynamic shapes
84
+ print("Exporting model with torch.export...")
85
+ exported = torch.export.export(
86
+ mod=model,
87
+ args=call.args,
88
+ kwargs=call.kwargs,
89
+ dynamic_shapes=dynamic_shapes,
90
+ )
91
+ print("Model exported successfully!")
92
 
93
+ # Step 5: AOT compile with inductor configs
94
+ print("AOT compiling with inductor optimizations...")
95
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
96
+
97
+ # Apply the compiled model
98
+ print("=" * 50)
99
+ print("Starting AOT optimization pipeline...")
100
+ print("=" * 50)
101
+
102
+ spaces.aoti_apply(compile_transformer(), model)
103
+
104
+ print("=" * 50)
105
+ print("AOT optimization complete!")
106
+ print("=" * 50)
107
 
108
 
109
+ def optimize_pipeline_(
110
+ pipeline: Any,
111
+ sample_image: Any,
112
+ sample_prompt: str = "sample prompt"
113
+ ):
 
 
114
  """
115
+ Optimize a diffusers-style pipeline.
116
+ Compatible with FluxPipeline or similar pipelines.
117
 
118
  Args:
119
+ pipeline: The pipeline with a .transformer attribute
120
+ sample_image: Sample image for capturing input shapes
121
+ sample_prompt: Sample prompt text
 
 
 
 
 
 
122
  """
123
+
124
+ @spaces.GPU(duration=1500)
125
+ def compile_transformer():
126
+ with spaces.aoti_capture(pipeline.transformer) as call:
127
+ pipeline(image=sample_image, prompt=sample_prompt)
128
 
129
+ dynamic_shapes = tree_map_only((torch.Tensor, bool), lambda t: None, call.kwargs)
130
+ dynamic_shapes |= TRANSFORMER_DYNAMIC_SHAPES
 
 
131
 
132
+ # Fuse QKV projections if available (diffusers pipelines)
133
+ if hasattr(pipeline.transformer, 'fuse_qkv_projections'):
134
+ pipeline.transformer.fuse_qkv_projections()
135
 
136
+ quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
 
 
 
 
 
 
137
 
138
+ exported = torch.export.export(
139
+ mod=pipeline.transformer,
140
+ args=call.args,
141
+ kwargs=call.kwargs,
142
+ dynamic_shapes=dynamic_shapes,
143
+ )
144
 
145
+ return spaces.aoti_compile(exported, INDUCTOR_CONFIGS)
146
+
147
+ spaces.aoti_apply(compile_transformer(), pipeline.transformer)
 
148
 
149
 
150
  def enable_cuda_optimizations():
 
174
  print(f"Some CUDA optimizations failed: {e}")
175
 
176
 
177
+ # Simpler optimization for non-ZeroGPU environments
178
+ def apply_float8_quantization_simple(model: nn.Module) -> nn.Module:
179
+ """
180
+ Apply Float8 quantization without AOT compilation.
181
+ Use this for non-ZeroGPU environments or when AOT is not available.
182
+ """
183
+ try:
184
+ print("Applying Float8 quantization...")
185
+ quantize_(model, Float8DynamicActivationFloat8WeightConfig())
186
+ print("Float8 quantization applied!")
187
+ return model
188
+ except Exception as e:
189
+ print(f"Float8 quantization failed: {e}")
190
+ return model
191
+
192
+
193
+ def apply_torch_compile_simple(
194
+ model: nn.Module,
195
+ mode: str = "reduce-overhead",
196
+ backend: str = "inductor"
197
+ ) -> nn.Module:
198
+ """
199
+ Apply torch.compile without AOT (JIT compilation).
200
+ Use this as a fallback when AOT compilation is not available.
201
+ """
202
+ try:
203
+ import torch._inductor.config as inductor_config
204
+
205
+ # Apply inductor configurations
206
+ for key, value in INDUCTOR_CONFIGS.items():
207
+ if hasattr(inductor_config, key):
208
+ setattr(inductor_config, key, value)
209
+
210
+ print(f"Applying torch.compile with mode='{mode}'...")
211
+ compiled_model = torch.compile(model, mode=mode, backend=backend)
212
+ print("torch.compile applied!")
213
+ return compiled_model
214
+ except Exception as e:
215
+ print(f"torch.compile failed: {e}")
216
+ return model
217
+
218
+
219
  def optimize_model(
220
  model: nn.Module,
221
  device: str = "cuda",
 
224
  compile_mode: str = "reduce-overhead"
225
  ) -> nn.Module:
226
  """
227
+ Simple optimization wrapper for non-ZeroGPU environments.
228
+ For ZeroGPU, use optimize_flux_model_ instead.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  """
230
  print("=" * 50)
231
  print("Applying model optimizations...")
232
  print("=" * 50)
233
 
 
234
  enable_cuda_optimizations()
235
 
 
236
  if next(model.parameters()).device.type != device:
237
  print(f"Moving model to {device}...")
238
  model = model.to(device)
239
 
 
240
  if use_float8:
241
+ model = apply_float8_quantization_simple(model)
242
 
 
243
  if use_compile:
244
+ model = apply_torch_compile_simple(model, mode=compile_mode)
245
 
246
  print("=" * 50)
247
  print("Model optimization complete!")
 
250
  return model
251
 
252
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
  def check_optimization_support():
254
+ """Check which optimizations are available."""
 
 
255
  print("Checking optimization support...")
256
  print("-" * 40)
257
 
 
258
  print(f"CUDA available: {torch.cuda.is_available()}")
259
  if torch.cuda.is_available():
260
  print(f" Device: {torch.cuda.get_device_name()}")
 
261
 
 
 
 
 
 
 
 
 
262
  try:
263
  from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
264
+ print(f"torchao Float8: Available")
265
  except ImportError:
266
+ print(f"torchao Float8: Not available")
267
 
 
268
  try:
269
+ import spaces
270
+ print(f"spaces (ZeroGPU AOT): Available")
271
+ print(f" aoti_capture: {hasattr(spaces, 'aoti_capture')}")
272
+ print(f" aoti_compile: {hasattr(spaces, 'aoti_compile')}")
273
+ print(f" aoti_apply: {hasattr(spaces, 'aoti_apply')}")
274
+ except ImportError:
275
+ print(f"spaces (ZeroGPU AOT): Not available")
276
 
277
  print("-" * 40)
278