Files changed (1) hide show
  1. app.py +137 -15
app.py CHANGED
@@ -1,10 +1,28 @@
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
- from diffusers import DiffusionPipeline
5
  import torch
 
6
  from huggingface_hub import login
7
  import os
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  device = "cuda" if torch.cuda.is_available() else "cpu"
10
 
@@ -16,16 +34,111 @@ login(token=HUGGINGFACE_TOKEN)
16
  base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
17
  lora_weights_path = "./pytorch_lora_weights.safetensors"
18
 
19
- # Load the base model
20
- pipeline = DiffusionPipeline.from_pretrained(
21
- base_model_repo,
22
- torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
23
- use_auth_token=HUGGINGFACE_TOKEN
 
 
 
 
 
24
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  pipeline.load_lora_weights(lora_weights_path)
26
- #pipeline.enable_sequential_cpu_offload() # Efficient memory usage
27
- #pipeline.enable_xformers_memory_efficient_attention() # Enable xformers memory efficient attention
28
- pipeline = pipeline.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
 
30
  MAX_SEED = np.iinfo(np.int32).max
31
  MAX_IMAGE_SIZE = 768 # Reduce max image size to fit within memory constraints
@@ -45,13 +158,15 @@ def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance
45
  height=height,
46
  generator=generator
47
  ).images[0]
 
 
48
 
49
- return image
50
 
51
  examples = [
52
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
53
- "An astronaut riding a green horse",
54
- "A delicious ceviche cheesecake slice",
55
  ]
56
 
57
  css = """
@@ -59,6 +174,12 @@ css = """
59
  margin: 0 auto;
60
  max-width: 520px;
61
  }
 
 
 
 
 
 
62
  """
63
 
64
  if torch.cuda.is_available():
@@ -85,6 +206,7 @@ with gr.Blocks(css=css) as demo:
85
  run_button = gr.Button("Run", scale=0)
86
 
87
  result = gr.Image(label="Result", show_label=False)
 
88
 
89
  with gr.Accordion("Advanced Settings", open=False):
90
  negative_prompt = gr.Textbox(
@@ -146,7 +268,7 @@ with gr.Blocks(css=css) as demo:
146
  run_button.click(
147
  fn=infer,
148
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
149
- outputs=[result]
150
  )
151
 
152
- demo.queue().launch()
 
1
  import gradio as gr
2
  import numpy as np
3
  import random
4
+ from diffusers import StableDiffusion3Pipeline, DiffusionPipeline
5
  import torch
6
+ from transformers import T5EncoderModel
7
  from huggingface_hub import login
8
  import os
9
+ import gc
10
+ import psutil
11
+
12
def flush():
    """Reclaim Python garbage, then release PyTorch's cached CUDA memory."""
    gc.collect()  # drop unreachable Python objects first so their tensors free
    torch.cuda.empty_cache()  # return cached GPU blocks to the driver (no-op on CPU)
15
+
16
def bytes_to_giga_bytes(bytes):
    """Convert a byte count to gibibytes (GiB).

    Note: the parameter name shadows the builtin ``bytes``; kept for
    backward compatibility with keyword callers.
    """
    return bytes / (1024 ** 3)
18
+
19
def get_memory_usage():
    """Return this process's resident set size formatted as ``'X.XX MB'``."""
    rss_bytes = psutil.Process(os.getpid()).memory_info().rss
    return f"{rss_bytes / (1024 ** 2):.2f} MB"
23
+
24
def log_memory(step):
    """Record current memory usage under the label *step* in ``memory_log``."""
    entry = f"{step}: {get_memory_usage()}"
    memory_log.append(entry)
26
 
27
# Prefer GPU when available. NOTE(review): later code moves the pipeline with
# a hard-coded .to("cuda") rather than using this variable — confirm intent.
device = "cuda" if torch.cuda.is_available() else "cpu"
28
 
 
34
base_model_repo = "stabilityai/stable-diffusion-3-medium-diffusers"
lora_weights_path = "./pytorch_lora_weights.safetensors"

# Chronological log of process memory usage; log_memory() appends to it.
memory_log = []

log_memory("Before loading the model")

# Load only the large T5 text encoder in 8-bit to cut its memory footprint.
text_encoder = T5EncoderModel.from_pretrained(
    base_model_repo,
    subfolder="text_encoder_3",
    load_in_8bit=True,
    device_map="auto",
)

# First pipeline: text-encoding stage only — transformer and VAE are omitted.
# It is used solely to pre-compute prompt embeddings.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder_3=text_encoder,
    transformer=None,
    vae=None,
    device_map="balanced",
)

log_memory("After loading the pipeline")

# NOTE(review): LoRA weights typically target the transformer, which is None
# in this encoder-only pipeline — confirm this first load is actually needed.
pipeline.load_lora_weights(lora_weights_path)
log_memory("After loading LoRA weights")

# Benchmark prompt encoding: 3 warmup passes, then average over 10 timed runs.
# FIX: `time` was used below but never imported at the top of the file.
with torch.no_grad():
    prompt = "a photo of a cat"  # hoisted: loop-invariant
    for _ in range(3):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    start = time.time()
    for _ in range(10):
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = pipeline.encode_prompt(prompt=prompt, prompt_2=None, prompt_3=None)
    end = time.time()
    avg_prompt_encoding_time = (end - start) / 10

# Free the encoder-only pipeline before loading the full diffusion model.
del text_encoder
del pipeline
flush()

# Second pipeline: diffusion stage only — all text encoders/tokenizers dropped;
# the embeddings computed above are fed in directly.
pipeline = StableDiffusion3Pipeline.from_pretrained(
    base_model_repo,
    text_encoder=None,
    text_encoder_2=None,
    text_encoder_3=None,
    tokenizer=None,
    tokenizer_2=None,
    tokenizer_3=None,
    torch_dtype=torch.float16,
).to("cuda")
pipeline.set_progress_bar_config(disable=True)

log_memory("After reloading the pipeline without text encoder")

# Re-apply the LoRA weights to the freshly loaded pipeline.
pipeline.load_lora_weights(lora_weights_path)
log_memory("After reloading LoRA weights for inference")

# Benchmark denoising: 3 warmup passes, then average over 10 timed runs.
# Embeddings are cast to fp16 to match the pipeline's torch_dtype.
for _ in range(3):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
start = time.time()
for _ in range(10):
    _ = pipeline(
        prompt_embeds=prompt_embeds.half(),
        negative_prompt_embeds=negative_prompt_embeds.half(),
        pooled_prompt_embeds=pooled_prompt_embeds.half(),
        negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
    )
end = time.time()
avg_inference_time = (end - start) / 10

log_memory("After inference")

print(f"Average prompt encoding time: {avg_prompt_encoding_time:.3f} seconds.")
print(f"Average inference time: {avg_inference_time:.3f} seconds.")
print(f"Total time: {(avg_prompt_encoding_time + avg_inference_time):.3f} seconds.")
print(
    f"Max memory allocated: {bytes_to_giga_bytes(torch.cuda.max_memory_allocated())} GB"
)

# Generate and save one sample image from the precomputed embeddings.
image = pipeline(
    prompt_embeds=prompt_embeds.half(),
    negative_prompt_embeds=negative_prompt_embeds.half(),
    pooled_prompt_embeds=pooled_prompt_embeds.half(),
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds.half(),
).images[0]
image.save("output_8bit.png")

log_memory("After saving the image")
142
 
143
# Largest seed accepted by the UI: the maximum signed 32-bit integer.
MAX_SEED = np.iinfo(np.int32).max
# Cap generated image dimensions to keep GPU memory usage manageable.
MAX_IMAGE_SIZE = 768
 
158
  height=height,
159
  generator=generator
160
  ).images[0]
161
+
162
+ log_memory("After inference")
163
 
164
+ return image, "\n".join(memory_log)
165
 
166
# Example prompts for the Gradio UI; each inner list is one row of inputs.
examples = [
    ["Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"],
    ["An astronaut riding a green horse"],
    ["A delicious ceviche cheesecake slice"],
]
171
 
172
  css = """
 
174
  margin: 0 auto;
175
  max-width: 520px;
176
  }
177
+ #memory-log {
178
+ white-space: pre-wrap;
179
+ background: #f8f9fa;
180
+ padding: 10px;
181
+ border-radius: 5px;
182
+ }
183
  """
184
 
185
  if torch.cuda.is_available():
 
206
  run_button = gr.Button("Run", scale=0)
207
 
208
  result = gr.Image(label="Result", show_label=False)
209
+ memory_log_output = gr.Textbox(label="Memory Log", elem_id="memory-log", lines=10, interactive=False)
210
 
211
  with gr.Accordion("Advanced Settings", open=False):
212
  negative_prompt = gr.Textbox(
 
268
  run_button.click(
269
  fn=infer,
270
  inputs=[prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
271
+ outputs=[result, memory_log_output]
272
  )
273
 
274
+ demo.queue().launch()