|
|
import base64
|
|
|
import io
|
|
|
import os
|
|
|
import time
|
|
|
import logging
|
|
|
from typing import List
|
|
|
from PIL import Image
|
|
|
from huggingface_hub import InferenceClient
|
|
|
from config import config
|
|
|
from models import ImageGenerationRequest, ImageData, ResponseFormat
|
|
|
|
|
|
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
|
|
|
class ImageGenerator:
    """Text-to-image generator using Hugging Face InferenceClient.

    Constructing the generator only makes sure ``config.OUTPUT_DIR`` exists;
    the ``InferenceClient`` itself is created lazily on first use.
    """

    def __init__(self) -> None:
        # Created on demand by _get_client(); reset to None by cleanup().
        self.client = None
        self._ensure_output_dir()

    def _ensure_output_dir(self) -> None:
        """Create ``config.OUTPUT_DIR`` if it does not already exist."""
        os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    def _get_client(self) -> InferenceClient:
        """Return the cached InferenceClient, creating it on the first call."""
        if self.client is None:
            self.client = InferenceClient(
                provider="replicate",
                api_key=config.HF_TOKEN,
            )
        return self.client

    def _image_to_base64(self, image: Image.Image) -> str:
        """Encode ``image`` as PNG and return the bytes as a base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        # Base64 output is pure ASCII, so decode explicitly as such.
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    def _save_image(self, image: Image.Image, filename: str) -> str:
        """Save ``image`` under ``config.OUTPUT_DIR`` and return its URL.

        The URL is ``<config.BASE_URL>/images/<filename>``. A trailing slash
        on ``config.BASE_URL`` is dropped so the result never contains
        ``//images/``.
        """
        filepath = os.path.join(config.OUTPUT_DIR, filename)
        image.save(filepath)
        base_url = str(config.BASE_URL).rstrip("/")
        return f"{base_url}/images/{filename}"

    def _generate_one(self, client, request: ImageGenerationRequest, index: int) -> ImageData:
        """Generate a single image and wrap it as ``ImageData`` (blocking).

        Calls the inference API, then either base64-encodes the result or
        saves it to disk, depending on ``request.response_format``. Meant to
        run in a worker thread so the event loop is not blocked.
        """
        image = client.text_to_image(
            request.prompt,
            model=config.DEFAULT_MODEL,
        )

        if request.response_format == ResponseFormat.B64_JSON:
            return ImageData(
                b64_json=self._image_to_base64(image),
                revised_prompt=request.prompt,
            )

        # A one-second timestamp plus the loop index is not unique across
        # requests arriving in the same second; the random suffix keeps
        # them from overwriting each other's files.
        timestamp = int(time.time())
        filename = f"generated_{timestamp}_{index}_{os.urandom(4).hex()}.png"
        url = self._save_image(image, filename)
        return ImageData(
            url=url,
            revised_prompt=request.prompt,
        )

    async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
        """Generate ``request.n`` images for ``request.prompt``.

        Images are generated one after another. A failure on one image is
        logged and skipped, so the result may hold fewer than ``request.n``
        items.

        Returns:
            The successfully generated images, in generation order.

        Raises:
            RuntimeError: If no image could be generated. The last
                underlying failure, if any, is attached as ``__cause__``.
        """
        # Function-scope import: the module's import block does not
        # include asyncio.
        import asyncio

        client = self._get_client()
        loop = asyncio.get_running_loop()

        results = []
        last_error = None

        for i in range(request.n):
            try:
                logger.info(
                    "Generating image %d/%d for prompt: %s...",
                    i + 1,
                    request.n,
                    request.prompt[:50],
                )
                # The inference call, PNG encoding and file save all block,
                # so run them off the event-loop thread.
                image_data = await loop.run_in_executor(
                    None, self._generate_one, client, request, i
                )
                results.append(image_data)
                logger.info("Successfully generated image %d/%d", i + 1, request.n)
            except Exception as exc:
                # Best effort per image: record the failure with its
                # traceback and move on to the next one.
                logger.exception("Failed to generate image %d: %s", i + 1, exc)
                last_error = exc
                continue

        if not results:
            raise RuntimeError("Failed to generate any images") from last_error

        return results

    def cleanup(self):
        """Drop the cached client; the next _get_client() call recreates it."""
        self.client = None
|
|
|
|
|
|
|
|
|
|
|
|
# Shared module-level instance; constructing it creates config.OUTPUT_DIR
# if needed (see ImageGenerator.__init__).
image_generator = ImageGenerator()