"""Diffusion pipeline palīgklase.""" from __future__ import annotations import logging from typing import Any from maris_core.utils.env import get_hf_model logger = logging.getLogger(__name__) class DiffusionPipeline: """Iesaiņo StableDiffusion/SDXL pipeline.""" def __init__(self, model_id: str | None = None) -> None: self.model_id = model_id or get_hf_model("IMAGE_MODEL") self._pipe: Any = None def load(self) -> None: """Ielādē modeli.""" try: import torch # type: ignore from diffusers import StableDiffusionPipeline # type: ignore self._pipe = StableDiffusionPipeline.from_pretrained( self.model_id, torch_dtype=torch.float16 ) self._pipe = self._pipe.to("cuda" if torch.cuda.is_available() else "cpu") logger.info("Ielādēts attēlu modelis: %s", self.model_id) except Exception as exc: # noqa: BLE001 logger.error("Nevar ielādēt diffusion modeli: %s", exc) def generate(self, prompt: str, **kwargs: Any) -> Any: """Ģenerē attēlu.""" if self._pipe is None: self.load() if self._pipe is None: return None return self._pipe(prompt, **kwargs).images[0]