File size: 5,001 Bytes
32c5da4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from __future__ import annotations

import logging
import os
from dataclasses import dataclass

LOGGER = logging.getLogger(__name__)

# Lazy imports to avoid torch loading issues on Windows. These module-level
# names stay ``None`` until ``_ensure_imports`` succeeds; code that needs them
# calls ``_ensure_imports()`` first and then reads the module globals.
torch = None
StableDiffusionImg2ImgPipeline = None
StableDiffusionPipeline = None

# Set once an import attempt has failed, so the import is not retried and the
# traceback is not re-logged on every availability check.
_IMPORT_FAILED = False


def _ensure_imports() -> None:
    """Import torch and diffusers on first use and bind them to module globals.

    On success, ``torch``, ``StableDiffusionImg2ImgPipeline`` and
    ``StableDiffusionPipeline`` are all set together. On failure the error is
    logged once, all three stay ``None``, and later calls return immediately
    without retrying the import.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline, _IMPORT_FAILED
    if torch is not None or _IMPORT_FAILED:
        return
    try:
        import torch as _torch
        from diffusers import StableDiffusionImg2ImgPipeline as _Img2Img
        from diffusers import StableDiffusionPipeline as _Pipeline
    except Exception as exc:  # pragma: no cover - optional dependency
        # Broad on purpose: any failure while loading the optional stack
        # (not only ImportError) should degrade to "unavailable", not crash.
        _IMPORT_FAILED = True
        LOGGER.error("✗ Failed to import torch/diffusers: %s", exc, exc_info=True)
        return
    # Bind only after every import succeeded, so the globals are never half-set.
    torch = _torch
    StableDiffusionImg2ImgPipeline = _Img2Img
    StableDiffusionPipeline = _Pipeline
    LOGGER.info("✓ torch and diffusers imported successfully")


@dataclass(slots=True)
class LocalAIRequest:
    prompt: str
    negative_prompt: str
    width: int
    height: int
    steps: int
    guidance: float
    seed: int
    init_image_path: str | None = None
    strength: float = 0.45
    model_variant: str | None = None


class LocalAIEngine:
    """Self-hosted local generation engine; no external API calls required.

    Pipelines are loaded lazily on the first ``generate`` call and cached on
    the instance. Configuration comes from environment variables:

    * ``IMAGEFORGE_LOCALAI_MODEL`` - model id (default ``"segmind/tiny-sd"``).
    * ``IMAGEFORGE_LOCALAI_LOCAL_ONLY`` - ``"1"`` to load only from the local
      cache instead of allowing downloads (default ``"0"``).
    * ``IMAGEFORGE_ENABLE_ATTENTION_SLICING`` - ``"1"`` (default) to enable
      attention slicing when running on CUDA.
    """

    def __init__(self) -> None:
        self.model_id = os.getenv("IMAGEFORGE_LOCALAI_MODEL", "segmind/tiny-sd")
        # Cached text-to-image / image-to-image pipelines; None until _ensure()
        # loads them, and reset to None when generate() switches models.
        self._pipe_t2i = None
        self._pipe_i2i = None

    def is_available(self) -> bool:
        """Return True if torch and diffusers could be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    @staticmethod
    def _prepare_for_device(pipe, device: str):
        """Move *pipe* to CUDA and optionally enable attention slicing; no-op on CPU."""
        if device != "cuda":
            return pipe
        pipe = pipe.to(device)
        if os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1":
            pipe.enable_attention_slicing()
        return pipe

    def _build_img2img(self, pipe):
        """Return an image-to-image pipeline sharing *pipe*'s modules, or None."""
        if StableDiffusionImg2ImgPipeline is None:
            return None
        try:
            # Reuse the already-loaded modules instead of a second
            # from_pretrained() call, which would hold every weight twice.
            # The modules are the same objects, so device placement and
            # attention slicing applied to *pipe* already cover them.
            return StableDiffusionImg2ImgPipeline(**pipe.components)
        except Exception as exc:  # noqa: BLE001 - img2img is an optional extra
            LOGGER.warning(
                "Image-to-image pipeline unavailable for '%s': %s", self.model_id, exc
            )
            return None

    def _ensure(self):
        """Load (once) and return the text-to-image pipeline.

        Also builds the image-to-image pipeline from the same components when
        diffusers provides that class.

        Raises:
            RuntimeError: if torch/diffusers are missing or the model cannot
                be loaded.
        """
        if not self.is_available():
            raise RuntimeError(
                "LocalAI dependencies missing. Install diffusers, torch, transformers, accelerate."
            )
        if self._pipe_t2i is not None:
            return self._pipe_t2i

        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        local_only = os.getenv("IMAGEFORGE_LOCALAI_LOCAL_ONLY", "0") == "1"
        LOGGER.info("Loading LocalAI model '%s' on %s", self.model_id, device)
        try:
            pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                # Downloads are allowed by default; an explicit
                # IMAGEFORGE_LOCALAI_LOCAL_ONLY=1 restricts loading to the cache.
                local_files_only=local_only,
                use_safetensors=True if "safetensors" in self.model_id else None,
            )
        except Exception as exc:  # noqa: BLE001
            LOGGER.error("Failed to load model '%s': %s", self.model_id, exc)
            raise RuntimeError(
                f"LocalAI model '{self.model_id}' could not be loaded. Error: {exc}"
            ) from exc
        self._pipe_t2i = self._prepare_for_device(pipe, device)
        self._pipe_i2i = self._build_img2img(self._pipe_t2i)
        return self._pipe_t2i

    def generate(self, req: LocalAIRequest):
        """Generate one image for *req* and return it.

        Image-to-image is used when ``req.init_image_path`` is set and that
        pipeline is available; otherwise text-to-image (with a warning if an
        init image had to be ignored). A ``req.model_variant`` that differs
        from the current model switches models and drops the cached pipelines.

        Raises:
            RuntimeError: propagated from model loading.
        """
        from PIL import Image

        variant = getattr(req, "model_variant", None)
        if variant and variant != self.model_id:
            self.model_id = variant
            self._pipe_t2i = None
            self._pipe_i2i = None
        pipe = self._ensure()
        # Generator on the pipeline's device, seeded from the request.
        generator = torch.Generator(device=pipe.device).manual_seed(req.seed)
        negative_prompt = req.negative_prompt or None  # "" -> no negative prompt
        if req.init_image_path and self._pipe_i2i is not None:
            with Image.open(req.init_image_path) as src:
                init_img = src.convert("RGB").resize((req.width, req.height))
            out = self._pipe_i2i(
                prompt=req.prompt,
                negative_prompt=negative_prompt,
                image=init_img,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                strength=max(0.0, min(1.0, req.strength)),
                generator=generator,
            )
        else:
            if req.init_image_path:
                LOGGER.warning(
                    "Image-to-image pipeline unavailable; ignoring init image '%s'",
                    req.init_image_path,
                )
            out = pipe(
                prompt=req.prompt,
                negative_prompt=negative_prompt,
                width=req.width,
                height=req.height,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                generator=generator,
            )
        return out.images[0]