rahul7star commited on
Commit
07742f6
·
verified ·
1 Parent(s): c465c0d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -28
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import torch
2
  from diffusers import DiffusionPipeline
3
  import spaces
4
- from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights
5
  from time import perf_counter
6
  import argparse
7
 
@@ -10,8 +10,8 @@ CKPT_ID = "black-forest-labs/Flux.1-Dev"
10
  def get_pipe_kwargs():
11
  return {
12
  "prompt": "A cat holding a sign that says hello world",
13
- "height": 1024,
14
- "width": 1024,
15
  "guidance_scale": 3.5,
16
  "num_inference_steps": 50,
17
  "max_sequence_length": 512,
@@ -21,7 +21,7 @@ def get_pipe_kwargs():
21
  def load_pipeline():
22
  pipe = DiffusionPipeline.from_pretrained(
23
  CKPT_ID,
24
- torch_dtype=torch.float32, # CPU only
25
  device_map="cpu"
26
  )
27
  pipe.set_progress_bar_config(disable=True)
@@ -30,36 +30,32 @@ def load_pipeline():
30
  @torch.no_grad()
31
  def aot_compile_load(pipe, regional=False):
32
  prompt = "example prompt"
33
-
34
  torch.compiler.reset()
35
  with torch._inductor.utils.fresh_inductor_cache():
36
  if regional:
37
- # Compile individual transformer blocks
38
- for block in pipe.transformer.transformer_blocks:
39
- with spaces.aoti_capture(block) as call:
40
- pipe(prompt=prompt)
41
- exported = torch.export.export(block, args=call.args, kwargs=call.kwargs)
42
- compiled = spaces.aoti_compile(exported)
43
- weights = ZeroGPUWeights(block.state_dict())
44
- compiled_block = ZeroGPUCompiledModel(compiled.archive_file, weights)
45
- block.forward = compiled_block
46
-
47
- for block in pipe.transformer.single_transformer_blocks:
48
- with spaces.aoti_capture(block) as call:
49
- pipe(prompt=prompt)
50
- exported = torch.export.export(block, args=call.args, kwargs=call.kwargs)
51
- compiled = spaces.aoti_compile(exported)
52
- weights = ZeroGPUWeights(block.state_dict())
53
- compiled_block = ZeroGPUCompiledModel(compiled.archive_file, weights)
54
- block.forward = compiled_block
55
  else:
56
- # Compile the whole transformer
57
- with spaces.aoti_capture(pipe.transformer) as call:
58
  pipe(prompt=prompt)
59
  exported = torch.export.export(pipe.transformer, args=call.args, kwargs=call.kwargs)
60
- compiled = spaces.aoti_compile(exported)
61
- spaces.aoti_apply(compiled, pipe.transformer)
62
-
 
63
  return pipe
64
 
65
  def measure_compile_time(pipe, regional=False):
 
1
  import torch
2
  from diffusers import DiffusionPipeline
3
  import spaces
4
+ from spaces.zero.torch.aoti import ZeroGPUCompiledModel, ZeroGPUWeights, aoti_capture, aoti_compile, aoti_apply
5
  from time import perf_counter
6
  import argparse
7
 
 
10
  def get_pipe_kwargs():
11
  return {
12
  "prompt": "A cat holding a sign that says hello world",
13
+ "height": 512, # reduce memory usage
14
+ "width": 512,
15
  "guidance_scale": 3.5,
16
  "num_inference_steps": 50,
17
  "max_sequence_length": 512,
 
21
  def load_pipeline():
22
  pipe = DiffusionPipeline.from_pretrained(
23
  CKPT_ID,
24
+ torch_dtype=torch.float32, # CPU-only
25
  device_map="cpu"
26
  )
27
  pipe.set_progress_bar_config(disable=True)
 
30
  @torch.no_grad()
31
  def aot_compile_load(pipe, regional=False):
32
  prompt = "example prompt"
33
+
34
  torch.compiler.reset()
35
  with torch._inductor.utils.fresh_inductor_cache():
36
  if regional:
37
+ # Compile transformer blocks **one at a time** to save memory
38
+ for block_list in [pipe.transformer.transformer_blocks, pipe.transformer.single_transformer_blocks]:
39
+ for i, block in enumerate(block_list):
40
+ with aoti_capture(block) as call:
41
+ pipe(prompt=prompt)
42
+ exported = torch.export.export(block, args=call.args, kwargs=call.kwargs)
43
+ compiled = aoti_compile(exported)
44
+ weights = ZeroGPUWeights(block.state_dict())
45
+ compiled_block = ZeroGPUCompiledModel(compiled.archive_file, weights)
46
+ block.forward = compiled_block # replace forward with compiled block
47
+ # Free memory
48
+ del exported, compiled, weights, compiled_block, call
49
+ torch.cuda.empty_cache() if torch.cuda.is_available() else None
 
 
 
 
 
50
  else:
51
+ # Compile the whole transformer at once
52
+ with aoti_capture(pipe.transformer) as call:
53
  pipe(prompt=prompt)
54
  exported = torch.export.export(pipe.transformer, args=call.args, kwargs=call.kwargs)
55
+ compiled = aoti_compile(exported)
56
+ aoti_apply(compiled, pipe.transformer)
57
+ del exported, compiled, call
58
+
59
  return pipe
60
 
61
  def measure_compile_time(pipe, regional=False):