Shilpaj committed on
Commit
179bc9a
·
verified ·
1 Parent(s): 5c6459f

Fix: App issue

Browse files
Files changed (2) hide show
  1. app.py +145 -207
  2. utils.py +39 -82
app.py CHANGED
@@ -4,254 +4,172 @@ Gradio Application for Stable Diffusion
4
  Author: Shilpaj Bhalerao
5
  Date: Feb 26, 2025
6
  """
7
-
8
  import os
9
  import torch
10
  import gradio as gr
11
- import spaces
12
  from tqdm.auto import tqdm
13
  from PIL import Image
14
  from utils import (
15
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
16
- vignette_loss, get_concept_embedding, load_concept_library, image_grid
17
  )
18
  from diffusers import StableDiffusionPipeline
19
 
20
- # Set device
21
- device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
22
- if device == "mps":
23
- os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
24
-
25
- # Load model with proper caching
26
- @spaces.GPU
27
- def load_model():
28
- return StableDiffusionPipeline.from_pretrained(
29
- "runwayml/stable-diffusion-v1-5",
30
- torch_dtype=torch.float16,
31
- safety_checker=None
32
- ).to(device)
33
-
34
- @spaces.GPU
35
- def get_pipeline():
36
- return load_model()
37
-
38
- # Load concept library
39
- concept_embeds, concept_tokens = load_concept_library(get_pipeline())
40
-
41
- # Define art style concepts
42
- art_concepts = {
43
- "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
44
- "oil_painting": "an oil painting, textured canvas, painterly technique",
45
- "watercolor": "a watercolor painting, fluid, soft edges",
46
- "digital_art": "digital art, computer generated, precise details",
47
- "comic_book": "comic book style, ink outlines, cel shading"
48
- }
49
-
50
- @spaces.GPU
51
- def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
52
- vignette_loss_scale, concept_style=None, concept_strength=0.5,
53
- height=512, width=512):
54
  """
55
- Generate latents using the UNet model
56
-
57
- Args:
58
- prompt (str): Text prompt
59
- seed (int): Random seed
60
- num_inference_steps (int): Number of denoising steps
61
- guidance_scale (float): Scale for classifier-free guidance
62
- vignette_loss_scale (float): Scale for vignette loss
63
- concept_style (str, optional): Style concept to use
64
- concept_strength (float): Strength of concept influence (0.0-1.0)
65
- height (int): Image height
66
- width (int): Image width
67
-
68
- Returns:
69
- torch.Tensor: Generated latents
70
  """
 
 
 
 
 
71
  # Set the seed
72
  generator = torch.manual_seed(seed)
73
- batch_size = 1
74
-
75
- # Clear GPU memory
76
- clear_gpu_memory()
77
-
78
- # Get concept embedding if specified
79
- concept_embedding = None
80
- if concept_style:
81
- if concept_style in concept_tokens:
82
- # Use pre-trained concept embedding
83
- concept_embedding = concept_embeds[concept_style].unsqueeze(0).to(device)
84
- elif concept_style in art_concepts:
85
- # Generate concept embedding from text description
86
- concept_text = art_concepts[concept_style]
87
- concept_embedding = get_concept_embedding(concept_text, get_pipeline().tokenizer, get_pipeline().text_encoder, device)
88
-
89
  # Prep text
90
- text_input = get_pipeline().tokenizer([prompt], padding="max_length", max_length=get_pipeline().tokenizer.model_max_length,
91
- truncation=True, return_tensors="pt")
92
- with torch.inference_mode():
93
- text_embeddings = get_pipeline().text_encoder(text_input.input_ids.to(device))[0]
94
-
 
 
95
  # Apply concept embedding influence if provided
96
  if concept_embedding is not None and concept_strength > 0:
97
  # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
98
  if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
 
99
  concept_embedding = concept_embedding.unsqueeze(0)
100
-
101
  # Create weighted blend between original text embedding and concept
102
  if text_embeddings.shape == concept_embedding.shape:
 
103
  text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding
104
-
105
- # Unconditional embedding for classifier-free guidance
 
 
 
106
  max_length = text_input.input_ids.shape[-1]
107
- uncond_input = get_pipeline().tokenizer(
108
  [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
109
  )
110
- with torch.inference_mode():
111
- uncond_embeddings = get_pipeline().text_encoder(uncond_input.input_ids.to(device))[0]
112
  text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
113
-
114
  # Prep Scheduler
115
- set_timesteps(get_pipeline().scheduler, num_inference_steps)
116
-
117
  # Prep latents
118
  latents = torch.randn(
119
- (batch_size, get_pipeline().unet.in_channels, height // 8, width // 8),
120
- generator=generator,
121
  )
122
  latents = latents.to(device)
123
- latents = latents * get_pipeline().scheduler.init_noise_sigma
124
-
125
- # Loop through diffusion process
126
- for i, t in tqdm(enumerate(get_pipeline().scheduler.timesteps), total=len(get_pipeline().scheduler.timesteps)):
127
- # Expand latents for classifier-free guidance
128
  latent_model_input = torch.cat([latents] * 2)
129
- sigma = get_pipeline().scheduler.sigmas[i]
130
- latent_model_input = get_pipeline().scheduler.scale_model_input(latent_model_input, t)
131
-
132
- # Predict the noise residual
133
- with torch.inference_mode():
134
- noise_pred = get_pipeline().unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
135
-
136
- # Perform classifier-free guidance
137
  noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
138
  noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
139
-
140
- # Apply additional guidance with vignette loss
141
- if vignette_loss_scale > 0 and i % 5 == 0:
142
  # Requires grad on the latents
143
  latents = latents.detach().requires_grad_()
144
-
145
- # Get the predicted x0
146
  latents_x0 = latents - sigma * noise_pred
147
-
 
148
  # Decode to image space
149
- denoised_images = get_pipeline().vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
150
-
151
  # Calculate loss
152
  loss = vignette_loss(denoised_images) * vignette_loss_scale
153
-
 
 
 
 
154
  # Get gradient
155
  cond_grad = torch.autograd.grad(loss, latents)[0]
156
-
157
  # Modify the latents based on this gradient
158
  latents = latents.detach() - cond_grad * sigma**2
159
-
160
- # Step with scheduler
161
- latents = get_pipeline().scheduler.step(noise_pred, t, latents).prev_sample
162
-
163
  return latents
164
 
165
- @spaces.GPU
166
  def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
167
- vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
168
- height=512, width=512):
169
  """
170
- Generate an image using Stable Diffusion
171
-
172
- Args:
173
- prompt (str): Text prompt
174
- seed (int): Random seed
175
- num_inference_steps (int): Number of denoising steps
176
- guidance_scale (float): Scale for classifier-free guidance
177
- vignette_loss_scale (float): Scale for vignette loss
178
- concept_style (str): Style concept to use
179
- concept_strength (float): Strength of concept influence (0.0-1.0)
180
- height (int): Image height
181
- width (int): Image width
182
-
183
- Returns:
184
- PIL.Image: Generated image
185
  """
186
- # Handle "none" concept style
187
- if concept_style == "none":
188
- concept_style = None
189
-
190
- # Generate latents
191
- latents = generate_latents(
192
- prompt=prompt,
193
- seed=seed,
194
- num_inference_steps=num_inference_steps,
195
- guidance_scale=guidance_scale,
196
- vignette_loss_scale=vignette_loss_scale,
197
- concept_style=concept_style,
198
- concept_strength=concept_strength,
199
- height=height,
200
- width=width
201
- )
202
-
203
- # Convert latents to image
204
- images = latents_to_pil(latents, get_pipeline().vae)
205
-
206
- return images[0]
207
 
208
- def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
209
- vignette_loss_scale=0.0, concept_strength=0.5):
 
210
  """
211
- Generate a grid of images with different style concepts
212
-
213
- Args:
214
- prompt (str): Text prompt
215
- seed (int): Random seed
216
- num_inference_steps (int): Number of denoising steps
217
- guidance_scale (float): Scale for classifier-free guidance
218
- vignette_loss_scale (float): Scale for vignette loss
219
- concept_strength (float): Strength of concept influence (0.0-1.0)
220
-
221
- Returns:
222
- PIL.Image: Grid of generated images
223
  """
224
- # List of styles to use
225
- styles = list(art_concepts.keys())
226
-
227
- # Generate images for each style
228
- images = []
229
- labels = []
230
-
231
- for i, style in enumerate(styles):
232
- # Generate image with this style
233
- latents = generate_latents(
234
- prompt=prompt,
235
- seed=seed + i, # Use different seeds for variety
236
- num_inference_steps=num_inference_steps,
237
- guidance_scale=guidance_scale,
238
- vignette_loss_scale=vignette_loss_scale,
239
- concept_style=style,
240
- concept_strength=concept_strength
241
- )
242
-
243
- # Convert latents to image
244
- style_images = latents_to_pil(latents, get_pipeline().vae)
245
- images.append(style_images[0])
246
- labels.append(style)
247
-
248
- # Create grid
249
- grid = image_grid(images, 1, len(styles), labels)
250
-
251
- return grid
252
 
253
  # Define Gradio interface
254
- @spaces.GPU(enable_queue=False)
255
  def create_demo():
256
  with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
257
  gr.Markdown("# Guided Stable Diffusion with Styles")
@@ -259,15 +177,17 @@ def create_demo():
259
  with gr.Tab("Single Image Generation"):
260
  with gr.Row():
261
  with gr.Column():
 
 
262
  prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
263
- seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
264
- num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
265
- guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
266
- vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
267
-
268
- all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
269
  concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
270
  concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
 
 
 
 
 
271
 
272
  generate_btn = gr.Button("Generate Image")
273
 
@@ -278,10 +198,9 @@ def create_demo():
278
  with gr.Row():
279
  with gr.Column():
280
  grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
281
- grid_seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Base Seed", value=42)
282
  grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
283
- grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
284
- grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
285
  grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
286
 
287
  grid_generate_btn = gr.Button("Generate Style Grid")
@@ -291,15 +210,15 @@ def create_demo():
291
 
292
  # Set up event handlers
293
  generate_btn.click(
294
- generate_latents,
295
  inputs=[prompt, seed, num_inference_steps, guidance_scale,
296
- vignette_loss_scale, concept_style, concept_strength],
297
  outputs=output_image
298
  )
299
 
300
  grid_generate_btn.click(
301
- generate_style_grid,
302
- inputs=[grid_prompt, grid_seed, grid_num_inference_steps,
303
  grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
304
  outputs=output_grid
305
  )
@@ -308,5 +227,24 @@ def create_demo():
308
 
309
  # Launch the app
310
  if __name__ == "__main__":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  demo = create_demo()
312
- demo.launch(debug=False, show_error=True, server_name="0.0.0.0", server_port=7860, cache_examples=True)
 
4
  Author: Shilpaj Bhalerao
5
  Date: Feb 26, 2025
6
  """
7
+ import gc
8
  import os
9
  import torch
10
  import gradio as gr
11
+ # import spaces
12
  from tqdm.auto import tqdm
13
  from PIL import Image
14
  from utils import (
15
  load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
16
+ vignette_loss, get_concept_embedding, image_grid
17
  )
18
  from diffusers import StableDiffusionPipeline
19
 
20
+
21
def generate_latents(prompt, seed, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width):
    """
    Function to generate latents from the UNet

    Relies on the module-level `tokenizer`, `text_encoder`, `unet`, `vae`, `scheduler`,
    `device` and `art_concepts` objects being defined before it is called.

    :param prompt: Text prompt
    :param seed: Seed for the random number generator (coerced to int)
    :param num_inference_steps: Number of denoising steps (coerced to int)
    :param guidance_scale: Scale for classifier-free guidance
    :param vignette_loss_scale: Scale for the vignette loss; a value <= 0 disables the extra guidance
    :param concept: Key into `art_concepts` selecting the concept embedding (optional);
                    an unknown key, "none" or None applies no concept
    :param concept_strength: How strongly to apply the concept (0.0-1.0)
    :param height: Image height in pixels (coerced to int)
    :param width: Image width in pixels (coerced to int)
    :return: Latents of the UNet. This will be passed to the VAE to generate the image
    """
    # Batch size
    batch_size = 1

    # Callers such as UI sliders may hand over floats; the seed, the step count and
    # the latent tensor shape all have to be real integers.
    seed = int(seed)
    num_inference_steps = int(num_inference_steps)
    height = int(height)
    width = int(width)

    # Set the seed
    generator = torch.manual_seed(seed)

    # Prep text
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Get the concept embedding (None when the concept is unknown or maps to no embedding)
    concept_embedding = art_concepts.get(concept)

    # Apply concept embedding influence if provided
    if concept_embedding is not None and concept_strength > 0:
        # Fix the dimension mismatch by adding a batch dimension to concept_embedding if needed
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            concept_embedding = concept_embedding.unsqueeze(0)

        # Create weighted blend between original text embedding and concept
        if text_embeddings.shape == concept_embedding.shape:
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding
            print(f"Successfully applied concept with strength {concept_strength}")
        else:
            print(f"Warning: Shapes still incompatible after adjustment. Concept: {concept_embedding.shape}, Text: {text_embeddings.shape}")

    # Unconditional (empty prompt) embedding for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Prep Scheduler
    set_timesteps(scheduler, num_inference_steps)

    # Prep latents
    # `unet.config.in_channels` replaces the deprecated direct attribute access `unet.in_channels`
    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * scheduler.init_noise_sigma

    # Loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Duplicate the latents so the unconditional and the text-conditioned
        # predictions come out of a single forward pass.
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform classifier-free guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Additional guidance, applied on every 5th step. It is skipped entirely when
        # the scale is zero: the loss is multiplied by the scale, so the gradient would
        # be zero anyway and the VAE decode + backward pass would be wasted work.
        if vignette_loss_scale > 0 and i % 5 == 0:
            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Get the predicted x0
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5  # range (0, 1)

            # Calculate loss
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Occasionally print it out
            if i % 10 == 0:
                print(i, 'loss:', loss.item())

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient
            latents = latents.detach() - cond_grad * sigma**2

        # Now step with scheduler
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents
124
 
125
+
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept="none", concept_strength=0.5, height=512, width=512):
    """
    Generate a single image

    Runs the denoising loop via `generate_latents`, decodes the result with the
    module-level `vae` and wraps the decoded image in a 1x1 grid without labels.

    :param prompt: Text prompt
    :param seed: Random seed
    :param num_inference_steps: Number of denoising steps
    :param guidance_scale: Scale for classifier-free guidance
    :param vignette_loss_scale: Scale for the vignette loss
    :param concept: Concept to influence generation
    :param concept_strength: How strongly to apply the concept (0.0-1.0)
    :param height: Image height
    :param width: Image width
    :return: Result of `image_grid` for the single decoded image
    """
    final_latents = generate_latents(
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        vignette_loss_scale=vignette_loss_scale,
        concept=concept,
        concept_strength=concept_strength,
        height=height,
        width=width,
    )
    decoded = latents_to_pil(final_latents, vae)
    return image_grid(decoded, 1, 1, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
 
136
+
137
def generate_style_images(prompt, num_inference_steps=30, guidance_scale=7.5,
                          vignette_loss_scale=0.0, concept_strength=0.5, height=512, width=512):
    """
    Function to generate images of all the styles

    One image is generated per entry of the module-level `art_concepts` dictionary
    (the "none" placeholder is skipped), each with its own fixed seed, and the
    results are arranged in a single-row labelled grid.

    :param prompt: Text prompt
    :param num_inference_steps: Number of denoising steps
    :param guidance_scale: Scale for classifier-free guidance
    :param vignette_loss_scale: Scale for the vignette loss
    :param concept_strength: How strongly to apply each concept (0.0-1.0)
    :param height: Image height
    :param width: Image width
    :return: Result of `image_grid` for the generated images
    :raises ValueError: If there is no style concept to render
    """
    seed_list = [2000, 1000, 500, 600, 100]

    # Every style except the "none" placeholder. `dict.keys()` is a view without
    # `.pop()`, and filtering by name does not depend on where "none" sits in the dict.
    concepts_list = [name for name in art_concepts if name != "none"]

    latents_collect = []
    concept_labels = []

    # zip() stops at the shorter of the two lists, so at most len(seed_list) styles are rendered
    for seed_no, concept in zip(seed_list, concepts_list):
        # Drop unreachable tensors first, then return the freed blocks to the device
        gc.collect()
        torch.cuda.empty_cache()

        print(f"Generating image with concept '{concept}' at strength {concept_strength}")

        # Generate latents using the concept embedding
        latents = generate_latents(prompt, seed_no, num_inference_steps, guidance_scale, vignette_loss_scale, concept, concept_strength, height, width)
        latents_collect.append(latents)
        concept_labels.append(f"{concept} ({concept_strength})")

    if not latents_collect:
        raise ValueError("No style concepts are available to render")

    # Show results
    latents_collect = torch.vstack(latents_collect)
    images = latents_to_pil(latents_collect, vae)

    # Size the grid by what was actually produced so it always matches the image count
    return image_grid(images, 1, len(images), concept_labels)
169
+
170
 
171
  # Define Gradio interface
172
+ # @spaces.GPU(enable_queue=False)
173
  def create_demo():
174
  with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
175
  gr.Markdown("# Guided Stable Diffusion with Styles")
 
177
  with gr.Tab("Single Image Generation"):
178
  with gr.Row():
179
  with gr.Column():
180
+ all_styles = ["none"] + list(art_concepts.keys())
181
+
182
  prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
183
+ seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=1000)
 
 
 
 
 
184
  concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
185
  concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
186
+ num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
187
+ height = gr.Slider(minimum=256, maximum=1024, step=1, label="Height", value=512)
188
+ width = gr.Slider(minimum=256, maximum=1024, step=1, label="Width", value=512)
189
+ guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
190
+ vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
191
 
192
  generate_btn = gr.Button("Generate Image")
193
 
 
198
  with gr.Row():
199
  with gr.Column():
200
  grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
 
201
  grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
202
+ grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=8.0)
203
+ grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=70.0)
204
  grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)
205
 
206
  grid_generate_btn = gr.Button("Generate Style Grid")
 
210
 
211
  # Set up event handlers
212
  generate_btn.click(
213
+ generate_image,
214
  inputs=[prompt, seed, num_inference_steps, guidance_scale,
215
+ vignette_loss_scale, concept_style, concept_strength, height, width],
216
  outputs=output_image
217
  )
218
 
219
  grid_generate_btn.click(
220
+ generate_style_images,
221
+ inputs=[grid_prompt, grid_num_inference_steps,
222
  grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
223
  outputs=output_grid
224
  )
 
227
 
228
  # Launch the app
if __name__ == "__main__":

    # Pick the compute device: CUDA first, then Apple MPS, otherwise CPU
    if torch.cuda.is_available():
        device = "cuda"
    elif torch.backends.mps.is_available():
        device = "mps"
        # NOTE(review): torch is already imported at this point - confirm the
        # fallback flag is still honoured when it is set this late.
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"
    else:
        device = "cpu"

    # Load models (these names are read as globals by the generation functions)
    vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device=device)

    # Art style concepts: each style name maps to the embedding of its text
    # description; "none" maps to no embedding at all.
    art_concepts = {
        style_name: get_concept_embedding(description, tokenizer, text_encoder, device)
        for style_name, description in {
            "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
            "oil_painting": "an oil painting, textured canvas, painterly technique",
            "watercolor": "a watercolor painting, fluid, soft edges",
            "digital_art": "digital art, computer generated, precise details",
            "comic_book": "comic book style, ink outlines, cel shading",
        }.items()
    }
    art_concepts["none"] = None

    # Build the interface and start serving it
    demo = create_demo()
    demo.launch(debug=True)
utils.py CHANGED
@@ -15,15 +15,12 @@ from transformers import CLIPTokenizer, CLIPTextModel
15
  # Disable HF transfer to avoid download issues
16
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
17
 
 
18
  def load_models(device="cuda"):
19
  """
20
  Load the necessary models for stable diffusion
21
-
22
- Args:
23
- device (str): Device to load models on ('cuda', 'mps', or 'cpu')
24
-
25
- Returns:
26
- tuple: (vae, tokenizer, text_encoder, unet, scheduler, pipe)
27
  """
28
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
29
 
@@ -63,27 +60,32 @@ def load_models(device="cuda"):
63
 
64
  return vae, tokenizer, text_encoder, unet, scheduler, pipe
65
 
 
66
  def clear_gpu_memory():
67
- """Clear GPU memory cache"""
 
 
68
  torch.cuda.empty_cache()
69
  gc.collect()
70
 
 
71
  def set_timesteps(scheduler, num_inference_steps):
72
- """Set timesteps for the scheduler with MPS compatibility fix"""
 
 
 
 
73
  scheduler.set_timesteps(num_inference_steps)
74
  scheduler.timesteps = scheduler.timesteps.to(torch.float32)
75
 
 
76
  def pil_to_latent(input_im, vae, device):
77
  """
78
  Convert the image to latents
79
-
80
- Args:
81
- input_im: Input PIL image
82
- vae: VAE model
83
- device: Device to run on
84
-
85
- Returns:
86
- Latents from VAE's encoder
87
  """
88
  from torchvision import transforms as tfms
89
 
@@ -92,16 +94,13 @@ def pil_to_latent(input_im, vae, device):
92
  latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device)*2-1) # Note scaling
93
  return 0.18215 * latent.latent_dist.sample()
94
 
 
95
  def latents_to_pil(latents, vae):
96
  """
97
  Convert the latents to images
98
-
99
- Args:
100
- latents: Latent tensor
101
- vae: VAE model
102
-
103
- Returns:
104
- list: PIL images
105
  """
106
  # batch of latents -> list of images
107
  latents = (1 / 0.18215) * latents
@@ -113,18 +112,15 @@ def latents_to_pil(latents, vae):
113
  pil_images = [Image.fromarray(image) for image in images]
114
  return pil_images
115
 
 
116
  def image_grid(imgs, rows, cols, labels=None):
117
  """
118
  Create a grid of images with optional labels.
119
-
120
- Args:
121
- imgs (list): List of PIL images to be arranged in a grid
122
- rows (int): Number of rows in the grid
123
- cols (int): Number of columns in the grid
124
- labels (list, optional): List of label strings for each image
125
-
126
- Returns:
127
- PIL.Image: A single image with all input images arranged in a grid and labeled
128
  """
129
  assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
130
 
@@ -164,17 +160,14 @@ def image_grid(imgs, rows, cols, labels=None):
164
 
165
  return grid
166
 
 
167
  def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
168
  """
169
  Creates a strong vignette effect (dark corners) and color shift.
170
-
171
- Args:
172
- images: Batch of images from VAE decoder (range 0-1)
173
- vignette_strength: How strong the darkening effect is (higher = more dramatic)
174
- color_shift: RGB color to shift the center toward [r, g, b]
175
-
176
- Returns:
177
- torch.Tensor: Loss value
178
  """
179
  batch_size, channels, height, width = images.shape
180
 
@@ -209,18 +202,15 @@ def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
209
  # Calculate loss - how different current image is from our target
210
  return torch.pow(images - target, 2).mean()
211
 
 
212
  def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
213
  """
214
  Generate CLIP embedding for a concept described in text
215
-
216
- Args:
217
- concept_text (str): Text description of the concept (e.g., "sketch painting")
218
- tokenizer: CLIP tokenizer
219
- text_encoder: CLIP text encoder
220
- device: Device to run on
221
-
222
- Returns:
223
- torch.Tensor: CLIP embedding for the concept
224
  """
225
  # Tokenize the concept text
226
  concept_tokens = tokenizer(
@@ -236,36 +226,3 @@ def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
236
  concept_embedding = text_encoder(concept_tokens)[0]
237
 
238
  return concept_embedding
239
-
240
- def load_concept_library(pipe):
241
- """
242
- Load textual inversion concepts from the SD concept library
243
-
244
- Args:
245
- pipe: StableDiffusionPipeline
246
-
247
- Returns:
248
- dict: Dictionary of token to embedding mappings
249
- """
250
- # Load textual inversion embeddings
251
- pipe.load_textual_inversion("sd-concepts-library/dreams")
252
- pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
253
- pipe.load_textual_inversion("sd-concepts-library/moebius")
254
- pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
255
- pipe.load_textual_inversion("sd-concepts-library/wlop-style")
256
-
257
- # Extract the embeddings from the pipeline
258
- tokens = ['<meeg>', '<midjourney-style>', '<moebius>', '<Marc_Allante>', '<wlop-style>']
259
- token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
260
- embeddings = pipe.text_encoder.get_input_embeddings().weight[token_ids].detach().cpu()
261
-
262
- # Create a dictionary with the embeddings
263
- learned_embeds = {}
264
- for i, token in enumerate(tokens):
265
- learned_embeds[token] = embeddings[i]
266
-
267
- # Save the embeddings for future use
268
- torch.save(learned_embeds, "learned_embeds.bin")
269
- print(f"Saved embeddings for tokens: {', '.join(tokens)}")
270
-
271
- return learned_embeds, tokens
 
15
  # Disable HF transfer to avoid download issues
16
  os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "0"
17
 
18
+
19
  def load_models(device="cuda"):
20
  """
21
  Load the necessary models for stable diffusion
22
+ :param device: (str) Device to load models on ('cuda', 'mps', or 'cpu')
23
+ :return: (tuple) (vae, tokenizer, text_encoder, unet, scheduler, pipe)
 
 
 
 
24
  """
25
  from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
26
 
 
60
 
61
  return vae, tokenizer, text_encoder, unet, scheduler, pipe
62
 
63
+
def clear_gpu_memory():
    """
    Clear GPU memory cache

    Runs the garbage collector and then releases the cached accelerator memory
    of whichever backend (CUDA and/or MPS) is available. Does nothing beyond
    the garbage collection on a CPU-only machine.
    :return: None
    """
    # Collect first: memory only goes back to the caching allocator once the
    # tensors holding it are destroyed, so flushing the cache before collecting
    # would leave that memory cached.
    gc.collect()

    if torch.cuda.is_available():
        torch.cuda.empty_cache()

    # Guarded lookups so torch builds without MPS support are unaffected
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available() and hasattr(torch, "mps"):
        torch.mps.empty_cache()
70
 
71
+
def set_timesteps(scheduler, num_inference_steps):
    """
    Configure the scheduler's timesteps and store them as float32 (MPS compatibility fix)
    :param scheduler: (Scheduler) Scheduler to set timesteps for; modified in place
    :param num_inference_steps: (int) Number of inference steps
    :return: None
    """
    scheduler.set_timesteps(num_inference_steps)

    # Re-store the freshly computed timesteps as float32
    float32_timesteps = scheduler.timesteps.to(torch.float32)
    scheduler.timesteps = float32_timesteps
80
 
81
+
82
  def pil_to_latent(input_im, vae, device):
83
  """
84
  Convert the image to latents
85
+ :param input_im: (PIL.Image) Input PIL image
86
+ :param vae: (VAE) VAE model
87
+ :param device: (str) Device to run on
88
+ :return: (torch.Tensor) Latents from VAE's encoder
 
 
 
 
89
  """
90
  from torchvision import transforms as tfms
91
 
 
94
  latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(device)*2-1) # Note scaling
95
  return 0.18215 * latent.latent_dist.sample()
96
 
97
+
98
  def latents_to_pil(latents, vae):
99
  """
100
  Convert the latents to images
101
+ :param latents: (torch.Tensor) Latent tensor
102
+ :param vae: (VAE) VAE model
103
+ :return: (list) PIL images
 
 
 
 
104
  """
105
  # batch of latents -> list of images
106
  latents = (1 / 0.18215) * latents
 
112
  pil_images = [Image.fromarray(image) for image in images]
113
  return pil_images
114
 
115
+
116
  def image_grid(imgs, rows, cols, labels=None):
117
  """
118
  Create a grid of images with optional labels.
119
+ :param imgs: (list) List of PIL images to be arranged in a grid
120
+ :param rows: (int) Number of rows in the grid
121
+ :param cols: (int) Number of columns in the grid
122
+ :param labels: (list, optional) List of label strings for each image
123
+ :return: (PIL.Image) A single image with all input images arranged in a grid and labeled
 
 
 
 
124
  """
125
  assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
126
 
 
160
 
161
  return grid
162
 
163
+
164
  def vignette_loss(images, vignette_strength=3.0, color_shift=[1.0, 0.5, 0.0]):
165
  """
166
  Creates a strong vignette effect (dark corners) and color shift.
167
+ :param images: (torch.Tensor) Batch of images from VAE decoder (range 0-1)
168
+ :param vignette_strength: (float) How strong the darkening effect is (higher = more dramatic)
169
+ :param color_shift: (list) RGB color to shift the center toward [r, g, b]
170
+ :return: (torch.Tensor) Loss value
 
 
 
 
171
  """
172
  batch_size, channels, height, width = images.shape
173
 
 
202
  # Calculate loss - how different current image is from our target
203
  return torch.pow(images - target, 2).mean()
204
 
205
+
206
  def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
207
  """
208
  Generate CLIP embedding for a concept described in text
209
+ :param concept_text: (str) Text description of the concept (e.g., "sketch painting")
210
+ :param tokenizer: (CLIPTokenizer) CLIP tokenizer
211
+ :param text_encoder: (CLIPTextModel) CLIP text encoder
212
+ :param device: (str) Device to run on
213
+ :return: (torch.Tensor) CLIP embedding for the concept
 
 
 
 
214
  """
215
  # Tokenize the concept text
216
  concept_tokens = tokenizer(
 
226
  concept_embedding = text_encoder(concept_tokens)[0]
227
 
228
  return concept_embedding