TTI / image_generator.py
Sam3838's picture
Upload 4 files
3b9f744 verified
raw
history blame
3.47 kB
import asyncio
import base64
import functools
import io
import logging
import os
import time
import uuid
from typing import List

from PIL import Image
from huggingface_hub import InferenceClient

from config import config
from models import ImageGenerationRequest, ImageData, ResponseFormat
# Module-scoped logger: records emitted here carry this module's dotted name.
logger = logging.getLogger(name=__name__)
class ImageGenerator:
    """Text-to-image generator backed by Hugging Face ``InferenceClient``.

    Images are requested through the ``replicate`` provider using
    ``config.DEFAULT_MODEL`` and returned either as base64-encoded PNG data or
    as URLs pointing at PNG files written under ``config.OUTPUT_DIR``.
    """

    def __init__(self):
        # The client is created lazily on first use by _get_client().
        self.client = None
        self._ensure_output_dir()

    def _ensure_output_dir(self):
        """Create ``config.OUTPUT_DIR`` (including parents) if it is missing."""
        os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    def _get_client(self):
        """Return the cached ``InferenceClient``, creating it on the first call."""
        if self.client is None:
            self.client = InferenceClient(
                provider="replicate",
                api_key=config.HF_TOKEN,
            )
        return self.client

    def _image_to_base64(self, image: Image.Image) -> str:
        """Encode *image* as PNG and return the bytes as a base64 text string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        # base64 output is pure ASCII, so this decode cannot fail.
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    def _save_image(self, image: Image.Image, filename: str) -> str:
        """Write *image* to ``config.OUTPUT_DIR/filename`` and return its URL.

        The URL is built as ``{config.BASE_URL}/images/{filename}``.
        """
        filepath = os.path.join(config.OUTPUT_DIR, filename)
        image.save(filepath)
        return f"{config.BASE_URL}/images/{filename}"

    def _make_filename(self, index: int) -> str:
        """Return a unique PNG filename for the *index*-th image of a request.

        The timestamp keeps names roughly chronological; the random suffix is
        what guarantees uniqueness, so requests completing within the same
        second never overwrite each other's files.
        """
        timestamp = int(time.time())
        return f"generated_{timestamp}_{index}_{uuid.uuid4().hex}.png"

    def _build_image_data(
        self, image: Image.Image, request: ImageGenerationRequest, index: int
    ) -> ImageData:
        """Wrap *image* in ``ImageData`` as requested by ``request.response_format``.

        ``B64_JSON`` embeds the PNG inline; any other format saves the image to
        disk and returns its URL.
        """
        if request.response_format == ResponseFormat.B64_JSON:
            return ImageData(
                b64_json=self._image_to_base64(image),
                revised_prompt=request.prompt,
            )
        url = self._save_image(image, self._make_filename(index))
        return ImageData(url=url, revised_prompt=request.prompt)

    async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
        """Generate ``request.n`` images for ``request.prompt``.

        Images are generated one after another. A failure for one image is
        logged (with traceback) and the remaining images are still attempted.

        Args:
            request: Carries the prompt, the number of images ``n`` and the
                desired ``response_format``.

        Returns:
            One ``ImageData`` per successfully generated image; this may be
            fewer than ``request.n``.

        Raises:
            RuntimeError: If no image could be generated. The last underlying
                error, if any, is chained as ``__cause__``.
        """
        client = self._get_client()
        loop = asyncio.get_running_loop()
        results: List[ImageData] = []
        last_error = None

        for i in range(request.n):
            try:
                logger.info(
                    "Generating image %s/%s for prompt: %s...",
                    i + 1, request.n, request.prompt[:50],
                )
                # text_to_image is a synchronous network call; running it in the
                # default executor keeps the event loop free to serve other
                # requests while the provider renders the image.
                image = await loop.run_in_executor(
                    None,
                    functools.partial(
                        client.text_to_image,
                        request.prompt,
                        model=config.DEFAULT_MODEL,
                    ),
                )
                results.append(self._build_image_data(image, request, i))
                logger.info("Successfully generated image %s/%s", i + 1, request.n)
            except Exception as exc:
                # Best-effort per image: record the failure and keep going so one
                # bad generation does not discard the others.
                logger.exception("Failed to generate image %s: %s", i + 1, exc)
                last_error = exc

        if not results:
            raise RuntimeError("Failed to generate any images") from last_error
        return results

    def cleanup(self):
        """Drop the cached client; the next generation call creates a new one."""
        self.client = None
# Module-level singleton shared by every importer of this module. Building it
# runs ImageGenerator.__init__, which creates config.OUTPUT_DIR on import.
image_generator: ImageGenerator = ImageGenerator()