File size: 1,638 Bytes
ee2e4d6 e7aa834 ee2e4d6 e7aa834 ee2e4d6 e7aa834 ee2e4d6 e7aa834 9977b3f e7aa834 ee2e4d6 e7aa834 ee2e4d6 e7aa834 ee2e4d6 e7aa834 ee2e4d6 e7aa834 3d4d1fb | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 | import logging
from PIL import Image
from huggingface_hub import InferenceClient
# Configure the root logger at INFO level when this module is imported.
logging.basicConfig(level=logging.INFO)
# Module-level logger named after this module; used by DiffusionClient.
logger = logging.getLogger(__name__)
class DiffusionClient:
    """Text-to-image client that delegates to a ``huggingface_hub.InferenceClient``.

    No model is loaded by this class itself: ``load_model`` only logs and
    flips an internal readiness flag, and ``gen_image`` forwards the request
    to ``InferenceClient.text_to_image``.
    """

    def __init__(
        self,
        model_id: str = "black-forest-labs/FLUX.1-schnell",
        hf_token: str | None = None,
        provider: str = "auto",
    ):
        """Create the underlying inference client.

        Args:
            model_id: Model identifier passed as ``model`` on every
                ``text_to_image`` call.
            hf_token: API token handed to ``InferenceClient`` as ``api_key``.
                Any falsy value (``None``, ``""``) is forwarded as ``None``.
            provider: Forwarded unchanged to ``InferenceClient``.
        """
        # Readiness is tracked separately from construction; see load_model().
        self._ready = False
        self.model_id = model_id
        # Collapse empty strings to None so the client never gets "" as a key.
        self.client = InferenceClient(api_key=hf_token or None, provider=provider)

    def load_model(self):
        """Mark the client as ready, logging once; later calls only log a skip."""
        if not self._ready:
            logger.info(
                "Image API client ready (model=%s, serverless inference).",
                self.model_id,
            )
            self._ready = True
            return
        logger.info("Image API client already ready. Skipping.")

    def gen_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        width: int = 768,
        height: int = 768,
    ) -> Image.Image | None:
        """Generate an image for *prompt* via ``InferenceClient.text_to_image``.

        Calls ``load_model`` first if the client has not been marked ready.

        Args:
            prompt: Text prompt sent to the model.
            negative_prompt: Sent as-is when non-empty; a falsy value is
                sent as ``None``.
            num_inference_steps: Forwarded unchanged.
            guidance_scale: Forwarded unchanged.
            width: Forwarded unchanged.
            height: Forwarded unchanged.

        Returns:
            Whatever ``text_to_image`` returns for the request.

        Raises:
            Exception: Any error raised while issuing the request is logged
                (with the prompt truncated to 120 characters) and re-raised
                unchanged.
        """
        if not self._ready:
            self.load_model()

        try:
            request = {
                "prompt": prompt,
                "model": self.model_id,
                # An empty negative prompt is sent as None rather than "".
                "negative_prompt": negative_prompt or None,
                "num_inference_steps": num_inference_steps,
                "guidance_scale": guidance_scale,
                "width": width,
                "height": height,
            }
            return self.client.text_to_image(**request)
        except Exception:
            logger.exception("Image generation failed for prompt: %.120s", prompt)
            raise
|