Spaces:
Sleeping
Sleeping
refactor: improve code readability and structure in OpenAI integration tests and services, update requirements for consistency
import asyncio
import base64
import logging
import os
import uuid
from typing import Optional, List

from openai import OpenAI

from ..core.config import settings
# Module-level logger named after this module; ImageGenerationService logs through it.
logger = logging.getLogger(__name__)
class ImageGenerationService:
    """Service for handling OpenAI image generation.

    Images are requested from the OpenAI API as base64 payloads, decoded and
    written as ``<uuid4>.png`` files inside ``self.output_dir``.  Two paths
    exist: plain text-to-image through ``client.images.generate`` and
    reference-image generation through ``client.responses.create`` with the
    ``image_generation`` tool, which falls back to the plain path when it
    produces no image at all.

    ``self.client`` is the synchronous OpenAI client, so every network call
    is executed in the event loop's default executor; otherwise the ``async``
    entry points would block the loop for the duration of each request.
    """

    def __init__(self):
        """Build the OpenAI client from ``OPENAI_API_KEY`` and prepare the output directory."""
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=60.0,  # Increase timeout for Hugging Face environment
            max_retries=2,  # Reduce retries to fail faster
        )
        self.output_dir = "generated_images"
        self._ensure_output_directory()

    def _ensure_output_directory(self):
        """Ensure the output directory exists (safe to call repeatedly)."""
        # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    async def _run_blocking(func, **kwargs):
        """Run the blocking callable ``func(**kwargs)`` in the default executor and return its result."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    def _save_image(self, b64_payload: str) -> str:
        """Decode a base64 image payload, write it under a fresh UUID name and return that filename.

        Raises whatever ``base64.b64decode`` or the file write raises; callers
        decide whether a single failed image is fatal.
        """
        image_bytes = base64.b64decode(b64_payload)
        filename = f"{uuid.uuid4()}.png"
        filepath = os.path.join(self.output_dir, filename)
        with open(filepath, "wb") as f:
            f.write(image_bytes)
        return filename

    def _save_response_images(self, response, fallback: bool = False) -> List[str]:
        """Save every image of an ``images.generate`` response; return the filenames that were written.

        A failure on one image is logged and skipped so the remaining images
        are still saved.  ``fallback`` only selects the wording of the log
        messages.
        """
        saved_label = "Fallback image" if fallback else "Image"
        failed_label = "fallback image" if fallback else "image"
        filenames = []
        # ``data`` may be missing on an empty response; treat that as "no images".
        for index, image_data in enumerate(response.data or [], start=1):
            try:
                filename = self._save_image(image_data.b64_json)
            except Exception as e:
                logger.warning(f"Failed to save {failed_label} {index}: {str(e)}")
                continue
            filenames.append(filename)
            logger.info(f"{saved_label} {index} saved successfully: {filename}")
        return filenames

    @staticmethod
    def _success_result(message: str, filenames: List[str]) -> dict:
        """Build the success payload; ``filenames`` must be non-empty."""
        return {
            "success": True,
            "message": message,
            "filename": filenames[0],
            "filenames": filenames,
            "count": len(filenames),
        }

    @staticmethod
    def _log_generation_failure(index: int, error: Exception) -> None:
        """Log a failed reference-image attempt, with a hint about the likely cause."""
        error_msg = str(error)
        lowered = error_msg.lower()
        logger.warning(f"Failed to generate image {index}: {error_msg}")
        # More specific error handling for network issues
        if "Connection error" in error_msg or "timeout" in lowered:
            logger.error(f"Network connectivity issue detected: {error_msg}")
            logger.error("This might be due to Hugging Face network restrictions")
        elif "api_key" in lowered or "unauthorized" in lowered:
            logger.error(f"API key issue detected: {error_msg}")
        elif "rate limit" in lowered:
            logger.error(f"Rate limit issue detected: {error_msg}")

    async def _fallback_to_dalle(
        self, prompt: str, size: str, n: int, model: str
    ) -> dict:
        """
        Fallback to regular DALL-E when responses API is blocked

        This sacrifices reference image capability but ensures the app works on Hugging Face

        Args:
            prompt: Text prompt for image generation
            size: Image size passed to ``images.generate``
            n: Number of images requested
            model: Model passed to ``images.generate``

        Returns:
            dict: Success payload listing the saved filename(s)

        Raises:
            Exception: If the request fails or no image could be saved; the
                original error is chained as ``__cause__``.
        """
        try:
            logger.info("Using DALL-E fallback (reference image will be ignored)")
            response = await self._run_blocking(
                self.client.images.generate,
                model=model,
                prompt=prompt,
                n=n,
                size=size,
                response_format="b64_json",
            )
            generated_filenames = self._save_response_images(response, fallback=True)
            if not generated_filenames:
                raise Exception("Fallback also failed to generate any images")
            return self._success_result(
                f"Generated {len(generated_filenames)}/{n} images using DALL-E fallback (reference image ignored due to network restrictions)",
                generated_filenames,
            )
        except Exception as e:
            logger.error(f"Fallback to DALL-E also failed: {str(e)}")
            raise Exception(
                f"Both responses API and DALL-E fallback failed: {str(e)}"
            ) from e

    async def _generate_with_reference(
        self, prompt: str, size: str, n: int, model: str, reference_image: str
    ) -> dict:
        """Generate ``n`` images from a prompt plus reference image via the responses API.

        One request is made per image.  Failed attempts are logged and
        skipped; if none succeeds the call is delegated to
        ``_fallback_to_dalle`` (``size`` and ``model`` are only used there).
        """
        # Use the newer responses API with image generation tools for reference images
        logger.info("Using reference image with responses API")
        # The request payload is identical for every attempt, so build it once.
        content = [
            {"type": "input_text", "text": prompt},
            {
                "type": "input_image",
                "image_url": f"data:image/jpeg;base64,{reference_image}",
            },
        ]
        generated_filenames = []
        # Generate multiple images by making multiple requests
        for i in range(n):
            try:
                logger.info(f"Generating image {i+1}/{n}")
                response = await self._run_blocking(
                    self.client.responses.create,
                    model="gpt-4.1",
                    input=[
                        {
                            "role": "user",
                            "content": content,
                        }
                    ],
                    tools=[{"type": "image_generation"}],
                )
                # Extract image generation results
                image_generation_calls = [
                    output
                    for output in response.output
                    if output.type == "image_generation_call"
                ]
                if not image_generation_calls:
                    logger.warning(
                        f"No image generation calls found in response {i+1}, likely returned text instead"
                    )
                    continue
                image_data = image_generation_calls[0].result
                if not image_data:
                    logger.warning(f"No image data returned from generation {i+1}")
                    continue
                filename = self._save_image(image_data)
                generated_filenames.append(filename)
                logger.info(f"Image {i+1} saved successfully: {filename}")
            except Exception as e:
                self._log_generation_failure(i + 1, e)
                continue

        if not generated_filenames:
            # If responses API failed due to network restrictions, try fallback to regular DALL-E
            logger.warning(
                "Responses API failed, attempting fallback to regular DALL-E"
            )
            return await self._fallback_to_dalle(prompt, size, n, model)

        logger.info(f"Successfully generated {len(generated_filenames)}/{n} images")
        return self._success_result(
            f"Generated {len(generated_filenames)}/{n} images successfully",
            generated_filenames,
        )

    async def _generate_text_only(
        self, prompt: str, size: str, n: int, model: str
    ) -> dict:
        """Generate images from a text prompt alone via ``images.generate``.

        Raises:
            Exception: If not a single image could be decoded and saved, so
                the caller never reports success with zero images.
        """
        # Use traditional DALL-E for text-only prompts
        logger.info("Using DALL-E for text-only generation")
        response = await self._run_blocking(
            self.client.images.generate,
            model=model,
            prompt=prompt,
            n=n,
            size=size,
            response_format="b64_json",
        )
        generated_filenames = self._save_response_images(response)
        if not generated_filenames:
            raise Exception("No images could be saved from the response")
        return self._success_result(
            f"Generated {len(generated_filenames)}/{n} images successfully",
            generated_filenames,
        )

    async def generate_image(
        self,
        prompt: str,
        size: str = "256x256",
        n: int = 1,
        model: str = "dall-e-3",
        reference_image: Optional[str] = None,
    ) -> dict:
        """
        Generate image(s) using OpenAI, optionally using a reference image

        Never raises for generation problems: any error is logged and turned
        into a failure payload.

        NOTE(review): OpenAI's Images API documents dall-e-3 as accepting only
        n=1 and the sizes 1024x1024, 1792x1024 and 1024x1792, so the defaults
        here (dall-e-3 with 256x256) look mutually incompatible - confirm that
        callers pass a matching model/size. Defaults are left unchanged to
        preserve the interface.

        Args:
            prompt: Text prompt for image generation
            size: Image size (256x256, 512x512, 1024x1024)
            n: Number of images to generate
            model: Model to use for generation (text-only path and fallback)
            reference_image: Base64 encoded reference image (optional)

        Returns:
            dict: Result containing success status, message, and filename(s).
            On failure ``success`` is False, ``filename`` is None,
            ``filenames`` is empty and ``count`` is 0.
        """
        try:
            logger.info(f"Generating {n} image(s) with prompt: {prompt}")
            if reference_image:
                return await self._generate_with_reference(
                    prompt, size, n, model, reference_image
                )
            return await self._generate_text_only(prompt, size, n, model)
        except Exception as e:
            logger.error(f"Error generating image: {str(e)}")
            return {
                "success": False,
                "message": f"Failed to generate image: {str(e)}",
                "filename": None,
                "filenames": [],
                "count": 0,
            }
# Create a singleton instance.
# NOTE: this runs at import time, so importing this module constructs the
# OpenAI client and creates the "generated_images" directory if it is missing.
image_service = ImageGenerationService()