minakshi.mathpal committed on
Commit
5ac005e
·
1 Parent(s): 7f152f4

changes made to all the files

Browse files
Files changed (2) hide show
  1. app.py +19 -2
  2. custom_stable_diffusion.py +77 -38
app.py CHANGED
@@ -5,6 +5,9 @@ import time
5
  import os
6
  from PIL import Image
7
  from custom_stable_diffusion import StableDiffusionConfig, StableDiffusionModels,ImageProcessor, generate_with_multiple_concepts,generate_with_multiple_concepts_and_color
 
 
 
8
  st.set_page_config(
9
  page_title="Butterfly Color Diffusion",
10
  page_icon="🦋",
@@ -17,9 +20,19 @@ if 'models' not in st.session_state:
17
  st.session_state.models = None
18
  st.session_state.config = None
19
 
 
 
 
20
  # Function to load models
21
  @st.cache_resource
22
  def load_models():
 
 
 
 
 
 
 
23
  config = StableDiffusionConfig(
24
  height=512,
25
  width=512,
@@ -35,6 +48,10 @@ def load_models():
35
  with st.spinner("Loading Stable Diffusion models... This may take a minute."):
36
  models.load_models()
37
  models.set_timesteps()
 
 
 
 
38
  return models, config, image_processor
39
 
40
  # Title and description
@@ -160,7 +177,7 @@ if standard_button:
160
  caption = f"Standard Stable Diffusion"
161
  if concept_name:
162
  caption += f" with {concept_name} concept"
163
- st.image(image, caption=caption, use_container_width=True)
164
  st.write(f"Generation time: {end_time - start_time:.2f} seconds")
165
 
166
  # Generate color-guided image
@@ -210,7 +227,7 @@ if color_button:
210
  caption = f"Color-Guided Stable Diffusion"
211
  if concept_name:
212
  caption += f" with {concept_name} concept"
213
- st.image(image, caption=caption, use_container_width=True)
214
  st.write(f"Generation time: {end_time - start_time:.2f} seconds")
215
 
216
  # Explanation section
 
5
  import os
6
  from PIL import Image
7
  from custom_stable_diffusion import StableDiffusionConfig, StableDiffusionModels,ImageProcessor, generate_with_multiple_concepts,generate_with_multiple_concepts_and_color
8
+ import sys
9
+ import transformers
10
+ import diffusers
11
  st.set_page_config(
12
  page_title="Butterfly Color Diffusion",
13
  page_icon="🦋",
 
20
  st.session_state.models = None
21
  st.session_state.config = None
22
 
23
+ # Add this near the top of your app.py
24
+ debug_mode = st.sidebar.checkbox("Debug Mode", value=True)
25
+
26
  # Function to load models
27
  @st.cache_resource
28
  def load_models():
29
+ if debug_mode:
30
+ st.write("Debug: Starting model loading")
31
+ st.write(f"Debug: Python version: {sys.version}")
32
+ st.write(f"Debug: Torch version: {torch.__version__}")
33
+ st.write(f"Debug: Transformers version: {transformers.__version__}")
34
+ st.write(f"Debug: Diffusers version: {diffusers.__version__}")
35
+
36
  config = StableDiffusionConfig(
37
  height=512,
38
  width=512,
 
48
  with st.spinner("Loading Stable Diffusion models... This may take a minute."):
49
  models.load_models()
50
  models.set_timesteps()
51
+
52
+ if debug_mode:
53
+ st.write(f"Debug: Models loaded successfully. Device: {config.device}")
54
+
55
  return models, config, image_processor
56
 
57
  # Title and description
 
177
  caption = f"Standard Stable Diffusion"
178
  if concept_name:
179
  caption += f" with {concept_name} concept"
180
+ st.image(image, caption=caption, use_column_width=True)
181
  st.write(f"Generation time: {end_time - start_time:.2f} seconds")
182
 
183
  # Generate color-guided image
 
227
  caption = f"Color-Guided Stable Diffusion"
228
  if concept_name:
229
  caption += f" with {concept_name} concept"
230
+ st.image(image, caption=caption, use_column_width=True)
231
  st.write(f"Generation time: {end_time - start_time:.2f} seconds")
232
 
233
  # Explanation section
custom_stable_diffusion.py CHANGED
@@ -60,27 +60,65 @@ class StableDiffusionModels:
60
  self.scheduler= None
61
 
62
  def load_models(self, model_version:str="CompVis/stable-diffusion-v1-4"):
63
- """
64
- Load all the required models for stable diffusion.
65
- """
66
- # Load VAE
67
- self.vae = AutoencoderKL.from_pretrained(model_version, subfolder="vae")
68
-
69
- # Load tokenizer and text encoder - IMPORTANT: Use the correct model
70
- self.tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
71
- self.text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
72
 
73
- # Load UNet
74
- self.unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
75
-
76
- # Load scheduler
77
- self.scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
78
-
79
- self.vae = self.vae.to(self.config.device)
80
- self.text_encoder = self.text_encoder.to(self.config.device)
81
- self.unet = self.unet.to(self.config.device)
82
- print(self.config.device)
83
- return self
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  def set_timesteps(self, num_inference_steps:int=None):
86
  """
@@ -296,24 +334,25 @@ class TextEmbeddingProcessor:
296
  else:
297
  print(f"Failed to load concept: {concept_name}")
298
 
299
- def generate_with_multiple_concepts(models, config, image_processor, prompt,concepts, output_dir="generated_images"):
300
- """
301
- Generate images using multiple concepts and save them in separate folders
302
-
303
- """
304
-
305
- os.makedirs(output_dir, exist_ok=True)
306
-
307
- for concept in concepts:
308
- concepts_dir= os.path.join(output_dir,concept)
309
- os.makedirs(concepts_dir,exist_ok=True)
310
-
311
- output_path = os.path.join(concepts_dir,f"{concept}.png")
312
-
313
- text_processor = TextEmbeddingProcessor(models, config, image_processor, prompt)
314
-
315
- text_processor.prepare_embeddings_with_concepts(prompt, concept_name= concept, output_path=output_path)
316
- print(f"Saved iamge to {output_path}")
 
317
 
318
  def channel_loss(images, channel_idx=2, target_value=0.9):
319
  """
 
60
  self.scheduler= None
61
 
62
  def load_models(self, model_version:str="CompVis/stable-diffusion-v1-4"):
63
+ """
64
+ Load all the required models for stable diffusion.
65
+ """
66
+ try:
67
+ # Add cache directory to ensure files are saved in a writable location
68
+ cache_dir = "./model_cache"
69
+ os.makedirs(cache_dir, exist_ok=True)
70
+
71
+ # Load VAE
72
+ self.vae = AutoencoderKL.from_pretrained(
73
+ model_version,
74
+ subfolder="vae",
75
+ cache_dir=cache_dir,
76
+ local_files_only=False
77
+ )
78
+
79
+ # Load tokenizer and text encoder with explicit cache directory
80
+ self.tokenizer = CLIPTokenizer.from_pretrained(
81
+ "openai/clip-vit-large-patch14",
82
+ cache_dir=cache_dir,
83
+ local_files_only=False
84
+ )
85
+
86
+ self.text_encoder = CLIPTextModel.from_pretrained(
87
+ "openai/clip-vit-large-patch14",
88
+ cache_dir=cache_dir,
89
+ local_files_only=False
90
+ )
91
 
92
+ # Load UNet
93
+ self.unet = UNet2DConditionModel.from_pretrained(
94
+ model_version,
95
+ subfolder="unet",
96
+ cache_dir=cache_dir,
97
+ local_files_only=False
98
+ )
99
+
100
+ # Load scheduler
101
+ self.scheduler = LMSDiscreteScheduler(
102
+ beta_start=0.00085,
103
+ beta_end=0.012,
104
+ beta_schedule="scaled_linear",
105
+ num_train_timesteps=1000
106
+ )
107
+
108
+ # Move models to device
109
+ self.vae = self.vae.to(self.config.device)
110
+ self.text_encoder = self.text_encoder.to(self.config.device)
111
+ self.unet = self.unet.to(self.config.device)
112
+
113
+ print(f"Using device: {self.config.device}")
114
+ return self
115
+
116
+ except Exception as e:
117
+ print(f"Error loading models: {str(e)}")
118
+ # Add more detailed error information
119
+ import traceback
120
+ traceback.print_exc()
121
+ raise
122
 
123
  def set_timesteps(self, num_inference_steps:int=None):
124
  """
 
334
  else:
335
  print(f"Failed to load concept: {concept_name}")
336
 
337
+ def generate_with_multiple_concepts(models, config, image_processor, prompt, concepts, output_dir="concept_images"):
338
+ """
339
+ Generate images using multiple concepts
340
+ """
341
+ os.makedirs(output_dir, exist_ok=True)
342
+
343
+ if not concepts:
344
+ # Handle the case with no concept
345
+ # ... your existing code ...
346
+ # Make sure to return the PIL Image object
347
+ return pil_image
348
+
349
+ for concept in concepts:
350
+ # ... your existing code ...
351
+ # Make sure to return the PIL Image object
352
+ return pil_image
353
+
354
+ # If we get here, return None
355
+ return None
356
 
357
  def channel_loss(images, channel_idx=2, target_value=0.9):
358
  """