# (scraped Hugging Face Space page chrome removed)
import torch
from diffusers import DiffusionPipeline
import spaces
from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply
from time import perf_counter
# -----------------------------
# Checkpoint and generation settings
# -----------------------------
# NOTE(review): canonical Hub casing is "black-forest-labs/FLUX.1-dev"; the
# Hub resolves repo ids case-insensitively, but confirm before relying on it.
CKPT_ID = "black-forest-labs/Flux.1-Dev"

# Keyword arguments for the timed pipeline call. Tiny resolution and a
# reduced step count keep CPU memory and runtime manageable.
PIPE_KWARGS = dict(
    prompt="A cat holding a sign that says hello world",
    height=256,  # very small to reduce memory
    width=256,
    guidance_scale=3.5,
    num_inference_steps=25,  # fewer steps
    generator=torch.manual_seed(0),  # deterministic sampling
)
# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    """Build the checkpoint's pipeline on CPU in float32 and return it.

    Progress bars are disabled so they don't pollute the timing output.
    """
    pipeline = DiffusionPipeline.from_pretrained(
        CKPT_ID,
        device_map="cpu",
        torch_dtype=torch.float32,
    )
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
@torch.no_grad()
def compile_pipe(pipe):
    """AOT-compile ``pipe.transformer`` in place and return the pipeline.

    Captures a single transformer invocation, exports it with
    ``torch.export``, compiles the exported program ahead-of-time, and
    swaps the compiled artifact into the pipeline.

    Args:
        pipe: a loaded diffusers pipeline whose ``transformer`` module
            will be replaced by its AOT-compiled version.

    Returns:
        The same pipeline object, mutated in place.
    """
    with torch._inductor.utils.fresh_inductor_cache():
        # Capture one transformer call. Use the SAME spatial size as the
        # real run — torch.export can specialize on the captured static
        # shapes, so tracing at the pipeline's (much larger) defaults would
        # both waste a full-resolution CPU inference and risk a shape
        # mismatch at the timed 256x256 run. One denoising step is enough
        # for capture.
        with aoti_capture(pipe.transformer) as call:
            pipe(
                prompt="dummy",
                height=PIPE_KWARGS["height"],
                width=PIPE_KWARGS["width"],
                num_inference_steps=1,
            )
        exported = torch.export.export(
            pipe.transformer, args=call.args, kwargs=call.kwargs
        )
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)
        # Free export/compile intermediates before the timed run to keep
        # peak memory down.
        del exported, compiled, call
    return pipe
# -----------------------------
# Measure runtime
# -----------------------------
@torch.no_grad()
def run_pipe(pipe):
    """Run one generation with ``PIPE_KWARGS``.

    Returns:
        A ``(seconds, image)`` tuple — wall-clock latency of the call and
        the first generated image.
    """
    t0 = perf_counter()
    image = pipe(**PIPE_KWARGS).images[0]
    elapsed = perf_counter() - t0
    return elapsed, image
# -----------------------------
# Main
# -----------------------------
def _main():
    """Script entry point: load, compile, time one generation, save it."""
    pipeline = load_pipe()
    pipeline = compile_pipe(pipeline)  # light aoti compile
    latency, image = run_pipe(pipeline)
    print(f"Lightweight CPU + aoti latency: {latency:.2f}s")
    image.save("cpu_lightweight.png")


if __name__ == "__main__":
    _main()