File size: 6,020 Bytes
f06b6e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import base64
import io
import os
import sys
import time
import logging
import tempfile
import subprocess
from typing import List
from enum import Enum

# Install required packages
def install_packages():
    """Install required packages using pip"""
    packages = [
        "pillow",
        "huggingface_hub",
        "pydantic"
    ]
    
    for package in packages:
        try:
            __import__(package.replace("-", "_"))
            print(f"{package} already installed")
        except ImportError:
            print(f"Installing {package}...")
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])

# Install packages before importing
install_packages()

from PIL import Image
from huggingface_hub import InferenceClient
from pydantic import BaseModel

logger = logging.getLogger(__name__)

# Define models directly in the file
class ResponseFormat(str, Enum):
    URL = "url"
    B64_JSON = "b64_json"

class ImageGenerationRequest(BaseModel):
    prompt: str
    model: str = "black-forest-labs/flux-schnell"
    n: int = 1
    size: str = "1024x1024"
    quality: str = "standard"
    response_format: ResponseFormat = ResponseFormat.URL

class ImageData(BaseModel):
    url: str = None
    b64_json: str = None
    revised_prompt: str = None

class ImageGenerator:
    """Text-to-image generator using Hugging Face InferenceClient"""
    
    def __init__(self, hf_token: str = None):
        self.client = None
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.output_dir = tempfile.mkdtemp(prefix="image_gen_")
        self.base_url = "http://localhost:8000"  # Default base URL
        self.default_model = "black-forest-labs/flux-schnell"
        self._ensure_output_dir()
        
    def _ensure_output_dir(self):
        """Ensure output directory exists"""
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Using temporary directory: {self.output_dir}")
        
    def _get_client(self):
        """Get or create the InferenceClient"""
        if self.client is None:
            if not self.hf_token:
                raise ValueError("HuggingFace token is required. Set HF_TOKEN environment variable or pass it to constructor.")
            
            self.client = InferenceClient(
                token=self.hf_token,
            )
        return self.client
    
    def _image_to_base64(self, image: Image.Image) -> str:
        """Convert PIL Image to base64 string"""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        img_str = base64.b64encode(buffer.getvalue()).decode()
        return img_str
    
    def _save_image(self, image: Image.Image, filename: str) -> str:
        """Save image and return URL"""
        filepath = os.path.join(self.output_dir, filename)
        image.save(filepath)
        return f"{self.base_url}/images/{filename}"
    
    def set_config(self, hf_token: str = None, base_url: str = None, default_model: str = None):
        """Set configuration parameters"""
        if hf_token:
            self.hf_token = hf_token
            self.client = None  # Reset client to use new token
        if base_url:
            self.base_url = base_url
        if default_model:
            self.default_model = default_model
    
    async def generate_images(self, request: ImageGenerationRequest) -> List[ImageData]:
        """Generate images based on the request"""
        client = self._get_client()
        
        # Generate images
        results = []
        
        for i in range(request.n):
            try:
                logger.info(f"Generating image {i+1}/{request.n} for prompt: {request.prompt[:50]}...")
                
                # Generate the image using HuggingFace InferenceClient
                image = client.text_to_image(
                    request.prompt,
                    model=request.model or self.default_model,
                )
                
                # Create response based on format
                if request.response_format == ResponseFormat.B64_JSON:
                    image_data = ImageData(
                        b64_json=self._image_to_base64(image),
                        revised_prompt=request.prompt
                    )
                else:
                    # Save image and return URL
                    timestamp = int(time.time())
                    filename = f"generated_{timestamp}_{i}.png"
                    url = self._save_image(image, filename)
                    image_data = ImageData(
                        url=url,
                        revised_prompt=request.prompt
                    )
                
                results.append(image_data)
                logger.info(f"Successfully generated image {i+1}/{request.n}")
                
            except Exception as e:
                logger.error(f"Failed to generate image {i+1}: {e}")
                # Continue with other images
                continue
        
        if not results:
            raise Exception("Failed to generate any images")
            
        return results
    
    def cleanup(self):
        """Cleanup resources and temporary directory"""
        self.client = None
        # Clean up temporary directory
        import shutil
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
            print(f"Cleaned up temporary directory: {self.output_dir}")

# Example usage
if __name__ == "__main__":
    # Create generator instance
    generator = ImageGenerator()
    
    # Set HuggingFace token (replace with your actual token)
    generator.set_config(hf_token="your_hf_token_here")
    
    # Example request
    request = ImageGenerationRequest(
        prompt="A beautiful sunset over mountains",
        n=1,
        response_format=ResponseFormat.URL
    )
    
    # Note: This would need to be run in an async context
    # results = await generator.generate_images(request)
    print("Image generator setup complete!")