Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| import os | |
| from dataclasses import dataclass | |
LOGGER = logging.getLogger(__name__)

# Lazy imports to avoid torch loading issues on Windows.
# These placeholders are replaced by ``_ensure_imports`` once torch and
# diffusers import successfully; they stay ``None`` when the import fails.
torch = None
StableDiffusionImg2ImgPipeline = None
StableDiffusionPipeline = None

# Exception from the first failed import attempt, or None if no attempt has
# failed. Remembering it stops every later call from re-running the slow
# partial import and re-logging the same traceback.
_import_error = None


def _ensure_imports() -> None:
    """Import torch and diffusers on first use and publish them as module globals.

    Safe to call repeatedly. After a successful import it returns immediately.
    After a failed import it also returns immediately, without retrying.
    Callers detect the failure because ``torch`` /
    ``StableDiffusionPipeline`` remain ``None``.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline, _import_error
    if torch is not None or _import_error is not None:
        return
    try:
        import torch as _torch
        from diffusers import StableDiffusionImg2ImgPipeline as _Img2Img
        from diffusers import StableDiffusionPipeline as _Pipeline
    except Exception as exc:  # pragma: no cover - optional dependency
        # Deliberately broad: the torch loading issues this lazy import works
        # around need not surface as ImportError.
        _import_error = exc
        LOGGER.error("✗ Failed to import torch/diffusers: %s", exc, exc_info=True)
        return
    # Publish only after every import succeeded, so the globals are never half-set.
    torch = _torch
    StableDiffusionImg2ImgPipeline = _Img2Img
    StableDiffusionPipeline = _Pipeline
    LOGGER.info("✓ torch and diffusers imported successfully")
| class LocalAIRequest: | |
| prompt: str | |
| negative_prompt: str | |
| width: int | |
| height: int | |
| steps: int | |
| guidance: float | |
| seed: int | |
| init_image_path: str | None = None | |
| strength: float = 0.45 | |
| model_variant: str | None = None | |
class LocalAIEngine:
    """Self-hosted local generation engine; no external API calls required.

    Wraps a diffusers Stable Diffusion text-to-image pipeline and, when
    diffusers provides it, an image-to-image pipeline that shares the same
    loaded model components. Pipelines are created lazily on first use and
    cached on the instance.

    The model id comes from ``IMAGEFORGE_LOCALAI_MODEL`` (default
    ``segmind/tiny-sd``). A request can switch it via
    ``LocalAIRequest.model_variant``.
    """

    def __init__(self) -> None:
        self.model_id = os.getenv("IMAGEFORGE_LOCALAI_MODEL", "segmind/tiny-sd")
        # Created lazily by _ensure(); reset to None when the model changes.
        self._pipe_t2i = None
        self._pipe_i2i = None

    def is_available(self) -> bool:
        """Return True if torch and diffusers could be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    @staticmethod
    def _finalize_pipeline(pipe, device: str):
        """Move *pipe* to CUDA when selected and apply optional attention slicing.

        Attention slicing is enabled unless ``IMAGEFORGE_ENABLE_ATTENTION_SLICING``
        is set to something other than ``"1"``. Returns the pipeline to use,
        which is the result of ``.to()`` when the pipeline was moved.
        """
        if device == "cuda":
            pipe = pipe.to(device)
        if os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1":
            pipe.enable_attention_slicing()
        return pipe

    def _ensure(self):
        """Load the text-to-image pipeline for ``self.model_id`` once, then return it.

        On first load this also tries to build ``self._pipe_i2i`` from the
        same components. If that fails, img2img stays disabled and a warning
        is logged; text-to-image keeps working.

        Raises:
            RuntimeError: if torch/diffusers are missing or the model cannot
                be loaded.
        """
        # is_available() performs the lazy import itself.
        if not self.is_available():
            raise RuntimeError(
                "LocalAI dependencies missing. Install diffusers, torch, transformers, accelerate."
            )
        if self._pipe_t2i is not None:
            return self._pipe_t2i

        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        LOGGER.info("Loading LocalAI model '%s' on %s", self.model_id, device)
        try:
            # local_files_only=False is forced on purpose so a missing model
            # is downloaded; IMAGEFORGE_LOCALAI_LOCAL_ONLY is intentionally
            # not consulted.
            pipe = StableDiffusionPipeline.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                local_files_only=False,
                use_safetensors=True if "safetensors" in self.model_id else None,
            )
        except Exception as exc:  # noqa: BLE001
            LOGGER.error("Failed to load model '%s': %s", self.model_id, exc)
            raise RuntimeError(
                f"LocalAI model '{self.model_id}' could not be loaded. Error: {exc}"
            ) from exc
        self._pipe_t2i = self._finalize_pipeline(pipe, device)

        if StableDiffusionImg2ImgPipeline is not None:
            try:
                # Reuse the already-loaded modules instead of a second
                # from_pretrained call, so the weights are held only once.
                pipe_i2i = StableDiffusionImg2ImgPipeline(**self._pipe_t2i.components)
            except Exception as exc:  # noqa: BLE001
                LOGGER.warning(
                    "Img2img pipeline unavailable for '%s'; init images will be ignored: %s",
                    self.model_id,
                    exc,
                )
            else:
                self._pipe_i2i = self._finalize_pipeline(pipe_i2i, device)
        return self._pipe_t2i

    def generate(self, req: "LocalAIRequest"):
        """Generate one image for *req* and return the first image of the pipeline output.

        If ``req.model_variant`` names a different model, the engine switches
        to it first and drops the cached pipelines. Image-to-image is used when
        ``req.init_image_path`` is set and the img2img pipeline exists.
        Otherwise text-to-image is used, with a warning if an init image was
        given.

        Raises:
            RuntimeError: propagated from ``_ensure`` when dependencies or the
                model are unavailable.
        """
        from PIL import Image

        variant = getattr(req, "model_variant", None)
        if variant and variant != self.model_id:
            self.model_id = variant
            self._pipe_t2i = None
            self._pipe_i2i = None

        pipe = self._ensure()
        generator = torch.Generator(device=pipe.device).manual_seed(req.seed)
        negative = req.negative_prompt or None

        if req.init_image_path and self._pipe_i2i is not None:
            # The context manager releases the file handle; convert() returns
            # a new image that stays valid after the source is closed.
            with Image.open(req.init_image_path) as src:
                init_img = src.convert("RGB").resize((req.width, req.height))
            out = self._pipe_i2i(
                prompt=req.prompt,
                negative_prompt=negative,
                image=init_img,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                strength=max(0.0, min(1.0, req.strength)),
                generator=generator,
            )
        else:
            if req.init_image_path:
                LOGGER.warning(
                    "Img2img pipeline unavailable; ignoring init image '%s'",
                    req.init_image_path,
                )
            out = pipe(
                prompt=req.prompt,
                negative_prompt=negative,
                width=req.width,
                height=req.height,
                guidance_scale=req.guidance,
                num_inference_steps=req.steps,
                generator=generator,
            )
        return out.images[0]