Shilpaj committed on
Commit
00ad3f0
·
verified ·
1 Parent(s): e2a7d7d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -23
app.py CHANGED
@@ -16,6 +16,7 @@ from utils import (
16
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
17
  vignette_loss, get_concept_embedding, load_concept_library, image_grid
18
  )
 
19
 
20
 
21
  # Set device
@@ -23,8 +24,24 @@ device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is
23
  if device == "mps":
24
  os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
25
 
26
- # Load models
27
- vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
  # Load concept library
30
  concept_embeds, concept_tokens = load_concept_library(pipe)
@@ -74,13 +91,13 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
74
  elif concept_style in art_concepts:
75
  # Generate concept embedding from text description
76
  concept_text = art_concepts[concept_style]
77
- concept_embedding = get_concept_embedding(concept_text, tokenizer, text_encoder, device)
78
 
79
  # Prep text
80
- text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length,
81
  truncation=True, return_tensors="pt")
82
- with torch.no_grad():
83
- text_embeddings = text_encoder(text_input.input_ids.to(device))[0]
84
 
85
  # Apply concept embedding influence if provided
86
  if concept_embedding is not None and concept_strength > 0:
@@ -94,34 +111,34 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
94
 
95
  # Unconditional embedding for classifier-free guidance
96
  max_length = text_input.input_ids.shape[-1]
97
- uncond_input = tokenizer(
98
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
99
  )
100
- with torch.no_grad():
101
- uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
102
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
103
 
104
  # Prep Scheduler
105
- set_timesteps(scheduler, num_inference_steps)
106
 
107
  # Prep latents
108
  latents = torch.randn(
109
- (batch_size, unet.in_channels, height // 8, width // 8),
110
  generator=generator,
111
  )
112
  latents = latents.to(device)
113
- latents = latents * scheduler.init_noise_sigma
114
 
115
  # Loop through diffusion process
116
- for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
117
  # Expand latents for classifier-free guidance
118
  latent_model_input = torch.cat([latents] * 2)
119
- sigma = scheduler.sigmas[i]
120
- latent_model_input = scheduler.scale_model_input(latent_model_input, t)
121
 
122
  # Predict the noise residual
123
- with torch.no_grad():
124
- noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
125
 
126
  # Perform classifier-free guidance
127
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
@@ -136,7 +153,7 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
136
  latents_x0 = latents - sigma * noise_pred
137
 
138
  # Decode to image space
139
- denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
140
 
141
  # Calculate loss
142
  loss = vignette_loss(denoised_images) * vignette_loss_scale
@@ -148,10 +165,11 @@ def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
148
  latents = latents.detach() - cond_grad * sigma**2
149
 
150
  # Step with scheduler
151
- latents = scheduler.step(noise_pred, t, latents).prev_sample
152
 
153
  return latents
154
 
 
155
  def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
156
  vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
157
  height=512, width=512):
@@ -190,7 +208,7 @@ def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
190
  )
191
 
192
  # Convert latents to image
193
- images = latents_to_pil(latents, vae)
194
 
195
  return images[0]
196
 
@@ -230,7 +248,7 @@ def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=
230
  )
231
 
232
  # Convert latents to image
233
- style_images = latents_to_pil(latents, vae)
234
  images.append(style_images[0])
235
  labels.append(style)
236
 
@@ -240,7 +258,7 @@ def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=
240
  return grid
241
 
242
  # Define Gradio interface
243
- @spaces.GPU(enable_queue=True)
244
  def create_demo():
245
  with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
246
  gr.Markdown("# Guided Stable Diffusion with Styles")
@@ -299,4 +317,4 @@ def create_demo():
299
  # Launch the app
300
  if __name__ == "__main__":
301
  demo = create_demo()
302
- demo.launch()
 
16
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
17
  vignette_loss, get_concept_embedding, load_concept_library, image_grid
18
  )
19
+ from diffusers import StableDiffusionPipeline
20
 
21
 
22
  # Set device
 
24
  if device == "mps":
25
  os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
26
 
27
+ # Load model once at startup using caching
28
+ @spaces.GPUCache
29
+ def load_models():
30
+ model_id = "runwayml/stable-diffusion-v1-5"
31
+
32
+ pipe = StableDiffusionPipeline.from_pretrained(
33
+ model_id,
34
+ torch_dtype=torch.float16,
35
+ safety_checker=None,
36
+ use_safetensors=True
37
+ ).to(device)
38
+
39
+ # Disable unnecessary progress bars
40
+ pipe.set_progress_bar_config(disable=True)
41
+ return pipe
42
+
43
+ # Initialize pipeline once
44
+ pipe = load_models()
45
 
46
  # Load concept library
47
  concept_embeds, concept_tokens = load_concept_library(pipe)
 
91
  elif concept_style in art_concepts:
92
  # Generate concept embedding from text description
93
  concept_text = art_concepts[concept_style]
94
+ concept_embedding = get_concept_embedding(concept_text, pipe.tokenizer, pipe.text_encoder, device)
95
 
96
  # Prep text
97
+ text_input = pipe.tokenizer([prompt], padding="max_length", max_length=pipe.tokenizer.model_max_length,
98
  truncation=True, return_tensors="pt")
99
+ with torch.inference_mode():
100
+ text_embeddings = pipe.text_encoder(text_input.input_ids.to(device))[0]
101
 
102
  # Apply concept embedding influence if provided
103
  if concept_embedding is not None and concept_strength > 0:
 
111
 
112
  # Unconditional embedding for classifier-free guidance
113
  max_length = text_input.input_ids.shape[-1]
114
+ uncond_input = pipe.tokenizer(
115
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
116
  )
117
+ with torch.inference_mode():
118
+ uncond_embeddings = pipe.text_encoder(uncond_input.input_ids.to(device))[0]
119
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
120
 
121
  # Prep Scheduler
122
+ set_timesteps(pipe.scheduler, num_inference_steps)
123
 
124
  # Prep latents
125
  latents = torch.randn(
126
+ (batch_size, pipe.unet.in_channels, height // 8, width // 8),
127
  generator=generator,
128
  )
129
  latents = latents.to(device)
130
+ latents = latents * pipe.scheduler.init_noise_sigma
131
 
132
  # Loop through diffusion process
133
+ for i, t in tqdm(enumerate(pipe.scheduler.timesteps), total=len(pipe.scheduler.timesteps)):
134
  # Expand latents for classifier-free guidance
135
  latent_model_input = torch.cat([latents] * 2)
136
+ sigma = pipe.scheduler.sigmas[i]
137
+ latent_model_input = pipe.scheduler.scale_model_input(latent_model_input, t)
138
 
139
  # Predict the noise residual
140
+ with torch.inference_mode():
141
+ noise_pred = pipe.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
142
 
143
  # Perform classifier-free guidance
144
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
 
153
  latents_x0 = latents - sigma * noise_pred
154
 
155
  # Decode to image space
156
+ denoised_images = pipe.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
157
 
158
  # Calculate loss
159
  loss = vignette_loss(denoised_images) * vignette_loss_scale
 
165
  latents = latents.detach() - cond_grad * sigma**2
166
 
167
  # Step with scheduler
168
+ latents = pipe.scheduler.step(noise_pred, t, latents).prev_sample
169
 
170
  return latents
171
 
172
+ @spaces.GPU
173
  def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
174
  vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
175
  height=512, width=512):
 
208
  )
209
 
210
  # Convert latents to image
211
+ images = latents_to_pil(latents, pipe.vae)
212
 
213
  return images[0]
214
 
 
248
  )
249
 
250
  # Convert latents to image
251
+ style_images = latents_to_pil(latents, pipe.vae)
252
  images.append(style_images[0])
253
  labels.append(style)
254
 
 
258
  return grid
259
 
260
  # Define Gradio interface
261
+ @spaces.GPU(enable_queue=False)
262
  def create_demo():
263
  with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
264
  gr.Markdown("# Guided Stable Diffusion with Styles")
 
317
  # Launch the app
318
  if __name__ == "__main__":
319
  demo = create_demo()
320
+ demo.launch(debug=False, show_error=True, server_name="0.0.0.0", server_port=7860)