import logging

from PIL import Image
from huggingface_hub import InferenceClient

# NOTE(review): configuring the root logger at import time affects every
# program that imports this module; consider moving this call to the
# application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class DiffusionClient:
    """Generate images from text prompts via ``huggingface_hub.InferenceClient``.

    This class does not load any model weights itself; every generation
    request is forwarded to ``InferenceClient.text_to_image`` with the
    configured ``model_id``.
    """

    def __init__(
        self,
        model_id: str = "black-forest-labs/FLUX.1-schnell",
        hf_token: str | None = None,
        provider: str = "auto",
    ):
        """Create the underlying inference client.

        Args:
            model_id: Hub model identifier passed as ``model=`` on every
                ``text_to_image`` call.
            hf_token: Hugging Face API token. An empty string or ``None`` is
                passed to the client as ``None``, leaving token resolution to
                ``InferenceClient``'s own defaults.
            provider: Inference provider name forwarded to ``InferenceClient``.
        """
        self.model_id = model_id
        # Normalise "" to None so an empty token is treated as "not given".
        self.client = InferenceClient(api_key=hf_token or None, provider=provider)
        self._ready = False

    def load_model(self):
        """Mark the client as ready; repeated calls are no-ops.

        Nothing is downloaded or loaded here. The method only logs and flips
        the ``_ready`` flag, which ``gen_image`` checks before its first call.
        """
        if self._ready:
            logger.info("Image API client already ready. Skipping.")
            return
        logger.info(
            "Image API client ready (model=%s, serverless inference).", self.model_id
        )
        self._ready = True

    def gen_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        width: int = 768,
        height: int = 768,
    ) -> Image.Image:
        """Generate one image for ``prompt`` using the configured model.

        The default step count and guidance scale presumably target the
        default FLUX.1-schnell model; confirm before changing ``model_id``.

        Args:
            prompt: Text description of the desired image.
            negative_prompt: Content to steer away from; an empty string means
                no negative prompt is sent.
            num_inference_steps: Number of denoising steps requested.
            guidance_scale: Classifier-free guidance scale requested.
            width: Requested output width in pixels.
            height: Requested output height in pixels.

        Returns:
            The image returned by ``InferenceClient.text_to_image``.

        Raises:
            Exception: Any error raised by the inference client is logged
                (with the prompt truncated to 120 characters) and re-raised
                unchanged.
        """
        if not self._ready:
            self.load_model()
        try:
            return self.client.text_to_image(
                prompt=prompt,
                model=self.model_id,
                # Send None rather than "" so the parameter is omitted.
                negative_prompt=negative_prompt or None,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
            )
        except Exception:
            # Log with traceback for diagnostics, then let the caller decide.
            logger.exception("Image generation failed for prompt: %.120s", prompt)
            raise