| """Diffusion pipeline palīgklase.""" |
|
|
| from __future__ import annotations |
|
|
| import logging |
| from typing import Any |
|
|
| from maris_core.utils.env import get_hf_model |
|
|
# Logger named after this module (standard per-module logging pattern).
logger: logging.Logger = logging.getLogger(__name__)
|
|
|
|
| class DiffusionPipeline: |
| """Iesaiņo StableDiffusion/SDXL pipeline.""" |
|
|
| def __init__(self, model_id: str | None = None) -> None: |
| self.model_id = model_id or get_hf_model("IMAGE_MODEL") |
| self._pipe: Any = None |
|
|
| def load(self) -> None: |
| """Ielādē modeli.""" |
| try: |
| import torch |
| from diffusers import StableDiffusionPipeline |
|
|
| self._pipe = StableDiffusionPipeline.from_pretrained( |
| self.model_id, torch_dtype=torch.float16 |
| ) |
| self._pipe = self._pipe.to("cuda" if torch.cuda.is_available() else "cpu") |
| logger.info("Ielādēts attēlu modelis: %s", self.model_id) |
| except Exception as exc: |
| logger.error("Nevar ielādēt diffusion modeli: %s", exc) |
|
|
| def generate(self, prompt: str, **kwargs: Any) -> Any: |
| """Ģenerē attēlu.""" |
| if self._pipe is None: |
| self.load() |
| if self._pipe is None: |
| return None |
| return self._pipe(prompt, **kwargs).images[0] |
|
|