# Path: maris-ai-master/core-python/maris_core/images/diffusion_pipeline.py
# Last commit: f440f03 "Maris AI model sync" (MarisUK, verified)
"""Helper class wrapping a diffusers image-generation pipeline."""
from __future__ import annotations
import logging
from typing import Any
from maris_core.utils.env import get_hf_model
# Module-level logger; DiffusionPipeline.load reports model load success/failure through it.
logger = logging.getLogger(__name__)
class DiffusionPipeline:
    """Lazily loaded wrapper around a Hugging Face ``diffusers`` text-to-image pipeline.

    The underlying pipeline is created on first use (or by an explicit
    :meth:`load` call). Loading is best-effort: failures are logged and
    :meth:`generate` then returns ``None`` instead of raising.

    NOTE(review): this class was documented as wrapping "StableDiffusion/SDXL",
    but :meth:`load` uses ``diffusers.StableDiffusionPipeline``, which does not
    load SDXL checkpoints (those need ``StableDiffusionXLPipeline`` or
    ``AutoPipelineForText2Image``) — confirm which models IMAGE_MODEL points to.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Remember which model to load; nothing is loaded yet.

        Args:
            model_id: Hugging Face model id. When falsy, the id is taken from
                ``get_hf_model("IMAGE_MODEL")``.
        """
        self.model_id = model_id or get_hf_model("IMAGE_MODEL")
        # Loaded diffusers pipeline, or None until load() succeeds.
        self._pipe: Any = None

    @staticmethod
    def _runtime_for(cuda_available: bool) -> tuple[str, str]:
        """Return ``(device, torch dtype attribute name)`` for running the pipeline.

        Half precision is only used on CUDA: many CPU kernels have no float16
        implementation, so CPU inference has to stay in float32.
        """
        if cuda_available:
            return "cuda", "float16"
        return "cpu", "float32"

    def load(self) -> None:
        """Load the model into ``self._pipe`` (best-effort, never raises).

        Any failure (missing ``torch``/``diffusers``, download error, out of
        memory while moving to the device, ...) is logged with its traceback,
        and ``self._pipe`` is left unchanged. A half-initialised pipeline is
        therefore never stored.
        """
        try:
            import torch  # type: ignore
            from diffusers import StableDiffusionPipeline  # type: ignore

            device, dtype_name = self._runtime_for(torch.cuda.is_available())
            pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id, torch_dtype=getattr(torch, dtype_name)
            )
            # Build into a local; only publish it once the device move succeeded.
            pipe = pipe.to(device)
        except Exception as exc:  # noqa: BLE001 - loading is deliberately best-effort
            logger.exception("Nevar ielādēt diffusion modeli: %s", exc)
            return
        self._pipe = pipe
        logger.info("Ielādēts attēlu modelis: %s", self.model_id)

    def generate(self, prompt: str, **kwargs: Any) -> Any:
        """Generate a single image for *prompt*.

        The model is loaded on first call. If loading failed earlier, it is
        retried on every call until it succeeds.

        Args:
            prompt: Text prompt passed to the pipeline.
            **kwargs: Forwarded unchanged to the pipeline call.

        Returns:
            The first element of the pipeline output's ``images``, or ``None``
            if the model could not be loaded. (Presumably a PIL image with
            diffusers' default ``output_type`` — confirm against callers.)
        """
        if self._pipe is None:
            self.load()
        if self._pipe is None:
            return None
        return self._pipe(prompt, **kwargs).images[0]