# Stable-Fast Compilation

## Overview

Stable-Fast is a JIT compilation framework that optimizes Stable Diffusion UNet models by tracing execution, fusing operators, and optionally capturing CUDA graphs. It can provide a significant speedup for SD1.5/SDXL batch workflows with no practical loss in output quality.

Unlike runtime attention optimizations (SageAttention, SpargeAttn), Stable-Fast compiles the model **once, during the first inference pass**. The compiled model is cached and reused for subsequent generations with compatible shapes.

## How It Works

Stable-Fast applies three optimization layers:

### 1. TorchScript Tracing

The first forward pass through the UNet is recorded into a static computational graph:

```python
traced_model = torch.jit.trace(unet, example_inputs)
```

This eliminates Python interpreter overhead and enables downstream graph optimizations.
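
For intuition, here is a minimal toy sketch of the trace-then-freeze step (illustrative only; Stable-Fast does the equivalent on the full UNet, and freezing corresponds to the `enable_jit_freeze` option described under Configuration):

```python
import torch

# Toy stand-in for the UNet: trace it once, then freeze the graph so weights
# are inlined as constants and more fusion passes become possible.
module = torch.nn.Sequential(
    torch.nn.Conv2d(4, 8, 3, padding=1),
    torch.nn.SiLU(),
).eval()

example = torch.randn(1, 4, 64, 64)

with torch.no_grad():
    traced = torch.jit.trace(module, example)   # record the static graph
    frozen = torch.jit.freeze(traced)           # inline parameters, enable more optimization

out = frozen(example)  # subsequent calls skip the Python-level forward()
```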

### 2. Operator Fusion

The traced graph undergoes pattern-based fusion:

- **Conv + BatchNorm fusion**: Merges normalization into convolution weights
- **Activation fusion**: Fuses ReLU/GELU/SiLU directly into linear/conv ops
- **Memory layout optimization**: Converts to channels-last format for faster conv execution
- **Triton kernels**: Replaces PyTorch ops with hand-tuned Triton implementations (if `enable_triton=True`)

Example fusion:

```python
# Before:
x = conv(input)
x = batch_norm(x)
x = relu(x)

# After:
x = fused_conv_bn_relu(input)  # Single kernel launch
```
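
The Conv + BatchNorm case is plain algebra on the weights. A minimal sketch of the folding, with a numerical sanity check (illustrative; not Stable-Fast's actual implementation):

```python
import torch

def fold_conv_bn(conv: torch.nn.Conv2d, bn: torch.nn.BatchNorm2d) -> torch.nn.Conv2d:
    """Fold a BatchNorm that follows a Conv2d into the conv's weights and bias."""
    fused = torch.nn.Conv2d(
        conv.in_channels, conv.out_channels, conv.kernel_size,
        stride=conv.stride, padding=conv.padding, bias=True,
    )
    scale = bn.weight.data / torch.sqrt(bn.running_var + bn.eps)
    fused.weight.data = conv.weight.data * scale.reshape(-1, 1, 1, 1)
    conv_bias = conv.bias.data if conv.bias is not None else torch.zeros(conv.out_channels)
    fused.bias.data = (conv_bias - bn.running_mean) * scale + bn.bias.data
    return fused

# Sanity check: the fused conv matches conv -> bn in eval mode.
conv, bn = torch.nn.Conv2d(4, 8, 3, padding=1), torch.nn.BatchNorm2d(8).eval()
x = torch.randn(1, 4, 16, 16)
assert torch.allclose(fold_conv_bn(conv, bn)(x), bn(conv(x)), atol=1e-5)
```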

### 3. CUDA Graph Capture (Optional)

When `enable_cuda_graph=True`, the entire forward pass is captured as a static CUDA graph:

- Kernel launches are recorded once and replayed on subsequent runs
- Eliminates CPU launch overhead (~10-15% speedup)
- Requires fixed input shapes and batch sizes

**Trade-off:** Higher VRAM usage (~500MB for graph buffers) and less flexibility.
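
For intuition, this is roughly what capture and replay look like with PyTorch's public CUDA graph API; the toy `model` and buffer shapes below are placeholders, and Stable-Fast handles all of this internally:

```python
import torch

# Illustrative stand-in for the UNet; capture requires CUDA and fixed shapes.
model = torch.nn.Conv2d(4, 4, 3, padding=1).cuda().eval()
static_input = torch.randn(1, 4, 64, 96, device="cuda")

# Warm up on a side stream before capture (the pattern recommended by PyTorch).
side = torch.cuda.Stream()
side.wait_stream(torch.cuda.current_stream())
with torch.cuda.stream(side), torch.no_grad():
    model(static_input)
torch.cuda.current_stream().wait_stream(side)

# Record every kernel launch of one forward pass into a graph.
graph = torch.cuda.CUDAGraph()
with torch.no_grad(), torch.cuda.graph(graph):
    static_output = model(static_input)

# Replay: copy new data into the fixed input buffer, then re-issue all
# recorded kernels with a single CPU-side call.
static_input.copy_(torch.randn_like(static_input))
graph.replay()
result = static_output.clone()
```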

## Installation

### Windows/Linux (Manual)

Follow the [official guide](https://github.com/chengzeyi/stable-fast?tab=readme-ov-file#installation):

```bash
# Install from PyPI (recommended)
pip install stable-fast

# Or build from source for latest features
git clone https://github.com/chengzeyi/stable-fast
cd stable-fast
pip install -e .
```

**Prerequisites:**
- PyTorch 2.0+ with CUDA support
- xformers (optional but recommended)
- Triton (optional for Triton kernel fusion)
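
A quick, throwaway way to confirm these are importable (module names assumed to be `torch`, `xformers`, `triton`, and `sfast`):

```python
import importlib

# Report which prerequisites are present; xformers and triton are optional.
for name in ("torch", "xformers", "triton", "sfast"):
    try:
        mod = importlib.import_module(name)
        print(f"{name}: {getattr(mod, '__version__', 'installed')}")
    except ImportError:
        print(f"{name}: not installed")
```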

### Docker

Stable-Fast is included in the Docker image when `INSTALL_STABLE_FAST=1`:

```bash
docker-compose build --build-arg INSTALL_STABLE_FAST=1
```

Default is `0` (disabled) to reduce image size and build time.

## Usage

### Streamlit UI

Enable in the **Performance** section of the sidebar:

1. Check **Stable Fast**
2. Generate images; the first run compiles the model (30-60s delay)
3. Subsequent generations reuse the cached compiled model

**Visual indicator:** The first generation shows "Compiling model..." in the progress bar.

### REST API

Pass `stable_fast: true` in the request payload:

```bash
curl -X POST http://localhost:7861/api/generate \
  -H "Content-Type: application/json" \
  -d '{
        "prompt": "a peaceful garden with cherry blossoms",
        "width": 768,
        "height": 512,
        "num_images": 1,
        "stable_fast": true
      }'
```
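
The same request from Python, for scripted or batch submissions (a sketch using `requests` against the endpoint above; the response structure depends on your deployment):

```python
import requests

payload = {
    "prompt": "a peaceful garden with cherry blossoms",
    "width": 768,
    "height": 512,
    "num_images": 1,
    "stable_fast": True,
}

# Generous timeout: the first request also pays the compilation cost.
response = requests.post("http://localhost:7861/api/generate", json=payload, timeout=600)
response.raise_for_status()
print(response.json())
```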

### Configuration

Stable-Fast behavior is controlled by `CompilationConfig`:

```python
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig

config = CompilationConfig.Default()
config.enable_xformers = True           # Use xformers attention
config.enable_cuda_graph = False        # CUDA graphs (set True for max speed)
config.enable_jit_freeze = True         # Freeze traced graph
config.enable_cnn_optimization = True   # Conv fusion
config.enable_triton = False            # Triton kernels (experimental)
config.memory_format = torch.channels_last  # Optimize memory layout
```
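
On its own the config does nothing; it is handed to stable-fast's pipeline compiler. A standalone sketch with a diffusers pipeline, assuming the `compile` helper exported from the same module (check the upstream README for the exact entry point):

```python
import torch
from diffusers import StableDiffusionPipeline
from sfast.compilers.diffusion_pipeline_compiler import (
    CompilationConfig,
    compile as sfast_compile,  # assumed entry point; see the stable-fast README
)

config = CompilationConfig.Default()
config.enable_xformers = True
config.enable_cuda_graph = False

# Any SD1.5 checkpoint works; this repo id is just an example.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe = sfast_compile(pipe, config)

# The first call triggers tracing/fusion; later calls with the same shapes reuse it.
image = pipe("a peaceful garden with cherry blossoms", num_inference_steps=20).images[0]
```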

LightDiffusion-Next uses sensible defaults (CUDA graphs disabled by default for flexibility). To override:

```python
# In src/StableFast/StableFast.py
from sfast.compilers.diffusion_pipeline_compiler import CompilationConfig

def gen_stable_fast_config(enable_cuda_graph=False):
    config = CompilationConfig.Default()
    config.enable_cuda_graph = enable_cuda_graph  # Pass True for max speed
    # ... rest of config
    return config
```

## Performance

### Speedup Benchmarks

Stable-Fast provides speedup through:
- **JIT compilation**: Eliminates Python overhead
- **Operator fusion**: Reduces kernel launches
- **CUDA graphs** (optional): Further reduces CPU overhead

Speedup varies significantly based on:
- GPU architecture
- Batch size and generation count
- Model size (SD1.5 vs SDXL)
- Whether CUDA graphs are enabled

**Note:** Performance benefits are most noticeable for batch operations (50+ images). For single 20-step generations, compilation overhead may exceed speedup gains.
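
To check whether compilation pays off for a particular workload, time a few generations with and without `stable_fast` enabled. A rough harness (the `generate` callable is a placeholder for whatever call your deployment exposes):

```python
import time
import torch

def average_seconds(generate, n=5):
    """Average wall-clock time of n generations, excluding the warm-up/compile run."""
    generate()  # warm-up: pays the tracing/compilation cost if Stable-Fast is enabled
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(n):
        generate()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / n

# Usage: average_seconds(lambda: pipe("test prompt", num_inference_steps=20))
```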

### Compilation Time

First-run compilation overhead:

- **SD1.5 UNet**: ~30s (traced once per resolution/batch size)
- **SDXL UNet**: ~60s (larger model)
- **Subsequent runs**: <1s (cached)

Cached compiled models persist in `~/.cache/torch_extensions/`. Clear this directory to force recompilation.

## Stacking with Other Optimizations

Stable-Fast is **fully compatible** with SageAttention, SpargeAttn and WaveSpeed:

### Stable-Fast + SageAttention

```yaml
stable_fast: true
# SageAttention auto-detected
```

**Result:** 1.7× (Stable-Fast) × 1.15× (SageAttention) ≈ **~2× total speedup**

### Stable-Fast + SpargeAttn

```yaml
stable_fast: true
# SpargeAttn auto-detected
```

**Result:** 1.7× (Stable-Fast) × 1.4× (SpargeAttn) ≈ **~2.4× total speedup**

### Stable-Fast + SpargeAttn + DeepCache

```yaml
stable_fast: true
deepcache:
  enabled: true
  interval: 3
  depth: 2
# SpargeAttn auto-detected
```

**Result:** 1.7× (Stable-Fast) × 1.4× (SpargeAttn) × 2-3× (DeepCache) ≈ **~4-5× total speedup** in practice (the individual gains do not compose perfectly)

## Compatibility

### Compatible With

- ✅ Stable Diffusion 1.5
- ✅ Stable Diffusion 2.1
- ✅ SDXL
- ✅ All samplers (Euler, DPM++, etc.)
- ✅ LoRA adapters
- ✅ Textual inversion embeddings
- ✅ HiresFix
- ✅ ADetailer
- ✅ Img2Img (with fixed denoise strength)
- ✅ SageAttention/SpargeAttn
- ✅ WaveSpeed caching

### Not Compatible With

- ❌ Flux models (different architecture, no UNet)
- ❌ Dynamic resolution changes after compilation
- ❌ Dynamic batch size changes after compilation (with CUDA graphs)
- ⚠️ Frequent model switching (recompiles each time)

## Troubleshooting

### Slow First Run / Repeated Recompilation

**Symptom:** Every generation triggers compilation, even with identical settings.

**Causes:**
1. Cache directory not writable
2. System clock incorrect (invalidates timestamps)
3. Different model loaded (each model is cached separately)

**Fixes:**
```bash
# Check cache permissions
ls -la ~/.cache/torch_extensions

# Ensure stable timestamps
date  # Should be correct

# Mount cache in Docker to persist across container restarts
docker run -v ~/.cache/torch_extensions:/root/.cache/torch_extensions ...
```

### CUDA Out of Memory During Compilation

**Symptom:** OOM error on first run but not subsequent runs.

**Cause:** Compilation allocates temporary buffers for tracing.

**Fixes:**
1. Disable CUDA graphs: `enable_cuda_graph=False` (saves ~500MB)
2. Reduce batch size temporarily for first run
3. Clear other VRAM consumers (close other apps, disable model caching)

### Compilation Hangs or Crashes

**Symptom:** Process freezes during "Compiling model..." step.

**Causes:**
1. Triton compilation error (if `enable_triton=True`)
2. Driver incompatibility
3. Insufficient CPU RAM for graph analysis

**Fixes:**
```bash
# Disable Triton
# In src/StableFast/StableFast.py:
config.enable_triton = False

# Update NVIDIA driver
nvidia-smi  # Check version, upgrade if < 525.x

# Increase Docker memory limit
# In docker-compose.yml:
deploy:
  resources:
    limits:
      memory: 16G  # Increase from default
```

### Error: `torch.jit.trace` fails

**Symptom:** `RuntimeError: Could not trace model`

**Cause:** Dynamic control flow in model (if/else statements depending on runtime values).

**Fix:** This is rare with standard SD models. If it occurs:
1. Check for custom LoRA/embeddings with dynamic logic
2. Disable Stable-Fast for that specific generation
3. Report issue with model details

### Model Quality Degradation

**Symptom:** Compiled model produces different outputs than baseline.

**Cause:** Numeric precision differences from operator fusion (very rare).

**Fixes:**
```python
# Disable aggressive optimizations
config.enable_cnn_optimization = False
config.memory_format = None  # Use default layout
```

If issue persists, disable Stable-Fast and file a bug report.

## Advanced Configuration

### Custom Compilation Config

Override defaults in `src/StableFast/StableFast.py`:

```python
def gen_stable_fast_config(enable_cuda_graph=False):
    config = CompilationConfig.Default()

    # Keep exactly one of the preset blocks below active; later assignments
    # silently overwrite earlier ones.

    # Maximum speed (higher VRAM usage):
    # config.enable_cuda_graph = True
    # config.enable_triton = True
    # config.prefer_lowp_gemm = True  # Use FP16 matrix multiplies

    # Balanced (recommended):
    config.enable_cuda_graph = False
    config.enable_triton = False
    config.enable_cnn_optimization = True

    # Debug (no optimizations):
    # config.enable_cuda_graph = False
    # config.enable_jit_freeze = False
    # config.enable_cnn_optimization = False

    return config
```

### Clear Cached Compilations

```bash
# Linux/Mac
rm -rf ~/.cache/torch_extensions

# Windows
rmdir /s /q %USERPROFILE%\.cache\torch_extensions

# Docker (mount cache as volume)
docker run -v my_cache:/root/.cache/torch_extensions ...
docker volume rm my_cache  # Clear cache
```

### Profile Compilation

```bash
# Enable debug logging
export LD_SERVER_LOGLEVEL=DEBUG

# Run generation and check logs
cat logs/server.log | grep "Stable"
```

## Best Practices

### Production Deployments

1. **Pre-compile models** during startup with a warm-up request (only for batch/long-running services)
2. **Mount cache volume** to persist compilations across container restarts
3. **Disable CUDA graphs** if serving multiple batch sizes
4. **Enable CUDA graphs** for fixed-resolution APIs with consistent high-volume traffic
5. **Disable Stable-Fast entirely** for single-shot API endpoints (compilation overhead exceeds benefit)

Example warm-up:

```python
# In startup script
import torch

def warmup_stable_fast(model, width=768, height=512):
    """Pre-compile the model with a dummy input."""
    dummy_input = torch.randn(1, 4, height // 8, width // 8, device="cuda")
    dummy_timestep = torch.tensor([999], device="cuda")

    with torch.no_grad():
        model(dummy_input, dummy_timestep, c={})

    print("Stable-Fast compilation complete")
```

### Development Workflows

1. **Disable Stable-Fast** when experimenting with new models/LoRAs (avoids repeated recompilation)
2. **Enable for final testing** to verify production performance
3. **Clear cache** after upgrading PyTorch/CUDA drivers

## Citation

If you use Stable-Fast in your work:

```bibtex
@misc{stable-fast,
  author = {Cheng Zeyi},
  title = {stable-fast: Fast Inference for Stable Diffusion},
  year = {2023},
  publisher = {GitHub},
  url = {https://github.com/chengzeyi/stable-fast}
}
```

## Resources

- [Stable-Fast Repository](https://github.com/chengzeyi/stable-fast)
- [Installation Guide](https://github.com/chengzeyi/stable-fast?tab=readme-ov-file#installation)
- [TorchScript Documentation](https://pytorch.org/docs/stable/jit.html)
- [CUDA Graphs Guide](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/)