rahul7star commited on
Commit
569bc6c
·
verified ·
1 Parent(s): 07742f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -54
app.py CHANGED
@@ -1,77 +1,66 @@
1
  import torch
2
  from diffusers import DiffusionPipeline
3
  import spaces
4
- from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, aoti_capture, aoti_compile, aoti_apply
5
  from time import perf_counter
6
- import argparse
7
 
8
  CKPT_ID = "black-forest-labs/Flux.1-Dev"
9
 
10
def get_pipe_kwargs():
    """Build the keyword arguments for one deterministic sample generation.

    The seeded generator makes repeated runs reproducible; the modest
    512x512 resolution keeps memory usage down.
    """
    kwargs = dict(
        prompt="A cat holding a sign that says hello world",
        height=512,  # reduce memory usage
        width=512,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
        generator=torch.manual_seed(0),
    )
    return kwargs
 
20
 
21
def load_pipeline():
    """Instantiate the Flux.1-Dev pipeline on CPU with progress bars off.

    Loads in float32 (full precision) for the CPU-only run and returns
    the ready-to-call pipeline object.
    """
    load_opts = {
        "torch_dtype": torch.float32,  # CPU-only
        "device_map": "cpu",
    }
    pipe = DiffusionPipeline.from_pretrained(CKPT_ID, **load_opts)
    pipe.set_progress_bar_config(disable=True)
    return pipe
29
 
 
 
 
30
@torch.no_grad()
def aot_compile_load(pipe, regional=False):
    """AOT-compile the pipeline's transformer, whole or block-by-block.

    With ``regional=True`` each transformer block is captured, exported,
    and compiled individually (lower peak memory); otherwise the full
    transformer is compiled in a single pass.  Returns the same ``pipe``
    with its transformer forward path swapped for the compiled artifacts.
    """
    capture_prompt = "example prompt"

    torch.compiler.reset()
    with torch._inductor.utils.fresh_inductor_cache():
        if regional:
            # Compile transformer blocks one at a time to save memory.
            block_lists = (
                pipe.transformer.transformer_blocks,
                pipe.transformer.single_transformer_blocks,
            )
            for blocks in block_lists:
                for block in blocks:
                    # Run the pipe once so aoti_capture can record this
                    # block's call arguments.
                    with aoti_capture(block) as captured:
                        pipe(prompt=capture_prompt)
                    ep = torch.export.export(block, args=captured.args, kwargs=captured.kwargs)
                    artifact = aoti_compile(ep)
                    zero_weights = ZeroGPUWeights(block.state_dict())
                    runner = ZeroGPUCompiledModel(artifact.archive_file, zero_weights)
                    block.forward = runner  # replace forward with compiled block
                    # Free intermediates before moving to the next block.
                    del ep, artifact, zero_weights, runner, captured
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()
        else:
            # Single-shot compile of the whole transformer.
            with aoti_capture(pipe.transformer) as captured:
                pipe(prompt=capture_prompt)
            ep = torch.export.export(pipe.transformer, args=captured.args, kwargs=captured.kwargs)
            artifact = aoti_compile(ep)
            aoti_apply(artifact, pipe.transformer)
            del ep, artifact, captured

    return pipe
60
 
61
def measure_compile_time(pipe, regional=False):
    """Time the AOT compile step and run one sanity-check inference.

    Returns a tuple of (compile seconds, generated image).
    """
    t0 = perf_counter()
    pipe = aot_compile_load(pipe, regional=regional)
    t1 = perf_counter()
    # Generate one image to confirm the compiled pipe still works.
    image = pipe(**get_pipe_kwargs()).images[0]
    return t1 - t0, image
68
 
 
 
 
69
if __name__ == "__main__":
    # CLI: pass --regional to compile transformer blocks one at a time.
    cli = argparse.ArgumentParser()
    cli.add_argument("--regional", action="store_true")
    args = cli.parse_args()

    pipe = load_pipeline()
    latency, image = measure_compile_time(pipe, regional=args.regional)
    print(f"{args.regional=}, CPU compile + run latency: {latency:.2f} secs")
    image.save(f"regional@{args.regional}.png")
 
1
  import torch
2
  from diffusers import DiffusionPipeline
3
  import spaces
4
+ from spaces.zero.torch.aoti import aoti_capture, aoti_compile, aoti_apply
5
  from time import perf_counter
 
6
 
7
  CKPT_ID = "black-forest-labs/Flux.1-Dev"
8
 
9
# -----------------------------
# Pipeline arguments
# -----------------------------
# Deterministic, low-memory generation settings consumed by run_pipe().
PIPE_KWARGS = dict(
    prompt="A cat holding a sign that says hello world",
    height=256,  # very small to reduce memory
    width=256,
    guidance_scale=3.5,
    num_inference_steps=25,  # fewer steps
    generator=torch.manual_seed(0),
)
20
 
21
# -----------------------------
# Load pipeline
# -----------------------------
def load_pipe():
    """Load Flux.1-Dev on CPU in float32 with progress bars disabled."""
    opts = {
        "torch_dtype": torch.float32,  # full precision for the CPU-only run
        "device_map": "cpu",
    }
    pipe = DiffusionPipeline.from_pretrained(CKPT_ID, **opts)
    pipe.set_progress_bar_config(disable=True)
    return pipe
32
 
33
# -----------------------------
# Compile transformer using aoti (lightweight)
# -----------------------------
@torch.no_grad()
def compile_pipe(pipe):
    """AOT-compile the pipe's transformer in one shot and return the pipe.

    A throwaway prompt is run once so aoti_capture can record the
    transformer's call arguments; the captured call is then exported,
    compiled, and applied back onto the transformer in place.
    """
    with torch._inductor.utils.fresh_inductor_cache():
        # Record the transformer's inputs from a single dummy run.
        with aoti_capture(pipe.transformer) as captured:
            pipe(prompt="dummy")
        ep = torch.export.export(pipe.transformer, args=captured.args, kwargs=captured.kwargs)
        artifact = aoti_compile(ep)
        aoti_apply(artifact, pipe.transformer)
        # Release the export/compile intermediates promptly.
        del ep, artifact, captured
    return pipe
47
 
48
# -----------------------------
# Measure runtime
# -----------------------------
@torch.no_grad()
def run_pipe(pipe):
    """Run one generation with PIPE_KWARGS; return (seconds, image)."""
    t0 = perf_counter()
    image = pipe(**PIPE_KWARGS).images[0]
    t1 = perf_counter()
    return t1 - t0, image
 
 
57
 
58
# -----------------------------
# Main
# -----------------------------
def main():
    """Load, lightly AOT-compile, run once, and save the result."""
    pipe = load_pipe()
    pipe = compile_pipe(pipe)  # light aoti compile
    latency, image = run_pipe(pipe)
    print(f"Lightweight CPU + aoti latency: {latency:.2f}s")
    image.save("cpu_lightweight.png")


if __name__ == "__main__":
    main()