| import logging | |
| from PIL import Image | |
| from huggingface_hub import InferenceClient | |
# NOTE(review): basicConfig() here configures the *root* logger as a side
# effect of importing this module (it is a no-op if the root logger already
# has handlers). Libraries usually leave logging configuration to the
# application; kept as-is to preserve current behavior.
logging.basicConfig(level="INFO")
# Module-scoped logger used by DiffusionClient below.
logger = logging.getLogger(name=__name__)
class DiffusionClient:
    """Text-to-image client backed by ``huggingface_hub.InferenceClient``.

    No weights are loaded in this process: every :meth:`gen_image` call is a
    request made through ``self.client`` to the configured ``provider``.
    """

    def __init__(
        self,
        model_id: str = "black-forest-labs/FLUX.1-schnell",
        hf_token: str | None = None,
        provider: str = "auto",
    ):
        """Create the underlying inference client.

        Args:
            model_id: Hub model id sent as ``model=`` on every request.
            hf_token: Hugging Face access token. ``None`` or ``""`` means no
                explicit token is passed to ``InferenceClient``.
            provider: Inference provider name forwarded to ``InferenceClient``
                (``"auto"`` defers provider selection to the library).
        """
        self.model_id = model_id
        self.provider = provider
        # Normalize "" to None so an empty string is never sent as a token.
        self.client = InferenceClient(api_key=hf_token or None, provider=provider)
        self._ready = False

    def load_model(self) -> None:
        """Mark the client as ready; safe to call repeatedly.

        Nothing is downloaded or loaded locally; this only flips the ready
        flag and logs which model/provider will be used. Repeated calls log
        that the client is already ready and return immediately.
        """
        if self._ready:
            logger.info("Image API client already ready. Skipping.")
            return
        # Report the configured provider rather than assuming serverless
        # inference: with provider="auto" (or any third-party provider) the
        # request is not necessarily served by HF serverless inference.
        logger.info(
            "Image API client ready (model=%s, provider=%s).",
            self.model_id,
            self.provider,
        )
        self._ready = True

    def gen_image(
        self,
        prompt: str,
        negative_prompt: str = "",
        num_inference_steps: int = 4,
        guidance_scale: float = 0.0,
        width: int = 768,
        height: int = 768,
    ) -> Image.Image:
        """Generate a single image for ``prompt`` via the remote API.

        The step/guidance defaults appear tuned for FLUX.1-schnell (few steps,
        no guidance) — revisit them if ``model_id`` is changed.

        Args:
            prompt: Text prompt describing the desired image.
            negative_prompt: Text to steer away from; an empty string is sent
                as ``None`` (i.e. no negative prompt).
            num_inference_steps: Number of denoising steps requested.
            guidance_scale: Guidance scale requested.
            width: Requested output width in pixels.
            height: Requested output height in pixels.

        Returns:
            The image object returned by ``InferenceClient.text_to_image``.
            This method never returns ``None``; failures raise instead.

        Raises:
            Exception: any error from ``text_to_image`` is logged (prompt
                truncated to 120 characters) and re-raised unchanged.
        """
        # Guard here (rather than always calling load_model) so a ready client
        # does not log "already ready" on every generation.
        if not self._ready:
            self.load_model()
        try:
            image = self.client.text_to_image(
                prompt=prompt,
                model=self.model_id,
                negative_prompt=negative_prompt or None,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                width=width,
                height=height,
            )
        except Exception:
            logger.exception("Image generation failed for prompt: %.120s", prompt)
            raise
        return image