# File size: 6,020 bytes (revision f06b6e7)
import asyncio
import base64
import functools
import io
import logging
import os
import subprocess
import sys
import tempfile
import time
import uuid
from enum import Enum
from typing import List, Optional
# Install required packages
# Maps each pip distribution name to the module name used to import it.
# The two differ for Pillow, which installs as "pillow" but imports as "PIL";
# probing for a module called "pillow" would always fail and re-run pip.
REQUIRED_PACKAGES = {
    "pillow": "PIL",
    "huggingface_hub": "huggingface_hub",
    "pydantic": "pydantic",
}


def install_packages(packages=None):
    """Install any missing required packages using pip.

    Each package is probed by importing its module; pip is invoked only
    when that import raises ImportError.

    Args:
        packages: Optional mapping of pip distribution name to importable
            module name. Defaults to REQUIRED_PACKAGES.

    Raises:
        subprocess.CalledProcessError: If a pip invocation exits non-zero.
    """
    if packages is None:
        packages = REQUIRED_PACKAGES

    for package, module_name in packages.items():
        try:
            __import__(module_name)
        except ImportError:
            print(f"Installing {package}...")
            # Use the running interpreter's pip so the package lands in the
            # same environment this module is executing in.
            subprocess.check_call([sys.executable, "-m", "pip", "install", package])
        else:
            print(f"{package} already installed")
# Install packages before importing: the third-party imports that follow
# (PIL, huggingface_hub, pydantic) need the distributions to be present.
# NOTE: this runs at module import time, so importing this file may invoke pip.
install_packages()
from PIL import Image
from huggingface_hub import InferenceClient
from pydantic import BaseModel
# Module-level logger; ImageGenerator.generate_images reports progress and
# per-image failures through it.
logger = logging.getLogger(__name__)
# Define models directly in the file
class ResponseFormat(str, Enum):
    """Ways a generated image can be returned to the caller.

    Members are also plain strings, so they compare equal to their values.
    """

    # The image is referenced by a URL.
    URL = "url"
    # The image is returned inline as base64-encoded text.
    B64_JSON = "b64_json"
class ImageGenerationRequest(BaseModel):
    """Parameters describing one image-generation request.

    Only ``prompt`` is required; every other field has a default.
    """

    # Text description of the image to generate.
    prompt: str
    # Model identifier handed to the inference client.
    model: str = "black-forest-labs/flux-schnell"
    # Number of images to generate for this prompt.
    n: int = 1
    # NOTE(review): ``size`` and ``quality`` are accepted here but are not
    # read by ImageGenerator.generate_images in this file -- confirm whether
    # they are meant to be forwarded to the model.
    size: str = "1024x1024"
    quality: str = "standard"
    # Selects whether results carry a URL or inline base64 data.
    response_format: ResponseFormat = ResponseFormat.URL
class ImageData(BaseModel):
    """One generated image as returned by ImageGenerator.generate_images.

    The generator fills in either ``url`` or ``b64_json`` depending on the
    requested response format, and copies the request prompt into
    ``revised_prompt``. The unused fields stay ``None``, so every field is
    declared Optional; a bare ``str`` annotation with a ``None`` default
    would mis-describe the field as non-nullable.
    """

    # Location of the saved image; set for URL responses.
    url: Optional[str] = None
    # Base64-encoded PNG data; set for b64_json responses.
    b64_json: Optional[str] = None
    # Prompt associated with the image.
    revised_prompt: Optional[str] = None
class ImageGenerator:
    """Text-to-image generator using Hugging Face InferenceClient.

    Each instance owns a temporary output directory (created in the
    constructor, removed by ``cleanup``) where URL-format images are saved.
    The inference client is created lazily on first use.
    """

    def __init__(self, hf_token: Optional[str] = None):
        """Create a generator.

        Args:
            hf_token: Hugging Face API token. Falls back to the ``HF_TOKEN``
                environment variable when omitted or empty.
        """
        self.client = None
        self.hf_token = hf_token or os.getenv("HF_TOKEN")
        self.output_dir = tempfile.mkdtemp(prefix="image_gen_")
        self.base_url = "http://localhost:8000"  # Default base URL
        self.default_model = "black-forest-labs/flux-schnell"
        self._ensure_output_dir()

    def _ensure_output_dir(self):
        """Ensure output directory exists"""
        os.makedirs(self.output_dir, exist_ok=True)
        print(f"Using temporary directory: {self.output_dir}")

    def _get_client(self):
        """Get or create the InferenceClient.

        Raises:
            ValueError: If no Hugging Face token has been configured.
        """
        if self.client is None:
            if not self.hf_token:
                raise ValueError("HuggingFace token is required. Set HF_TOKEN environment variable or pass it to constructor.")
            self.client = InferenceClient(
                token=self.hf_token,
            )
        return self.client

    def _image_to_base64(self, image: "Image.Image") -> str:
        """Convert PIL Image to a base64 string of its PNG encoding."""
        buffer = io.BytesIO()
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("ascii")

    def _unique_filename(self, index: int) -> str:
        """Build a PNG filename that is unique per call.

        A one-second timestamp plus the loop index is not enough on its own:
        two requests handled within the same second would overwrite each
        other's files, so a random component is included.
        """
        timestamp = int(time.time())
        return f"generated_{timestamp}_{uuid.uuid4().hex[:8]}_{index}.png"

    def _save_image(self, image: "Image.Image", filename: str) -> str:
        """Save image into the output directory and return its URL.

        Args:
            image: Image object exposing ``save(path)``.
            filename: File name (not a path) to store the image under.

        Returns:
            ``<base_url>/images/<filename>``.
        """
        # Re-create the directory if cleanup() removed it earlier.
        os.makedirs(self.output_dir, exist_ok=True)
        filepath = os.path.join(self.output_dir, filename)
        image.save(filepath)
        # Strip a trailing slash so a base_url such as "http://host/" does
        # not yield "http://host//images/...".
        return f"{self.base_url.rstrip('/')}/images/{filename}"

    def set_config(
        self,
        hf_token: Optional[str] = None,
        base_url: Optional[str] = None,
        default_model: Optional[str] = None,
    ):
        """Set configuration parameters.

        Only truthy arguments are applied; omitted or empty values leave the
        current setting untouched.
        """
        if hf_token:
            self.hf_token = hf_token
            self.client = None  # Reset client to use new token
        if base_url:
            self.base_url = base_url
        if default_model:
            self.default_model = default_model

    async def generate_images(self, request: "ImageGenerationRequest") -> List["ImageData"]:
        """Generate images based on the request.

        Images are generated one at a time. A failure for one image is
        logged and skipped so the remaining images are still attempted.

        Args:
            request: Prompt, model, image count and response format.

        Returns:
            One ImageData per successfully generated image.

        Raises:
            ValueError: If no token is configured or ``request.n`` is below 1.
            RuntimeError: If every generation attempt failed; the last
                underlying error is attached as the cause.
        """
        client = self._get_client()
        if request.n < 1:
            raise ValueError("n must be at least 1")

        model = request.model or self.default_model
        loop = asyncio.get_running_loop()

        results = []
        last_error = None
        for i in range(request.n):
            try:
                logger.info(
                    "Generating image %d/%d for prompt: %s...",
                    i + 1, request.n, request.prompt[:50],
                )
                # text_to_image is a blocking call; run it in the default
                # executor so the event loop stays responsive meanwhile.
                image = await loop.run_in_executor(
                    None,
                    functools.partial(client.text_to_image, request.prompt, model=model),
                )

                # Create response based on format
                if request.response_format == ResponseFormat.B64_JSON:
                    image_data = ImageData(
                        b64_json=self._image_to_base64(image),
                        revised_prompt=request.prompt,
                    )
                else:
                    # Save image and return URL
                    url = self._save_image(image, self._unique_filename(i))
                    image_data = ImageData(
                        url=url,
                        revised_prompt=request.prompt,
                    )
            except Exception as e:
                # Best effort: remember the error, keep the traceback in the
                # log, and continue with the other images.
                last_error = e
                logger.exception("Failed to generate image %d: %s", i + 1, e)
                continue

            results.append(image_data)
            logger.info("Successfully generated image %d/%d", i + 1, request.n)

        if not results:
            raise RuntimeError("Failed to generate any images") from last_error
        return results

    def cleanup(self):
        """Cleanup resources and temporary directory"""
        self.client = None
        # Clean up temporary directory
        import shutil
        if os.path.exists(self.output_dir):
            shutil.rmtree(self.output_dir)
            print(f"Cleaned up temporary directory: {self.output_dir}")
# Example usage
if __name__ == "__main__":
    # Create generator instance (this also creates its temporary directory)
    generator = ImageGenerator()
    try:
        # Fall back to a placeholder only when no real token was found, so a
        # valid HF_TOKEN from the environment is not overwritten.
        # Replace the placeholder with your actual token.
        if not generator.hf_token:
            generator.set_config(hf_token="your_hf_token_here")

        # Example request
        request = ImageGenerationRequest(
            prompt="A beautiful sunset over mountains",
            n=1,
            response_format=ResponseFormat.URL,
        )

        # Note: generate_images is a coroutine and must be awaited from an
        # async context, for example via
        # asyncio.run(generator.generate_images(request)).
        print("Image generator setup complete!")
    finally:
        # Remove the temporary directory created by the constructor.
        generator.cleanup()