Spaces:
Runtime error
Runtime error
anbucur
committed on
Commit
·
5d8e518
1
Parent(s):
25e8e9c
Refactor generate_design method in ProductionDesignModel for improved image handling and variation generation
Browse files
- Updated the method to accept various image types (PIL Image, numpy array, torch tensor) and ensure proper conversion to RGB format.
- Enhanced parameter handling by consolidating the retrieval of prompt, number of variations, and other settings from kwargs.
- Implemented distinct seed generation for each variation to ensure diversity in outputs.
- Improved error handling and logging for better traceability during the design generation process.
- Cleared CUDA cache after each variation generation to optimize memory usage.
- prod_model.py +54 -55
prod_model.py
CHANGED
|
@@ -162,83 +162,82 @@ class ProductionDesignModel(DesignModel):
|
|
| 162 |
if torch.cuda.is_available():
|
| 163 |
torch.cuda.empty_cache()
|
| 164 |
|
| 165 |
-
def generate_design(self, image
|
| 166 |
-
"""
|
| 167 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 168 |
"""
|
| 169 |
try:
|
| 170 |
-
#
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
|
| 175 |
# Get parameters
|
| 176 |
-
|
| 177 |
-
guidance_scale = float(kwargs.get('guidance_scale', 10.0))
|
| 178 |
num_steps = int(kwargs.get('num_steps', 50))
|
|
|
|
| 179 |
strength = float(kwargs.get('strength', 0.9))
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
logging.info(f"Generating design with parameters: guidance_scale={guidance_scale}, "
|
| 183 |
-
f"num_steps={num_steps}, strength={strength}, img_size={img_size}")
|
| 184 |
-
|
| 185 |
-
# Prepare prompt
|
| 186 |
-
pos_prompt = f"{prompt}, {self.additional_quality_suffix}"
|
| 187 |
-
|
| 188 |
-
# Process input image
|
| 189 |
-
orig_size = image.size
|
| 190 |
-
input_image = self._resize_image(image, img_size)
|
| 191 |
|
| 192 |
-
#
|
| 193 |
-
|
|
|
|
| 194 |
|
| 195 |
-
# Generate segmentation
|
| 196 |
-
seg_map = self._segment_image(input_image)
|
| 197 |
-
|
| 198 |
-
# Generate IP-adapter reference image
|
| 199 |
-
self._flush()
|
| 200 |
-
ip_image = self.guide_pipe(
|
| 201 |
-
pos_prompt,
|
| 202 |
-
num_inference_steps=num_steps,
|
| 203 |
-
negative_prompt=self.neg_prompt,
|
| 204 |
-
generator=self.generator
|
| 205 |
-
).images[0]
|
| 206 |
-
|
| 207 |
-
# Generate variations
|
| 208 |
variations = []
|
| 209 |
for i in range(num_variations):
|
| 210 |
try:
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
| 214 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
num_inference_steps=num_steps,
|
| 216 |
-
strength=strength,
|
| 217 |
guidance_scale=guidance_scale,
|
| 218 |
-
|
| 219 |
-
|
| 220 |
-
|
| 221 |
-
control_image=[depth_map, seg_map],
|
| 222 |
-
controlnet_conditioning_scale=[0.5, 0.5]
|
| 223 |
).images[0]
|
| 224 |
|
| 225 |
-
|
| 226 |
-
|
| 227 |
-
variations.append(variation)
|
| 228 |
|
| 229 |
except Exception as e:
|
| 230 |
-
logging.error(f"Error generating variation {i}: {e}")
|
| 231 |
continue
|
| 232 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 233 |
if not variations:
|
| 234 |
logging.warning("No variations were generated successfully")
|
| 235 |
-
return [image] # Return original image if no variations
|
| 236 |
-
|
| 237 |
return variations
|
| 238 |
-
|
| 239 |
except Exception as e:
|
| 240 |
-
logging.error(f"Error in generate_design: {e}")
|
| 241 |
-
return [image] # Return original image
|
| 242 |
|
| 243 |
def __del__(self):
|
| 244 |
"""Cleanup when the model is deleted"""
|
|
|
|
| 162 |
if torch.cuda.is_available():
|
| 163 |
torch.cuda.empty_cache()
|
| 164 |
|
| 165 |
+
def generate_design(self, image, num_variations=1, **kwargs):
|
| 166 |
+
"""Generate design variations using the model.
|
| 167 |
+
|
| 168 |
+
Args:
|
| 169 |
+
image: Input image (PIL Image, numpy array, or torch tensor)
|
| 170 |
+
num_variations: Number of variations to generate
|
| 171 |
+
**kwargs: Additional parameters like prompt, num_steps, guidance_scale, strength
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
List of generated images
|
| 175 |
"""
|
| 176 |
try:
|
| 177 |
+
# Convert image to PIL Image if needed
|
| 178 |
+
if isinstance(image, np.ndarray):
|
| 179 |
+
image = Image.fromarray(image)
|
| 180 |
+
elif isinstance(image, torch.Tensor):
|
| 181 |
+
# Convert tensor to numpy then PIL
|
| 182 |
+
image = Image.fromarray((image.cpu().numpy() * 255).astype(np.uint8))
|
| 183 |
+
|
| 184 |
+
if not isinstance(image, Image.Image):
|
| 185 |
+
raise ValueError(f"Unsupported image type: {type(image)}")
|
| 186 |
+
|
| 187 |
+
# Ensure image is RGB
|
| 188 |
+
if image.mode != "RGB":
|
| 189 |
+
image = image.convert("RGB")
|
| 190 |
|
| 191 |
# Get parameters
|
| 192 |
+
prompt = kwargs.get('prompt', '')
|
|
|
|
| 193 |
num_steps = int(kwargs.get('num_steps', 50))
|
| 194 |
+
guidance_scale = float(kwargs.get('guidance_scale', 10.0))
|
| 195 |
strength = float(kwargs.get('strength', 0.9))
|
| 196 |
+
seed_param = kwargs.get('seed')
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
+
# Handle seed
|
| 199 |
+
base_seed = int(time.time()) if seed_param is None else int(seed_param)
|
| 200 |
+
logging.info(f"Using base seed: {base_seed}")
|
| 201 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 202 |
variations = []
|
| 203 |
for i in range(num_variations):
|
| 204 |
try:
|
| 205 |
+
# Generate distinct seed for each variation
|
| 206 |
+
seed = base_seed + i
|
| 207 |
+
generator = torch.Generator(device=self.device).manual_seed(seed)
|
| 208 |
+
|
| 209 |
+
# Generate variation
|
| 210 |
+
output = self.pipe(
|
| 211 |
+
prompt=prompt,
|
| 212 |
+
image=image,
|
| 213 |
num_inference_steps=num_steps,
|
|
|
|
| 214 |
guidance_scale=guidance_scale,
|
| 215 |
+
strength=strength,
|
| 216 |
+
generator=generator,
|
| 217 |
+
negative_prompt=self.neg_prompt
|
|
|
|
|
|
|
| 218 |
).images[0]
|
| 219 |
|
| 220 |
+
variations.append(output)
|
| 221 |
+
logging.info(f"Successfully generated variation {i} with seed {seed}")
|
|
|
|
| 222 |
|
| 223 |
except Exception as e:
|
| 224 |
+
logging.error(f"Error generating variation {i}: {str(e)}")
|
| 225 |
continue
|
| 226 |
+
|
| 227 |
+
finally:
|
| 228 |
+
# Clear CUDA cache after each variation
|
| 229 |
+
if torch.cuda.is_available():
|
| 230 |
+
torch.cuda.empty_cache()
|
| 231 |
+
|
| 232 |
if not variations:
|
| 233 |
logging.warning("No variations were generated successfully")
|
| 234 |
+
return [image] # Return original image if no variations generated
|
| 235 |
+
|
| 236 |
return variations
|
| 237 |
+
|
| 238 |
except Exception as e:
|
| 239 |
+
logging.error(f"Error in generate_design: {str(e)}")
|
| 240 |
+
return [image] # Return original image on error
|
| 241 |
|
| 242 |
def __del__(self):
|
| 243 |
"""Cleanup when the model is deleted"""
|