# Q&C: When Quantization Meets Cache in Efficient Image Generation

**Unofficial implementation** of the paper: [Q&C: When Quantization Meets Cache in Efficient Image Generation](https://arxiv.org/abs/2503.02508)

> The official code was announced at `https://github.com/xinding-sys/Quant-Cache` but is not yet publicly available. This repo provides a working implementation based on the paper's methodology sections.

## 📋 Overview

This repo implements the Q&C method for accelerating Diffusion Transformers (DiTs) by **combining post-training quantization with feature caching**. The paper identifies two key challenges when combining these techniques and proposes solutions:

1. **TAP** (Temporal-Aware Parallel Clustering): improves calibration dataset selection for PTQ when caching reduces sample diversity
2. **VC** (Variance Compensation): corrects exposure bias amplified by the quantization+cache combination

## 🏗️ Architecture

```
qandc/
├── __init__.py                    # Package exports
├── quantizer.py                   # Uniform PTQ (W8A8/W4A8, Eq 1-3)
├── cache.py                       # FORA-style feature caching (Section 2.1)
├── tap.py                         # TAP calibration selection (Section 3.1, Algorithm 1)
└── variance_compensation.py       # VC exposure bias correction (Section 3.2, Eq 9-12)

run_experiment.py                  # Self-contained experiment runner
results/
└── experiment_summary.json        # Our experimental results
```

## 🚀 Quick Start

```bash
pip install torch torchvision diffusers transformers accelerate scipy scikit-learn
```

```python
from diffusers import DiTPipeline, DDPMScheduler
from qandc import quantize_model, apply_cache_to_dit, reset_all_caches

# Load DiT-XL/2
pipe = DiTPipeline.from_pretrained("facebook/DiT-XL-2-256")
pipe.scheduler = DDPMScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")

# Apply W8A8 quantization (170 Linear layers)
pipe.transformer = quantize_model(pipe.transformer, w_bits=8, a_bits=8,
                                   skip_patterns=["pos_embed", "norm"])

# Apply feature caching (recompute every 5th step)
apply_cache_to_dit(pipe.transformer, cache_interval=5)

# Generate images
reset_all_caches(pipe.transformer)
output = pipe(class_labels=[207], num_inference_steps=50, guidance_scale=4.0)
```

## 📊 Experiment Results

We ran **6 ablation experiments** on DiT-XL/2-256 with the DDPM scheduler to validate the paper's claims. All runs used a CPU with 16 images and 20 steps, a reduced scale chosen to fit free compute (the paper uses 10K images and 50 steps on A100 GPUs).

| Experiment | Inception Score ↑ | Time/Image (s) ↓ | Speedup | Description |
|:-----------|------------------:|------------------:|--------:|:------------|
| **FP Baseline** | **13.45** | 65.52 | 1.00x | Full-precision DiT-XL/2, DDPM 20 steps |
| Quant Only (W8A8) | 7.53 | 56.54 | 1.16x | Uniform PTQ, no caching |
| Cache Only (N=4) | 1.79 | 17.10 | 3.83x | FORA-style caching, no quantization |
| Q&C Naive | 1.84 | 20.89 | 3.14x | Quant + Cache, no TAP/VC |
| Q&C + TAP | 1.84 | 19.69 | 3.33x | + Temporal-Aware Parallel Clustering |
| **Q&C Full (TAP+VC)** | **2.27** | 21.12 | **3.10x** | Full method with Variance Compensation |

### Key Observations

1. **Caching provides dramatic speedup** (3.8x) but severely degrades quality, confirming the paper's Challenge 1
2. **Naive Q+C combination is catastrophic** (IS drops from 13.45 → 1.84), confirming Challenge 2
3. **Q&C Full (TAP+VC) shows IS improvement** (1.84 → 2.27, +23%) over naive combination, demonstrating VC's effectiveness at correcting exposure bias
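4. **Q&C + TAP is slightly faster per image than the naive combination** (19.69 s vs 20.89 s) at the same IS (1.84). TAP only changes which calibration samples are selected for PTQ, so at this scale the timing gap may reflect run-to-run variation rather than a real efficiency gain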
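
The speedup and relative-IS figures quoted above can be recomputed from the results table. The sketch below copies the table values into a dictionary (the variable names are illustrative and not part of `run_experiment.py`):

```python
# Values copied from the results table: (Inception Score, seconds per image).
results = {
    "FP Baseline": (13.45, 65.52),
    "Quant Only (W8A8)": (7.53, 56.54),
    "Cache Only (N=4)": (1.79, 17.10),
    "Q&C Naive": (1.84, 20.89),
    "Q&C + TAP": (1.84, 19.69),
    "Q&C Full (TAP+VC)": (2.27, 21.12),
}

baseline_time = results["FP Baseline"][1]

# Speedup = baseline time per image / method time per image.
speedups = {name: baseline_time / t for name, (_, t) in results.items()}

# Relative IS change of the full method over the naive combination.
is_naive = results["Q&C Naive"][0]
is_full = results["Q&C Full (TAP+VC)"][0]
relative_is_gain = (is_full - is_naive) / is_naive

for name, speedup in speedups.items():
    print(f"{name:<20} {speedup:.2f}x")
print(f"IS gain of full over naive: {relative_is_gain:+.1%}")
```

With only 16 images per run, these ratios describe this particular run and should be read as trends, not precise measurements.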

### Paper Reference (Table 1, ImageNet 256×256, W8A8, 50 steps)

| Method | FID ↓ | sFID ↓ | IS ↑ | Precision ↑ | Speed |
|:-------|------:|-------:|-----:|------------:|------:|
| DDPM (FP) | 5.22 | 17.63 | 237.8 | 0.8056 | 5× |
| PTQ4DiT | 5.45 | 19.50 | 250.68 | 0.7882 | 10× |
| **Q&C (paper)** | **5.43** | **19.52** | **250.68** | **0.7895** | **12.7×** |

> **Note:** Our numbers are NOT directly comparable to the paper's because: (1) we use only 16 images (paper: 10K), (2) 20 steps (paper: 50), (3) CPU execution, and (4) aggressive cache interval of 4 (paper optimizes this). The purpose is to validate the *relative trends* between methods.

## 🔧 Implementation Details

### Quantization (quantizer.py)
- **Uniform symmetric quantization** following Eq 1-3 from the paper
- **Channel-wise** quantization for weights (per output channel)
- **Tensor-wise** quantization for activations
- Supports W8A8 (8-bit weights, 8-bit activations) and W4A8
- Replaces all `nn.Linear` layers except normalization and positional embeddings
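
To illustrate the difference between channel-wise and tensor-wise scales, here is a minimal dependency-free sketch of symmetric fake quantization on nested lists. It is not the code in `quantizer.py` (which operates on `torch` tensors); the function names, the clamp range `[-q_max, q_max]`, and the max-abs scale rule are assumptions made for illustration.

```python
def quantize_symmetric(values, n_bits):
    """Fake-quantize a list of floats with one shared symmetric scale (assumed max-abs rule)."""
    q_max = 2 ** (n_bits - 1) - 1              # 127 for 8 bits
    max_abs = max(abs(v) for v in values)
    scale = max_abs / q_max if max_abs > 0 else 1.0
    # Round to the integer grid, clamp, then map back to floats.
    return [max(-q_max, min(q_max, round(v / scale))) * scale for v in values]


def quantize_per_channel(matrix, n_bits):
    """Channel-wise: every row (output channel) gets its own scale."""
    return [quantize_symmetric(row, n_bits) for row in matrix]


def quantize_per_tensor(matrix, n_bits):
    """Tensor-wise: a single scale is shared by all elements."""
    flat = [v for row in matrix for v in row]
    q_flat = quantize_symmetric(flat, n_bits)
    width = len(matrix[0])
    return [q_flat[i:i + width] for i in range(0, len(q_flat), width)]


def max_error(original, quantized):
    return max(abs(a - b) for a, b in zip(original, quantized))


# Toy weight: the second output channel has a much smaller magnitude than the first.
weight = [[0.5, -1.0, 0.25], [0.01, 0.02, -0.04]]
per_channel = quantize_per_channel(weight, n_bits=8)
per_tensor = quantize_per_tensor(weight, n_bits=8)

print("row 1 error, per-channel:", max_error(weight[1], per_channel[1]))
print("row 1 error, per-tensor: ", max_error(weight[1], per_tensor[1]))
```

In this toy case the small-magnitude row is reproduced more accurately with its own scale, which is the usual motivation for quantizing weights per output channel.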

### Feature Caching (cache.py)
- **FORA-style static caching**: wraps each transformer block
- At every N-th step: run the full forward pass and cache the residual output
- For the following N-1 steps: reuse the cached residual (skipping the expensive MHSA + FFN)
- `__getattr__` delegation ensures transparency for DiT's conditioning code
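
The caching pattern described above can be sketched without any framework. The wrapper below is a simplified stand-in for what `cache.py` does, not its actual code: `CachedBlock` and `ToyBlock` are hypothetical names, inputs are plain floats instead of tensors, and "residual" is assumed to mean the block output minus its input.

```python
class CachedBlock:
    """Wraps a block and recomputes its residual only every `cache_interval` calls."""

    def __init__(self, block, cache_interval):
        self.block = block
        self.cache_interval = cache_interval
        self.step = 0
        self.cached_residual = None

    def reset(self):
        """Clear the cache before starting a new sampling run."""
        self.step = 0
        self.cached_residual = None

    def __call__(self, x):
        if self.step % self.cache_interval == 0 or self.cached_residual is None:
            # Full forward pass; remember what the block added to its input.
            out = self.block(x)
            self.cached_residual = out - x
        else:
            # Skip the block and reuse the stored residual.
            out = x + self.cached_residual
        self.step += 1
        return out

    def __getattr__(self, name):
        # Called only when normal lookup fails: delegate to the wrapped block.
        if name == "block":
            raise AttributeError(name)
        return getattr(self.block, name)


class ToyBlock:
    """Stand-in for a transformer block that counts how often it runs."""
    hidden_size = 4

    def __init__(self):
        self.calls = 0

    def __call__(self, x):
        self.calls += 1
        return x + 0.5 * x


toy = ToyBlock()
cached = CachedBlock(toy, cache_interval=5)
x = 1.0
for _ in range(10):
    x = cached(x)

print(toy.calls)           # 2: the block ran at steps 0 and 5 only
print(cached.hidden_size)  # 4: attribute lookup is forwarded to the wrapped block
```

Calling `reset()` between generations plays the same role as `reset_all_caches` in the Quick Start: without it, a stale residual from the previous run would be reused.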

### TAP (tap.py)
- **Spatial similarity**: cosine similarity between flattened latent features (Eq 7)
- **Temporal similarity**: Gaussian kernel on timestep distances (Eq 8)
- **Combined similarity**: `A_final = α·A_spatial + (1-α)·A_temporal` (Eq 6)
- **Parallel subsampling**: m=3 independent subsamples, each 1/20 of full dataset
- **Spectral clustering** on each subsample → co-occurrence matrix → final KMeans
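
As a rough illustration of the combined similarity, the sketch below builds the affinity matrix for a handful of samples using only the standard library. The Gaussian kernel form `exp(-Δt² / (2σ²))` and the values of `alpha` and `sigma` are assumptions for this example; the exact form and settings used in `tap.py` and in the paper's Eq 8 may differ. The clustering steps are not shown.

```python
import math


def cosine_similarity(u, v):
    """Cosine similarity between two flattened feature vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


def combined_similarity(features, timesteps, alpha=0.5, sigma=10.0):
    """A_final = alpha * A_spatial + (1 - alpha) * A_temporal for every sample pair."""
    n = len(features)
    affinity = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            a_spatial = cosine_similarity(features[i], features[j])
            # Assumed Gaussian kernel on the timestep distance.
            dt = timesteps[i] - timesteps[j]
            a_temporal = math.exp(-(dt ** 2) / (2 * sigma ** 2))
            affinity[i][j] = alpha * a_spatial + (1 - alpha) * a_temporal
    return affinity


# Three toy samples: the first two share features and timestep, the third differs in both.
features = [[1.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
timesteps = [0, 0, 100]
affinity = combined_similarity(features, timesteps)

for row in affinity:
    print([round(value, 3) for value in row])
```

A matrix like this is what a spectral clustering routine with a precomputed affinity would consume on each subsample.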

### Variance Compensation (variance_compensation.py)
- Implements both the **full analytical K_t** (Eq 12) and a **simplified version**
- Corrects variance shift in later denoising stages (t > T/2)
- `x_corrected = μ + K_t · (x̂ - μ)` where K_t is the per-channel, per-timestep correction factor
- Calibrated offline using a few samples through the quantized+cached pipeline
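
The correction formula can be tried on a single channel at a single timestep. In the sketch below, `K_t` is estimated as the ratio of standard deviations between a reference signal and the degraded one. This variance-matching rule is an assumption made for illustration: it is neither the paper's analytical Eq 12 nor necessarily the simplified version in `variance_compensation.py`, and the function names are hypothetical.

```python
import math
import statistics


def variance_matching_factor(reference, degraded):
    """Hypothetical K_t: sqrt(Var[reference] / Var[degraded]) for one channel and timestep."""
    return math.sqrt(statistics.pvariance(reference) / statistics.pvariance(degraded))


def compensate(x_hat, k_t):
    """Apply x_corrected = mu + K_t * (x_hat - mu), with mu the mean of x_hat."""
    mu = statistics.fmean(x_hat)
    return [mu + k_t * (value - mu) for value in x_hat]


# Toy calibration data: the degraded values keep the mean but lose variance.
reference = [0.0, 2.0, 4.0, 6.0]   # e.g. full-precision activations
degraded = [1.5, 2.5, 3.5, 4.5]    # e.g. quantized + cached activations

k_t = variance_matching_factor(reference, degraded)
corrected = compensate(degraded, k_t)

print("K_t:", k_t)
print("variance before:", statistics.pvariance(degraded))
print("variance after: ", statistics.pvariance(corrected))
```

In a full pipeline one such factor would be stored per channel and per timestep during offline calibration, then applied only at the timesteps where the variance shift is corrected.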

## 🔬 Running Full Experiments

For GPU-scale experiments matching the paper:

```python
# Modify run_experiment.py settings:
args = {
    "num_steps": 50,           # Paper: 50/100/250
    "num_images": 10000,       # Paper: 10,000
    "batch_size": 16,          # GPU batch size
    "cache_interval": 5,       # Tune for quality vs speed
    "num_calib_samples": 800,  # Paper recommendation
    "tap_clusters": 100,       # Paper setting
}
```

## 📝 Citation

```bibtex
@article{qandc2025,
  title={Q\&C: When Quantization Meets Cache in Efficient Image Generation},
  author={Xinding et al.},
  journal={arXiv preprint arXiv:2503.02508},
  year={2025}
}
```

## 📄 License

This implementation is provided for research purposes. The DiT model (`facebook/DiT-XL-2-256`) is released under the CC-BY-NC-4.0 license.