"""Local Stable Diffusion image provider built on the ``diffusers`` library."""

from __future__ import annotations

import logging
import os
from pathlib import Path

from .interface import ProviderRequest, ProviderResult, ProviderUnavailableError

LOGGER = logging.getLogger(__name__)

# Lazy imports to avoid torch loading issues on Windows
torch = None
StableDiffusionImg2ImgPipeline = None
StableDiffusionPipeline = None


def _ensure_imports():
    """Populate the module-level torch/diffusers names on first use.

    The three globals are only assigned when *all* imports succeed, so a
    partial install leaves every one of them as ``None``. Failures are
    tolerated (the dependency is optional) but logged at debug level so a
    broken install can be diagnosed.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline
    if torch is not None:
        return
    try:
        import torch as _torch
        from diffusers import StableDiffusionImg2ImgPipeline as _Img2Img
        from diffusers import StableDiffusionPipeline as _Pipeline
    except Exception:  # pragma: no cover - optional dependency
        # Deliberately broad: the import may fail with more than ImportError.
        LOGGER.debug("Diffusion dependencies could not be imported", exc_info=True)
        return
    torch = _torch
    StableDiffusionImg2ImgPipeline = _Img2Img
    StableDiffusionPipeline = _Pipeline


class DiffusionProvider:
    """Image provider that runs Stable Diffusion locally through diffusers.

    Pipelines are created lazily and cached on the instance. The model id
    defaults to the constructor argument but can be overridden with the
    ``IMAGEFORGE_DIFFUSION_MODEL`` environment variable, or per request via
    ``request.model_variant``.
    """

    id = "diffusion"
    name = "Stable Diffusion (local)"
    description = "Uses diffusers for local Stable Diffusion generation"

    def __init__(self, model_id: str = "segmind/tiny-sd") -> None:
        """Store the model id; no model is loaded until it is needed."""
        self.model_id = os.getenv("IMAGEFORGE_DIFFUSION_MODEL", model_id)
        self._pipe: StableDiffusionPipeline | None = None
        self._img2img_pipe: StableDiffusionImg2ImgPipeline | None = None

    def is_available(self) -> bool:
        """Return True when both torch and diffusers could be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    def _load_pipeline(self, pipeline_cls, label: str):
        """Load ``pipeline_cls`` for ``self.model_id`` and prepare it for use.

        Uses CUDA with float16 when available, otherwise CPU with float32.
        ``IMAGEFORGE_DIFFUSION_LOCAL_ONLY=1`` restricts loading to the local
        cache; ``IMAGEFORGE_ENABLE_ATTENTION_SLICING`` (default ``"1"``)
        toggles attention slicing. ``label`` is only used in log and error
        text (e.g. ``"diffusion"`` or ``"diffusion img2img"``).

        Raises:
            ProviderUnavailableError: if ``from_pretrained`` fails for any
                reason; the original exception is chained as the cause.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        local_only = os.getenv("IMAGEFORGE_DIFFUSION_LOCAL_ONLY", "0") == "1"
        LOGGER.info("Loading %s model '%s' on %s", label, self.model_id, device)
        try:
            loaded = pipeline_cls.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                local_files_only=local_only,
            )
        except Exception as exc:  # noqa: BLE001
            if local_only:
                mode_hint = "local cache only"
                advice = "Set IMAGEFORGE_DIFFUSION_LOCAL_ONLY=0 to allow downloading models."
            else:
                # Downloads were already allowed, so the env-var advice
                # would be misleading here.
                mode_hint = "online download"
                advice = "Check the model id and your network connection."
            raise ProviderUnavailableError(
                f"{label.capitalize()} model '{self.model_id}' could not be loaded ({mode_hint}). "
                f"{advice}"
            ) from exc
        if device == "cuda":
            loaded = loaded.to(device)
        if os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1":
            loaded.enable_attention_slicing()
        return loaded

    def _ensure_pipeline(self) -> StableDiffusionPipeline:
        """Return the cached text-to-image pipeline, loading it if needed.

        Raises:
            ProviderUnavailableError: if torch/diffusers are not importable
                or the model cannot be loaded.
        """
        _ensure_imports()
        if StableDiffusionPipeline is None or torch is None:
            raise ProviderUnavailableError(
                "Diffusion dependencies missing. Install diffusers, torch, and transformers."
            )
        if self._pipe is None:
            self._pipe = self._load_pipeline(StableDiffusionPipeline, "diffusion")
        return self._pipe

    def _ensure_img2img_pipeline(self) -> StableDiffusionImg2ImgPipeline | None:
        """Return the cached img2img pipeline, loading it if needed.

        Returns ``None`` (instead of raising) when torch/diffusers are not
        importable.

        Raises:
            ProviderUnavailableError: if the model cannot be loaded.
        """
        _ensure_imports()
        if StableDiffusionImg2ImgPipeline is None or torch is None:
            return None
        if self._img2img_pipe is None:
            self._img2img_pipe = self._load_pipeline(
                StableDiffusionImg2ImgPipeline, "diffusion img2img"
            )
        return self._img2img_pipe

    @staticmethod
    def _load_init_image(request: ProviderRequest):
        """Open the request's init image as RGB, resized to the target size."""
        from PIL import Image

        # convert()/resize() return new images, so the source file can be
        # closed deterministically as soon as the block exits.
        with Image.open(request.init_image_path) as source:
            return source.convert("RGB").resize((request.width, request.height))

    @staticmethod
    def _make_step_callback(request: ProviderRequest, image_index: int, progress, is_cancelled):
        """Build the per-step pipeline callback for image ``image_index``.

        The callback aborts via ``RuntimeError`` when cancellation is
        requested and otherwise reports progress as a percentage of the
        whole job, so the value never moves backwards between images.
        """
        total_steps = max(1, request.steps)

        def _callback(step: int, timestep: int, latents):  # noqa: ANN001
            if is_cancelled():
                raise RuntimeError("Generation cancelled")
            image_fraction = min(1.0, (step + 1) / total_steps)
            overall = int(((image_index + image_fraction) / request.count) * 100)
            progress(overall, f"Diffusion step {step + 1}/{request.steps} (image {image_index + 1})")

        return _callback

    def generate(self, request: ProviderRequest, output_dir: Path, progress, is_cancelled) -> ProviderResult:
        """Generate ``request.count`` images and save them as PNGs.

        Args:
            request: Generation parameters. When ``model_variant`` is set and
                differs from the current model id, the cached pipelines are
                dropped and the new model is used. A truthy
                ``init_image_path`` selects img2img instead of text-to-image.
            output_dir: Directory for the results; created if missing. Files
                are named ``image_01.png``, ``image_02.png``, ...
            progress: Called as ``progress(percent, message)``.
            is_cancelled: Zero-argument callable; a truthy result stops work.

        Returns:
            ProviderResult listing the images written so far. Empty when
            cancelled before starting; partial when cancelled between images.

        Raises:
            ProviderUnavailableError: dependencies or the model are unavailable.
            RuntimeError: raised by the step callback when cancellation is
                requested while an image is being generated (assumes the
                pipeline lets callback exceptions propagate).
        """
        if is_cancelled():
            return ProviderResult(image_paths=[])

        if request.model_variant and request.model_variant != self.model_id:
            # Switching models invalidates both cached pipelines.
            self.model_id = request.model_variant
            self._pipe = None
            self._img2img_pipe = None

        progress(1, "Loading diffusion model")
        pipe = self._ensure_pipeline()
        output_dir.mkdir(parents=True, exist_ok=True)

        image_paths: list[Path] = []
        # Loaded on first use and shared by every image in this request.
        img2img_pipe = None
        init_image = None

        for idx in range(request.count):
            if is_cancelled():
                break

            # Each image gets its own deterministic seed derived from the base.
            seed = request.seed + idx
            generator = torch.Generator(device=pipe.device).manual_seed(seed)
            step_callback = self._make_step_callback(request, idx, progress, is_cancelled)

            if request.init_image_path:
                if img2img_pipe is None:
                    img2img_pipe = self._ensure_img2img_pipeline()
                    if img2img_pipe is None:
                        raise ProviderUnavailableError("Img2Img requires diffusers img2img pipeline support")
                if init_image is None:
                    init_image = self._load_init_image(request)
                result = img2img_pipe(
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt or None,
                    image=init_image,
                    num_inference_steps=request.steps,
                    guidance_scale=request.guidance,
                    # Clamp to the [0, 1] range before handing it to the pipeline.
                    strength=max(0.0, min(1.0, request.img2img_strength)),
                    generator=generator,
                    callback=step_callback,
                    callback_steps=1,
                )
            else:
                result = pipe(
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt or None,
                    width=request.width,
                    height=request.height,
                    num_inference_steps=request.steps,
                    guidance_scale=request.guidance,
                    generator=generator,
                    callback=step_callback,
                    callback_steps=1,
                )

            image = result.images[0]
            image_path = output_dir / f"image_{idx + 1:02d}.png"
            image.save(image_path, format="PNG")
            image_paths.append(image_path)

            pct = int(((idx + 1) / request.count) * 100)
            progress(pct, f"Diffusion image {idx + 1}/{request.count} complete")

        return ProviderResult(image_paths=image_paths)