minakshi.mathpal commited on
Commit
9c9f1a4
·
1 Parent(s): 72c399a

Initial commit with Stable Diffusion color guidance app

Browse files
Files changed (2) hide show
  1. app.py +162 -0
  2. custom_stable_diffusion.py +640 -0
app.py ADDED
@@ -0,0 +1,162 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import streamlit as st
import torch
import random
import time
from PIL import Image
from utils import StableDiffusionConfig, StableDiffusionModels, generate_image
# NOTE(review): this imports from `utils`, but the commit only adds
# `custom_stable_diffusion.py` — confirm a `utils` module with `generate_image`
# actually exists in the repo.

# Set page config — must be the first st.* call in a Streamlit script.
st.set_page_config(
    page_title="Butterfly Color Diffusion",
    page_icon="🦋",
    layout="wide",
    initial_sidebar_state="expanded"
)

# Initialize session state so the (expensive) models persist across reruns.
if 'models' not in st.session_state:
    st.session_state.models = None
    st.session_state.config = None
20
+
21
# Function to load models
@st.cache_resource
def load_models():
    """Create a default generation config, load all Stable Diffusion
    components once (cached by Streamlit), and return (models, config)."""
    sd_config = StableDiffusionConfig(
        height=512,
        width=512,
        num_inference_steps=30,
        guidance_scale=7.5,
        seed=random.randint(1, 1000000),
        batch_size=1,
        device=None,
        max_length=77,
    )
    sd_models = StableDiffusionModels(sd_config)
    spinner_msg = "Loading Stable Diffusion models... This may take a minute."
    with st.spinner(spinner_msg):
        sd_models.load_models()
        sd_models.set_timesteps()
    return sd_models, sd_config
39
+
40
# Title and description
st.title("🦋 Butterfly Color Diffusion")
st.markdown("""
Generate beautiful butterfly images with Stable Diffusion and explore color guidance technology
that enhances yellow tones. Compare standard image generation with color-guided generation to see
how targeted color loss functions can transform your results.
""")

# Sidebar with controls
st.sidebar.title("Generation Settings")

# Common settings shared by both generation modes.
prompt = st.sidebar.text_area(
    "Prompt",
    value="A detailed photograph of a colorful monarch butterfly with orange and black wings, resting on a purple flower in a lush garden with sunlight",
    height=100
)

steps = st.sidebar.slider("Inference Steps", min_value=10, max_value=100, value=30, step=1)
guidance_scale = st.sidebar.slider("Guidance Scale", min_value=1.0, max_value=15.0, value=7.5, step=0.1)
seed = st.sidebar.number_input("Seed (0 for random)", min_value=0, max_value=1000000, value=0, step=1)

# Color guidance settings — only the color-guided mode reads these.
st.sidebar.markdown("---")
st.sidebar.subheader("Color Guidance Settings")
yellow_strength = st.sidebar.slider("Yellow Strength", min_value=0, max_value=500, value=200, step=10)
guidance_interval = st.sidebar.slider("Guidance Interval", min_value=1, max_value=10, value=5, step=1)

# Two side-by-side columns: standard output on the left, color-guided on the right.
col1, col2 = st.columns(2)

with col1:
    st.subheader("Standard Stable Diffusion")
    standard_button = st.button("Generate Standard Image", use_container_width=True)

with col2:
    st.subheader("Color-Guided Stable Diffusion")
    color_button = st.button("Generate Color-Guided Image", use_container_width=True)
78
+
79
# Load models lazily — only when one of the generate buttons is pressed.
if standard_button or color_button:
    if st.session_state.models is None:
        st.session_state.models, st.session_state.config = load_models()

    # Push the current sidebar values into the shared config.
    st.session_state.config.num_inference_steps = steps
    st.session_state.config.guidance_scale = guidance_scale

    # A seed of 0 means "pick a random seed for this run".
    if seed == 0:
        seed = random.randint(1, 1000000)
    st.session_state.config.seed = seed
    st.sidebar.write(f"Using seed: {seed}")
93
+
94
# Generate standard image (no color guidance: both loss scales are 0).
if standard_button:
    with col1:
        with st.spinner("Generating standard image..."):
            progress_bar = st.progress(0)
            start_time = time.time()

            image = generate_image(
                models=st.session_state.models,
                config=st.session_state.config,
                prompt=prompt,
                blue_loss_scale=0,
                yellow_loss_scale=0,
                progress_bar=progress_bar
            )

            end_time = time.time()
            st.image(image, caption="Standard Stable Diffusion", use_column_width=True)
            st.write(f"Generation time: {end_time - start_time:.2f} seconds")

# Generate color-guided image — yellow strength and interval come from the sidebar.
if color_button:
    with col2:
        with st.spinner("Generating color-guided image..."):
            progress_bar = st.progress(0)
            start_time = time.time()

            image = generate_image(
                models=st.session_state.models,
                config=st.session_state.config,
                prompt=prompt,
                blue_loss_scale=0,
                yellow_loss_scale=yellow_strength,
                guidance_interval=guidance_interval,
                progress_bar=progress_bar
            )

            end_time = time.time()
            st.image(image, caption="Color-Guided Stable Diffusion", use_column_width=True)
            st.write(f"Generation time: {end_time - start_time:.2f} seconds")
134
+
135
# Explanation section — static markdown describing the two generation modes.
st.markdown("---")
st.header("How It Works")
st.markdown("""
### Standard Stable Diffusion
The standard approach uses text-to-image generation with classifier-free guidance to create images based on your prompt.

### Color-Guided Stable Diffusion
The color-guided approach adds a custom loss function during the diffusion process that encourages:
- Higher values in the red and green channels
- Lower values in the blue channel

This combination creates a yellow tone in the final image. The strength parameter controls how strongly this color guidance affects the generation process.

### Technical Details
During each step of the diffusion process, we:
1. Calculate the predicted image at that step
2. Measure how far it is from our desired color profile
3. Calculate the gradient of this loss with respect to the latents
4. Adjust the latents to reduce the loss
5. Continue with the standard diffusion process

This approach allows for targeted control of specific visual attributes while maintaining the overall quality and coherence of the generated image.
""")

# Footer
st.markdown("---")
st.markdown("Created with ❤️ using Stable Diffusion and Streamlit")
custom_stable_diffusion.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ from base64 import b64encode
4
+ from pathlib import Path
5
+ from typing import List, Dict, Tuple, Optional, Union
6
+
7
+ import numpy as np
8
+ import torch
9
+ from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
10
+ from huggingface_hub import notebook_login, hf_hub_download
11
+ from IPython.display import HTML
12
+ from matplotlib import pyplot as plt
13
+ from PIL import Image
14
+ from torch import autocast
15
+ from torchvision import transforms as tfms
16
+ from tqdm.auto import tqdm
17
+ from transformers import CLIPTextModel, CLIPTokenizer, logging
18
+
19
class StableDiffusionConfig:
    """
    Configuration container for Stable Diffusion generation.

    Also resolves the torch device (cuda > mps > cpu) when none is given,
    and seeds a CPU generator for reproducible initial latent noise.

    Args:
        height: Output image height in pixels (must be a multiple of 8).
        width: Output image width in pixels (must be a multiple of 8).
        num_inference_steps: Number of denoising steps.
        guidance_scale: Classifier-free guidance strength.
        seed: RNG seed for the initial latent noise.
        batch_size: Number of images per generation.
        device: Explicit torch device string; auto-detected when None.
        max_length: Tokenizer max sequence length (77 for CLIP).
    """
    def __init__(self, height: int = 512,
                 width: int = 512,
                 num_inference_steps: int = 50,
                 guidance_scale: float = 7.5,  # fixed: default is a float, was annotated int
                 seed: int = 32,
                 batch_size: int = 1,
                 device: Optional[str] = None,  # fixed: None is a valid value
                 max_length: int = 77):
        self.height = height
        self.width = width
        self.num_inference_steps = num_inference_steps
        self.guidance_scale = guidance_scale
        self.seed = seed
        self.batch_size = batch_size
        self.max_length = max_length

        # Pick the best available backend when the caller did not choose one.
        if device is None:
            self.device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
            if "mps" == self.device:
                # Some ops are unimplemented on MPS; fall back to CPU instead of erroring.
                os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "TRUE"
        else:
            self.device = device

        # CPU generator seeded for reproducible initial noise.
        self.generator = torch.manual_seed(self.seed)
50
+
51
class StableDiffusionModels:
    """
    Container that loads and holds the Stable Diffusion components:
    VAE, CLIP tokenizer/text encoder, UNet, and LMS scheduler.
    """
    def __init__(self, config: StableDiffusionConfig):
        self.config = config
        self.vae = None
        self.tokenizer = None
        self.text_encoder = None
        self.unet = None
        self.scheduler = None

    def load_models(self, model_version: str = "CompVis/stable-diffusion-v1-4"):
        """
        Load all required models from `model_version` and move them to the
        configured device. Returns self for chaining.
        """
        # Load VAE
        self.vae = AutoencoderKL.from_pretrained(model_version, subfolder="vae")

        # SD v1.x uses the original CLIP ViT-L/14 text tower, so load it
        # directly rather than from the diffusion checkpoint.
        self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
        self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

        # Load UNet.
        # Fixed: use `model_version` instead of a hard-coded repo id so the
        # parameter actually selects the UNet checkpoint.
        self.unet = UNet2DConditionModel.from_pretrained(model_version, subfolder="unet")

        # Scheduler with the SD v1 training betas.
        self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)

        self.vae = self.vae.to(self.config.device)
        self.text_encoder = self.text_encoder.to(self.config.device)
        self.unet = self.unet.to(self.config.device)
        print(self.config.device)
        return self

    def set_timesteps(self, num_inference_steps: int = None):
        """
        Configure the scheduler for `num_inference_steps` denoising steps
        (falls back to the config value). Returns self for chaining.
        """
        if num_inference_steps is None:
            num_inference_steps = self.config.num_inference_steps
        self.scheduler.set_timesteps(num_inference_steps)

        # LMS timesteps default to float64, which MPS cannot handle.
        self.scheduler.timesteps = self.scheduler.timesteps.to(torch.float32)
        return self
97
+
98
class ImageProcessor:
    """Converts between PIL images and scaled VAE latent space."""

    def __init__(self, models: StableDiffusionModels, config: StableDiffusionConfig):
        self.models = models
        self.config = config

    def pil_to_latent(self, input_im: Image.Image) -> torch.Tensor:
        """Encode a PIL image into SD-scaled VAE latents."""
        with torch.no_grad():
            # Map [0, 1] pixel values to [-1, 1] and add a batch dimension.
            pixels = tfms.ToTensor()(input_im).unsqueeze(0).to(self.config.device) * 2 - 1
            # Encode and sample from the latent distribution, applying the
            # standard SD latent scaling factor.
            posterior = self.models.vae.encode(pixels)
            return 0.18215 * posterior.latent_dist.sample()

    def latents_to_pil(self, latents: torch.Tensor) -> List[Image.Image]:
        """Decode SD-scaled latents back into a list of PIL images."""
        # Undo the latent scaling factor before decoding.
        scaled = (1 / 0.18215) * latents

        with torch.no_grad():
            decoded = self.models.vae.decode(scaled).sample

        # [-1, 1] -> [0, 1], then to HWC uint8 arrays.
        decoded = (decoded / 2 + 0.5).clamp(0, 1)
        frames = decoded.detach().cpu().permute(0, 2, 3, 1).numpy()
        frames = (frames * 255).round().astype("uint8")
        return [Image.fromarray(frame) for frame in frames]
130
+
131
class TextEmbeddingProcessor:
    """Builds, modifies, and runs CLIP text embeddings for concept-based generation."""

    def __init__(self, models: StableDiffusionModels, config: StableDiffusionConfig, imageprocessor: ImageProcessor, prompt: str):
        self.models = models
        self.config = config
        # Token/position embedding layers pulled out of the CLIP text model so
        # individual token embeddings can be replaced before encoding.
        self.token_emb_layer = models.text_encoder.text_model.embeddings.token_embedding
        self.pos_emb_layer = models.text_encoder.text_model.embeddings.position_embedding
        self.position_ids = models.text_encoder.text_model.embeddings.position_ids[:, :77]
        self.position_embeddings = self.pos_emb_layer(self.position_ids)
        self.imageprocessor = imageprocessor
        self.prompt = prompt

    def load_embedding(self, concept_name: str):
        """Download a textual-inversion concept from the Hugging Face
        sd-concepts-library.

        Returns:
            The loaded embedding object, or None on failure.
        """
        try:
            # Download the file
            file_path = hf_hub_download(
                repo_id=f"sd-concepts-library/{concept_name}",
                filename="learned_embeds.bin",
                repo_type="model"
            )
            # NOTE(review): torch.load on a downloaded pickle is unsafe for
            # untrusted repos; consider weights_only=True on newer torch.
            embedding = torch.load(file_path)
            return embedding
        except Exception as e:
            print(f"Error downloading concept {concept_name}: {e}")
            # Fixed: was `return None, None` — a 2-tuple, which passed the
            # callers' `if embeddings is not None` checks and crashed later.
            return None

    def tokenize_text(self, prompt=None) -> Tuple[torch.Tensor, int]:
        """Tokenize a prompt.

        Returns:
            (input_ids on device, token id found at index 4 of the prompt —
            assumed to be the concept placeholder slot).
        """
        if prompt is None:
            prompt = self.prompt

        if isinstance(prompt, str):
            text_input = self.models.tokenizer(
                prompt,
                padding="max_length",
                truncation=True,
                max_length=self.models.tokenizer.model_max_length,
                return_tensors="pt"
            )
        # NOTE(review): a non-string prompt falls through without defining
        # text_input and would raise NameError below — confirm callers only pass str.
        position = text_input["input_ids"][0][4].item()  # Get the position of the concept token

        input_ids = text_input.input_ids.to(self.config.device)
        return input_ids, position

    def get_output_embeds(self, input_embeddings):
        """Run pre-built input embeddings through the CLIP text encoder stack."""
        # CLIP's text model uses causal mask, so we prepare it here:
        bsz, seq_len = input_embeddings.shape[:2]
        # NOTE(review): _build_causal_attention_mask is a private API removed
        # in newer transformers releases — pin the version or migrate.
        causal_attention_mask = self.models.text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)

        # Call the encoder with output_hidden_states=True so it doesn't just
        # return the pooled final predictions.
        encoder_outputs = self.models.text_encoder.text_model.encoder(
            inputs_embeds=input_embeddings,
            attention_mask=None,  # We aren't using an attention mask so that can be None
            causal_attention_mask=causal_attention_mask.to(self.config.device),
            output_attentions=None,
            output_hidden_states=True,  # We want the output embs not the final output
            return_dict=None,
        )

        # We're interested in the output hidden state only
        output = encoder_outputs[0]

        # There is a final layer norm we need to pass these through
        output = self.models.text_encoder.text_model.final_layer_norm(output)

        return output

    def generate_with_embs(self, text_embeddings, output_path=None, return_image=False):
        """Run the full diffusion loop from pre-computed text embeddings.

        Optionally saves the decoded image to `output_path` and/or returns it
        when `return_image` is True.
        """
        height = self.config.height  # default height of Stable Diffusion
        width = self.config.width  # default width of Stable Diffusion
        num_inference_steps = self.config.num_inference_steps  # Number of denoising steps
        guidance_scale = self.config.guidance_scale  # Scale for classifier-free guidance
        generator = torch.manual_seed(self.config.seed)  # Seed generator to create the inital latent noise
        batch_size = 1

        text_input = self.models.tokenizer(self.prompt, padding="max_length", truncation=True, max_length=self.models.tokenizer.model_max_length, return_tensors="pt")
        max_length = text_input.input_ids.shape[-1]
        uncond_input = self.models.tokenizer(
            [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
        )
        with torch.no_grad():
            uncond_embeddings = self.models.text_encoder(uncond_input.input_ids.to(self.config.device))[0]
        text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

        # Prep Scheduler
        self.models.set_timesteps(num_inference_steps)

        # Prep latents
        latents = torch.randn(
            (batch_size, self.models.unet.config.in_channels, height // 8, width // 8),
            generator=generator,
        )
        latents = latents.to(self.config.device)
        latents = latents * self.models.scheduler.init_noise_sigma

        # Denoising loop
        for i, t in tqdm(enumerate(self.models.scheduler.timesteps), total=len(self.models.scheduler.timesteps)):
            # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.models.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            with torch.no_grad():
                noise_pred = self.models.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

            # perform guidance
            noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.models.scheduler.step(noise_pred, t, latents).prev_sample

        if output_path is not None:
            # Ensure the output directory exists
            os.makedirs(os.path.dirname(output_path), exist_ok=True)

            # Make sure the output path has a file extension
            if not os.path.splitext(output_path)[1]:
                output_path = output_path + ".png"

            self.imageprocessor.latents_to_pil(latents)[0].save(output_path)

        if return_image:
            return self.imageprocessor.latents_to_pil(latents)[0]

    def prepare_embeddings_with_concepts(self, prompt, concept_name: str = None, output_path: str = None) -> None:
        """Encode text input into embeddings, splice in the concept embedding,
        and generate (and optionally save) an image.
        """
        # NOTE(review): the `prompt` argument is effectively ignored —
        # tokenize_text is called with self.prompt; confirm this is intended.
        input_ids, position = self.tokenize_text(self.prompt)
        token_embeddings = self.token_emb_layer(input_ids)
        embeddings = self.load_embedding(concept_name)

        if embeddings is not None:
            replacement_token_embedding = embeddings[next(iter(embeddings.keys()))].to(self.config.device)

            # Get the position indices where the token appears
            position_indices = torch.where(input_ids[0] == position)[0]

            if len(position_indices) > 0:
                # Get the shape of a single token embedding
                single_token_shape = token_embeddings[0, position_indices[0]].shape

                # Replace the token embedding at the specified position
                if replacement_token_embedding.shape != single_token_shape:
                    print("Warning: Embedding dimensions don't match. This might not be the right embedding.")

                    # Reshape if needed
                    if replacement_token_embedding.shape[0] != single_token_shape[0]:
                        print(f"Reshaping embedding from {replacement_token_embedding.shape} to {single_token_shape}")
                        replacement_token_embedding = replacement_token_embedding[:single_token_shape[0]]

                # Correctly index and replace the token embedding
                for idx in position_indices:
                    token_embeddings[0, idx] = replacement_token_embedding.to(self.config.device)

                # Combine with pos embs and run the text encoder + diffusion.
                input_embeddings = token_embeddings + self.position_embeddings
                modified_output_embeddings = self.get_output_embeds(input_embeddings)
                self.generate_with_embs(modified_output_embeddings, output_path=output_path)
            else:
                print(f"Token position {position} not found in input_ids")
        else:
            print(f"Failed to load concept: {concept_name}")
299
+
300
def generate_with_multiple_concepts(models, config, image_processor, prompt, concepts, output_dir="generated_images"):
    """
    Generate one image per textual-inversion concept, saving each into its
    own subfolder of `output_dir`.

    Args:
        models: Loaded StableDiffusionModels.
        config: StableDiffusionConfig to generate with.
        image_processor: ImageProcessor for latent <-> PIL conversion.
        prompt: Text prompt shared by all concepts.
        concepts: Iterable of sd-concepts-library concept names.
        output_dir: Root directory for saved images.
    """
    os.makedirs(output_dir, exist_ok=True)

    for concept in concepts:
        concepts_dir = os.path.join(output_dir, concept)
        os.makedirs(concepts_dir, exist_ok=True)

        output_path = os.path.join(concepts_dir, f"{concept}.png")

        # Fresh processor per concept so embeddings don't leak between runs.
        text_processor = TextEmbeddingProcessor(models, config, image_processor, prompt)

        text_processor.prepare_embeddings_with_concepts(prompt, concept_name=concept, output_path=output_path)
        print(f"Saved image to {output_path}")  # fixed typo: "iamge" -> "image"
318
+
319
def channel_loss(images, channel_idx=2, target_value=0.9):
    """
    Mean absolute deviation of one color channel from a target value.

    Args:
        images (torch.Tensor): Batch of images shaped [batch, channels, height, width]
        channel_idx (int): Index of the color channel to target (0=R, 1=G, 2=B)
        target_value (float): Target value for the channel (0-1)

    Returns:
        torch.Tensor: Scalar loss value
    """
    deviation = images[:, channel_idx] - target_value
    return deviation.abs().mean()

def blue_loss(images, target=0.9):
    """Loss that pushes the blue channel toward `target` (more blue)."""
    return channel_loss(images, channel_idx=2, target_value=target)

def yellow_loss(images):
    """
    Loss that pushes images toward yellow: high red, high green, low blue
    (yellow = high R + high G + low B). Equal-weighted average of the three
    per-channel losses.
    """
    components = (
        channel_loss(images, channel_idx=0, target_value=0.9),  # red up
        channel_loss(images, channel_idx=1, target_value=0.9),  # green up
        channel_loss(images, channel_idx=2, target_value=0.1),  # blue down
    )
    return sum(components) / 3
346
+
347
def generate_with_concept_and_color(
    models,
    config,
    image_processor,
    prompt,
    concept_name,
    output_dir="concept_images",
    blue_loss_scale=0,
    yellow_loss_scale=400,
    guidance_interval=3  # Changed from 5 to 3 to apply more frequently
):
    """
    Generate an image using a textual-inversion concept and optional color
    guidance, then save it under `output_dir/<concept_name>/`.

    Splices the concept embedding into the prompt's token embeddings, runs the
    diffusion loop with classifier-free guidance, and every `guidance_interval`
    steps nudges the latents down the gradient of the requested color losses.

    Returns:
        The final PIL image, or None if the concept cannot be loaded or its
        token position is not found.
    """
    # Create output directory
    concept_dir = os.path.join(output_dir, f"{concept_name}")
    os.makedirs(concept_dir, exist_ok=True)

    # Define output path with color info in filename
    color_info = ""
    if blue_loss_scale > 0:
        color_info += f"_blue{blue_loss_scale}"
    if yellow_loss_scale > 0:
        color_info += f"_yellow{yellow_loss_scale}"

    output_path = os.path.join(concept_dir, f"{concept_name}{color_info}.png")

    # Create text processor
    text_processor = TextEmbeddingProcessor(models, config, image_processor, prompt)

    # Load concept embedding
    embeddings = text_processor.load_embedding(concept_name)

    if embeddings is None:
        print(f"Failed to load concept: {concept_name}")
        return

    # Process text with concept
    input_ids, position = text_processor.tokenize_text(prompt)
    token_embeddings = text_processor.token_emb_layer(input_ids)

    # Handle different embedding formats (dict from learned_embeds.bin,
    # tuple, or bare tensor).
    if isinstance(embeddings, dict):
        replacement_token_embedding = embeddings[next(iter(embeddings.keys()))].to(config.device)
    elif isinstance(embeddings, tuple) and len(embeddings) >= 2:
        replacement_token_embedding = embeddings[1].to(config.device)
    elif isinstance(embeddings, torch.Tensor):
        replacement_token_embedding = embeddings.to(config.device)
    else:
        print(f"Unsupported embedding format for concept: {concept_name}")
        return

    # Get the position indices where the token appears.
    # NOTE(review): `position` is a token *id* (taken from index 4 of the
    # prompt by tokenize_text) matched against input_ids values — confirm
    # that is the intended lookup.
    position_indices = torch.where(input_ids[0] == position)[0]

    if len(position_indices) == 0:
        print(f"Token position {position} not found in input_ids")
        return

    # Get the shape of a single token embedding
    single_token_shape = token_embeddings[0, position_indices[0]].shape

    # Truncate the replacement embedding if its width doesn't match.
    if replacement_token_embedding.shape != single_token_shape:
        print("Warning: Embedding dimensions don't match. This might not be the right embedding.")
        if replacement_token_embedding.shape[0] != single_token_shape[0]:
            print(f"Reshaping embedding from {replacement_token_embedding.shape} to {single_token_shape}")
            replacement_token_embedding = replacement_token_embedding[:single_token_shape[0]]

    # Replace the token embedding at the specified position
    for idx in position_indices:
        token_embeddings[0, idx] = replacement_token_embedding.to(config.device)

    # Combine with position embeddings and run the CLIP encoder stack.
    input_embeddings = token_embeddings + text_processor.position_embeddings
    text_embeddings = text_processor.get_output_embeds(input_embeddings)

    # Get uncond embeddings for classifier-free guidance.
    uncond_input = models.tokenizer(
        [""], padding="max_length", max_length=77, return_tensors="pt"
    )
    with torch.no_grad():
        uncond_embeddings = models.text_encoder(uncond_input.input_ids.to(config.device))[0]

    # Concatenate for classifier-free guidance
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    # Set timesteps
    models.set_timesteps(config.num_inference_steps)

    # Prepare latents
    height = config.height
    width = config.width
    batch_size = config.batch_size

    # Create a generator on the same device as where the tensor will be created
    if "cuda" in str(config.device):
        generator = torch.Generator(device="cuda").manual_seed(config.seed)
    else:
        generator = torch.manual_seed(config.seed)

    latents = torch.randn(
        (batch_size, models.unet.config.in_channels, height // 8, width // 8),
        generator=generator,
        device=config.device
    )
    latents = latents * models.scheduler.init_noise_sigma

    # Define color loss functions.
    # NOTE(review): these shadow the module-level channel_loss/blue_loss/
    # yellow_loss, and this yellow_loss uses different targets and weights
    # (blue weighted 2x) than the module-level one — confirm which is intended.
    def channel_loss(images, channel_idx=2, target_value=0.9):
        return torch.abs(images[:, channel_idx] - target_value).mean()

    def blue_loss(images, target=0.9):
        return channel_loss(images, channel_idx=2, target_value=target)

    def yellow_loss(images, red_target=0.95, green_target=0.95, blue_target=0.05):
        """
        Make images more yellow by increasing red and green channels and decreasing blue
        Yellow = high R + high G + low B

        Args:
            images: The image tensor
            red_target: Target value for red channel (higher = more red)
            green_target: Target value for green channel (higher = more green)
            blue_target: Target value for blue channel (lower = less blue)
        """
        red_high = torch.abs(images[:, 0] - red_target).mean()
        green_high = torch.abs(images[:, 1] - green_target).mean()
        blue_low = torch.abs(images[:, 2] - blue_target).mean()

        # Weight the blue channel more heavily to really reduce blue
        return (red_high + green_high + blue_low * 2) / 4

    # Denoising loop
    for i, t in tqdm(enumerate(models.scheduler.timesteps), total=len(models.scheduler.timesteps)):
        # Expand latents for classifier-free guidance
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = models.scheduler.scale_model_input(latent_model_input, t)

        # Predict noise
        with torch.no_grad():
            noise_pred = models.unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]

        # Perform guidance
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + config.guidance_scale * (noise_pred_text - noise_pred_uncond)

        # Apply color guidance every `guidance_interval` steps.
        if (blue_loss_scale > 0 or yellow_loss_scale > 0) and i % guidance_interval == 0:
            # Get the current sigma value
            sigma = models.scheduler.sigmas[i]

            # Requires grad on the latents
            latents = latents.detach().requires_grad_()

            # Predicted denoised sample x0; gradient flows only through
            # `latents` since noise_pred was computed under no_grad.
            latents_x0 = latents - sigma * noise_pred

            # Decode to image space (result roughly in [0, 1]).
            denoised_images = models.vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5

            # Calculate combined loss
            loss = 0
            if blue_loss_scale > 0:
                blue_loss_value = blue_loss(denoised_images) * blue_loss_scale
                loss += blue_loss_value

            if yellow_loss_scale > 0:
                yellow_loss_value = yellow_loss(denoised_images) * yellow_loss_scale
                loss += yellow_loss_value

            # Print loss occasionally
            if i % 10 == 0:
                print(f"Step {i}, Loss: {loss.item()}")
                if blue_loss_scale > 0 and yellow_loss_scale > 0:
                    print(f" Blue loss: {blue_loss_value.item()}, Yellow loss: {yellow_loss_value.item()}")

            # Get gradient
            cond_grad = torch.autograd.grad(loss, latents)[0]

            # Modify the latents based on this gradient (using sigma squared like in the example)
            latents = latents.detach() - cond_grad * sigma**2

        # Step with scheduler
        latents = models.scheduler.step(noise_pred, t, latents).prev_sample

    # Decode the final image
    with torch.no_grad():
        decoded = models.vae.decode((1 / 0.18215) * latents).sample

    # Convert to PIL image
    image = (decoded / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    image = (image * 255).round().astype("uint8")[0]
    pil_image = Image.fromarray(image)

    # Save the image
    os.makedirs(os.path.dirname(output_path), exist_ok=True)
    pil_image.save(output_path)
    print(f"Saved image to {output_path}")

    return pil_image
549
+
550
# Function to generate multiple concepts with color guidance
def generate_with_multiple_concepts_and_color(
    models,
    config,
    image_processor,
    prompt,
    concepts,
    output_dir="concept_images",
    blue_loss_scale=0,
    yellow_loss_scale=0
):
    """
    Run color-guided concept generation for every concept in `concepts`,
    saving results under `output_dir`.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Arguments shared by every per-concept call.
    shared_kwargs = dict(
        models=models,
        config=config,
        image_processor=image_processor,
        prompt=prompt,
        output_dir=output_dir,
        blue_loss_scale=blue_loss_scale,
        yellow_loss_scale=yellow_loss_scale,
    )

    for concept in concepts:
        print(f"Generating image for concept: {concept} and color guidance")
        generate_with_concept_and_color(concept_name=concept, **shared_kwargs)
578
+
579
+
580
# Example usage
if __name__ == "__main__":
    # Initialize configuration (device is auto-detected by the constructor).
    config = StableDiffusionConfig(
        height=512,
        width=512,
        num_inference_steps=30,
        guidance_scale=7.5,
        seed=42,
        batch_size=1,
        device=None,
        max_length=77
    )
    # Fixed: removed the dead `if config.device is None:` re-detection block.
    # StableDiffusionConfig.__init__ always resolves the device, so the branch
    # could never run, and its body checked `config.device` against "mps"
    # instead of the freshly detected local variable.

    # Load models
    models = StableDiffusionModels(config)
    models.load_models()
    models.set_timesteps()

    # Create image processor
    image_processor = ImageProcessor(models, config)

    # Define base prompt shared by all runs.
    base_prompt = "A detailed photograph of a colorful monarch butterfly with orange and black wings, resting on a purple flower in a lush garden with sunlight"

    # List of concepts to use (these should be available in the Hugging Face sd-concepts-library)
    concepts = [
        "concept-art-2-1",
        "canna-lily-flowers102",
        "arcane-style-jv",
        "seismic-image",
        "azalea-flowers102"
    ]

    # Generate images for all concepts (no color guidance).
    generate_with_multiple_concepts(
        models=models,
        config=config,
        image_processor=image_processor,
        prompt=base_prompt,
        concepts=concepts,
        output_dir="concept_images"
    )

    # Same concepts, with yellow color guidance enabled.
    generate_with_multiple_concepts_and_color(
        models=models,
        config=config,
        image_processor=image_processor,
        prompt=base_prompt,
        concepts=concepts,
        output_dir="concept_images",
        blue_loss_scale=0,  # Set to 0 to disable blue guidance
        yellow_loss_scale=200  # Set to 0 to disable yellow guidance
    )
640
+