File size: 1,971 Bytes
c465c0d
 
f8bf429
569bc6c
f8bf429
 
 
 
569bc6c
 
 
 
 
 
 
 
 
 
 
f8bf429
569bc6c
 
 
 
c465c0d
 
569bc6c
c465c0d
 
f8bf429
 
 
569bc6c
 
 
f8bf429
569bc6c
f8bf429
569bc6c
 
 
 
 
 
 
f8bf429
 
569bc6c
 
 
 
 
f8bf429
569bc6c
f8bf429
569bc6c
f8bf429
569bc6c
 
 
f8bf429
569bc6c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply
from time import perf_counter

CKPT_ID = "black-forest-labs/Flux.1-Dev"

# -----------------------------
# Pipeline arguments
# -----------------------------
# Keep seeding the global RNG at import time as before; unseeded random ops
# (e.g. the example call in compile_pipe) still draw from it.
torch.manual_seed(0)

# Keyword arguments for the timed generation in run_pipe.
PIPE_KWARGS = {
    "prompt": "A cat holding a sign that says hello world",
    "height": 256,  # very small to reduce memory
    "width": 256,
    "guidance_scale": 3.5,
    "num_inference_steps": 25,  # fewer steps
    # A dedicated CPU generator instead of the value returned by
    # torch.manual_seed(0): that value is the *global* default generator, whose
    # state is advanced by every unseeded random op that runs before the
    # benchmark, so the timed run would not really start from seed 0.
    # NOTE: this generator is shared by every call that uses PIPE_KWARGS, so its
    # state still advances between repeated runs.
    "generator": torch.Generator("cpu").manual_seed(0),
}

# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    """Load the ``CKPT_ID`` diffusion pipeline for CPU inference.

    Weights are loaded in float32 with ``device_map="cpu"``, and the
    pipeline's progress bar is disabled so timing output stays clean.

    Returns:
        The loaded ``DiffusionPipeline`` instance.
    """
    loaded = DiffusionPipeline.from_pretrained(
        CKPT_ID, torch_dtype=torch.float32, device_map="cpu"
    )
    loaded.set_progress_bar_config(disable=True)
    return loaded

# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
@torch.no_grad()
def compile_pipe(pipe, *, height=None, width=None):
    """Ahead-of-time compile ``pipe.transformer`` in place and return ``pipe``.

    One example pipeline call is made under ``aoti_capture`` to record the
    transformer's call arguments. The transformer is then exported with
    ``torch.export.export``, compiled with ``aoti_compile``, and installed back
    onto ``pipe.transformer`` with ``aoti_apply``.

    Args:
        pipe: Pipeline whose ``transformer`` is compiled (mutated in place).
        height: Image height for the example call. Defaults to
            ``PIPE_KWARGS["height"]`` so the captured shapes match the
            benchmark run.
        width: Image width for the example call. Defaults to
            ``PIPE_KWARGS["width"]``.

    Returns:
        The same ``pipe`` object.
    """
    # torch.export specializes the exported graph to the example inputs'
    # shapes when no dynamic_shapes are given. The transformer's input sizes
    # depend on the image resolution, so capturing at the pipeline's default
    # resolution (as the bare pipe(prompt="dummy") did) would produce a graph
    # that does not match the smaller benchmark call.
    if height is None:
        height = PIPE_KWARGS["height"]
    if width is None:
        width = PIPE_KWARGS["width"]

    # Compile against a fresh Inductor cache rather than reusing earlier artifacts.
    with torch._inductor.utils.fresh_inductor_cache():
        # Capture the transformer's call arguments from one example call.
        with aoti_capture(pipe.transformer) as call:
            pipe(prompt="dummy", height=height, width=width)
        exported = torch.export.export(pipe.transformer, args=call.args, kwargs=call.kwargs)
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)
        # Drop local references to the intermediate artifacts.
        del exported, compiled, call
    return pipe

# -----------------------------
# Measure runtime
# -----------------------------
@torch.no_grad()
def run_pipe(pipe):
    """Time a single generation using ``PIPE_KWARGS``.

    Args:
        pipe: Pipeline to call with ``**PIPE_KWARGS``.

    Returns:
        A ``(elapsed_seconds, image)`` tuple, where ``image`` is the first
        entry of the output's ``images``.
    """
    t0 = perf_counter()
    first_image = pipe(**PIPE_KWARGS).images[0]
    elapsed = perf_counter() - t0
    return elapsed, first_image

# -----------------------------
# Main
# -----------------------------
if __name__ == "__main__":
    # Load on CPU, apply the light AoTI compile, then time one generation.
    pipeline = compile_pipe(load_pipe())
    seconds, result = run_pipe(pipeline)
    print(f"Lightweight CPU + aoti latency: {seconds:.2f}s")
    result.save("cpu_lightweight.png")