Spaces:
Sleeping
Sleeping
| """ | |
| Enhanced ImageDryer with better error handling and alternative API options. | |
| This module provides a more robust implementation of the ImageDryer class | |
| with improved error handling, retry logic, and fallback across several Stability AI engines. | |
| """ | |
| import os | |
| import time | |
| import io | |
| import base64 | |
| import requests | |
| import random | |
| from typing import Optional, List, Dict, Any | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
# Load environment variables via python-dotenv at import time, so that
# EnhancedImageDryer.__init__ can read STABILITY_API_KEY with os.getenv.
load_dotenv()
class EnhancedImageDryer:
    """Make photos of wet items look dry using Stability AI image-to-image.

    ``process_image`` calls the Stability AI v1 REST API. On successive
    attempts it cycles through ``self.engines`` and the prompt variations
    returned by ``get_prompt_variations``. ``apply_fallback_drying_effect``
    is a purely local tone adjustment that callers can use when the API
    fails.
    """

    def __init__(self):
        """Read the API key from ``STABILITY_API_KEY`` and set retry defaults."""
        self.api_key = os.getenv("STABILITY_API_KEY")
        self.api_host = "https://api.stability.ai"
        # Engines to try; attempt i uses engines[i % len(engines)].
        self.engines = [
            "stable-diffusion-xl-1024-v1-0",
            "stable-diffusion-v1-5",
            "stable-diffusion-512-v2-1",
        ]
        self.current_engine = self.engines[0]
        self.max_retries = 3
        self.retry_delay = 2  # base delay in seconds; multiplied per attempt

    def preprocess_image(self, image: Image.Image) -> Image.Image:
        """Return *image* as RGB with its longest side at most 1024 px.

        The aspect ratio is preserved. An image that is already RGB and
        within bounds is returned as-is.

        NOTE(review): some Stability engines may require specific
        init-image dimensions (e.g. multiples of 64). This method does not
        enforce that; confirm against the API documentation.
        """
        if image.mode != "RGB":
            image = image.convert("RGB")
        max_size = 1024
        longest = max(image.size)
        if longest > max_size:
            ratio = max_size / longest
            # Clamp to 1 px so extreme aspect ratios (e.g. 5000x1) cannot
            # produce a zero-sized dimension, which Image.resize rejects.
            new_size = tuple(max(1, int(dim * ratio)) for dim in image.size)
            image = image.resize(new_size, Image.LANCZOS)
        return image

    def get_prompt_variations(self) -> List[List[Dict[str, Any]]]:
        """Return the prompt sets to try, one list of weighted prompts per attempt.

        Each inner list holds ``{"text": str, "weight": int}`` dicts. The
        weight-1 entry is the positive prompt; the weight -1 entry lists
        things to steer away from. A fresh list is built on every call, so
        callers may mutate the result safely.
        """
        return [
            # Standard drying prompt
            [
                {
                    "text": "A completely dry version of this item, photorealistic, detailed texture, no water or moisture",
                    "weight": 1,
                },
                {
                    "text": "wet, moist, damp, water droplets, puddles, stains",
                    "weight": -1,
                },
            ],
            # Alternative drying prompt with more emphasis on dryness
            [
                {
                    "text": "Bone dry, completely dried out, arid, desert-like dryness, crisp texture, no moisture whatsoever",
                    "weight": 1,
                },
                {
                    "text": "wet, damp, moist, humidity, water, liquid, droplets, condensation",
                    "weight": -1,
                },
            ],
            # Focus on texture and detail
            [
                {
                    "text": "Dry texture, detailed fabric, no moisture, sun-dried appearance, crisp details",
                    "weight": 1,
                },
                {
                    "text": "wet appearance, water stains, dampness, moisture",
                    "weight": -1,
                },
            ],
        ]

    def process_image(self, image: Image.Image) -> Optional[Image.Image]:
        """Make *image* look dry using the Stability AI image-to-image API.

        At most ``self.max_retries`` attempts are made. Attempt ``i`` uses
        engine ``self.engines[i % len(self.engines)]`` and prompt variation
        ``i % len(variations)``, and records the engine in
        ``self.current_engine``.

        Between attempts it waits ``retry_delay * (i + 1)`` seconds. After an
        HTTP 429 it waits at least ``retry_delay * 3`` seconds. It never
        sleeps after the final attempt. An HTTP 401 (missing or invalid
        credentials) stops immediately, because every engine would reject
        it the same way.

        Returns:
            The generated image, or ``None`` if no API key is configured or
            every attempt failed. Failures are reported with ``print``
            rather than raised.
        """
        if not self.api_key:
            print("Error: No Stability API key found in environment variables.")
            return None

        img_bytes = self.image_to_bytes(self.preprocess_image(image))
        prompt_variations = self.get_prompt_variations()

        for attempt in range(self.max_retries):
            self.current_engine = self.engines[attempt % len(self.engines)]
            prompts = prompt_variations[attempt % len(prompt_variations)]
            print(f"Attempt {attempt+1}/{self.max_retries} using engine: {self.current_engine}")

            rate_limited = False
            try:
                response = self._post_image_to_image(self.current_engine, prompts, img_bytes)
                if response.status_code == 200:
                    result = self._decode_first_artifact(response)
                    print(f"Successfully processed image with engine: {self.current_engine}")
                    return result

                print(f"API request failed with status code {response.status_code}: {response.text}")
                if response.status_code == 401:
                    # Bad or missing credentials fail identically on every engine.
                    print("Authentication failed. Not retrying.")
                    return None
                if response.status_code == 429:  # Too Many Requests
                    print("Rate limited. Waiting longer before retry...")
                    rate_limited = True
                elif response.status_code >= 500:  # Server errors
                    print("Server error. Retrying with different engine...")
                else:
                    print("Client error. Retrying with different engine and prompt...")
            except Exception as e:
                # Deliberate best-effort boundary: network errors, malformed
                # JSON/base64 and undecodable images all count as one failed
                # attempt rather than aborting the retry loop.
                print(f"Error during API request: {e}")

            if attempt < self.max_retries - 1:
                multiplier = max(3, attempt + 1) if rate_limited else attempt + 1
                delay = self.retry_delay * multiplier
                print(f"Retrying in {delay} seconds...")
                time.sleep(delay)

        print("All retry attempts failed.")
        return None

    def _post_image_to_image(
        self, engine: str, prompts: List[Dict[str, Any]], img_bytes: bytes
    ) -> requests.Response:
        """Send one multipart image-to-image request to *engine* and return the raw response."""
        url = f"{self.api_host}/v1/generation/{engine}/image-to-image"
        headers = {
            # Request JSON explicitly; the success path parses JSON artifacts.
            "Accept": "application/json",
            "Authorization": f"Bearer {self.api_key}",
        }
        # Per-engine tuning: SDXL engines use one parameter set, all others another.
        if "xl" in engine:
            image_strength, cfg_scale, steps = 0.35, 7, 30
        else:
            image_strength, cfg_scale, steps = 0.4, 8, 25
        form: Dict[str, Any] = {
            "image_strength": image_strength,
            "cfg_scale": cfg_scale,
            "samples": 1,
            "steps": steps,
        }
        for i, prompt in enumerate(prompts):
            form[f"text_prompts[{i}][text]"] = prompt["text"]
            form[f"text_prompts[{i}][weight]"] = prompt["weight"]
        files = {"init_image": ("image.png", img_bytes, "image/png")}
        print("Sending request to Stability AI API...")
        return requests.post(url, headers=headers, files=files, data=form, timeout=60)

    @staticmethod
    def _decode_first_artifact(response: requests.Response) -> Image.Image:
        """Decode the first base64 artifact of a successful JSON response into an image."""
        payload = response.json()
        image_data = base64.b64decode(payload["artifacts"][0]["base64"])
        result = Image.open(io.BytesIO(image_data))
        # Image.open is lazy. Force decoding here so a corrupt payload fails
        # inside the caller's retry loop instead of later, at first use.
        result.load()
        return result

    def save_image(self, image: Image.Image, filename: str) -> None:
        """Save *image* to *filename*; Pillow picks the format from the extension."""
        image.save(filename)

    def image_to_bytes(self, image: Image.Image) -> bytes:
        """Encode *image* as PNG and return the raw bytes."""
        img_byte_arr = io.BytesIO()
        image.save(img_byte_arr, format="PNG")
        return img_byte_arr.getvalue()

    def apply_fallback_drying_effect(
        self,
        image: Image.Image,
        *,
        brightness_factor: float = 1.2,
        contrast_factor: float = 1.1,
        saturation_factor: float = 0.8,
    ) -> Image.Image:
        """Apply a simple local "drying" look, as a fallback when the API fails.

        The image is brightened by *brightness_factor* and its contrast is
        scaled around mid-grey (128) by *contrast_factor*; each step is
        clamped to 0-255. HSV saturation is then scaled by
        *saturation_factor*. The defaults reproduce the original fixed
        factors.

        Returns:
            A new RGB image. *image* itself is not modified.
        """
        from colorsys import hsv_to_rgb, rgb_to_hsv
        from functools import lru_cache

        print("Applying fallback drying effect...")
        if image.mode != "RGB":
            image = image.convert("RGB")

        def tone(p: int) -> int:
            # Brightness, then contrast; each stage is clamped to 8 bits.
            brightened = min(255, max(0, int(p * brightness_factor)))
            return min(255, max(0, int(128 + contrast_factor * (brightened - 128))))

        # point() evaluates tone() once per input level and applies it per
        # band, returning a new image.
        result = image.point(tone)

        # The cache is local to this call, so it is released when the call returns.
        @lru_cache(maxsize=None)
        def desaturate(pixel):
            r, g, b = pixel
            h, s, v = rgb_to_hsv(r / 255, g / 255, b / 255)
            s = min(1.0, max(0.0, s * saturation_factor))
            r, g, b = hsv_to_rgb(h, s, v)
            return (int(r * 255), int(g * 255), int(b * 255))

        # PixelAccess is far cheaper per pixel than getpixel/putpixel, and
        # repeated colours are converted only once.
        pixels = result.load()
        width, height = result.size
        for y in range(height):
            for x in range(width):
                pixels[x, y] = desaturate(pixels[x, y])
        return result