minakshi.mathpal commited on
Commit
e875d47
·
1 Parent(s): 8a23def

refactored app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -3
app.py CHANGED
@@ -2,6 +2,7 @@ import streamlit as st
2
  import torch
3
  import random
4
  import time
 
5
  from PIL import Image
6
  from custom_stable_diffusion import StableDiffusionConfig, StableDiffusionModels,ImageProcessor, generate_with_multiple_concepts,generate_with_multiple_concepts_and_color
7
  st.set_page_config(
@@ -123,16 +124,39 @@ if standard_button:
123
  progress_bar = st.progress(0)
124
  start_time = time.time()
125
 
126
- image = generate_with_multiple_concepts(
 
127
  models=st.session_state.models,
128
  config=st.session_state.config,
129
  image_processor=st.session_state.image_processor,
130
  prompt=prompt,
131
- concepts=[concept_name] if concept_name else [], # Pass the selected concept
132
  output_dir="concept_images"
133
  )
134
 
135
  end_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
  caption = f"Standard Stable Diffusion"
137
  if concept_name:
138
  caption += f" with {concept_name} concept"
@@ -146,7 +170,8 @@ if color_button:
146
  progress_bar = st.progress(0)
147
  start_time = time.time()
148
 
149
- image = generate_with_multiple_concepts_and_color(
 
150
  models=st.session_state.models,
151
  config=st.session_state.config,
152
  image_processor=st.session_state.image_processor,
@@ -158,6 +183,30 @@ if color_button:
158
  )
159
 
160
  end_time = time.time()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  caption = f"Color-Guided Stable Diffusion"
162
  if concept_name:
163
  caption += f" with {concept_name} concept"
 
2
  import torch
3
  import random
4
  import time
5
+ import os
6
  from PIL import Image
7
  from custom_stable_diffusion import StableDiffusionConfig, StableDiffusionModels,ImageProcessor, generate_with_multiple_concepts,generate_with_multiple_concepts_and_color
8
  st.set_page_config(
 
124
  progress_bar = st.progress(0)
125
  start_time = time.time()
126
 
127
+ # Call the generation function
128
+ result = generate_with_multiple_concepts(
129
  models=st.session_state.models,
130
  config=st.session_state.config,
131
  image_processor=st.session_state.image_processor,
132
  prompt=prompt,
133
+ concepts=[concept_name] if concept_name else [],
134
  output_dir="concept_images"
135
  )
136
 
137
  end_time = time.time()
138
+
139
+ # Check if we got a valid image back
140
+ if result is not None and hasattr(result, 'format'):
141
+ # It's a PIL Image object
142
+ image = result
143
+ else:
144
+ # Try to load the image from the expected output path
145
+ try:
146
+ if concept_name:
147
+ image_path = f"concept_images/{concept_name}/{concept_name}.png"
148
+ else:
149
+ image_path = "concept_images/standard_image.png"
150
+
151
+ if os.path.exists(image_path):
152
+ image = Image.open(image_path)
153
+ else:
154
+ st.error(f"Could not find generated image at {image_path}")
155
+ continue
156
+ except Exception as e:
157
+ st.error(f"Error loading image: {str(e)}")
158
+ continue
159
+
160
  caption = f"Standard Stable Diffusion"
161
  if concept_name:
162
  caption += f" with {concept_name} concept"
 
170
  progress_bar = st.progress(0)
171
  start_time = time.time()
172
 
173
+ # Call the generation function
174
+ result = generate_with_multiple_concepts_and_color(
175
  models=st.session_state.models,
176
  config=st.session_state.config,
177
  image_processor=st.session_state.image_processor,
 
183
  )
184
 
185
  end_time = time.time()
186
+
187
+ # Check if we got a valid image back
188
+ if result is not None and hasattr(result, 'format'):
189
+ # It's a PIL Image object
190
+ image = result
191
+ else:
192
+ # Try to load the image from the expected output path
193
+ try:
194
+ if concept_name:
195
+ # Determine the filename based on color guidance
196
+ color_info = f"_yellow{yellow_strength}" if yellow_strength > 0 else ""
197
+ image_path = f"concept_images/{concept_name}/{concept_name}{color_info}.png"
198
+ else:
199
+ image_path = "concept_images/color_guided_image.png"
200
+
201
+ if os.path.exists(image_path):
202
+ image = Image.open(image_path)
203
+ else:
204
+ st.error(f"Could not find generated image at {image_path}")
205
+ continue
206
+ except Exception as e:
207
+ st.error(f"Error loading image: {str(e)}")
208
+ continue
209
+
210
  caption = f"Color-Guided Stable Diffusion"
211
  if concept_name:
212
  caption += f" with {concept_name} concept"