Shilpaj commited on
Commit
74fa5e8
·
verified ·
1 Parent(s): 2e09a45

Feat: App files

Browse files
Files changed (3) hide show
  1. app.py +308 -0
  2. requirements.txt +16 -0
  3. utils.py +268 -0
app.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Gradio Application for Stable Diffusion
4
+ Author: Shilpaj Bhalerao
5
+ Date: Feb 26, 2025
6
+ """
7
+
8
+ import os
9
+ import torch
10
+ import gradio as gr
11
+ from tqdm.auto import tqdm
12
+ import numpy as np
13
+ from PIL import Image
14
+ from utils import (
15
+ load_models, clear_gpu_memory, set_timesteps, latents_to_pil,
16
+ vignette_loss, get_concept_embedding, load_concept_library, image_grid
17
+ )
18
+
19
# Hugging Face Space configuration
# BUG FIX: the original used `@gr.Blocks.add_decorator`, which is not a Gradio
# API (AttributeError at import time), and `space` received the decorated
# *function*, so `demo.queue(...)` was called on a function object.
# `space` is now a plain decorator around the demo factory: it builds the
# Blocks app and configures its request queue (keeps the Space on low
# concurrency so a single zero-GPU worker is not overloaded).
def space(fn, **queue_kwargs):
    """
    Decorator for a zero-argument demo factory.

    Args:
        fn: Factory callable that returns a gr.Blocks instance
        **queue_kwargs: Extra keyword arguments forwarded to demo.queue()

    Returns:
        callable: Wrapper that builds, queues, and returns the Blocks app
    """
    def wrapper(*args, **kwargs):
        demo = fn(*args, **kwargs)
        # NOTE(review): concurrency_count was removed in Gradio 4; this is
        # valid for the gradio>=3.20 pin in requirements.txt — confirm the
        # installed major version.
        demo.queue(concurrency_count=1, max_size=10, **queue_kwargs)
        return demo
    return wrapper
26
+
27
# Select compute device: prefer CUDA, then Apple-Silicon MPS, then CPU.
device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
if device == "mps":
    # Allow PyTorch to fall back to CPU kernels for ops MPS does not implement.
    os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

# Load all Stable Diffusion components once at import time:
# VAE, CLIP tokenizer/encoder, UNet, noise scheduler, and the full pipeline
# (the pipeline is only needed for loading textual-inversion concepts).
vae, tokenizer, text_encoder, unet, scheduler, pipe = load_models(device)

# Load textual-inversion concepts from the SD concept library.
# Returns a (token -> embedding) dict and the list of concept tokens.
concept_embeds, concept_tokens = load_concept_library(pipe)

# Free-text art-style descriptions, used when the selected style is not a
# learned textual-inversion token (see generate_latents).
art_concepts = {
    "sketch_painting": "a sketch painting, pencil drawing, hand-drawn illustration",
    "oil_painting": "an oil painting, textured canvas, painterly technique",
    "watercolor": "a watercolor painting, fluid, soft edges",
    "digital_art": "digital art, computer generated, precise details",
    "comic_book": "comic book style, ink outlines, cel shading"
}
46
+
47
def generate_latents(prompt, seed, num_inference_steps, guidance_scale,
                     vignette_loss_scale, concept_style=None, concept_strength=0.5,
                     height=512, width=512):
    """
    Run the denoising diffusion loop and return the final latents.

    Supports classifier-free guidance, optional style-concept blending of the
    text embedding, and optional vignette-loss guidance applied every 5th step.

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss (0 disables it)
        concept_style (str, optional): Style concept — either a textual-inversion
            token from the concept library or a key of `art_concepts`
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height (must be a multiple of 8)
        width (int): Image width (must be a multiple of 8)

    Returns:
        torch.Tensor: Generated latents of shape (1, C, height//8, width//8)
    """
    # Seed the global RNG; the returned generator is reused for the initial noise
    generator = torch.manual_seed(seed)
    batch_size = 1

    # Free cached accelerator memory before a new run
    clear_gpu_memory()

    # Resolve the style concept to an embedding, if one was requested
    concept_embedding = None
    if concept_style:
        if concept_style in concept_tokens:
            # Pre-trained textual-inversion embedding from the concept library
            concept_embedding = concept_embeds[concept_style].unsqueeze(0).to(device)
        elif concept_style in art_concepts:
            # Encode the free-text style description through CLIP
            concept_text = art_concepts[concept_style]
            concept_embedding = get_concept_embedding(concept_text, tokenizer, text_encoder, device)

    # Tokenize and encode the prompt
    text_input = tokenizer([prompt], padding="max_length", max_length=tokenizer.model_max_length,
                           truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(device))[0]

    # Blend the concept embedding into the prompt embedding
    if concept_embedding is not None and concept_strength > 0:
        # Align ranks: concept may be (seq, dim) while text is (batch, seq, dim)
        if len(concept_embedding.shape) == 2 and len(text_embeddings.shape) == 3:
            concept_embedding = concept_embedding.unsqueeze(0)

        # Linear interpolation; silently skipped if shapes still disagree
        # (e.g. a single-token library embedding vs. a full-sequence embedding)
        if text_embeddings.shape == concept_embedding.shape:
            text_embeddings = (1 - concept_strength) * text_embeddings + concept_strength * concept_embedding

    # Unconditional (empty-prompt) embedding for classifier-free guidance
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(device))[0]
    # Stack [uncond, cond] so one UNet call serves both guidance branches
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Configure the scheduler's timestep schedule (with MPS float32 fix)
    set_timesteps(scheduler, num_inference_steps)

    # Initial Gaussian latents, scaled to the scheduler's starting sigma
    # NOTE(review): `unet.in_channels` is deprecated in newer diffusers in
    # favor of `unet.config.in_channels` — confirm against the pinned version.
    latents = torch.randn(
        (batch_size, unet.in_channels, height // 8, width // 8),
        generator=generator,
    )
    latents = latents.to(device)
    latents = latents * scheduler.init_noise_sigma

    # Denoising loop
    for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
        # Duplicate latents for the [uncond, cond] batch
        latent_model_input = torch.cat([latents] * 2)
        sigma = scheduler.sigmas[i]
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        # Predict the noise residual (no gradients through the UNet)
        with torch.no_grad():
            noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Classifier-free guidance: extrapolate from uncond toward cond
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Vignette-loss guidance: every 5th step, nudge latents down the
        # gradient of a pixel-space loss computed on the predicted x0
        if vignette_loss_scale > 0 and i % 5 == 0:
            # Re-attach latents to the autograd graph
            latents = latents.detach().requires_grad_()

            # Predicted clean latents (x0) under the LMS sigma parameterization
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space; output rescaled from [-1, 1] to (0, 1)
            denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Scalar guidance loss on the decoded image
            loss = vignette_loss(denoised_images) * vignette_loss_scale

            # Gradient of the loss w.r.t. the latents (flows through the VAE)
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Gradient step, scaled by sigma^2 as in standard SD guidance recipes
            latents = latents.detach() - cond_grad * sigma**2

        # Advance one scheduler step
        latents = scheduler.step(noise_pred, t, latents).prev_sample

    return latents
160
+
161
def generate_image(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                   vignette_loss_scale=0.0, concept_style="none", concept_strength=0.5,
                   height=512, width=512):
    """
    Generate a single image with Stable Diffusion.

    Thin wrapper around `generate_latents` + `latents_to_pil`, used as the
    Gradio click handler for the single-image tab.

    Args:
        prompt (str): Text prompt
        seed (int): Random seed
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_style (str): Style concept; the UI sends "none" for no style
        concept_strength (float): Strength of concept influence (0.0-1.0)
        height (int): Image height
        width (int): Image width

    Returns:
        PIL.Image: Generated image
    """
    # The dropdown uses the sentinel "none"; map it to None for generate_latents
    style = None if concept_style == "none" else concept_style

    latents = generate_latents(
        prompt=prompt,
        seed=seed,
        num_inference_steps=num_inference_steps,
        guidance_scale=guidance_scale,
        vignette_loss_scale=vignette_loss_scale,
        concept_style=style,
        concept_strength=concept_strength,
        height=height,
        width=width,
    )

    # Decode the final latents and return the single generated frame
    return latents_to_pil(latents, vae)[0]
202
+
203
def generate_style_grid(prompt, seed=42, num_inference_steps=30, guidance_scale=7.5,
                        vignette_loss_scale=0.0, concept_strength=0.5):
    """
    Render the same prompt once per art style and assemble a 1-row grid.

    Args:
        prompt (str): Text prompt
        seed (int): Base random seed; each style uses seed + its index
        num_inference_steps (int): Number of denoising steps
        guidance_scale (float): Scale for classifier-free guidance
        vignette_loss_scale (float): Scale for vignette loss
        concept_strength (float): Strength of concept influence (0.0-1.0)

    Returns:
        PIL.Image: Labeled grid of one image per style in `art_concepts`
    """
    styles = list(art_concepts.keys())
    images, labels = [], []

    for offset, style in enumerate(styles):
        # Different seed per style so the grid shows some variety
        style_latents = generate_latents(
            prompt=prompt,
            seed=seed + offset,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            vignette_loss_scale=vignette_loss_scale,
            concept_style=style,
            concept_strength=concept_strength
        )
        images.append(latents_to_pil(style_latents, vae)[0])
        labels.append(style)

    # Single row, one column per style, labeled under each image
    return image_grid(images, 1, len(styles), labels)
247
+
248
# Define the Gradio interface.
# `@space` configures the request queue on the Blocks app (see `space` above).
@space
def create_demo():
    """Build the two-tab Gradio UI and wire its event handlers.

    Returns:
        gr.Blocks: The assembled (unlaunched) Gradio app.
    """
    with gr.Blocks(title="Guided Stable Diffusion with Styles") as demo:
        gr.Markdown("# Guided Stable Diffusion with Styles")

        # Tab 1: one prompt -> one image, with optional style concept
        with gr.Tab("Single Image Generation"):
            with gr.Row():
                with gr.Column():
                    prompt = gr.Textbox(label="Prompt", placeholder="A cat sitting on a chair")
                    seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Seed", value=42)
                    num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)

                    # Style choices: "none" sentinel + textual-inversion tokens
                    # + free-text art concepts
                    all_styles = ["none"] + concept_tokens + list(art_concepts.keys())
                    concept_style = gr.Dropdown(choices=all_styles, label="Style Concept", value="none")
                    concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)

                    generate_btn = gr.Button("Generate Image")

                with gr.Column():
                    output_image = gr.Image(label="Generated Image", type="pil")

        # Tab 2: one prompt rendered once per art style, shown as a grid
        with gr.Tab("Style Grid"):
            with gr.Row():
                with gr.Column():
                    grid_prompt = gr.Textbox(label="Prompt", placeholder="A dog running in the park")
                    grid_seed = gr.Slider(minimum=0, maximum=10000, step=1, label="Base Seed", value=42)
                    grid_num_inference_steps = gr.Slider(minimum=10, maximum=100, step=1, label="Inference Steps", value=30)
                    grid_guidance_scale = gr.Slider(minimum=1.0, maximum=15.0, step=0.1, label="Guidance Scale", value=7.5)
                    grid_vignette_loss_scale = gr.Slider(minimum=0.0, maximum=100.0, step=1.0, label="Vignette Loss Scale", value=0.0)
                    grid_concept_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, label="Concept Strength", value=0.5)

                    grid_generate_btn = gr.Button("Generate Style Grid")

                with gr.Column():
                    output_grid = gr.Image(label="Style Grid", type="pil")

        # Event handlers. Note: height/width are not exposed in the UI, so
        # generate_image runs with its 512x512 defaults.
        generate_btn.click(
            generate_image,
            inputs=[prompt, seed, num_inference_steps, guidance_scale,
                    vignette_loss_scale, concept_style, concept_strength],
            outputs=output_image
        )

        grid_generate_btn.click(
            generate_style_grid,
            inputs=[grid_prompt, grid_seed, grid_num_inference_steps,
                    grid_guidance_scale, grid_vignette_loss_scale, grid_concept_strength],
            outputs=output_grid
        )

    return demo

# Script entry point: build the app and start the Gradio server.
if __name__ == "__main__":
    demo = create_demo()
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Core dependencies
2
+ torch>=1.7.0
3
+ torchvision>=0.8.0
4
+ diffusers>=0.12.0
5
+ transformers>=4.25.1
6
+ accelerate>=0.16.0
7
+ ftfy>=6.1.1
8
+ gradio>=3.20.0
9
+ numpy>=1.22.0
10
+ Pillow>=9.0.0
11
+ tqdm>=4.64.0
12
+ huggingface-hub>=0.12.0
13
+
14
+ # Optional dependencies for better performance
15
+ scipy>=1.9.0
16
+ matplotlib>=3.5.0
utils.py ADDED
@@ -0,0 +1,268 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ Utility functions for the application
4
+ Author: Shilpaj Bhalerao
5
+ Date: Feb 26, 2025
6
+ """
7
+
8
+ import torch
9
+ import gc
10
+ from PIL import Image, ImageDraw, ImageFont
11
+ from diffusers import StableDiffusionPipeline
12
+ from transformers import CLIPTokenizer, CLIPTextModel
13
+ import os
14
+
15
def load_models(device="cuda"):
    """
    Load the necessary models for stable diffusion.

    Downloads (or reads from the HF cache) the VAE, CLIP tokenizer/encoder,
    UNet, an LMS scheduler, and a full StableDiffusionPipeline, then moves
    the torch modules to `device`.

    Args:
        device (str): Device to load models on ('cuda', 'mps', or 'cpu');
            silently downgraded if the requested accelerator is unavailable

    Returns:
        tuple: (vae, tokenizer, text_encoder, unet, scheduler, pipe)
    """
    # Imported here so module import does not require diffusers internals
    from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel

    # Downgrade the device if CUDA was requested but is not available
    if device == "cuda" and not torch.cuda.is_available():
        device = "mps" if torch.backends.mps.is_available() else "cpu"
    if device == "mps":
        # Allow CPU fallback for ops MPS does not implement
        os.environ['PYTORCH_ENABLE_MPS_FALLBACK'] = "1"

    print(f"Loading models on {device}...")

    # Autoencoder: decodes latents back into image space
    vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")

    # CLIP tokenizer + text encoder for prompt conditioning
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

    # UNet: predicts the noise residual during the diffusion loop
    unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")

    # LMS noise scheduler with the standard SD v1 beta schedule
    scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

    # Full pipeline, used only for loading textual-inversion concepts.
    # NOTE(review): the standalone components come from SD v1-4 while the
    # pipeline is SD v1-5 (fp16 on CUDA) — embeddings extracted from the
    # pipeline are assumed compatible with the v1-4 text encoder; confirm.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5",
        torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
    )

    # Move everything onto the chosen device (tokenizer is CPU-only by nature)
    vae = vae.to(device)
    text_encoder = text_encoder.to(device)
    unet = unet.to(device)
    pipe = pipe.to(device)

    return vae, tokenizer, text_encoder, unet, scheduler, pipe
61
+
62
def clear_gpu_memory():
    """
    Release cached accelerator memory and run Python garbage collection.

    FIX: the original called `torch.cuda.empty_cache()` unconditionally,
    which is wasteful/fragile on CPU-only and MPS hosts (this app explicitly
    supports both); the CUDA calls are now guarded by an availability check.
    Collecting garbage between the two cache flushes lets blocks held by
    unreachable tensors be returned to the allocator before the second flush.
    """
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    # Free unreachable Python objects (and the tensors they own) first
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
67
+
68
def set_timesteps(scheduler, num_inference_steps):
    """Configure the scheduler's denoising schedule.

    After delegating to `scheduler.set_timesteps`, the timestep tensor is
    cast to float32 — MPS cannot handle the float64 tensors some diffusers
    schedulers produce.
    """
    scheduler.set_timesteps(num_inference_steps)
    timesteps = scheduler.timesteps
    scheduler.timesteps = timesteps.to(torch.float32)
72
+
73
def pil_to_latent(input_im, vae, device):
    """
    Encode a PIL image into the VAE's latent space.

    Args:
        input_im: Input PIL image
        vae: VAE model
        device: Device to run on

    Returns:
        torch.Tensor: Scaled latents sampled from the VAE encoder's
        distribution (batch of 1, e.g. shape (1, 4, 64, 64) for 512x512 input)
    """
    from torchvision import transforms as tfms

    # [0, 1] tensor -> [-1, 1] range expected by the VAE, with a batch dim
    pixel_batch = tfms.ToTensor()(input_im).unsqueeze(0).to(device) * 2 - 1
    with torch.no_grad():
        encoded = vae.encode(pixel_batch)
    # 0.18215 is the standard SD latent scaling factor
    return 0.18215 * encoded.latent_dist.sample()
91
+
92
def latents_to_pil(latents, vae):
    """
    Decode a batch of latents into a list of PIL images.

    Args:
        latents: Latent tensor (batch, C, H/8, W/8)
        vae: VAE model

    Returns:
        list: One PIL image per batch element
    """
    # Undo the standard SD latent scaling before decoding
    scaled = (1 / 0.18215) * latents
    with torch.no_grad():
        decoded = vae.decode(scaled).sample
    # VAE output is in [-1, 1]; map to [0, 1] then to uint8 pixels
    pixels = (decoded / 2 + 0.5).clamp(0, 1)
    arrays = pixels.detach().cpu().permute(0, 2, 3, 1).numpy()
    frames = (arrays * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in frames]
112
+
113
def image_grid(imgs, rows, cols, labels=None):
    """
    Create a grid of images with optional labels.

    Args:
        imgs (list): List of PIL images to be arranged in a grid
        rows (int): Number of rows in the grid
        cols (int): Number of columns in the grid
        labels (list, optional): List of label strings for each image

    Returns:
        PIL.Image: A single image with all input images arranged in a grid
        and (optionally) labeled; 30px of padding is added below the grid
        when labels are given

    FIX: the original computed `label_height` but never used it, duplicating
    the padding logic inside the `Image.new` call; the variable now drives
    the canvas height. The label-count check is also performed before any
    allocation so a mismatch fails fast.
    """
    assert len(imgs) == rows*cols, f"Number of images ({len(imgs)}) must equal rows*cols ({rows*cols})"
    if labels:
        assert len(labels) == len(imgs), "Number of labels must match number of images"

    w, h = imgs[0].size
    # Extra strip at the bottom of the canvas for the label row
    label_height = 30 if labels else 0
    grid = Image.new('RGB', size=(cols * w, rows * h + label_height))

    # Paste images left-to-right, top-to-bottom
    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))

    if labels:
        draw = ImageDraw.Draw(grid)

        # Prefer a standard font; fall back to PIL's built-in default
        try:
            font = ImageFont.truetype("arial.ttf", 14)
        except IOError:
            font = ImageFont.load_default()

        for i, label in enumerate(labels):
            # Anchor the text just above each image's bottom edge
            x = (i % cols) * w + 10
            y = (i // cols + 1) * h - 5

            # White 1px outline in four diagonal directions for visibility
            for dx, dy in [(1, 1), (-1, -1), (1, -1), (-1, 1)]:
                draw.text((x + dx, y + dy), label, fill=(255, 255, 255), font=font)

            # Main text (black)
            draw.text((x, y), label, fill=(0, 0, 0), font=font)

    return grid
163
+
164
def vignette_loss(images, vignette_strength=3.0, color_shift=(1.0, 0.5, 0.0)):
    """
    Guidance loss encouraging a vignette (dark corners) and a center color shift.

    Builds a target image by darkening the input toward the edges and shifting
    its center toward `color_shift`, then returns the MSE between the input
    and that target.

    FIX: `color_shift` default changed from a mutable list to a tuple — a
    mutable default argument is a classic Python pitfall; `torch.tensor`
    accepts either, so callers are unaffected.

    Args:
        images: Batch of decoded images in range (0, 1).
            NOTE(review): the color-shift tensor is viewed as (1, 3, 1, 1),
            so 3 channels are assumed, and the broadcasting in the offset
            computation only lines up for batch size 1 — confirm callers
            (the app always passes a single decoded image).
        vignette_strength: How strong the edge darkening is (higher = darker)
        color_shift: RGB color the center is shifted toward (r, g, b)

    Returns:
        torch.Tensor: Scalar loss value
    """
    batch_size, channels, height, width = images.shape

    # Coordinate grid centered at 0 with range [-1, 1]
    y = torch.linspace(-1, 1, height).view(-1, 1).repeat(1, width).to(images.device)
    x = torch.linspace(-1, 1, width).view(1, -1).repeat(height, 1).to(images.device)

    # Distance from center, normalized so corners are ~1 (1.414 ≈ sqrt(2))
    radius = torch.sqrt(x.pow(2) + y.pow(2)) / 1.414

    # Vignette mask: bright (≈1) in the center, decaying toward the edges
    vignette = torch.exp(-vignette_strength * radius)

    # Color-shift target: strongest in the center, fading quadratically outward
    color_tensor = torch.tensor(color_shift, dtype=torch.float32).view(1, 3, 1, 1).to(images.device)
    center_mask = 1.0 - radius.unsqueeze(0).unsqueeze(0)
    center_mask = torch.pow(center_mask, 2.0)  # sharper transition

    # Build the target from a copy of the input
    target = images.clone()

    # Darken every channel by the vignette mask
    for c in range(channels):
        target[:, c] = target[:, c] * vignette

    # Shift the center toward the target color, less so near the edges
    for c in range(channels):
        color_offset = (color_tensor[:, c] - images[:, c]) * center_mask
        target[:, c] = target[:, c] + color_offset.squeeze(1)

    # MSE between the current image and the vignetted/shifted target
    return torch.pow(images - target, 2).mean()
208
+
209
def get_concept_embedding(concept_text, tokenizer, text_encoder, device):
    """
    Encode a textual concept description into CLIP embedding space.

    Args:
        concept_text (str): Text description of the concept (e.g. "sketch painting")
        tokenizer: CLIP tokenizer
        text_encoder: CLIP text encoder
        device: Device to run on

    Returns:
        torch.Tensor: CLIP hidden-state embedding for the concept
    """
    # Tokenize, padding/truncating to the model's maximum sequence length
    tokenized = tokenizer(
        concept_text,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt"
    )
    input_ids = tokenized.input_ids.to(device)

    # Inference only — no gradients needed through the frozen text encoder
    with torch.no_grad():
        embedding = text_encoder(input_ids)[0]

    return embedding
236
+
237
def load_concept_library(pipe):
    """
    Load textual inversion concepts from the SD concept library.

    Downloads five community textual-inversion embeddings into the pipeline,
    extracts their learned token embeddings, and persists them to
    `learned_embeds.bin` in the working directory as a side effect.

    Args:
        pipe: StableDiffusionPipeline (mutated: new tokens are registered
            in its tokenizer/text encoder)

    Returns:
        tuple: (dict mapping concept token -> embedding tensor on CPU,
                list of the concept token strings)
    """
    # Pull each concept from the hub (network access / HF cache required)
    pipe.load_textual_inversion("sd-concepts-library/dreams")
    pipe.load_textual_inversion("sd-concepts-library/midjourney-style")
    pipe.load_textual_inversion("sd-concepts-library/moebius")
    pipe.load_textual_inversion("sd-concepts-library/style-of-marc-allante")
    pipe.load_textual_inversion("sd-concepts-library/wlop-style")

    # The placeholder tokens registered by the concepts above, in load order
    tokens = ['<meeg>', '<midjourney-style>', '<moebius>', '<Marc_Allante>', '<wlop-style>']
    token_ids = pipe.tokenizer.convert_tokens_to_ids(tokens)
    # Detach onto CPU so the embeddings outlive the pipeline/device
    embeddings = pipe.text_encoder.get_input_embeddings().weight[token_ids].detach().cpu()

    # token -> embedding mapping
    learned_embeds = {}
    for i, token in enumerate(tokens):
        learned_embeds[token] = embeddings[i]

    # Cache the embeddings on disk for future use
    torch.save(learned_embeds, "learned_embeds.bin")
    print(f"Saved embeddings for tokens: {', '.join(tokens)}")

    return learned_embeds, tokens