File size: 1,293 Bytes
f440f03
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
"""Diffusion pipeline palīgklase."""

from __future__ import annotations

import logging
from typing import Any

from maris_core.utils.env import get_hf_model

# Module-level logger; DiffusionPipeline.load reports model load success and
# failure through it.
logger = logging.getLogger(__name__)


class DiffusionPipeline:
    """Wrap a Hugging Face ``diffusers`` text-to-image pipeline (Stable Diffusion / SDXL).

    The underlying pipeline is loaded lazily on the first :meth:`generate` call
    (or explicitly via :meth:`load`). Loading is best-effort: failures are
    logged and leave the wrapper unloaded, in which case :meth:`generate`
    returns ``None``.
    """

    def __init__(self, model_id: str | None = None) -> None:
        """Create the wrapper without loading any weights.

        Args:
            model_id: Hugging Face model id or local path. When omitted (or
                empty), the value returned by ``get_hf_model("IMAGE_MODEL")``
                is used.
        """
        self.model_id = model_id or get_hf_model("IMAGE_MODEL")
        # Loaded pipeline object; stays None until load() fully succeeds.
        self._pipe: Any = None

    @staticmethod
    def _device_and_dtype(torch_mod: Any) -> tuple[str, Any]:
        """Return ``(device, dtype)`` for loading the pipeline.

        Half precision is used only on CUDA. Many CPU kernels have no float16
        implementation, so a float16 pipeline running on CPU fails at
        generation time; float32 is used there instead.
        """
        if torch_mod.cuda.is_available():
            return "cuda", torch_mod.float16
        return "cpu", torch_mod.float32

    def load(self) -> None:
        """Load the model weights and move the pipeline to the best device.

        Best-effort: any failure (missing ``torch``/``diffusers``, download or
        device errors) is logged with its traceback and ``self._pipe`` is left
        unchanged, so a partially initialised pipeline is never kept.
        """
        try:
            import torch  # type: ignore
            # The generic loader picks the concrete pipeline class (e.g. SD or
            # SDXL) from the checkpoint's config, so both model families work.
            from diffusers import DiffusionPipeline as _HFDiffusionPipeline  # type: ignore

            device, dtype = self._device_and_dtype(torch)
            pipe = _HFDiffusionPipeline.from_pretrained(self.model_id, torch_dtype=dtype)
            pipe = pipe.to(device)
        except Exception as exc:  # noqa: BLE001 - best-effort load, reported via logging
            logger.exception("Nevar ielādēt diffusion modeli: %s", exc)
            return
        # Publish the pipeline only after every loading step has succeeded.
        self._pipe = pipe
        logger.info("Ielādēts attēlu modelis: %s", self.model_id)

    def generate(self, prompt: str, **kwargs: Any) -> Any:
        """Generate one image for *prompt*.

        Loads the pipeline on first use. If a previous load failed, loading is
        attempted again on each call. Extra keyword arguments are passed
        straight to the underlying pipeline call.

        Returns:
            The first element of the pipeline output's ``images``, or ``None``
            if the model could not be loaded.
        """
        if self._pipe is None:
            self.load()
            if self._pipe is None:  # load failed; the error was already logged
                return None
        return self._pipe(prompt, **kwargs).images[0]