TSXu committed on
Commit
a53108a
·
1 Parent(s): 3974489

Add Float8 quantization and torch.compile optimizations

Browse files

- Add optimization.py module with torchao Float8 quantization support
- Add torch.compile with inductor optimizations (max_autotune, cudagraphs, etc.)
- Enable CUDA optimizations (TF32, Flash SDPA, cuDNN benchmark)
- Add --float8, --compile, --compile-mode CLI arguments
- Update requirements.txt with torchao>=0.4.0 and torch>=2.4.0

Files changed (3) hide show
  1. inference.py +40 -1
  2. optimization.py +322 -0
  3. requirements.txt +2 -1
inference.py CHANGED
@@ -18,6 +18,7 @@ from huggingface_hub import hf_hub_download, snapshot_download
18
  from src.flux.util import configs, load_ae, load_clip, load_t5
19
  from src.flux.model import Flux
20
  from src.flux.xflux_pipeline import XFluxSampler
 
21
 
22
 
23
  # HuggingFace Hub model IDs
@@ -150,6 +151,9 @@ class CalligraphyGenerator:
150
  author_descriptions_path: str = "calligraphy_styles_en.json",
151
  use_deepspeed: bool = False,
152
  use_4bit_quantization: bool = False,
 
 
 
153
  deepspeed_config: Optional[str] = None,
154
  dtype: Optional[str] = None
155
  ):
@@ -166,6 +170,10 @@ class CalligraphyGenerator:
166
  font_descriptions_path: path to font style descriptions JSON
167
  author_descriptions_path: path to author style descriptions JSON
168
  use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
 
 
 
 
169
  deepspeed_config: path to DeepSpeed config JSON file
170
  dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto
171
  """
@@ -176,7 +184,13 @@ class CalligraphyGenerator:
176
  self.use_deepspeed = use_deepspeed
177
  self.deepspeed_config = deepspeed_config
178
  self.use_4bit_quantization = use_4bit_quantization
 
 
 
179
  self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
 
 
 
180
 
181
  # Load font and author style descriptions
182
  if os.path.exists(font_descriptions_path):
@@ -232,6 +246,17 @@ class CalligraphyGenerator:
232
  )
233
  if self.use_deepspeed:
234
  self.model = self._init_deepspeed(self.model)
 
 
 
 
 
 
 
 
 
 
 
235
 
236
  # Load VAE
237
  if self.use_deepspeed or offload:
@@ -1088,14 +1113,28 @@ if __name__ == "__main__":
1088
  parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
1089
  parser.add_argument("--list-authors", action="store_true", help="List available authors")
1090
  parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
 
 
 
 
 
 
1091
 
1092
  args = parser.parse_args()
 
 
 
 
 
1093
 
1094
  # Initialize generator
1095
  generator = CalligraphyGenerator(
1096
  model_name="flux-dev",
1097
  device=args.device,
1098
- checkpoint_path=args.checkpoint
 
 
 
1099
  )
1100
 
1101
  # List available options
 
18
  from src.flux.util import configs, load_ae, load_clip, load_t5
19
  from src.flux.model import Flux
20
  from src.flux.xflux_pipeline import XFluxSampler
21
+ from optimization import optimize_model, enable_cuda_optimizations, check_optimization_support
22
 
23
 
24
  # HuggingFace Hub model IDs
 
151
  author_descriptions_path: str = "calligraphy_styles_en.json",
152
  use_deepspeed: bool = False,
153
  use_4bit_quantization: bool = False,
154
+ use_float8_quantization: bool = False,
155
+ use_torch_compile: bool = False,
156
+ compile_mode: str = "reduce-overhead",
157
  deepspeed_config: Optional[str] = None,
158
  dtype: Optional[str] = None
159
  ):
 
170
  font_descriptions_path: path to font style descriptions JSON
171
  author_descriptions_path: path to author style descriptions JSON
172
  use_deepspeed: whether to use DeepSpeed ZeRO for memory optimization
173
+ use_4bit_quantization: whether to use 4-bit quantization (quanto/bitsandbytes)
174
+ use_float8_quantization: whether to use Float8 quantization (torchao) for faster inference
175
+ use_torch_compile: whether to use torch.compile for optimized inference
176
+ compile_mode: torch.compile mode - "reduce-overhead", "max-autotune", or "default"
177
  deepspeed_config: path to DeepSpeed config JSON file
178
  dtype: force specific dtype for inference: "fp16", "bf16", "fp32", or None for auto
179
  """
 
184
  self.use_deepspeed = use_deepspeed
185
  self.deepspeed_config = deepspeed_config
186
  self.use_4bit_quantization = use_4bit_quantization
187
+ self.use_float8_quantization = use_float8_quantization
188
+ self.use_torch_compile = use_torch_compile
189
+ self.compile_mode = compile_mode
190
  self.forced_dtype = dtype # "fp16", "bf16", "fp32", or None for auto
191
+
192
+ # Enable CUDA optimizations early
193
+ enable_cuda_optimizations()
194
 
195
  # Load font and author style descriptions
196
  if os.path.exists(font_descriptions_path):
 
246
  )
247
  if self.use_deepspeed:
248
  self.model = self._init_deepspeed(self.model)
249
+
250
+ # Apply Float8 quantization and torch.compile optimizations
251
+ if not self.use_deepspeed and not self.use_4bit_quantization:
252
+ if self.use_float8_quantization or self.use_torch_compile:
253
+ self.model = optimize_model(
254
+ self.model,
255
+ device=str(self.device),
256
+ use_float8=self.use_float8_quantization,
257
+ use_compile=self.use_torch_compile,
258
+ compile_mode=self.compile_mode
259
+ )
260
 
261
  # Load VAE
262
  if self.use_deepspeed or offload:
 
1113
  parser.add_argument("--checkpoint", type=str, default=None, help="Checkpoint path")
1114
  parser.add_argument("--list-authors", action="store_true", help="List available authors")
1115
  parser.add_argument("--list-fonts", action="store_true", help="List available font styles")
1116
+ parser.add_argument("--float8", action="store_true", help="Use Float8 quantization (torchao) for faster inference")
1117
+ parser.add_argument("--compile", action="store_true", help="Use torch.compile for optimized inference")
1118
+ parser.add_argument("--compile-mode", type=str, default="reduce-overhead",
1119
+ choices=["reduce-overhead", "max-autotune", "default"],
1120
+ help="torch.compile mode")
1121
+ parser.add_argument("--check-optimizations", action="store_true", help="Check available optimization support")
1122
 
1123
  args = parser.parse_args()
1124
+
1125
+ # Check optimization support if requested
1126
+ if args.check_optimizations:
1127
+ check_optimization_support()
1128
+ exit(0)
1129
 
1130
  # Initialize generator
1131
  generator = CalligraphyGenerator(
1132
  model_name="flux-dev",
1133
  device=args.device,
1134
+ checkpoint_path=args.checkpoint,
1135
+ use_float8_quantization=args.float8,
1136
+ use_torch_compile=args.compile,
1137
+ compile_mode=args.compile_mode
1138
  )
1139
 
1140
  # List available options
optimization.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Model optimization utilities for faster inference using:
3
+ - Float8 quantization via torchao
4
+ - torch.compile with inductor optimizations
5
+ - CUDA graph capture for reduced kernel launch overhead
6
+
7
+ Inspired by FLUX-Kontext-fp8 optimization techniques.
8
+ """
9
+
10
+ from typing import Optional, Callable, Any
11
+ import torch
12
+ import torch.nn as nn
13
+
14
+
15
# Inductor (torch.compile backend) tuning flags, applied in apply_torch_compile().
# Each key is only set if torch._inductor.config actually has that attribute,
# so unknown keys are skipped on older torch versions.
INDUCTOR_CONFIGS = {
    'conv_1x1_as_mm': True,                          # treat 1x1 convs as matmuls
    'epilogue_fusion': False,                        # deliberately disabled here
    'coordinate_descent_tuning': True,               # fine-grained kernel-config search
    'coordinate_descent_check_all_directions': True,
    'max_autotune': True,                            # exhaustive autotune: slower compile, faster run
}
23
+
24
+
25
def apply_float8_quantization(model: nn.Module, device: str = "cuda") -> nn.Module:
    """
    Apply Float8 dynamic activation and weight quantization using torchao.

    This provides significant speedup on GPUs with native FP8 support (H100, etc.)
    and reasonable speedup on other GPUs through reduced memory bandwidth.

    Never raises: if torchao is missing or quantization fails for any reason,
    the original model is returned unchanged.

    Args:
        model: PyTorch model to quantize
        device: Target device for the model (e.g. "cuda" or "cuda:0")

    Returns:
        Quantized model, or the original model on failure.
    """
    try:
        from torchao.quantization import quantize_, Float8DynamicActivationFloat8WeightConfig

        print("Applying Float8 dynamic activation + Float8 weight quantization...")

        # Move model to device first if not already there. next() with a sentinel
        # avoids StopIteration on a parameterless model, and comparing device
        # *types* avoids a spurious move when device is "cuda:0" and the model
        # already sits on CUDA.
        first_param = next(model.parameters(), None)
        if first_param is not None and first_param.device.type != torch.device(device).type:
            model = model.to(device)

        # quantize_ mutates the module in place (swaps eligible layers to Float8)
        quantize_(model, Float8DynamicActivationFloat8WeightConfig())

        print("Float8 quantization applied successfully!")
        return model

    except ImportError as e:
        print(f"torchao not available for Float8 quantization: {e}")
        print("Install with: pip install torchao")
        return model
    except Exception as e:
        print(f"Float8 quantization failed: {e}")
        print("Falling back to unquantized model")
        return model
63
+
64
def apply_torch_compile(
    model: nn.Module,
    mode: str = "reduce-overhead",
    fullgraph: bool = False,
    dynamic: bool = True,
    backend: str = "inductor"
) -> nn.Module:
    """
    Compile *model* with torch.compile using the tuned inductor settings.

    Args:
        model: PyTorch model to compile
        mode: Compilation mode - "reduce-overhead" (best for inference),
            "max-autotune" (slower compile, faster runtime), or "default"
        fullgraph: If True, requires entire forward to be capturable (faster but stricter)
        dynamic: If True, allows dynamic shapes (recommended for variable input sizes)
        backend: Compiler backend - "inductor" is recommended

    Returns:
        Compiled model, or the original model unchanged if compilation setup fails.
    """
    try:
        import torch._inductor.config as inductor_config

        # Push our tuning flags into the inductor config; silently skip any
        # flag this torch version does not expose.
        for flag in INDUCTOR_CONFIGS:
            if hasattr(inductor_config, flag):
                setattr(inductor_config, flag, INDUCTOR_CONFIGS[flag])

        print(f"Applying torch.compile with mode='{mode}', backend='{backend}'...")

        optimized = torch.compile(
            model,
            mode=mode,
            fullgraph=fullgraph,
            dynamic=dynamic,
            backend=backend,
        )

        print("torch.compile applied successfully!")
        return optimized

    except Exception as e:
        print(f"torch.compile failed: {e}")
        print("Falling back to uncompiled model")
        return model
110
+
111
+
112
def enable_cuda_optimizations():
    """
    Turn on global CUDA backend flags that speed up inference.

    No-op (with a message) when CUDA is unavailable; any failure while
    setting individual flags is reported, never raised.
    """
    if not torch.cuda.is_available():
        print("CUDA not available, skipping CUDA optimizations")
        return

    try:
        # TF32 matmuls / convolutions on Ampere-and-newer GPUs
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

        # Let cuDNN benchmark and cache the fastest conv algorithm per shape
        torch.backends.cudnn.benchmark = True

        # Prefer the fast SDPA kernels; turn off the slow math fallback
        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(True)
        torch.backends.cuda.enable_math_sdp(False)

        print("CUDA optimizations enabled (TF32, cuDNN benchmark, Flash SDPA)")
    except Exception as e:
        print(f"Some CUDA optimizations failed: {e}")
137
+
138
+
139
def optimize_model(
    model: nn.Module,
    device: str = "cuda",
    use_float8: bool = True,
    use_compile: bool = True,
    compile_mode: str = "reduce-overhead"
) -> nn.Module:
    """
    Apply all optimizations to the model for maximum inference speed.

    Optimizations applied, in order:
        1. CUDA backend optimizations (TF32, Flash SDPA, etc.)
        2. Float8 quantization via torchao (if available)
        3. torch.compile with inductor optimizations

    Args:
        model: PyTorch model to optimize
        device: Target device (e.g. "cuda" or "cuda:0")
        use_float8: Whether to apply Float8 quantization
        use_compile: Whether to apply torch.compile
        compile_mode: Mode for torch.compile

    Returns:
        Optimized model
    """
    print("=" * 50)
    print("Applying model optimizations...")
    print("=" * 50)

    # 1. Enable CUDA optimizations
    enable_cuda_optimizations()

    # 2. Move model to device. next() with a sentinel avoids an uncaught
    # StopIteration on a parameterless model, and comparing device *types*
    # avoids a spurious move when device is "cuda:0" and the model is
    # already on CUDA.
    first_param = next(model.parameters(), None)
    if first_param is not None and first_param.device.type != torch.device(device).type:
        print(f"Moving model to {device}...")
        model = model.to(device)

    # 3. Apply Float8 quantization (no-op fallback if torchao is missing)
    if use_float8:
        model = apply_float8_quantization(model, device)

    # 4. Apply torch.compile (no-op fallback on failure)
    if use_compile:
        model = apply_torch_compile(model, mode=compile_mode)

    print("=" * 50)
    print("Model optimization complete!")
    print("=" * 50)

    return model
189
+
190
+
191
def warmup_model(
    model: nn.Module,
    warmup_fn: Callable[[], Any],
    num_warmup: int = 3
):
    """
    Run a few forward passes so a compiled model finishes JIT compilation.

    Args:
        model: The model (should already be compiled)
        warmup_fn: Function that runs a forward pass
        num_warmup: Number of warmup iterations

    Failures of individual iterations are reported and skipped, never raised.
    """
    print(f"Warming up model with {num_warmup} iterations...")

    with torch.no_grad():
        for iteration in range(1, num_warmup + 1):
            try:
                warmup_fn()
            except Exception as e:
                print(f" Warmup {iteration}/{num_warmup} failed: {e}")
            else:
                print(f" Warmup {iteration}/{num_warmup} complete")

    # Drain any queued CUDA kernels before declaring the warmup done
    if torch.cuda.is_available():
        torch.cuda.synchronize()

    print("Model warmup complete!")
219
+
220
+
221
class CUDAGraphWrapper(nn.Module):
    """
    Wrapper that captures and replays CUDA graphs for reduced kernel launch overhead.

    Note: CUDA graphs require static input shapes. Use this only if your input
    dimensions are fixed.
    """

    def __init__(self, model: nn.Module, warmup_fn: Callable[[], tuple]):
        # NOTE(review): warmup_fn is accepted but never stored or called here;
        # capture() does its own inline warmup instead. Confirm whether this
        # parameter should be used or removed.
        super().__init__()
        self.model = model
        self.graph = None           # torch.cuda.CUDAGraph, set by capture()
        self.static_inputs = None   # cloned input tensors reused on every replay
        self.static_outputs = None  # outputs produced inside the captured graph
        self._captured = False      # forward() falls back to eager until True

    def capture(self, *sample_inputs):
        """
        Capture the CUDA graph with sample inputs.

        Args:
            *sample_inputs: Sample inputs with the exact shapes that will be used
        """
        if not torch.cuda.is_available():
            print("CUDA not available, skipping graph capture")
            return

        print("Capturing CUDA graph...")

        # Warmup: run eager passes first so one-time lazy initialization
        # happens outside the capture.
        with torch.no_grad():
            for _ in range(3):
                _ = self.model(*sample_inputs)

        torch.cuda.synchronize()

        # Capture
        self.graph = torch.cuda.CUDAGraph()

        # Create static tensors: clones kept alive so the captured graph can
        # read the same memory addresses on every replay.
        self.static_inputs = tuple(inp.clone() for inp in sample_inputs)

        with torch.cuda.graph(self.graph):
            self.static_outputs = self.model(*self.static_inputs)

        self._captured = True
        print("CUDA graph captured successfully!")

    def forward(self, *inputs):
        # Before capture(), behave exactly like the wrapped model.
        if not self._captured:
            return self.model(*inputs)

        # Copy inputs to static buffers
        for static_inp, inp in zip(self.static_inputs, inputs):
            static_inp.copy_(inp)

        # Replay graph
        self.graph.replay()

        # NOTE: the same output tensors are returned on every call and are
        # rewritten by the next replay — copy them if they must persist.
        return self.static_outputs
281
+
282
+
283
+ # Utility function to check available optimizations
284
+ def check_optimization_support():
285
+ """
286
+ Check which optimizations are available on the current system.
287
+ """
288
+ print("Checking optimization support...")
289
+ print("-" * 40)
290
+
291
+ # CUDA
292
+ print(f"CUDA available: {torch.cuda.is_available()}")
293
+ if torch.cuda.is_available():
294
+ print(f" Device: {torch.cuda.get_device_name()}")
295
+ print(f" Capability: {torch.cuda.get_device_capability()}")
296
+
297
+ # torch.compile
298
+ try:
299
+ import torch._dynamo
300
+ print(f"torch.compile available: True")
301
+ except ImportError:
302
+ print(f"torch.compile available: False")
303
+
304
+ # torchao Float8
305
+ try:
306
+ from torchao.quantization import Float8DynamicActivationFloat8WeightConfig
307
+ print(f"torchao Float8 available: True")
308
+ except ImportError:
309
+ print(f"torchao Float8 available: False")
310
+
311
+ # Flash Attention
312
+ try:
313
+ flash_available = torch.backends.cuda.flash_sdp_enabled() if torch.cuda.is_available() else False
314
+ print(f"Flash SDPA available: {flash_available}")
315
+ except:
316
+ print(f"Flash SDPA available: Unknown")
317
+
318
+ print("-" * 40)
319
+
320
+
321
+ if __name__ == "__main__":
322
+ check_optimization_support()
requirements.txt CHANGED
@@ -8,8 +8,9 @@ safetensors>=0.4.0
8
  # Model and inference
9
  optimum-quanto
10
  bitsandbytes>=0.41.0
11
- torch
12
  torchvision
 
13
  timm
14
  sentencepiece
15
  diffusers
 
8
  # Model and inference
9
  optimum-quanto
10
  bitsandbytes>=0.41.0
11
+ torch>=2.4.0
12
  torchvision
13
+ torchao>=0.4.0
14
  timm
15
  sentencepiece
16
  diffusers