PixelForge / imageforge /backend /app /providers /diffusion_provider.py
Gregorfun's picture
Initial commit
32c5da4
from __future__ import annotations
import logging
import os
from pathlib import Path
from .interface import ProviderRequest, ProviderResult, ProviderUnavailableError
LOGGER = logging.getLogger(__name__)

# torch and the diffusers pipeline classes start out unbound (None) and are only
# imported on demand by _ensure_imports(); deferring the import avoids torch
# loading issues on Windows at module import time.
torch = StableDiffusionImg2ImgPipeline = StableDiffusionPipeline = None
# Remembers why the optional imports failed so that _ensure_imports() does not
# re-attempt (and re-fail) the torch/diffusers import on every call.
_IMPORT_ERROR: Exception | None = None


def _ensure_imports() -> None:
    """Bind torch and the diffusers pipeline classes to module globals on first use.

    Best effort: if the optional dependencies cannot be imported, the globals stay
    ``None`` (callers check for that), the failure is logged once, and subsequent
    calls return immediately instead of retrying the import.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline, _IMPORT_ERROR
    if torch is not None or _IMPORT_ERROR is not None:
        return
    try:
        import torch as _torch
        from diffusers import StableDiffusionImg2ImgPipeline as _Img2Img
        from diffusers import StableDiffusionPipeline as _Pipeline
    except Exception as exc:  # pragma: no cover - optional dependency
        # Deliberately broad: a broken native install can surface as OSError or
        # RuntimeError rather than ImportError. Record it instead of hiding it.
        _IMPORT_ERROR = exc
        LOGGER.warning(
            "Diffusion provider unavailable: optional dependencies failed to import (%s)",
            exc,
        )
        return
    # Bind all three together so callers never observe a half-initialised state.
    torch = _torch
    StableDiffusionImg2ImgPipeline = _Img2Img
    StableDiffusionPipeline = _Pipeline
class DiffusionProvider:
    """Generate images with a locally executed Stable Diffusion pipeline (diffusers).

    The text-to-image and image-to-image pipelines are created lazily on first use
    and cached on the instance. When one of them is already loaded, the other is
    assembled from its components so the model weights are held in memory once.
    """

    id = "diffusion"
    name = "Stable Diffusion (local)"
    description = "Uses diffusers for local Stable Diffusion generation"

    def __init__(self, model_id: str = "segmind/tiny-sd") -> None:
        # IMAGEFORGE_DIFFUSION_MODEL, when set, takes precedence over the argument.
        self.model_id = os.getenv("IMAGEFORGE_DIFFUSION_MODEL", model_id)
        self._pipe: StableDiffusionPipeline | None = None
        self._img2img_pipe: StableDiffusionImg2ImgPipeline | None = None

    def is_available(self) -> bool:
        """Return True when torch and the diffusers pipelines could be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    def _require_dependencies(self) -> None:
        """Raise ProviderUnavailableError unless torch and diffusers are importable."""
        if not self.is_available():
            raise ProviderUnavailableError(
                "Diffusion dependencies missing. Install diffusers, torch, and transformers."
            )

    def _load_pipeline(self, pipeline_cls: type, label: str, donor: object | None):
        """Create a ``pipeline_cls`` instance for ``self.model_id`` on the best device.

        Args:
            pipeline_cls: diffusers pipeline class to instantiate.
            label: pipeline kind used in log/error text ("model" or "img2img model").
            donor: an already-loaded pipeline for the same model whose modules are
                reused instead of loading the weights again, or None.

        Returns:
            The pipeline, moved to CUDA when available, with attention slicing
            enabled unless IMAGEFORGE_ENABLE_ATTENTION_SLICING is not "1".

        Raises:
            ProviderUnavailableError: if the weights cannot be loaded.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        if donor is not None:
            # Share UNet/VAE/text encoder/etc. with the loaded sibling pipeline.
            LOGGER.info("Reusing loaded components for diffusion %s '%s'", label, self.model_id)
            pipe = pipeline_cls(**donor.components)
        else:
            dtype = torch.float16 if device == "cuda" else torch.float32
            local_only = os.getenv("IMAGEFORGE_DIFFUSION_LOCAL_ONLY", "0") == "1"
            LOGGER.info("Loading diffusion %s '%s' on %s", label, self.model_id, device)
            try:
                pipe = pipeline_cls.from_pretrained(
                    self.model_id,
                    torch_dtype=dtype,
                    local_files_only=local_only,
                )
            except Exception as exc:  # noqa: BLE001
                if local_only:
                    mode_hint = "local cache only"
                    advice = "Set IMAGEFORGE_DIFFUSION_LOCAL_ONLY=0 to allow downloading models."
                else:
                    # Downloads were already allowed; suggesting to allow them is useless.
                    mode_hint = "online download"
                    advice = "Check the model id and network access."
                raise ProviderUnavailableError(
                    f"Diffusion {label} '{self.model_id}' could not be loaded ({mode_hint}). "
                    f"{advice}"
                ) from exc
        if device == "cuda":
            pipe = pipe.to(device)
        if os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1":
            pipe.enable_attention_slicing()
        return pipe

    def _ensure_pipeline(self) -> StableDiffusionPipeline:
        """Return the cached text-to-image pipeline, creating it on first use.

        Raises:
            ProviderUnavailableError: if dependencies are missing or loading fails.
        """
        self._require_dependencies()
        if self._pipe is None:
            self._pipe = self._load_pipeline(StableDiffusionPipeline, "model", self._img2img_pipe)
        return self._pipe

    def _ensure_img2img_pipeline(self) -> StableDiffusionImg2ImgPipeline | None:
        """Return the cached image-to-image pipeline, or None if diffusers lacks it.

        Raises:
            ProviderUnavailableError: if the model weights cannot be loaded.
        """
        _ensure_imports()
        if StableDiffusionImg2ImgPipeline is None or torch is None:
            return None
        if self._img2img_pipe is None:
            self._img2img_pipe = self._load_pipeline(
                StableDiffusionImg2ImgPipeline, "img2img model", self._pipe
            )
        return self._img2img_pipe

    @staticmethod
    def _load_init_image(path, width: int, height: int):
        """Read the img2img source image as RGB, resized to ``(width, height)``."""
        from PIL import Image

        # The context manager closes the file; convert()/resize() return detached copies.
        with Image.open(path) as source:
            return source.convert("RGB").resize((width, height))

    @staticmethod
    def _expected_steps(steps: int, strength: float | None) -> int:
        """Return the number of denoising steps a run is expected to report.

        For img2img (``strength`` not None) diffusers modulates
        ``num_inference_steps`` by ``strength``; this mirrors that scaling and is
        used for progress display only.
        """
        if strength is None:
            return steps
        return min(int(steps * strength), steps)

    @staticmethod
    def _overall_percent(image_index: int, step: int, steps_per_image: int, image_count: int) -> int:
        """Map (image, step) to a monotonic 0-100 percentage over the whole batch."""
        per_image = max(1, steps_per_image)
        # Clamp so a step counter overshooting the estimate never bleeds into the next image.
        fraction = min(step + 1, per_image) / per_image
        return int(((image_index + fraction) / max(1, image_count)) * 100)

    def generate(self, request: ProviderRequest, output_dir: Path, progress, is_cancelled) -> ProviderResult:
        """Render ``request.count`` images into ``output_dir`` as PNG files.

        Args:
            request: generation parameters. A truthy ``init_image_path`` selects
                img2img; ``model_variant`` switches the model and drops cached pipelines.
            output_dir: directory for ``image_NN.png`` files (created if missing).
            progress: callable taking ``(percent: int, message: str)``.
            is_cancelled: zero-argument callable; checked before each image and at
                every denoising step.

        Returns:
            ProviderResult listing the images written (fewer than requested if
            cancellation was noticed between images).

        Raises:
            ProviderUnavailableError: dependencies missing or model load failure.
            RuntimeError: "Generation cancelled" when cancelled mid-denoising.
        """
        if is_cancelled():
            return ProviderResult(image_paths=[])
        if request.model_variant and request.model_variant != self.model_id:
            # A different model invalidates both cached pipelines.
            self.model_id = request.model_variant
            self._pipe = None
            self._img2img_pipe = None

        progress(1, "Loading diffusion model")
        self._require_dependencies()

        use_img2img = bool(request.init_image_path)
        if use_img2img:
            # Only the img2img pipeline is needed; no separate text-to-image load.
            active_pipe = self._ensure_img2img_pipeline()
            if active_pipe is None:
                raise ProviderUnavailableError("Img2Img requires diffusers img2img pipeline support")
            strength = max(0.0, min(1.0, request.img2img_strength))
            # Loop-invariant: read and resize the source image once for the whole batch.
            init_image = self._load_init_image(request.init_image_path, request.width, request.height)
            expected_steps = self._expected_steps(request.steps, strength)
        else:
            active_pipe = self._ensure_pipeline()
            expected_steps = self._expected_steps(request.steps, None)

        output_dir.mkdir(parents=True, exist_ok=True)
        image_paths: list[Path] = []
        for idx in range(request.count):
            if is_cancelled():
                break
            generator = torch.Generator(device=active_pipe.device).manual_seed(request.seed + idx)

            def _callback(step: int, timestep: int, latents, _idx: int = idx):  # noqa: ANN001
                # Raising is the only way to abort a running diffusers call.
                if is_cancelled():
                    raise RuntimeError("Generation cancelled")
                progress(
                    self._overall_percent(_idx, step, expected_steps, request.count),
                    f"Diffusion step {step + 1}/{expected_steps} (image {_idx + 1})",
                )

            # NOTE(review): newer diffusers releases deprecate callback/callback_steps
            # in favour of callback_on_step_end — confirm against the pinned version.
            call_kwargs = {
                "prompt": request.prompt,
                "negative_prompt": request.negative_prompt or None,
                "num_inference_steps": request.steps,
                "guidance_scale": request.guidance,
                "generator": generator,
                "callback": _callback,
                "callback_steps": 1,
            }
            if use_img2img:
                result = active_pipe(image=init_image, strength=strength, **call_kwargs)
            else:
                result = active_pipe(width=request.width, height=request.height, **call_kwargs)

            image_path = output_dir / f"image_{idx + 1:02d}.png"
            result.images[0].save(image_path, format="PNG")
            image_paths.append(image_path)
            pct = int(((idx + 1) / request.count) * 100)
            progress(pct, f"Diffusion image {idx + 1}/{request.count} complete")
        return ProviderResult(image_paths=image_paths)