image_generation / backend /app /services /image_service.py
ChenyuRabbitLove's picture
refactor: improve code readability and structure in OpenAI integration tests and services, update requirements for consistency
f5c3d9c
import asyncio
import base64
import logging
import os
import uuid
from typing import Optional, List

from openai import OpenAI

from ..core.config import settings
# Module-level logger named after this module; used by ImageGenerationService.
logger = logging.getLogger(__name__)
class ImageGenerationService:
    """Service for handling OpenAI image generation.

    Images are requested from the OpenAI API as base64 payloads, decoded and
    written as ``<uuid4>.png`` files inside ``self.output_dir``. All public
    results are plain dicts with the keys ``success``, ``message``,
    ``filename``, ``filenames`` and ``count``.
    """

    def __init__(self):
        """Create the OpenAI client and make sure the output directory exists."""
        self.client = OpenAI(
            api_key=os.getenv("OPENAI_API_KEY"),
            timeout=60.0,  # Increase timeout for Hugging Face environment
            max_retries=2,  # Reduce retries to fail faster
        )
        self.output_dir = "generated_images"
        self._ensure_output_directory()

    def _ensure_output_directory(self):
        """Ensure the output directory exists.

        ``exist_ok=True`` avoids the check-then-create race of testing
        ``os.path.exists`` first and makes the call safe to repeat.
        """
        os.makedirs(self.output_dir, exist_ok=True)

    @staticmethod
    def _reference_image_url(reference_image: str) -> str:
        """Return a data URL for the reference image.

        Bare base64 is wrapped as a JPEG data URL (the historical behaviour).
        A value that already is a data URL is passed through unchanged so it
        is not prefixed twice; ``:`` is not in the base64 alphabet, so a bare
        payload can never be mistaken for one.
        """
        if reference_image.startswith("data:"):
            return reference_image
        return f"data:image/jpeg;base64,{reference_image}"

    @staticmethod
    def _build_success_result(filenames: List[str], message: str) -> dict:
        """Build the success payload shared by every generation path."""
        return {
            "success": True,
            "message": message,
            "filename": filenames[0] if filenames else None,
            "filenames": filenames,
            "count": len(filenames),
        }

    async def _run_blocking(self, func, **kwargs):
        """Run a blocking callable in the event loop's default executor.

        The OpenAI client created in ``__init__`` is synchronous; calling it
        directly from a coroutine would stall the event loop for the whole
        duration of the HTTP request.
        """
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(None, lambda: func(**kwargs))

    def _save_image(self, b64_data: str) -> str:
        """Decode a base64 image, write it to the output directory.

        Returns:
            The generated file name (``<uuid4>.png``), relative to
            ``self.output_dir``.

        Raises:
            Exception: whatever ``base64.b64decode`` or the file write raises.
        """
        image_bytes = base64.b64decode(b64_data)
        filename = f"{uuid.uuid4()}.png"
        with open(os.path.join(self.output_dir, filename), "wb") as f:
            f.write(image_bytes)
        return filename

    def _save_response_images(self, images, label: str) -> List[str]:
        """Save every item of an images API response, skipping failures.

        Args:
            images: Iterable of response items exposing ``b64_json``.
            label: Wording used in log messages ("image" / "fallback image").

        Returns:
            File names of the images that were saved successfully.
        """
        saved = []
        for index, image_data in enumerate(images, start=1):
            try:
                filename = self._save_image(image_data.b64_json)
            except Exception as e:
                # Best effort: one bad payload must not discard the others.
                logger.warning("Failed to save %s %s: %s", label, index, e)
                continue
            saved.append(filename)
            logger.info(
                "%s %s saved successfully: %s", label.capitalize(), index, filename
            )
        return saved

    async def _generate_with_dalle(
        self, prompt: str, size: str, n: int, model: str, label: str
    ) -> List[str]:
        """Request images from the images API and save them to disk."""
        response = await self._run_blocking(
            self.client.images.generate,
            model=model,
            prompt=prompt,
            n=n,
            size=size,
            response_format="b64_json",
        )
        # Treat a missing ``data`` attribute value as "no images".
        return self._save_response_images(response.data or [], label)

    @staticmethod
    def _log_generation_error(index: int, error: Exception) -> None:
        """Log a failed responses-API attempt with a hint about the cause."""
        error_msg = str(error)
        lowered = error_msg.lower()
        logger.warning("Failed to generate image %s: %s", index, error_msg)
        # More specific error handling for network issues
        if "Connection error" in error_msg or "timeout" in lowered:
            logger.error("Network connectivity issue detected: %s", error_msg)
            logger.error("This might be due to Hugging Face network restrictions")
        elif "api_key" in lowered or "unauthorized" in lowered:
            logger.error("API key issue detected: %s", error_msg)
        elif "rate limit" in lowered:
            logger.error("Rate limit issue detected: %s", error_msg)

    async def _generate_with_reference(
        self, prompt: str, n: int, reference_image: str
    ) -> List[str]:
        """Generate up to ``n`` images with the responses API and a reference image.

        One request is made per image. Failed attempts are logged and
        skipped, so the returned list may be shorter than ``n`` or empty.
        """
        content = [
            {"type": "input_text", "text": prompt},
            {
                "type": "input_image",
                "image_url": self._reference_image_url(reference_image),
            },
        ]
        generated_filenames = []
        for index in range(1, n + 1):
            try:
                logger.info("Generating image %s/%s", index, n)
                response = await self._run_blocking(
                    self.client.responses.create,
                    model="gpt-4.1",
                    input=[{"role": "user", "content": content}],
                    tools=[{"type": "image_generation"}],
                )
                # Extract image generation results
                image_generation_calls = [
                    output
                    for output in response.output
                    if output.type == "image_generation_call"
                ]
                if not image_generation_calls:
                    logger.warning(
                        "No image generation calls found in response %s, "
                        "likely returned text instead",
                        index,
                    )
                    continue
                image_data = image_generation_calls[0].result
                if not image_data:
                    logger.warning("No image data returned from generation %s", index)
                    continue
                filename = self._save_image(image_data)
                generated_filenames.append(filename)
                logger.info("Image %s saved successfully: %s", index, filename)
            except Exception as e:
                self._log_generation_error(index, e)
        return generated_filenames

    async def _fallback_to_dalle(
        self, prompt: str, size: str, n: int, model: str
    ) -> dict:
        """
        Fallback to regular DALL-E when responses API is blocked
        This sacrifices reference image capability but ensures the app works on Hugging Face

        Returns:
            dict: Success payload listing the saved file names.

        Raises:
            RuntimeError: If the request fails or no image could be saved; the
                original error is chained as ``__cause__``.
        """
        try:
            logger.info("Using DALL-E fallback (reference image will be ignored)")
            generated_filenames = await self._generate_with_dalle(
                prompt, size, n, model, label="fallback image"
            )
            if not generated_filenames:
                raise RuntimeError("Fallback also failed to generate any images")
            return self._build_success_result(
                generated_filenames,
                f"Generated {len(generated_filenames)}/{n} images using DALL-E fallback (reference image ignored due to network restrictions)",
            )
        except Exception as e:
            logger.error("Fallback to DALL-E also failed: %s", e)
            raise RuntimeError(
                f"Both responses API and DALL-E fallback failed: {str(e)}"
            ) from e

    async def generate_image(
        self,
        prompt: str,
        size: str = "256x256",
        n: int = 1,
        model: str = "dall-e-3",
        reference_image: Optional[str] = None,
    ) -> dict:
        """
        Generate image(s) using OpenAI, optionally using a reference image

        With a reference image the responses API is used (one request per
        image); if that yields nothing, the call falls back to the images API
        and the reference image is ignored. Without a reference image the
        images API is used directly. This method never raises: every failure
        is reported through the returned dict.

        Args:
            prompt: Text prompt for image generation
            size: Image size (256x256, 512x512, 1024x1024)
            n: Number of images to generate
            model: Model to use for generation
            reference_image: Base64 encoded reference image, or a complete
                ``data:`` URL (optional)

        Returns:
            dict: Result containing success status, message, and filename(s).
            ``success`` is False when no image could be produced.

        NOTE(review): the default ``size``/``model`` pair and ``n > 1`` may not
        be accepted by every model - confirm against the OpenAI Images API docs.
        """
        try:
            logger.info("Generating %s image(s) with prompt: %s", n, prompt)
            if reference_image:
                # Use the newer responses API with image generation tools for reference images
                logger.info("Using reference image with responses API")
                generated_filenames = await self._generate_with_reference(
                    prompt, n, reference_image
                )
                if not generated_filenames:
                    # If responses API failed due to network restrictions, try fallback to regular DALL-E
                    logger.warning(
                        "Responses API failed, attempting fallback to regular DALL-E"
                    )
                    return await self._fallback_to_dalle(prompt, size, n, model)
                logger.info(
                    "Successfully generated %s/%s images", len(generated_filenames), n
                )
            else:
                # Use traditional DALL-E for text-only prompts
                logger.info("Using DALL-E for text-only generation")
                generated_filenames = await self._generate_with_dalle(
                    prompt, size, n, model, label="image"
                )
                if not generated_filenames:
                    # Previously reported as success with filename=None.
                    raise RuntimeError("No images were generated")
            return self._build_success_result(
                generated_filenames,
                f"Generated {len(generated_filenames)}/{n} images successfully",
            )
        except Exception as e:
            logger.error("Error generating image: %s", e)
            return {
                "success": False,
                "message": f"Failed to generate image: {str(e)}",
                "filename": None,
                "filenames": [],
                "count": 0,
            }
# Create a singleton instance
# NOTE: this runs at import time, so importing the module constructs the
# OpenAI client and creates the output directory as a side effect.
image_service = ImageGenerationService()