Spaces:
Sleeping
Sleeping
File size: 5,001 Bytes
32c5da4 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 | from __future__ import annotations
import logging
import os
from dataclasses import dataclass
LOGGER = logging.getLogger(__name__)
# Lazy imports to avoid torch loading issues on Windows
torch = None
StableDiffusionImg2ImgPipeline = None
StableDiffusionPipeline = None
def _ensure_imports():
    """Import torch and diffusers on first use and publish them as globals.

    The module-level names ``torch``, ``StableDiffusionImg2ImgPipeline`` and
    ``StableDiffusionPipeline`` start out as ``None``. A successful call
    rebinds all three; a failed call logs the error (with traceback) and
    leaves them as ``None`` so callers can detect the missing dependency.
    Once ``torch`` is bound, later calls return immediately.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline

    already_loaded = torch is not None
    if already_loaded:
        return

    try:
        import torch as torch_module
        from diffusers import (
            StableDiffusionImg2ImgPipeline as img2img_cls,
            StableDiffusionPipeline as text2img_cls,
        )

        # Rebind only after every import succeeded, so the three globals are
        # either all set or all still None.
        torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline = (
            torch_module,
            img2img_cls,
            text2img_cls,
        )
        LOGGER.info("✓ torch and diffusers imported successfully")
    except Exception as exc:  # pragma: no cover - optional dependency
        LOGGER.error("✗ Failed to import torch/diffusers: %s", exc, exc_info=True)
@dataclass(slots=True)
class LocalAIRequest:
prompt: str
negative_prompt: str
width: int
height: int
steps: int
guidance: float
seed: int
init_image_path: str | None = None
strength: float = 0.45
model_variant: str | None = None
class LocalAIEngine:
    """Self-hosted local generation engine; no external API calls required.

    Pipelines are built lazily on first use and cached on the instance. The
    model id is read from the ``IMAGEFORGE_LOCALAI_MODEL`` environment
    variable (default ``segmind/tiny-sd``) and may be replaced per request
    through the request's ``model_variant`` attribute.
    """

    def __init__(self) -> None:
        self.model_id = os.getenv("IMAGEFORGE_LOCALAI_MODEL", "segmind/tiny-sd")
        self._pipe_t2i = None
        self._pipe_i2i = None

    def is_available(self) -> bool:
        """Return True when both torch and diffusers could be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    def _ensure(self):
        """Return the cached text-to-image pipeline, building it if needed.

        On a cache miss the model ``self.model_id`` is loaded once and an
        image-to-image pipeline is created from the same components, so the
        weights are held in memory a single time. Both pipelines are stored
        on the instance together, only after everything succeeded.

        Raises:
            RuntimeError: if torch/diffusers are not importable, or if the
                model cannot be loaded.
        """
        # is_available() already triggers the lazy imports.
        if not self.is_available():
            raise RuntimeError(
                "LocalAI dependencies missing. Install diffusers, torch, transformers, accelerate."
            )
        if self._pipe_t2i is not None:
            return self._pipe_t2i

        device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision only on GPU.
        dtype = torch.float16 if device == "cuda" else torch.float32
        slicing = os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1"

        LOGGER.info("Loading LocalAI model '%s' on %s", self.model_id, device)
        try:
            # NOTE(review): downloads stay force-enabled here, as before.
            # IMAGEFORGE_LOCALAI_LOCAL_ONLY used to be applied only to a
            # second, separate img2img load; that load no longer exists.
            # Honouring the variable for this load would be a behaviour
            # change - confirm before wiring it in.
            pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                local_files_only=False,
                use_safetensors=True if "safetensors" in self.model_id else None,
            )
        except Exception as exc:  # noqa: BLE001
            LOGGER.error("Failed to load model '%s': %s", self.model_id, exc)
            raise RuntimeError(
                f"LocalAI model '{self.model_id}' could not be loaded. Error: {exc}"
            ) from exc

        if device == "cuda":
            pipe = pipe.to(device)
        if slicing:
            pipe.enable_attention_slicing()

        pipe_i2i = None
        if StableDiffusionImg2ImgPipeline is not None:
            # Reuse the already-loaded (and already-moved) modules instead of
            # calling from_pretrained a second time. The shared modules carry
            # the device placement and attention-slicing setting with them.
            pipe_i2i = StableDiffusionImg2ImgPipeline(**pipe.components)

        # Publish both pipelines together so a failure above can never leave
        # the engine with text-to-image cached but img2img missing.
        self._pipe_t2i = pipe
        self._pipe_i2i = pipe_i2i
        return self._pipe_t2i

    def generate(self, req: LocalAIRequest):
        """Run one generation and return the first image of the result.

        If ``req.model_variant`` is set and differs from the current model,
        the cached pipelines are dropped and the new model is loaded. Should
        that load fail, the previous model id is restored so later requests
        are not stuck on the broken variant.

        When ``req.init_image_path`` is set and an img2img pipeline exists,
        the image is opened, converted to RGB, resized to
        ``(req.width, req.height)`` and used as the starting image;
        otherwise plain text-to-image is run.

        Raises:
            RuntimeError: propagated from pipeline loading.
        """
        requested = getattr(req, "model_variant", None)
        if requested and requested != self.model_id:
            previous_id = self.model_id
            self.model_id = requested
            # Drop the old pipelines before loading so both models are never
            # referenced at the same time.
            self._pipe_t2i = None
            self._pipe_i2i = None
            try:
                pipe = self._ensure()
            except Exception:
                # Roll back the id; the previous model reloads on next use.
                self.model_id = previous_id
                raise
        else:
            pipe = self._ensure()

        generator = torch.Generator(device=pipe.device).manual_seed(req.seed)
        negative = req.negative_prompt or None

        if req.init_image_path and self._pipe_i2i is not None:
            from PIL import Image

            # The context manager closes the source file; convert/resize
            # return independent in-memory images.
            with Image.open(req.init_image_path) as src:
                init_img = src.convert("RGB").resize((req.width, req.height))
            out = self._pipe_i2i(
                prompt=req.prompt,
                negative_prompt=negative,
                image=init_img,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                strength=max(0.0, min(1.0, req.strength)),
                generator=generator,
            )
        else:
            if req.init_image_path:
                LOGGER.warning(
                    "init_image_path was provided but the img2img pipeline is "
                    "unavailable; falling back to text-to-image"
                )
            out = pipe(
                prompt=req.prompt,
                negative_prompt=negative,
                width=req.width,
                height=req.height,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                generator=generator,
            )
        return out.images[0]
|