# Spaces: Paused
| import torch | |
| from diffusers import DiffusionPipeline | |
| import spaces | |
| from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply | |
| from time import perf_counter | |
# Model checkpoint handed to DiffusionPipeline.from_pretrained().
CKPT_ID = "black-forest-labs/Flux.1-Dev"

# -----------------------------
# Pipeline arguments
# -----------------------------
# Keyword arguments forwarded verbatim to the pipeline call in run_pipe().
PIPE_KWARGS = dict(
    prompt="A cat holding a sign that says hello world",
    height=256,  # very small to reduce memory
    width=256,
    guidance_scale=3.5,
    num_inference_steps=25,  # fewer steps
    # torch.manual_seed() seeds the global RNG and returns that generator.
    generator=torch.manual_seed(0),
)
| # ----------------------------- | |
| # Load pipeline | |
| # ----------------------------- | |
def load_pipe():
    """Load the CKPT_ID checkpoint as a float32 diffusion pipeline on CPU.

    Returns:
        The pipeline object, with its progress bar disabled.
    """
    load_options = {
        "torch_dtype": torch.float32,
        # NOTE(review): whether "cpu" is an accepted pipeline-level
        # device_map depends on the installed diffusers version - confirm.
        "device_map": "cpu",
    }
    pipeline = DiffusionPipeline.from_pretrained(CKPT_ID, **load_options)
    pipeline.set_progress_bar_config(disable=True)
    return pipeline
| # ----------------------------- | |
| # Compile transformer using aoti (lightweight) | |
| # ----------------------------- | |
def compile_pipe(pipe, example_kwargs=None):
    """Ahead-of-time compile ``pipe.transformer`` and apply the result to it.

    The transformer call is captured from one example pipeline invocation,
    exported with ``torch.export.export`` and compiled with ``aoti_compile``;
    ``aoti_apply`` then installs the compiled artifact on ``pipe.transformer``.

    ``torch.export.export`` is called without ``dynamic_shapes``, so the
    exported graph is specialised to the shapes seen during capture. The
    example invocation therefore has to use the same resolution as the real
    run. Capturing with only a prompt would trace the pipeline's default
    resolution instead of the one in ``PIPE_KWARGS``.

    Args:
        pipe: Pipeline whose ``transformer`` attribute is compiled in place.
        example_kwargs: Keyword arguments for the example pipeline call used
            for capture. Defaults to a dummy prompt plus the ``height`` and
            ``width`` entries of ``PIPE_KWARGS`` (when present).

    Returns:
        The same ``pipe`` object that was passed in.
    """
    # Import the helper explicitly instead of relying on the private
    # ``torch._inductor.utils`` submodule already being loaded as an attribute.
    from torch._inductor.utils import fresh_inductor_cache

    if example_kwargs is None:
        example_kwargs = {"prompt": "dummy"}
        # Only the spatial size matters for the traced shapes; the seeded
        # generator is deliberately left out of the example call.
        for size_key in ("height", "width"):
            if size_key in PIPE_KWARGS:
                example_kwargs[size_key] = PIPE_KWARGS[size_key]

    with fresh_inductor_cache():
        # Capture the transformer's call arguments from one example run.
        with aoti_capture(pipe.transformer) as call:
            pipe(**example_kwargs)

        exported = torch.export.export(
            pipe.transformer,
            args=call.args,
            kwargs=call.kwargs,
        )
        compiled = aoti_compile(exported)
        aoti_apply(compiled, pipe.transformer)

    return pipe
| # ----------------------------- | |
| # Measure runtime | |
| # ----------------------------- | |
def run_pipe(pipe):
    """Run one generation with PIPE_KWARGS and time it.

    Args:
        pipe: Callable pipeline; its result must expose an ``images`` sequence.

    Returns:
        A ``(elapsed_seconds, image)`` tuple, where ``image`` is the first
        entry of the result's ``images``.
    """
    started_at = perf_counter()
    result = pipe(**PIPE_KWARGS)
    first_image = result.images[0]
    elapsed = perf_counter() - started_at
    return elapsed, first_image
| # ----------------------------- | |
| # Main | |
| # ----------------------------- | |
def main():
    """Load the pipeline, compile its transformer, time one run, save the image."""
    pipe = compile_pipe(load_pipe())  # light aoti compile
    latency, image = run_pipe(pipe)
    print(f"Lightweight CPU + aoti latency: {latency:.2f}s")
    image.save("cpu_lightweight.png")


if __name__ == "__main__":
    main()