# Spaces: Runtime error
| from model import DesignModel | |
| from PIL import Image | |
| import numpy as np | |
| from typing import List | |
| import random | |
| import time | |
| import torch | |
| from diffusers import StableDiffusionImg2ImgPipeline | |
| from transformers import CLIPTokenizer | |
| import logging | |
| import os | |
| from datetime import datetime | |
# Logging setup: every record goes both to a per-day file under ./logs
# and to the console stream.
log_dir = "logs"
os.makedirs(log_dir, exist_ok=True)

# The file name carries the current date (YYYYMMDD) taken at import time.
_date_stamp = datetime.now().strftime('%Y%m%d')
log_file = os.path.join(log_dir, f"prod_model_{_date_stamp}.log")

# Handler order is kept: file first, then console.
_log_handlers = [logging.FileHandler(log_file), logging.StreamHandler()]
logging.basicConfig(
    handlers=_log_handlers,
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
class ProductionDesignModel(DesignModel):
    """Image-to-image design generator built on a Stable Diffusion pipeline.

    On construction the class loads ``stabilityai/stable-diffusion-2-1`` into
    a ``StableDiffusionImg2ImgPipeline`` (float16 on CUDA, float32 on CPU) and
    stores default positive/negative prompt fragments. ``generate_design``
    then produces one NumPy array per requested variation.
    """

    def __init__(self):
        """Select the device, load the pipeline and set the default prompts.

        Raises:
            Exception: any error raised while loading the model is logged and
                re-raised unchanged.
        """
        super().__init__()
        try:
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            logging.info(f"Using device: {self.device}")
            self.model_id = "stabilityai/stable-diffusion-2-1"
            logging.info(f"Loading model: {self.model_id}")

            self.pipe = self._load_pipeline()

            # Reuse the tokenizer the pipeline already loaded. The model
            # repository keeps its tokenizer files in a "tokenizer"
            # sub-folder, so loading CLIPTokenizer from the repository root
            # fails; the pipeline's own tokenizer is also guaranteed to be
            # the one actually used to encode prompts.
            self.tokenizer = self.pipe.tokenizer

            # Set default prompts
            self.neg_prompt = "blurry, low quality, distorted, deformed, disfigured, watermark, text, bad proportions, duplicate, double, multiple, broken, cropped"
            self.additional_quality_suffix = "interior design, 4K, high resolution, photorealistic"
        except Exception as e:
            logging.error(f"Error in initialization: {e}")
            raise

    def _load_pipeline(self):
        """Load the img2img pipeline and enable memory optimizations.

        Returns:
            The ready-to-use ``StableDiffusionImg2ImgPipeline``.

        Raises:
            Exception: loading errors are logged and re-raised.
        """
        on_gpu = self.device == "cuda"
        try:
            pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
                self.model_id,
                torch_dtype=torch.float16 if on_gpu else torch.float32,
                safety_checker=None,  # Disable safety checker for performance
            )
            pipe.enable_attention_slicing()
            if on_gpu:
                # Model CPU offload installs hooks that move each sub-model
                # to the GPU on demand, so the pipeline must NOT be moved to
                # CUDA by hand beforehand.
                pipe.enable_model_cpu_offload()
                pipe.enable_vae_slicing()
            else:
                pipe = pipe.to(self.device)
            logging.info("Model loaded successfully")
            return pipe
        except Exception as e:
            logging.error(f"Error loading model: {e}")
            raise

    def _prepare_prompt(self, prompt: str) -> str:
        """Append the quality suffix to ``prompt`` and cap its token length.

        Args:
            prompt: user-supplied description; may be empty.

        Returns:
            The prompt with the quality suffix appended, truncated to the
            tokenizer's usable length. If anything fails, the prompt is
            returned without the suffix.
        """
        try:
            prompt = (prompt or "").strip()
            # Avoid a dangling leading ", " when the user gave no prompt.
            if prompt:
                full_prompt = f"{prompt}, {self.additional_quality_suffix}"
            else:
                full_prompt = self.additional_quality_suffix

            tokens = self.tokenizer.tokenize(full_prompt)
            # The context window (77 for CLIP) includes the start and end
            # markers, which tokenize() does not emit, so only
            # model_max_length - 2 positions are available for text.
            max_tokens = getattr(self.tokenizer, "model_max_length", 77) - 2
            if len(tokens) > max_tokens:
                logging.warning(f"Prompt too long ({len(tokens)} tokens). Truncating...")
                tokens = tokens[:max_tokens]
                full_prompt = self.tokenizer.convert_tokens_to_string(tokens)
            logging.info(f"Prepared prompt: {full_prompt}")
            return full_prompt
        except Exception as e:
            logging.error(f"Error preparing prompt: {e}")
            return prompt  # Return original prompt if processing fails

    def generate_design(self, image: Image.Image, num_variations: int = 1, **kwargs) -> List[np.ndarray]:
        """Generate design variations of ``image``.

        Args:
            image: source picture; converted to RGB when needed.
            num_variations: number of variations to attempt.
            **kwargs: optional settings. A value of ``None`` is treated the
                same as an omitted key.
                prompt (str): text description, default ``''``.
                num_steps (int): inference steps, clamped to 20..100,
                    default 50.
                guidance_scale (float): clamped to 1..20, default 7.5.
                strength (float): clamped to 0.1..1.0, default 0.75.
                seed (int): base seed, default is the current Unix time.

        Returns:
            A list of arrays, one per successful variation. If the first
            attempts fail, the RGB input image is inserted as the first
            element; if the whole call fails, the list holds only the RGB
            input image.
        """
        generation_start = time.time()

        def option(name, default):
            # UIs commonly pass None for "not set"; treat it like a missing key.
            value = kwargs.get(name)
            return default if value is None else value

        try:
            # Log input parameters
            logging.info(f"Generating {num_variations} variations with parameters: {kwargs}")
            num_variations = int(num_variations)

            # Read parameters and clamp them to their supported ranges.
            prompt = option('prompt', '')
            num_steps = max(20, min(100, int(option('num_steps', 50))))
            guidance_scale = max(1.0, min(20.0, float(option('guidance_scale', 7.5))))
            strength = max(0.1, min(1.0, float(option('strength', 0.75))))
            base_seed = int(option('seed', int(time.time())))

            full_prompt = self._prepare_prompt(prompt)

            # Space the seeds out so each variation starts from distinct noise.
            seeds = [base_seed + i * 10000 for i in range(num_variations)]
            logging.info(f"Using seeds: {seeds}")

            if image.mode != "RGB":
                image = image.convert("RGB")

            variations = []
            generator = torch.Generator(device=self.device)
            for i, seed in enumerate(seeds):
                try:
                    variation_start = time.time()
                    generator.manual_seed(seed)
                    output = self.pipe(
                        prompt=full_prompt,
                        negative_prompt=self.neg_prompt,
                        image=image,
                        num_inference_steps=num_steps,
                        guidance_scale=guidance_scale,
                        strength=strength,
                        generator=generator,
                    ).images[0]
                    variations.append(np.array(output))
                    variation_time = time.time() - variation_start
                    logging.info(f"Generated variation {i+1}/{num_variations} in {variation_time:.2f}s")
                except Exception as e:
                    # One failed variation must not abort the others.
                    logging.error(f"Error generating variation {i+1}: {e}")
                    if not variations:  # If no successful variations yet
                        variations.append(np.array(image.convert('RGB')))

            total_time = time.time() - generation_start
            logging.info(f"Generation completed in {total_time:.2f}s")
            return variations
        except Exception as e:
            # logging.exception records the message together with the traceback.
            logging.exception(f"Error in generate_design: {e}")
            return [np.array(image.convert('RGB'))]
        finally:
            if self.device == "cuda":
                torch.cuda.empty_cache()
                logging.info("Cleared CUDA cache")

    def __del__(self):
        """Cleanup when the model is deleted"""
        try:
            # getattr: __init__ may have failed before self.device was set.
            if getattr(self, "device", None) == "cuda":
                torch.cuda.empty_cache()
                logging.info("Final CUDA cache cleanup")
        except Exception:
            # Best-effort only: during interpreter shutdown the modules used
            # above may already be torn down.
            pass