File size: 3,471 Bytes
3b9f744
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import asyncio
import base64
import functools
import io
import logging
import os
import time
import uuid
from typing import List

from huggingface_hub import InferenceClient
from PIL import Image

from config import config
from models import ImageGenerationRequest, ImageData, ResponseFormat

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


class ImageGenerator:
    """Text-to-image generator using Hugging Face InferenceClient.

    The inference client is created lazily on first use. The directory named
    by ``config.OUTPUT_DIR`` is created when the generator is constructed.
    """

    def __init__(self) -> None:
        # The client is built on demand by _get_client().
        self.client = None
        self._ensure_output_dir()

    def _ensure_output_dir(self) -> None:
        """Create ``config.OUTPUT_DIR`` if it does not already exist."""
        os.makedirs(config.OUTPUT_DIR, exist_ok=True)

    def _get_client(self) -> InferenceClient:
        """Return the cached InferenceClient, creating it on first call."""
        if self.client is None:
            self.client = InferenceClient(
                provider="replicate",
                api_key=config.HF_TOKEN,
            )
        return self.client

    def _image_to_base64(self, image: Image.Image) -> str:
        """Encode ``image`` as PNG and return the bytes as a base64 string."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode()

    def _save_image(self, image: Image.Image, filename: str) -> str:
        """Save ``image`` under ``config.OUTPUT_DIR`` and return its URL.

        Args:
            image: The image to write; its ``save`` method is called with the
                target path.
            filename: File name (not a path) to store the image under.

        Returns:
            ``<config.BASE_URL>/images/<filename>``.
        """
        filepath = os.path.join(config.OUTPUT_DIR, filename)
        image.save(filepath)
        # Strip a trailing slash so a BASE_URL such as "http://host/" does not
        # produce "http://host//images/...".
        base_url = str(config.BASE_URL).rstrip("/")
        return f"{base_url}/images/{filename}"

    @staticmethod
    def _unique_filename(index: int) -> str:
        """Return a collision-resistant PNG file name for image ``index``.

        The timestamp alone only has one-second resolution, so two requests
        served in the same second would overwrite each other's files. A random
        UUID component makes each name unique.
        """
        timestamp = int(time.time())
        return f"generated_{timestamp}_{uuid.uuid4().hex}_{index}.png"

    async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
        """Generate ``request.n`` images for ``request.prompt``.

        Each image is generated independently; a failure on one image is
        logged and the remaining images are still attempted.

        Args:
            request: The generation request. ``prompt``, ``n`` and
                ``response_format`` are read from it.

        Returns:
            One ``ImageData`` per successfully generated image, carrying
            either ``b64_json`` or ``url`` depending on
            ``request.response_format``.

        Raises:
            Exception: If no image could be generated. The last underlying
                error, if any, is attached as ``__cause__``.
        """
        client = self._get_client()
        loop = asyncio.get_running_loop()

        results = []
        last_error = None

        for i in range(request.n):
            try:
                logger.info(
                    "Generating image %s/%s for prompt: %s...",
                    i + 1, request.n, request.prompt[:50],
                )

                # text_to_image is a blocking network call; run it in the
                # default executor so the event loop stays responsive.
                image = await loop.run_in_executor(
                    None,
                    functools.partial(
                        client.text_to_image,
                        request.prompt,
                        model=config.DEFAULT_MODEL,
                    ),
                )

                # Create response based on format
                if request.response_format == ResponseFormat.B64_JSON:
                    image_data = ImageData(
                        b64_json=self._image_to_base64(image),
                        revised_prompt=request.prompt,
                    )
                else:
                    # Save image and return URL
                    url = self._save_image(image, self._unique_filename(i))
                    image_data = ImageData(
                        url=url,
                        revised_prompt=request.prompt,
                    )

                results.append(image_data)
                logger.info("Successfully generated image %s/%s", i + 1, request.n)

            except asyncio.CancelledError:
                # Never treat task cancellation as a per-image failure
                # (CancelledError is an Exception subclass on Python 3.7).
                raise
            except Exception as e:
                # Best effort: record the failure and continue with the
                # remaining images.
                last_error = e
                logger.exception("Failed to generate image %s: %s", i + 1, e)
                continue

        if not results:
            raise Exception("Failed to generate any images") from last_error

        return results

    def cleanup(self) -> None:
        """Drop the cached client so the next call creates a fresh one."""
        self.client = None


# Shared module-level generator instance. Constructing it runs
# ImageGenerator.__init__, which creates config.OUTPUT_DIR via os.makedirs
# as soon as this module is imported.
image_generator = ImageGenerator()