EdgeDiffuse

Edge-deployable SD-Turbo via multi-stage compression: structural pruning → distillation → sensitivity-aware mixed-precision quantization (GPTQ) → QLoRA recovery.

Code & paper-style writeup: github.com/SeanHe727/EdgeDiffusion


What's in this repo

| File / dir | What it is |
|---|---|
| unet/ | Mixed-precision quantized UNet (GPTQ-applied). 152 of 192 Linear layers are quantized to INT4 (45) or INT8 (107); the rest stay fp16. Fake-quantized: values are rounded to the integer grid but stored as bf16. |
| text_encoder/, vae/, tokenizer/, scheduler/, model_index.json | Standard stabilityai/sd-turbo components, unmodified. |
| lora_adapter.pt | (Optional) QLoRA recovery adapter trained on top of the quantized UNet. Reduces LPIPS vs the fp16 baseline by ~8 % when applied. See "Advanced: QLoRA recovery adapter" below. |
| mp_quant_metadata.json | Per-layer bit-width assignment and GPTQ hyper-parameters, for full reproducibility. |
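
The per-layer assignment in mp_quant_metadata.json can be tallied to check the layer counts quoted in this section. The following is a minimal sketch, assuming the file exposes a `quantization` → `assignment` mapping from layer name to bit-width label (the layout used by the adapter-loading code in this card). The helper name, the layer names, and the `int4` / `int8` labels in the example are hypothetical; only the `fp16` label is confirmed by this card.

```python
from collections import Counter


def summarize_assignment(meta: dict) -> Counter:
    """Count how many layers were assigned to each bit-width label."""
    # Assumed layout: meta["quantization"]["assignment"] maps layer name -> label
    assignment = meta["quantization"]["assignment"]
    return Counter(assignment.values())


# Hypothetical excerpt; the real file should list every Linear layer
example_meta = {
    "quantization": {
        "assignment": {
            "down_blocks.0.attn.to_q": "int4",
            "down_blocks.0.attn.to_k": "int8",
            "down_blocks.0.attn.to_v": "int8",
            "mid_block.attn.to_out": "fp16",
        }
    }
}

counts = summarize_assignment(example_meta)
for label, n in sorted(counts.items()):
    print(f"{label}: {n} layers")
```

Run against the real metadata (loaded with `json.load`), the totals should match the 45 / 107 split listed for unet/, if the assumed layout holds.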

Quick start

from diffusers import StableDiffusionPipeline
import torch

pipe = StableDiffusionPipeline.from_pretrained(
    "ChenHe727/EdgeDiffusion",
    torch_dtype=torch.bfloat16,        # required: INT4 layers use bf16 dtype
)
pipe = pipe.to("cuda")

image = pipe(
    "a photo of a tabby cat sitting on a wooden chair, sharp focus",
    num_inference_steps=2,             # 2-step is the sweet spot for SD-Turbo derivatives
    guidance_scale=0.0,                # SD-Turbo doesn't use CFG
).images[0]

image.save("output.png")

Why 2 inference steps?

SD-Turbo is trained with adversarial diffusion distillation for 1-step generation. Empirically, 2 steps give the best quality/speed trade-off for our compressed model: 28 % faster than 4 steps, with marginally better LPIPS.
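
One way such a step-count comparison could be timed is sketched below. This is not the benchmark script behind the numbers in this card; `time_steps`, `relative_speedup`, and the stand-in workload are hypothetical names introduced for illustration.

```python
import time
from typing import Callable, Dict, Iterable


def time_steps(generate: Callable[[int], object],
               step_counts: Iterable[int],
               repeats: int = 3) -> Dict[int, float]:
    """Return the best wall-clock time in seconds for each step count."""
    results = {}
    for steps in step_counts:
        timings = []
        for _ in range(repeats):
            start = time.perf_counter()
            generate(steps)
            timings.append(time.perf_counter() - start)
        results[steps] = min(timings)
    return results


def relative_speedup(slow: float, fast: float) -> float:
    """Fraction of time saved by `fast` relative to `slow`."""
    return (slow - fast) / slow


# Stand-in workload so the sketch runs without a GPU. With the real pipeline,
# pass: lambda n: pipe(prompt, num_inference_steps=n, guidance_scale=0.0)
def fake_generate(steps: int) -> None:
    time.sleep(0.001 * steps)


timings = time_steps(fake_generate, [2, 4])
print(f"2 vs 4 steps: {relative_speedup(timings[4], timings[2]):.0%} faster")
```

When timing the real pipeline, a warm-up call before measuring is advisable so that one-time setup cost does not inflate the first result.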


Results

Benchmark on RTX 5070 (Blackwell), 512 × 512, 2-step inference:

| Variant | Params | Latency | VRAM | LPIPS vs original SD-Turbo | LPIPS vs fp16 baseline |
|---|---|---|---|---|---|
| stabilityai/sd-turbo (original) | 860 M | 0.146 s | 3.05 GB | 0 | 0.278 |
| fp16 baseline (pruned + distilled) | 642 M | 0.142 s | 2.64 GB | 0.278 | 0 |
| this repo (mp_quant PTQ) | 642 M | 0.145 s | 2.64 GB | 0.277 | 0.062 |
| with LoRA adapter loaded | 642 M + 9 MB | 0.171 s | 2.65 GB | 0.278 | 0.057 |

Key takeaway: mixed-precision quantization adds very little perceptual cost on top of the pruned + distilled baseline. LPIPS relative to the fp16 baseline is 0.062, and LPIPS relative to the original SD-Turbo is effectively unchanged (0.277 vs 0.278). The dominant quality cost in the pipeline is the pruning stage; quantization is close to "free".
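
The relative changes implied by the table can be recomputed directly. The sketch below uses only values copied from the results table; the helper name `relative_change` is introduced here for illustration.

```python
def relative_change(before: float, after: float) -> float:
    """Signed relative change from `before` to `after`, as a fraction."""
    return (after - before) / before


# Values copied from the results table
param_change = relative_change(860, 642)       # parameters (M): original -> pruned
vram_change = relative_change(3.05, 2.64)      # VRAM (GB): original -> this repo
lora_latency = relative_change(0.145, 0.171)   # latency (s): PTQ -> PTQ + adapter
lora_lpips = relative_change(0.062, 0.057)     # LPIPS vs fp16: PTQ -> PTQ + adapter

print(f"params:            {param_change:+.1%}")   # -25.3%
print(f"VRAM:              {vram_change:+.1%}")    # -13.4%
print(f"latency with LoRA: {lora_latency:+.1%}")   # +17.9%
print(f"LPIPS with LoRA:   {lora_lpips:+.1%}")     # -8.1%
```

Read this way, the adapter trades roughly 18 % more latency for an 8 % lower LPIPS against the fp16 baseline, while the parameter and VRAM savings come from the pruning stage.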


Advanced: QLoRA recovery adapter

The included lora_adapter.pt was trained for 500 steps with step-wise teacher-student distillation to recover residual PTQ quality loss. It reduces the LPIPS gap from 0.062 to 0.057 (~8 % improvement).

import torch
from peft import LoraConfig, get_peft_model
from diffusers import StableDiffusionPipeline
from huggingface_hub import hf_hub_download
import json

# Load base pipeline
pipe = StableDiffusionPipeline.from_pretrained(
    "ChenHe727/EdgeDiffusion", torch_dtype=torch.bfloat16,
).to("cuda")

# Discover which layers were quantized (LoRA targets these)
meta_path = hf_hub_download("ChenHe727/EdgeDiffusion", "mp_quant_metadata.json")
with open(meta_path) as f:
    meta = json.load(f)
target_fqns = [fqn for fqn, bit in meta["quantization"]["assignment"].items() if bit != "fp16"]

# Re-attach LoRA structure and load adapter weights
lora_state = torch.load(hf_hub_download("ChenHe727/EdgeDiffusion", "lora_adapter.pt"),
                        weights_only=False, map_location="cuda")
sample_key = next(k for k in lora_state if "lora_A" in k)
rank = lora_state[sample_key].shape[0]

pipe.unet = get_peft_model(pipe.unet, LoraConfig(
    r=rank, lora_alpha=rank * 2, target_modules=target_fqns,
    lora_dropout=0.0, bias="none",
))
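own = pipe.unet.state_dict()
loaded = 0
for k, v in lora_state.items():
    if k in own:
        own[k].copy_(v.to(own[k].device, dtype=own[k].dtype))
        loaded += 1
# Fail loudly if no key names matched, rather than silently running without the adapter
assert loaded > 0, "no adapter weights matched the UNet state dict"
pipe.unet.eval()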

# Generate as usual
image = pipe("a cat", num_inference_steps=2, guidance_scale=0.0).images[0]
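
For intuition about what the adapter does to each targeted layer: a LoRA update adds a scaled low-rank product to the frozen weight, W' = W + (alpha / r) · B · A. With lora_alpha = 2 × rank, as in the config above, the scale is 2. The sketch below is pure Python with made-up 2 × 2 numbers; none of the values come from lora_adapter.pt, and the helper names are hypothetical.

```python
def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def lora_effective_weight(w, lora_a, lora_b, alpha, rank):
    """Return W + (alpha / rank) * (B @ A)."""
    scale = alpha / rank
    delta = matmul(lora_b, lora_a)
    return [[w_ij + scale * d_ij for w_ij, d_ij in zip(w_row, d_row)]
            for w_row, d_row in zip(w, delta)]


rank = 1
w = [[1.0, 0.0], [0.0, 1.0]]   # frozen base weight, shape (out=2, in=2)
lora_a = [[0.5, -0.5]]         # lora_A, shape (rank, in)
lora_b = [[0.1], [0.2]]        # lora_B, shape (out, rank)

w_eff = lora_effective_weight(w, lora_a, lora_b, alpha=2 * rank, rank=rank)
for row in w_eff:
    print([round(v, 3) for v in row])
```

The first dimension of lora_A is the rank, which is why the loading code above can recover `rank` from `shape[0]` of any `lora_A` tensor.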

Pipeline overview

The model in this repo is the output of a three-stage compression pipeline applied to stabilityai/sd-turbo:

stabilityai/sd-turbo (860 M)
    ↓ structural pruning + step-wise distillation
ChenHe727/EdgeDiffusion_distilled_feat_attn (642 M, fp16)
    ↓ sensitivity-aware mixed-precision GPTQ (this repo's UNet)
    ↓ QLoRA recovery training (this repo's lora_adapter.pt)
ChenHe727/EdgeDiffusion (this repo)

Full design rationale, ablations, and reproducibility instructions: see the GitHub repo.


Limitations

  • Conv2d layers are not quantized in v1; only nn.Linear layers (attention projections, FFN) are. Conv2d holds ~70 % of UNet parameters; full quantization is planned for v2.
  • Fake-quant storage: weights are rounded to INT4/INT8 grids but stored as bf16 (2 bytes/value). Real packed INT4/INT8 storage would shrink the file from 1.22 GB to ~900 MB but requires a separate packing step.
  • LPIPS vs original SD-Turbo ≈ 0.28 mostly comes from the upstream pruning + distillation stage. The quantization stage itself adds only 0.005-0.062.
  • 2-step inference is the recommended default. 1-step works (faster) but quality drops noticeably; 4-step is slower and not better.
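
The fake-quant storage described in the list above can be illustrated with plain round-to-nearest quantization. This is a simplified sketch, not the GPTQ procedure used for this repo: it assumes a symmetric per-tensor scale, and the function names, sample weights, and value count are hypothetical.

```python
def fake_quantize(weights, bits):
    """Round weights onto a symmetric signed integer grid and return floats."""
    qmax = 2 ** (bits - 1) - 1          # 7 for INT4, 127 for INT8
    qmin = -qmax - 1                    # -8 for INT4, -128 for INT8
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / qmax if max_abs > 0 else 1.0
    quantized = []
    for w in weights:
        q = max(qmin, min(qmax, round(w / scale)))  # integer grid index
        quantized.append(q * scale)                 # stored back as a float
    return quantized, scale


def storage_bytes(n_values, bits):
    """Bytes needed to store n_values at the given bit-width, tightly packed."""
    return n_values * bits // 8


weights = [0.70, -0.33, 0.12, 0.0]
int4_values, int4_scale = fake_quantize(weights, bits=4)
print([round(v, 3) for v in int4_values])

n = 1_000_000  # hypothetical number of weights in a layer
print(storage_bytes(n, 16), storage_bytes(n, 8), storage_bytes(n, 4))
```

The quantized values take at most 16 distinct levels for INT4, yet each still occupies 2 bytes when kept as bf16. That is why the on-disk size only shrinks once a separate packing step stores the integer indices themselves.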

Acknowledgments
