Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import logging | |
| import os | |
| from pathlib import Path | |
| from .interface import ProviderRequest, ProviderResult, ProviderUnavailableError | |
LOGGER = logging.getLogger(__name__)

# torch/diffusers are deliberately not imported at module load time (torch can be
# troublesome to load on Windows). _ensure_imports() binds these three names on
# first use; until then, or if that import fails, they stay None.
torch = StableDiffusionImg2ImgPipeline = StableDiffusionPipeline = None
# Last exception raised while importing the optional dependencies (None if no
# failure has been seen). Used so the failure is logged once, not on every call.
_IMPORT_ERROR: Exception | None = None


def _ensure_imports() -> None:
    """Import torch and the diffusers pipeline classes on first use.

    On success the module-level ``torch``, ``StableDiffusionImg2ImgPipeline`` and
    ``StableDiffusionPipeline`` names are bound. All three are assigned together, and
    only after every import has succeeded. On failure they stay ``None`` and the
    import is attempted again on the next call. The failure is logged the first time
    it happens instead of being silently discarded, so a missing or broken install
    can be diagnosed.
    """
    global torch, StableDiffusionImg2ImgPipeline, StableDiffusionPipeline, _IMPORT_ERROR
    if torch is not None:
        return
    try:
        import torch as _torch
        from diffusers import StableDiffusionImg2ImgPipeline as _Img2Img
        from diffusers import StableDiffusionPipeline as _Pipeline
    except Exception as exc:  # noqa: BLE001  # pragma: no cover - optional dependency
        # Broad catch is intentional: torch loading problems (e.g. on Windows) need
        # not surface as ImportError, and this dependency is optional.
        if _IMPORT_ERROR is None:
            LOGGER.info("Diffusion dependencies could not be imported: %r", exc)
        _IMPORT_ERROR = exc
        return
    torch = _torch
    StableDiffusionImg2ImgPipeline = _Img2Img
    StableDiffusionPipeline = _Pipeline
class DiffusionProvider:
    """Image provider that runs a Stable Diffusion model locally through ``diffusers``.

    Pipelines are loaded lazily on first use and cached on the instance. The model id
    is taken from the ``IMAGEFORGE_DIFFUSION_MODEL`` environment variable when set,
    otherwise from the constructor argument. It can be switched per request via
    ``request.model_variant``.
    """

    id = "diffusion"
    name = "Stable Diffusion (local)"
    description = "Uses diffusers for local Stable Diffusion generation"

    class _Cancelled(RuntimeError):
        """Raised from the per-step callback to abort the current diffusers run."""

    def __init__(self, model_id: str = "segmind/tiny-sd") -> None:
        # The environment variable takes precedence over the constructor argument.
        self.model_id = os.getenv("IMAGEFORGE_DIFFUSION_MODEL", model_id)
        self._pipe: StableDiffusionPipeline | None = None
        self._img2img_pipe: StableDiffusionImg2ImgPipeline | None = None

    def is_available(self) -> bool:
        """Return True when torch and the diffusers pipeline classes can be imported."""
        _ensure_imports()
        return StableDiffusionPipeline is not None and torch is not None

    # -- loading ----------------------------------------------------------------

    @staticmethod
    def _require_dependencies() -> None:
        """Import the optional dependencies, raising ProviderUnavailableError if absent."""
        _ensure_imports()
        if StableDiffusionPipeline is None or torch is None:
            raise ProviderUnavailableError(
                "Diffusion dependencies missing. Install diffusers, torch, and transformers."
            )

    def _load_pipeline(self, pipeline_cls: type, label: str):
        """Load ``self.model_id`` with ``pipeline_cls`` on the best available device.

        CUDA with float16 weights is used when available, otherwise CPU with float32.
        ``IMAGEFORGE_DIFFUSION_LOCAL_ONLY=1`` restricts loading to the local cache.
        ``IMAGEFORGE_ENABLE_ATTENTION_SLICING`` (default ``"1"``) toggles attention
        slicing.

        Args:
            pipeline_cls: diffusers pipeline class to instantiate via ``from_pretrained``.
            label: human-readable name used in log and error messages.

        Raises:
            ProviderUnavailableError: if the model cannot be loaded.
        """
        device = "cuda" if torch.cuda.is_available() else "cpu"
        dtype = torch.float16 if device == "cuda" else torch.float32
        local_only = os.getenv("IMAGEFORGE_DIFFUSION_LOCAL_ONLY", "0") == "1"
        LOGGER.info("Loading %s '%s' on %s", label, self.model_id, device)
        try:
            pipe = pipeline_cls.from_pretrained(
                self.model_id,
                torch_dtype=dtype,
                local_files_only=local_only,
            )
        except Exception as exc:  # noqa: BLE001
            if local_only:
                mode_hint = "local cache only"
                advice = "Set IMAGEFORGE_DIFFUSION_LOCAL_ONLY=0 to allow downloading models."
            else:
                # Downloading was already allowed, so suggesting to allow it would mislead.
                mode_hint = "online download"
                advice = "Check the model id and network access."
            raise ProviderUnavailableError(
                f"{label.capitalize()} '{self.model_id}' could not be loaded ({mode_hint}). {advice}"
            ) from exc
        if device == "cuda":
            pipe = pipe.to(device)
        if os.getenv("IMAGEFORGE_ENABLE_ATTENTION_SLICING", "1") == "1":
            pipe.enable_attention_slicing()
        return pipe

    def _ensure_pipeline(self) -> StableDiffusionPipeline:
        """Return the cached text-to-image pipeline, loading it on first use.

        Raises:
            ProviderUnavailableError: if dependencies are missing or loading fails.
        """
        self._require_dependencies()
        if self._pipe is None:
            self._pipe = self._load_pipeline(StableDiffusionPipeline, "diffusion model")
        return self._pipe

    def _ensure_img2img_pipeline(self) -> StableDiffusionImg2ImgPipeline | None:
        """Return the cached image-to-image pipeline, or None if it cannot be imported.

        Raises:
            ProviderUnavailableError: if the model cannot be loaded.
        """
        _ensure_imports()
        if StableDiffusionImg2ImgPipeline is None or torch is None:
            return None
        if self._img2img_pipe is None:
            self._img2img_pipe = self._load_pipeline(
                StableDiffusionImg2ImgPipeline, "diffusion img2img model"
            )
        return self._img2img_pipe

    # -- helpers ----------------------------------------------------------------

    @staticmethod
    def _load_init_image(path, width: int, height: int):
        """Open ``path`` as an RGB image resized to ``width`` x ``height``.

        The file is closed before returning. ``convert`` produces an in-memory copy,
        so the returned image does not depend on the open file.
        """
        from PIL import Image

        with Image.open(path) as source:
            return source.convert("RGB").resize((width, height))

    @staticmethod
    def _denoising_steps(steps: int, strength: float | None = None) -> int:
        """Return the number of denoising steps a run is expected to execute (>= 1).

        Text-to-image (``strength`` is None) runs all ``steps``. Image-to-image skips
        the start of the schedule. This assumes the installed diffusers runs
        ``min(int(steps * strength), steps)`` img2img steps, as its img2img pipeline
        does. The value is used only for progress reporting.
        """
        total = steps if strength is None else min(int(steps * strength), steps)
        return max(1, total)

    @staticmethod
    def _overall_percent(image_index: int, fraction: float, count: int) -> int:
        """Map progress ``fraction`` (clamped to 0..1) within image ``image_index``
        (0-based) of ``count`` images to an overall 0..100 percentage."""
        fraction = max(0.0, min(1.0, fraction))
        return int(((image_index + fraction) / max(1, count)) * 100)

    # -- generation -------------------------------------------------------------

    def generate(self, request: ProviderRequest, output_dir: Path, progress, is_cancelled) -> ProviderResult:
        """Generate ``request.count`` images and save them as PNG files in ``output_dir``.

        Image ``i`` (0-based) uses seed ``request.seed + i`` and is written to
        ``image_{i+1:02d}.png``. When ``request.init_image_path`` is set, the img2img
        pipeline is used with ``request.img2img_strength`` clamped to [0, 1].
        Otherwise text-to-image runs at ``request.width`` x ``request.height``.

        Args:
            request: generation parameters.
            output_dir: directory for the PNG files (created if missing).
            progress: callable ``progress(percent, message)``. It receives overall
                progress, which does not move backwards between images.
            is_cancelled: zero-argument callable polled before each image and at every
                denoising step. Cancellation stops generation and returns the images
                saved so far.

        Returns:
            ProviderResult listing the saved image paths.

        Raises:
            ProviderUnavailableError: if dependencies or the model are unavailable.
        """
        if is_cancelled():
            return ProviderResult(image_paths=[])
        if request.model_variant and request.model_variant != self.model_id:
            # A different model invalidates both cached pipelines.
            self.model_id = request.model_variant
            self._pipe = None
            self._img2img_pipe = None

        progress(1, "Loading diffusion model")
        self._require_dependencies()
        if request.init_image_path:
            # Load only the pipeline this request needs; loading the text-to-image
            # weights as well would double model memory.
            pipe = self._ensure_img2img_pipeline()
            if pipe is None:
                raise ProviderUnavailableError("Img2Img requires diffusers img2img pipeline support")
            strength = max(0.0, min(1.0, request.img2img_strength))
            # The init image is the same for every output, so it is loaded once.
            mode_kwargs = {
                "image": self._load_init_image(request.init_image_path, request.width, request.height),
                "strength": strength,
            }
            total_steps = self._denoising_steps(request.steps, strength)
        else:
            pipe = self._ensure_pipeline()
            mode_kwargs = {"width": request.width, "height": request.height}
            total_steps = self._denoising_steps(request.steps)

        output_dir.mkdir(parents=True, exist_ok=True)
        count = request.count
        image_paths: list[Path] = []
        for idx in range(count):
            if is_cancelled():
                break
            generator = torch.Generator(device=pipe.device).manual_seed(request.seed + idx)

            def _callback(step: int, timestep: int, latents, image_index: int = idx) -> None:  # noqa: ANN001
                # Legacy diffusers per-step callback (called with callback_steps=1);
                # raising here aborts the run. NOTE(review): newer diffusers deprecate
                # `callback` in favour of `callback_on_step_end`.
                if is_cancelled():
                    raise self._Cancelled("Generation cancelled")
                done = step + 1
                progress(
                    self._overall_percent(image_index, done / total_steps, count),
                    f"Diffusion step {done}/{total_steps} (image {image_index + 1})",
                )

            try:
                result = pipe(
                    prompt=request.prompt,
                    negative_prompt=request.negative_prompt or None,
                    num_inference_steps=request.steps,
                    guidance_scale=request.guidance,
                    generator=generator,
                    callback=_callback,
                    callback_steps=1,
                    **mode_kwargs,
                )
            except self._Cancelled:
                # Mid-run cancellation behaves like the between-image check: stop and
                # return what has been saved so far.
                break
            image_path = output_dir / f"image_{idx + 1:02d}.png"
            result.images[0].save(image_path, format="PNG")
            image_paths.append(image_path)
            progress(
                self._overall_percent(idx + 1, 0.0, count),
                f"Diffusion image {idx + 1}/{count} complete",
            )
        return ProviderResult(image_paths=image_paths)