| from __future__ import annotations | |
| import base64 | |
| import glob | |
| import os | |
| import io | |
| import re | |
| import tempfile | |
| # Ensure we can import the pipeline (and the rest of `src`) from this repo | |
| # before any `src.*` imports are attempted. | |
| import sys | |
| import time | |
| from typing import Any, Dict, List, Optional | |
| sys.path.append(os.path.abspath(os.path.dirname(__file__))) | |
| from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect | |
| from fastapi.staticfiles import StaticFiles | |
| from pydantic import BaseModel | |
| from src.AutoEncoders.taesd import decode_latents_to_images | |
| from src.Device.ModelCache import get_model_cache | |
| from src.Core.Models.ModelFactory import list_available_models, list_available_controlnets | |
| from src.FileManaging.ImageSaver import pop_image_bytes | |
| # Logging setup | |
| import asyncio | |
| import logging | |
| import uuid | |
| from logging.handlers import RotatingFileHandler | |
| # Create a module-level logger with rotating file handler and request-id support | |
| class _RequestIdFilter(logging.Filter): | |
| def filter(self, record: logging.LogRecord) -> bool: # pragma: no cover - simple utility | |
| if not hasattr(record, "rid"): | |
| record.rid = "-" | |
| return True | |
| def _setup_logger() -> logging.Logger: | |
| os.makedirs("./logs", exist_ok=True) | |
| logger = logging.getLogger("lightdiffusion.server") | |
| if logger.handlers: | |
| return logger | |
| level_name = os.getenv("LD_SERVER_LOGLEVEL", "DEBUG").upper() | |
| try: | |
| level = getattr(logging, level_name, logging.DEBUG) | |
| except Exception: # pragma: no cover | |
| level = logging.DEBUG | |
| logger.setLevel(level) | |
| formatter = logging.Formatter( | |
| fmt="%(asctime)s | %(levelname)s | %(name)s | rid=%(rid)s | %(message)s", | |
| datefmt="%Y-%m-%d %H:%M:%S", | |
| ) | |
| file_handler = RotatingFileHandler( | |
| filename=os.path.join("./logs", "server.log"), | |
| maxBytes=5 * 1024 * 1024, | |
| backupCount=3, | |
| encoding="utf-8", | |
| ) | |
| file_handler.setFormatter(formatter) | |
| file_handler.addFilter(_RequestIdFilter()) | |
| logger.addHandler(file_handler) | |
| # Also log to stderr for interactive runs; avoid duplicate handlers if uvicorn config already propagates | |
| stream_handler = logging.StreamHandler() | |
| stream_handler.setFormatter(formatter) | |
| stream_handler.addFilter(_RequestIdFilter()) | |
| logger.addHandler(stream_handler) | |
| logger.propagate = False | |
| return logger | |
| logger = _setup_logger() | |
| logger.debug("server module loaded; cwd=%s", os.getcwd()) | |
| # Record server start time for telemetry | |
| SERVER_START_TS = time.time() | |
| try: | |
| # Import app_instance to control preview behavior during generation | |
| from src.user import app_instance as _app_instance | |
| from src.user.pipeline import pipeline | |
| except Exception as e: | |
| # Defer import error to runtime response for clarity | |
| pipeline = None # type: ignore | |
| _pipeline_import_error = e | |
| logger.exception("Failed to import pipeline: %s", e) | |
| else: | |
| _pipeline_import_error = None | |
| logger.info("Pipeline and app_instance imported successfully") | |
| class GenerateRequest(BaseModel): | |
| prompt: str | |
| negative_prompt: Optional[str] = "" | |
| width: int = 512 | |
| height: int = 512 | |
| num_images: int = 1 | |
| batch_size: int = 1 | |
| scheduler: str = "ays" | |
| sampler: str = "dpmpp_sde_cfgpp" | |
| steps: int = 20 | |
| hiresfix: bool = False | |
| adetailer: bool = False | |
| enhance_prompt: bool = False | |
| img2img_mode: bool = False | |
| img2img_image: Optional[str] = None | |
| img2img_denoise: float = 0.75 # Denoising strength: 0=keep original, 1=full generation | |
| stable_fast: bool = False | |
| reuse_seed: bool = False | |
| realistic_model: bool = False | |
| enable_multiscale: bool = False | |
| multiscale_preset: Optional[str] = "balanced" | |
| multiscale_intermittent: bool = True | |
| multiscale_factor: float = 0.5 | |
| multiscale_fullres_start: int = 10 | |
| multiscale_fullres_end: int = 8 | |
| keep_models_loaded: bool = True | |
| enable_preview: bool = False | |
| # Preview fidelity for this request: 'low' | 'balanced' | 'high' (default: balanced) | |
| preview_fidelity: str = "balanced" | |
| # CFG-free sampling parameters | |
| cfg_free_enabled: bool = False | |
| cfg_free_start_percent: float = 70.0 | |
| # Token Merging parameters | |
| tome_enabled: bool = False | |
| tome_ratio: float = 0.5 | |
| tome_max_downsample: int = 1 | |
| # Advanced CFG optimization parameters (batched_cfg enabled by default for 8% speedup) | |
| batched_cfg: bool = True | |
| dynamic_cfg_rescaling: bool = False | |
| dynamic_cfg_method: str = "variance" | |
| dynamic_cfg_percentile: float = 95.0 | |
| dynamic_cfg_target_scale: float = 7.0 | |
| adaptive_noise_enabled: bool = False | |
| adaptive_noise_method: str = "complexity" | |
| # Guidance | |
| cfg_scale: float = 7.0 | |
| guidance_scale: Optional[float] = None | |
| seed: Optional[int] = None # If provided >=0 we will reuse it | |
| # Model Selection | |
| model_path: Optional[str] = None | |
| refiner_model_path: Optional[str] = None | |
| refiner_switch_step: Optional[int] = None | |
| # ControlNet | |
| controlnet_enabled: bool = False | |
| controlnet_model: Optional[str] = None | |
| controlnet_strength: float = 1.0 | |
| controlnet_type: str = "canny" | |
| # torch.compile optimization (mutually exclusive with stable_fast) | |
| torch_compile: Optional[bool] = None | |
| vae_autotune: Optional[bool] = None | |
| # Weight quantization format: None, "fp8", or "nvfp4" | |
| weight_quantization: Optional[str] = None | |
| # FP8 inference (auto-gated to supported hardware: Ada Lovelace+) | |
| fp8_inference: bool = False | |
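| # Illustrative minimal JSON body for a text-to-image request using this model; | |
| # any omitted fields fall back to the defaults declared above: | |
| #   {"prompt": "a watercolor fox", "width": 512, "height": 512, | |
| #    "steps": 20, "num_images": 1, "enable_preview": false} | |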
| class SettingsPreferencesRequest(BaseModel): | |
| torch_compile: bool = False | |
| vae_autotune: bool = False | |
| app = FastAPI(title="LightDiffusion Server", version="1.0.0") | |
| async def get_controlnets(): | |
| """List available ControlNet models.""" | |
| try: | |
| models = list_available_controlnets() | |
| return {"models": models} | |
| except Exception as e: | |
| logger.exception("Failed to list controlnets") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| @app.on_event("startup") | |
| async def startup_event(): | |
| """Capture event loop reference and start background worker.""" | |
| global _main_event_loop | |
| _main_event_loop = asyncio.get_running_loop() | |
| # Migrate legacy include/last_seed.txt into the JSON settings store on startup | |
| try: | |
| from src.Core.SettingsStore import migrate_from_last_seed_txt | |
| migrated_seed = migrate_from_last_seed_txt() | |
| if migrated_seed is not None: | |
| logger.info("Migrated legacy include/last_seed.txt -> last_seed=%s", migrated_seed) | |
| except Exception: | |
| logger.exception("Failed to migrate legacy last_seed.txt on startup") | |
| await _generation_buffer.start() | |
| logger.info("Server startup complete, event loop captured for preview broadcasting") | |
| # Helpful, user-friendly startup URL(s) so users know what to open in a browser. | |
| try: | |
| port = int(os.environ.get("PORT") or os.environ.get("UVICORN_PORT") or 7861) | |
| except Exception: | |
| port = 7861 | |
| try: | |
| import socket | |
| s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |
| s.connect(("8.8.8.8", 80)) | |
| local_ip = s.getsockname()[0] | |
| s.close() | |
| except Exception: | |
| local_ip = "127.0.0.1" | |
| logger.info("Open the UI in a browser: http://localhost:%d/ (or on your network: http://%s:%d/)", port, local_ip, port) | |
| # Batching buffer ----------------------------------------------------------- | |
| LD_MAX_BATCH_SIZE = int(os.getenv("LD_MAX_BATCH_SIZE", "4")) | |
| LD_BATCH_TIMEOUT = float(os.getenv("LD_BATCH_TIMEOUT", "0.5")) | |
| # If set to true (1/true), the worker will wait the coalescing timeout when | |
| # there is a single candidate in a chosen group; otherwise singletons are | |
| processed immediately. The default processes singletons immediately to | |
| favor responsiveness and avoid perceived "stuck" behavior. | |
| LD_BATCH_WAIT_SINGLETONS = os.getenv("LD_BATCH_WAIT_SINGLETONS", "0").lower() in ("1", "true", "yes") | |
| # Limit total number of images we will process in a single pipeline run when | |
| # coalescing many requests into a group. If the sum of images across the group | |
| # is larger than this, we will split the group into smaller chunks and run the | |
| # pipeline sequentially to avoid memory pressure and downstream save failures. | |
| LD_MAX_IMAGES_PER_GROUP = int(os.getenv("LD_MAX_IMAGES_PER_GROUP", "256")) | |
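| # Example operator configuration for the batching buffer (illustrative shell | |
| # exports; the values shown are not the defaults above): | |
| #   export LD_MAX_BATCH_SIZE=8          # coalesce up to 8 compatible requests | |
| #   export LD_BATCH_TIMEOUT=0.25        # coalescing window in seconds | |
| #   export LD_BATCH_WAIT_SINGLETONS=1   # hold singletons for the window above | |
| #   export LD_MAX_IMAGES_PER_GROUP=64   # cap images per pipeline run | |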
| def _normalized_image_key(value: Optional[str]) -> str: | |
| """Return a stable image identity key for batching decisions.""" | |
| if not value: | |
| return "" | |
| if value.startswith("data:"): | |
| # Data URLs should already be normalized to a temp file before enqueue, | |
| # but keep a deterministic fallback in case this helper is called early. | |
| return value[:128] | |
| try: | |
| return os.path.abspath(os.path.realpath(value)) | |
| except Exception: | |
| return str(value) | |
| def _effective_guidance_scale(req: "GenerateRequest") -> float: | |
| """Normalize guidance scale for batch signatures and pipeline calls.""" | |
| return float(req.cfg_scale if req.guidance_scale is None else req.guidance_scale) | |
| def _has_running_loop() -> bool: | |
| try: | |
| asyncio.get_running_loop() | |
| return True | |
| except RuntimeError: | |
| return False | |
| class PendingRequest: | |
| def __init__(self, req: GenerateRequest, request_id: str): | |
| self.req = req | |
| self.request_id = request_id | |
| self.arrival = time.time() | |
| self.future: asyncio.Future = asyncio.get_running_loop().create_future() | |
| class GenerationBuffer: | |
| def __init__(self): | |
| self._pending: List[PendingRequest] = [] | |
| self._loop: Optional[asyncio.AbstractEventLoop] = None | |
| self._lock: asyncio.Lock | |
| self._new_request: asyncio.Event | |
| # Prefetching state | |
| self._prefetch_lock: asyncio.Lock | |
| self._prefetch_task: Optional[asyncio.Task] = None | |
| self._current_prefetch_path: Optional[str] = None | |
| # Statistics | |
| self._items_processed = 0 | |
| self._batches_processed = 0 | |
| self._requests_processed = 0 | |
| self._cumulative_wait_time = 0.0 | |
| self._last_batch_ts = 0.0 | |
| self._worker_task: Optional[asyncio.Task] = None | |
| self._reset_async_primitives(asyncio.get_running_loop() if _has_running_loop() else None) | |
| def _reset_async_primitives(self, loop: Optional[asyncio.AbstractEventLoop]) -> None: | |
| """Recreate loop-bound synchronization primitives. | |
| Test runs can start the in-process server multiple times on different | |
| event loops. The queue's Event/Lock objects must be recreated when the | |
| owning loop changes to avoid cross-loop RuntimeError during teardown. | |
| """ | |
| self._loop = loop | |
| self._lock = asyncio.Lock() | |
| self._new_request = asyncio.Event() | |
| self._prefetch_lock = asyncio.Lock() | |
| self._prefetch_task = None | |
| self._current_prefetch_path = None | |
| async def start(self): | |
| """Start the background worker task.""" | |
| current_loop = asyncio.get_running_loop() | |
| if self._loop is not current_loop: | |
| self._reset_async_primitives(current_loop) | |
| if self._worker_task is None or self._worker_task.done(): | |
| self._worker_task = asyncio.create_task(self._worker()) | |
| logger.info("GenerationBuffer worker task started") | |
| async def enqueue(self, pending: PendingRequest) -> dict: | |
| """Add a request to the queue and wait for completion.""" | |
| async with self._lock: | |
| self._pending.append(pending) | |
| self._new_request.set() | |
| # Wait for the worker to process this request | |
| return await pending.future | |
| async def _look_ahead_and_prefetch(self, current_batch_signature: tuple): | |
| """Analyze remaining queue and pre-load the next model if different.""" | |
| from src.user.pipeline import resolve_checkpoint_path | |
| async with self._lock: | |
| if not self._pending: | |
| return | |
| # Find the next group that has a different signature | |
| next_req = None | |
| for p in self._pending: | |
| sig = self._signature_for(p.req) | |
| if sig != current_batch_signature: | |
| next_req = p.req | |
| break | |
| if not next_req: | |
| return | |
| # Resolve the path for the next model | |
| target_path = resolve_checkpoint_path( | |
| realistic_model=next_req.realistic_model | |
| ) | |
| # Perform prefetch outside the queue lock | |
| async with self._prefetch_lock: | |
| # Skip if already prefetched or currently prefetching the same path | |
| if target_path == self._current_prefetch_path: | |
| return | |
| # Cancel existing prefetch if it's for a different model | |
| if self._prefetch_task and not self._prefetch_task.done(): | |
| self._prefetch_task.cancel() | |
| try: | |
| await self._prefetch_task | |
| except asyncio.CancelledError: | |
| pass | |
| self._current_prefetch_path = target_path | |
| async def prefetch_task(): | |
| try: | |
| logger.info("Prefetcher: Starting background load of %s", target_path) | |
| # Load to CPU RAM using the optimized util | |
| sd = await asyncio.to_thread(util.load_torch_file, target_path) | |
| # Store in cache | |
| get_model_cache().set_prefetched_model(target_path, sd) | |
| logger.info("Prefetcher: Successfully pre-loaded %s into RAM", target_path) | |
| except Exception as e: | |
| logger.warning("Prefetcher: Failed to pre-load %s: %s", target_path, e) | |
| finally: | |
| self._current_prefetch_path = None | |
| self._prefetch_task = asyncio.create_task(prefetch_task()) | |
| def _signature_for(self, req: GenerateRequest) -> tuple: | |
| # Grouping signature - requests must match these to be batched | |
| # Detect model type to determine if refiner is relevant | |
| from src.Core.Models.ModelFactory import detect_model_type | |
| is_sdxl = (detect_model_type(req.model_path) == "SDXL") | |
| guidance_scale = _effective_guidance_scale(req) | |
| normalized_img2img_image = _normalized_image_key(req.img2img_image) | |
| return ( | |
| str(req.model_path), # Model must match | |
| bool(req.realistic_model), | |
| int(req.width), | |
| int(req.height), | |
| int(max(1, req.batch_size)), | |
| bool(req.stable_fast), | |
| bool(req.torch_compile), | |
| bool(req.vae_autotune), | |
| bool(req.fp8_inference), | |
| str(req.weight_quantization), | |
| bool(req.img2img_mode), | |
| normalized_img2img_image, | |
| float(req.img2img_denoise), | |
| str(req.scheduler), | |
| str(req.sampler), | |
| int(req.steps), | |
| float(guidance_scale), | |
| bool(req.enhance_prompt), | |
| bool(req.reuse_seed), | |
| bool(req.enable_preview), | |
| str(req.preview_fidelity), | |
| # Treat multiscale options as batch-level — mixing them may | |
| # change the sampling schedule and therefore cannot be | |
| # safely combined into a single forward pass. | |
| bool(req.enable_multiscale), | |
| bool(req.multiscale_intermittent), | |
| float(req.multiscale_factor), | |
| int(req.multiscale_fullres_start), | |
| int(req.multiscale_fullres_end), | |
| bool(req.cfg_free_enabled), | |
| float(req.cfg_free_start_percent), | |
| bool(req.tome_enabled), | |
| float(req.tome_ratio), | |
| int(req.tome_max_downsample), | |
| bool(req.batched_cfg), | |
| bool(req.dynamic_cfg_rescaling), | |
| str(req.dynamic_cfg_method), | |
| float(req.dynamic_cfg_percentile), | |
| float(req.dynamic_cfg_target_scale), | |
| bool(req.adaptive_noise_enabled), | |
| str(req.adaptive_noise_method), | |
| # VRAM retention flags are also batch level | |
| bool(req.keep_models_loaded), | |
| # ControlNet (must match) | |
| bool(req.controlnet_enabled), | |
| str(req.controlnet_model), | |
| float(req.controlnet_strength), | |
| str(req.controlnet_type), | |
| # Refiner (must match only if it will actually be used) | |
| str(req.refiner_model_path) if is_sdxl else "", | |
| (int(req.refiner_switch_step) if req.refiner_switch_step is not None else -1) if is_sdxl else -1, | |
| # Note: hires_fix and adetailer remain intentionally NOT part of | |
| # this signature because they are executed per-sample. | |
| ) | |
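| # Practical consequence of this signature: two requests that differ only in | |
| # prompt, negative_prompt, seed, hires_fix or adetailer can share a batch, | |
| # while a different width, sampler, step count or model keeps them apart. | |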
| async def _worker(self): | |
| logger.info("Batching worker started; max_batch=%s timeout=%s", LD_MAX_BATCH_SIZE, LD_BATCH_TIMEOUT) | |
| while True: | |
| await self._new_request.wait() | |
| # Yield to the event loop so requests enqueued back-to-back can land | |
| # before we group the pending queue. | |
| await asyncio.sleep(0) | |
| async with self._lock: | |
| if not self._pending: | |
| self._new_request.clear() | |
| continue | |
| # Group pending requests by signature | |
| groups: Dict[tuple, List[PendingRequest]] = {} | |
| for p in self._pending: | |
| sig = self._signature_for(p.req) | |
| groups.setdefault(sig, []).append(p) | |
| # Choose the group with the oldest request | |
| chosen_sig = None | |
| oldest_time = float("inf") | |
| for sig, arr in groups.items(): | |
| if arr and arr[0].arrival < oldest_time: | |
| chosen_sig = sig | |
| oldest_time = arr[0].arrival | |
| if chosen_sig is None: | |
| self._new_request.clear() | |
| continue | |
| candidates = groups[chosen_sig] | |
| # Sort by arrival time (oldest first) | |
| candidates.sort(key=lambda x: x.arrival) | |
| # Debug: show group sizes for observability | |
| try: | |
| group_summary = {str(sig): len(arr) for sig, arr in groups.items()} | |
| logger.debug("Batch worker: pending groups=%s chosen_sig=%s group_size=%d oldest_arrival=%.3f", | |
| group_summary, str(chosen_sig), len(candidates), candidates[0].arrival if candidates else 0.0) | |
| except Exception: | |
| pass | |
| # Determine whether to wait for coalescing when there's only a | |
| # single candidate. This is controlled by LD_BATCH_WAIT_SINGLETONS | |
| # so operators can toggle the behavior at runtime via env. | |
| if len(candidates) == 1: | |
| age = time.time() - candidates[0].arrival | |
| if LD_BATCH_WAIT_SINGLETONS and age < LD_BATCH_TIMEOUT: | |
| # Old behavior: wait a bit for more arrivals before | |
| # processing a singleton so we can form a larger batch. | |
| logger.debug("Singleton group for signature %s is too new (age=%.3fs < timeout=%.3fs). Sleeping to coalesce.", str(chosen_sig), age, LD_BATCH_TIMEOUT) | |
| self._new_request.clear() | |
| await asyncio.sleep(LD_BATCH_TIMEOUT) | |
| # Re-arm the wakeup so the still-pending singleton is re-examined | |
| # even if no new request arrives during the wait. | |
| self._new_request.set() | |
| continue | |
| else: | |
| # Eager processing path (default): process singletons | |
| # immediately to avoid perceived "stuck" behavior. | |
| logger.debug("Processing singleton group for signature %s immediately (age=%.3fs). LD_BATCH_WAIT_SINGLETONS=%s", | |
| str(chosen_sig), age, LD_BATCH_WAIT_SINGLETONS) | |
| # Keep ControlNet requests singleton for now. Its image-conditioned | |
| # path has not been made batch-safe in the same way as text2img/img2img. | |
| max_group_size = 1 if candidates[0].req.controlnet_enabled else LD_MAX_BATCH_SIZE | |
| # Pick up to the allowed group size | |
| to_process = candidates[:max_group_size] | |
| # Remove selected items from pending list | |
| for p in to_process: | |
| try: | |
| self._pending.remove(p) | |
| except ValueError: | |
| pass | |
| if not self._pending: | |
| self._new_request.clear() | |
| # Trigger prefetching for the NEXT group while we process this one | |
| await self._look_ahead_and_prefetch(chosen_sig) | |
| # Process the selected group outside the lock | |
| try: | |
| try: | |
| logger.debug("Processing group chosen_sig=%s items=%d request_ids=%s", str(chosen_sig), len(to_process), [p.request_id for p in to_process]) | |
| except Exception: | |
| pass | |
| await self._process_group(to_process) | |
| # Update lightweight metrics only on success | |
| try: | |
| now_ts = time.time() | |
| self._batches_processed += 1 | |
| self._items_processed += sum( | |
| max(1, p.req.num_images) for p in to_process | |
| ) | |
| self._requests_processed += len(to_process) | |
| # Update cumulative wait time per-request | |
| wait_total = sum(now_ts - p.arrival for p in to_process) | |
| self._cumulative_wait_time += wait_total | |
| self._last_batch_ts = now_ts | |
| except Exception: | |
| # Metrics must never crash the worker loop | |
| logger.exception("Failed updating batch metrics") | |
| except Exception as e: | |
| logger.exception("Batch processing failed: %s", e) | |
| # Fail any unresolved futures so waiting clients are not left hanging. | |
| for p in to_process: | |
| if not p.future.done(): | |
| p.future.set_exception(HTTPException(status_code=500, detail=f"Generation failed: {e}")) | |
| async def _process_group(self, items: List[PendingRequest]): | |
| # All items share a signature as enforced by the grouping logic. | |
| if not items: | |
| return | |
| first_req = items[0].req | |
| flat_samples: List[dict[str, Any]] = [] | |
| for p in items: | |
| for _ in range(max(1, p.req.num_images)): | |
| flat_samples.append( | |
| { | |
| "request_id": p.request_id, | |
| "filename_prefix": f"LD-REQ-{p.request_id}", | |
| "seed": p.req.seed if (p.req.seed is not None and p.req.seed >= 0) else None, | |
| "hires_fix": bool(p.req.hiresfix), | |
| "adetailer": bool(p.req.adetailer), | |
| "prompt": p.req.prompt, | |
| "negative_prompt": p.req.negative_prompt or "", | |
| } | |
| ) | |
| # Prepare pipeline kwargs based on the shared signature (take from first) | |
| # Unique ID for this generation run; sent with every preview message | |
| # so the frontend can discard stale previews from previous runs. | |
| _gen_id = uuid.uuid4().hex[:12] | |
| pipeline_kwargs = dict( | |
| prompt=[], | |
| w=first_req.width, | |
| h=first_req.height, | |
| number=0, | |
| batch=0, | |
| scheduler=first_req.scheduler, | |
| sampler=first_req.sampler, | |
| steps=first_req.steps, | |
| cfg_scale=_effective_guidance_scale(first_req), | |
| enhance_prompt=first_req.enhance_prompt, | |
| img2img=first_req.img2img_mode, | |
| img2img_denoise=first_req.img2img_denoise, | |
| stable_fast=first_req.stable_fast, | |
| reuse_seed=first_req.reuse_seed, | |
| autohdr=True, | |
| realistic_model=first_req.realistic_model, | |
| model_path=first_req.model_path, | |
| refiner_model_path=first_req.refiner_model_path, | |
| refiner_switch_step=first_req.refiner_switch_step, | |
| negative_prompt=[], | |
| multiscale_preset=first_req.multiscale_preset, | |
| enable_multiscale=first_req.enable_multiscale, | |
| multiscale_factor=first_req.multiscale_factor, | |
| multiscale_fullres_start=first_req.multiscale_fullres_start, | |
| multiscale_fullres_end=first_req.multiscale_fullres_end, | |
| multiscale_intermittent_fullres=first_req.multiscale_intermittent, | |
| img2img_image=first_req.img2img_image, | |
| request_filename_prefix=f"LD-REQ-{items[0].request_id}", | |
| per_sample_info=[], | |
| cfg_free_enabled=first_req.cfg_free_enabled, | |
| cfg_free_start_percent=first_req.cfg_free_start_percent, | |
| tome_enabled=first_req.tome_enabled, | |
| tome_ratio=first_req.tome_ratio, | |
| tome_max_downsample=first_req.tome_max_downsample, | |
| # Advanced CFG optimizations (batched_cfg defaults to enabled) | |
| batched_cfg=first_req.batched_cfg, | |
| dynamic_cfg_rescaling=first_req.dynamic_cfg_rescaling, | |
| dynamic_cfg_method=first_req.dynamic_cfg_method, | |
| dynamic_cfg_percentile=first_req.dynamic_cfg_percentile, | |
| dynamic_cfg_target_scale=first_req.dynamic_cfg_target_scale, | |
| adaptive_noise_enabled=first_req.adaptive_noise_enabled, | |
| adaptive_noise_method=first_req.adaptive_noise_method, | |
| # ControlNet | |
| controlnet_model=first_req.controlnet_model if first_req.controlnet_enabled else None, | |
| controlnet_strength=first_req.controlnet_strength, | |
| controlnet_type=first_req.controlnet_type, | |
| # torch.compile | |
| torch_compile=first_req.torch_compile, | |
| vae_autotune=first_req.vae_autotune, | |
| # Weight quantization | |
| weight_quantization=first_req.weight_quantization, | |
| # FP8 inference | |
| fp8_inference=first_req.fp8_inference, | |
| # Add callback for WebSocket preview broadcasting | |
| callback=make_server_callback(first_req.steps, generation_id=_gen_id), | |
| ) | |
| # Notify clients that a new generation is starting so they can | |
| # discard stale previews from the previous run. | |
| sync_broadcast_preview( | |
| step=0, total_steps=first_req.steps, | |
| message_type="generation_start", | |
| generation_id=_gen_id, | |
| ) | |
| # Toggle preview state for the duration of the pipeline call | |
| prev_preview_state = None | |
| prev_keep_models_loaded = None | |
| prev_preview_settings = None | |
| try: | |
| try: | |
| prev_preview_state = _app_instance.app.previewer_var.get() | |
| _app_instance.app.previewer_var.set(bool(first_req.enable_preview)) | |
| except Exception: | |
| prev_preview_state = None | |
| # Apply per-request preview fidelity overrides (format / quality / sRGB) | |
| try: | |
| prev_preview_settings = _apply_preview_fidelity_to_app(first_req) | |
| except Exception: | |
| prev_preview_settings = None | |
| # Respect per-group model cache directive: toggle "keep loaded" | |
| # so the sampling pipeline sees the requested caching behavior. | |
| try: | |
| model_cache = get_model_cache() | |
| prev_keep_models_loaded = model_cache.get_keep_models_loaded() | |
| model_cache.set_keep_models_loaded(bool(first_req.keep_models_loaded)) | |
| except Exception: | |
| prev_keep_models_loaded = None | |
| saved_map: Dict[str, List[dict]] = {} | |
| total_images = len(flat_samples) | |
| # Respect ImageSaver.MAX_IMAGES_PER_SAVE and the requested batch size. | |
| # Multi-image runs always execute in deterministic chunks so that | |
| # `batch_size` means "images per sampling pass" and `num_images` | |
| # means "total outputs returned". | |
| try: | |
| from src.FileManaging import ImageSaver as _ImageSaver | |
| _max_save_limit = getattr(_ImageSaver, "MAX_IMAGES_PER_SAVE", LD_MAX_IMAGES_PER_GROUP) | |
| except Exception: | |
| _max_save_limit = LD_MAX_IMAGES_PER_GROUP | |
| max_save_limit = _max_save_limit if _max_save_limit and _max_save_limit > 0 else LD_MAX_IMAGES_PER_GROUP | |
| requested_batch_size = max(1, int(first_req.batch_size)) | |
| max_chunk_size = min(requested_batch_size, LD_MAX_IMAGES_PER_GROUP, max_save_limit) | |
| logger.info( | |
| "Processing group of %d request(s) -> %d image(s) with effective batch_size=%d across %d chunk(s)", | |
| len(items), | |
| total_images, | |
| max_chunk_size, | |
| (total_images + max_chunk_size - 1) // max_chunk_size if max_chunk_size > 0 else 0, | |
| ) | |
| chunks: list[list[dict[str, Any]]] = [ | |
| flat_samples[i : i + max_chunk_size] | |
| for i in range(0, total_images, max_chunk_size) | |
| ] | |
| try: | |
| for chunk in chunks: | |
| c_prompts = [entry["prompt"] for entry in chunk] | |
| c_negatives = [entry["negative_prompt"] for entry in chunk] | |
| c_per_sample_info = [ | |
| { | |
| "request_id": entry["request_id"], | |
| "filename_prefix": entry["filename_prefix"], | |
| "seed": entry["seed"], | |
| "hires_fix": entry["hires_fix"], | |
| "adetailer": entry["adetailer"], | |
| } | |
| for entry in chunk | |
| ] | |
| chunk_kwargs = dict(pipeline_kwargs) | |
| chunk_kwargs["prompt"] = c_prompts | |
| chunk_kwargs["negative_prompt"] = c_negatives | |
| chunk_kwargs["number"] = len(c_prompts) | |
| chunk_kwargs["batch"] = len(c_prompts) | |
| chunk_kwargs["per_sample_info"] = c_per_sample_info | |
| chunk_kwargs["request_filename_prefix"] = c_per_sample_info[0]["filename_prefix"] if c_per_sample_info else None | |
| chunk_start_ts = time.time() | |
| result = await asyncio.to_thread(pipeline, **chunk_kwargs) | |
| if isinstance(result, dict) and "batched_results" in result: | |
| for request_id, entries in result["batched_results"].items(): | |
| saved_map.setdefault(request_id, []).extend(entries) | |
| else: | |
| files = _find_images_since(chunk_start_ts) | |
| for f in files: | |
| name = os.path.basename(f) | |
| for entry in chunk: | |
| rid = entry["request_id"] | |
| if f"LD-REQ-{rid}" in name: | |
| saved_map.setdefault(rid, []).append({ | |
| "filename": name, | |
| "subfolder": os.path.relpath(os.path.dirname(f), "./output"), | |
| }) | |
| except InterruptedError: | |
| logger.info( | |
| "Generation interrupted for request_ids=%s", | |
| [p.request_id for p in items], | |
| ) | |
| sync_broadcast_preview( | |
| step=0, | |
| total_steps=first_req.steps, | |
| message_type="error", | |
| generation_id=_gen_id, | |
| ) | |
| for p in items: | |
| if not p.future.done(): | |
| p.future.set_exception(HTTPException(status_code=409, detail="Generation interrupted")) | |
| return | |
| # For each pending item, collect its images and set future result | |
| for p in items: | |
| imgs = saved_map.get(p.request_id, []) | |
| # Filter and select the first N images requested | |
| selected = imgs[: max(1, p.req.num_images)] | |
| if not selected: | |
| p.future.set_exception(HTTPException(status_code=500, detail="No images produced")) | |
| continue | |
| # Try to use in-memory byte buffer first (avoids disk I/O) | |
| buffered_images = pop_image_bytes(f"LD-REQ-{p.request_id}") | |
| b64_list = [] | |
| if buffered_images: | |
| # Use in-memory bytes directly - zero disk reads | |
| for buf_filename, buf_subfolder, png_bytes in buffered_images[:max(1, p.req.num_images)]: | |
| b64_data = base64.b64encode(png_bytes).decode("utf-8") | |
| mime_type = "image/png" | |
| if buf_filename.lower().endswith((".jpg", ".jpeg")): | |
| mime_type = "image/jpeg" | |
| elif buf_filename.lower().endswith(".webp"): | |
| mime_type = "image/webp" | |
| b64_list.append(f"data:{mime_type};base64,{b64_data}") | |
| else: | |
| # Fallback to disk reads | |
| for entry in selected: | |
| if isinstance(entry, list): | |
| # Safeguard against nested lists if any processor still returns them | |
| entry = entry[0] if entry else {} | |
| if not isinstance(entry, dict): | |
| continue | |
| filename = entry.get("filename", "") | |
| path = os.path.join("./output", entry.get("subfolder", ""), filename) | |
| try: | |
| b64_data = _encode_png_to_base64(path) | |
| mime_type = "image/png" | |
| if filename.lower().endswith(".jpg") or filename.lower().endswith(".jpeg"): | |
| mime_type = "image/jpeg" | |
| elif filename.lower().endswith(".webp"): | |
| mime_type = "image/webp" | |
| b64_list.append(f"data:{mime_type};base64,{b64_data}") | |
| except Exception as e: | |
| logger.exception("Failed to read image for request %s: %s", p.request_id, e) | |
| if len(b64_list) == 0: | |
| p.future.set_exception(HTTPException(status_code=500, detail="Failed to read generated images")) | |
| elif len(b64_list) == 1: | |
| p.future.set_result({"image": b64_list[0]}) | |
| else: | |
| p.future.set_result({"images": b64_list}) | |
| finally: | |
| try: | |
| if prev_preview_settings is not None: | |
| _restore_preview_settings(prev_preview_settings) | |
| except Exception: | |
| pass | |
| try: | |
| if prev_preview_state is not None: | |
| _app_instance.app.previewer_var.set(prev_preview_state) | |
| except Exception: | |
| pass | |
| try: | |
| # Restore previous model cache keep-loaded setting if we | |
| # changed it above. | |
| if prev_keep_models_loaded is not None: | |
| try: | |
| model_cache = get_model_cache() | |
| model_cache.set_keep_models_loaded(bool(prev_keep_models_loaded)) | |
| except Exception: | |
| pass | |
| except Exception: | |
| pass | |
| # Instantiate the buffer and start it on startup | |
| _generation_buffer = GenerationBuffer() | |
| async def _start_buffer(): | |
| await _generation_buffer.start() | |
| def health() -> Dict[str, str]: | |
| return {"status": "ok"} | |
| async def telemetry() -> Dict[str, Any]: | |
| """Return basic server and batching buffer telemetry. | |
| Fields: | |
| - uptime_seconds | |
| - pending_count | |
| - pending_by_signature (human-readable) | |
| - pending_preview (list of small pending request summaries) | |
| - worker_running | |
| - max_batch_size, batch_timeout | |
| - batches_processed, items_processed, last_batch_time | |
| - pipeline_import_ok and pipeline_import_error | |
| """ | |
| rid = uuid.uuid4().hex[:8] | |
| log = logging.LoggerAdapter(logger, {"rid": rid}) | |
| log.debug("telemetry requested") | |
| now = time.time() | |
| uptime = now - SERVER_START_TS | |
| # Build a small snapshot of queue state under the buffer lock | |
| async with _generation_buffer._lock: | |
| pending_count = len(_generation_buffer._pending) | |
| # Group pending requests by signature for visibility | |
| sig_counts: Dict[str, int] = {} | |
| pending_preview: List[Dict[str, Any]] = [] | |
| for p in _generation_buffer._pending: | |
| try: | |
| sig = _generation_buffer._signature_for(p.req) | |
| sig_key = str(sig) | |
| except Exception: | |
| sig_key = "<unknown>" | |
| sig_counts[sig_key] = sig_counts.get(sig_key, 0) + 1 | |
| # Keep preview small to avoid large payloads | |
| preview = { | |
| "request_id": p.request_id, | |
| "waiting_s": round(now - p.arrival, 3), | |
| "prompt_preview": (p.req.prompt[:120] + "…") if (p.req.prompt and len(p.req.prompt) > 120) else (p.req.prompt or ""), | |
| } | |
| pending_preview.append(preview) | |
| batches_processed = _generation_buffer._batches_processed | |
| items_processed = _generation_buffer._items_processed | |
| last_batch_ts = _generation_buffer._last_batch_ts | |
| worker_running = ( | |
| _generation_buffer._worker_task is not None | |
| and (not _generation_buffer._worker_task.done()) | |
| ) | |
| # Compute average wait times | |
| requests_processed = _generation_buffer._requests_processed | |
| cumulative_wait = _generation_buffer._cumulative_wait_time | |
| avg_processed_wait_s = ( | |
| (cumulative_wait / requests_processed) if requests_processed > 0 else None | |
| ) | |
| # Pending average wait (current queue) | |
| pending_avg_wait_s = ( | |
| (sum(now - p.arrival for p in _generation_buffer._pending) / pending_count) | |
| if pending_count > 0 | |
| else 0.0 | |
| ) | |
| # Model cache telemetry (memory and loaded models) | |
| memory_info_error = None | |
| try: | |
| model_cache = get_model_cache() | |
| memory_info = model_cache.get_memory_info() | |
| loaded_raw = model_cache.get_cached_sampling_models() | |
| loaded_models = [] | |
| for m in loaded_raw: | |
| try: | |
| name = getattr(m, "name", None) or getattr(m, "__class__", type(m)).__name__ | |
| except Exception: | |
| name = str(type(m)) | |
| loaded_models.append(name) | |
| loaded_models_count = len(loaded_models) | |
| except Exception as e: | |
| # Don't fail telemetry if model cache query fails. Capture a short | |
| # error string so callers can display a hint without exposing full | |
| # stack traces. Device-side CUDA asserts can leave the device in an | |
| # unusable state and will cause subsequent CUDA queries to fail; we | |
| # surface a concise message here instead of crashing the endpoint. | |
| try: | |
| # Prefer a succinct message | |
| memory_info_error = str(e) | |
| except Exception: | |
| memory_info_error = "unknown" | |
| logger.exception("Failed to fetch model cache telemetry: %s", memory_info_error) | |
| memory_info = None | |
| loaded_models = [] | |
| loaded_models_count = 0 | |
| return { | |
| "uptime_seconds": round(uptime, 3), | |
| "server_start_ts": SERVER_START_TS, | |
| "pending_count": pending_count, | |
| "pending_by_signature": sig_counts, | |
| "pending_preview": pending_preview[:20], | |
| "worker_running": worker_running, | |
| "max_batch_size": LD_MAX_BATCH_SIZE, | |
| "batch_timeout": LD_BATCH_TIMEOUT, | |
| "max_images_per_group": LD_MAX_IMAGES_PER_GROUP, | |
| "batches_processed": batches_processed, | |
| "items_processed": items_processed, | |
| "requests_processed": requests_processed, | |
| "last_batch_time": last_batch_ts, | |
| "avg_processed_wait_s": avg_processed_wait_s, | |
| "pending_avg_wait_s": pending_avg_wait_s, | |
| "memory_info": memory_info, | |
| "loaded_models_count": loaded_models_count, | |
| "loaded_models": loaded_models, | |
| "pipeline_import_ok": pipeline is not None, | |
| "pipeline_import_error": str(_pipeline_import_error) if _pipeline_import_error is not None else None, | |
| } | |
| # Settings API ------------------------------------------------------------ | |
| def _read_settings_preferences() -> Dict[str, bool]: | |
| from src.Core.SettingsStore import get_preferences | |
| return get_preferences() | |
| def _resolve_autotune_preferences(req: GenerateRequest) -> GenerateRequest: | |
| prefs = _read_settings_preferences() | |
| req.torch_compile = bool(prefs["torch_compile"] if req.torch_compile is None else req.torch_compile) | |
| req.vae_autotune = bool(prefs["vae_autotune"] if req.vae_autotune is None else req.vae_autotune) | |
| return req | |
| def _reset_autotune_runtime_state() -> None: | |
| """Clear runtime model state so changed autotune preferences take effect.""" | |
| from src.Core.Pipeline import reset_default_pipeline | |
| from src.Device.Device import clear_compiled_models | |
| from src.Device.ModelCache import clear_model_cache | |
| reset_default_pipeline() | |
| clear_model_cache() | |
| clear_compiled_models() | |
| async def api_get_settings_preferences(): | |
| """Return persisted server-wide generation preferences.""" | |
| try: | |
| return _read_settings_preferences() | |
| except Exception as e: | |
| logger.exception("Failed to read settings preferences: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_post_settings_preferences(body: SettingsPreferencesRequest): | |
| """Persist server-wide generation preferences and reset runtime caches if needed.""" | |
| try: | |
| from src.Core.SettingsStore import set_preferences | |
| current = _read_settings_preferences() | |
| incoming = { | |
| "torch_compile": bool(body.torch_compile), | |
| "vae_autotune": bool(body.vae_autotune), | |
| } | |
| stored = set_preferences(incoming) | |
| if stored != current: | |
| _reset_autotune_runtime_state() | |
| return stored | |
| except Exception as e: | |
| logger.exception("Failed to update settings preferences: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_get_last_settings(): | |
| """Return the last persisted seed (or null).""" | |
| try: | |
| from src.Core.SettingsStore import get_last_seed | |
| seed = get_last_seed() | |
| return {"seed": seed} | |
| except Exception as e: | |
| logger.exception("Failed to read last seed: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_get_settings_history(): | |
| """Return saved settings history (most-recent-first).""" | |
| try: | |
| from src.Core.SettingsStore import get_history | |
| return {"history": get_history()} | |
| except Exception as e: | |
| logger.exception("Failed to read settings history: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_post_settings_history(body: Dict[str, Any]): | |
| """Append a settings snapshot to history. | |
| Body: { settings: GenerationSettings, include_prompt: bool } | |
| By default `include_prompt` is False and prompt/negative_prompt are NOT persisted. | |
| """ | |
| try: | |
| settings = body.get("settings") | |
| if not settings: | |
| raise HTTPException(status_code=400, detail="Missing 'settings' in request body") | |
| include_prompt = bool(body.get("include_prompt", False)) | |
| if include_prompt: | |
| stored = dict(settings) | |
| else: | |
| # Default sanitized/parameter-only snapshot for privacy | |
| allowed = ["seed", "steps", "cfg_scale", "sampler", "scheduler", "model_path", "width", "height"] | |
| stored = {k: settings[k] for k in allowed if k in settings} | |
| from src.Core.SettingsStore import append_snapshot | |
| snap = append_snapshot({"settings": stored}) | |
| return {"snapshot": snap} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.exception("Failed to append settings history: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| async def api_post_image_metadata(body: Dict[str, Any]): | |
| """Extract PNG metadata from a base64/data-URL image payload and return | |
| a normalized metadata dictionary suitable for re-applying to UI settings. | |
| Body: { image: "data:image/png;base64,..." } or { image: "<base64>" } | |
| Returns: { metadata: { seed, steps, cfg_scale, sampler, scheduler, model_path, width, height, prompt?, negative_prompt? } } | |
| """ | |
| try: | |
| image_b64 = body.get("image") | |
| if not image_b64: | |
| raise HTTPException(status_code=400, detail="Missing 'image' in request body") | |
| # Accept data URL or raw base64 | |
| b64_data = None | |
| if isinstance(image_b64, str) and image_b64.startswith("data:"): | |
| idx = image_b64.find("base64,") | |
| if idx != -1: | |
| b64_data = image_b64[idx + len("base64,"):] | |
| elif isinstance(image_b64, str): | |
| b64_data = image_b64.strip().replace("\n", "") | |
| if not b64_data: | |
| raise HTTPException(status_code=400, detail="Invalid image payload") | |
| decoded = base64.b64decode(b64_data) | |
| # Parse PNG metadata using PIL | |
| from PIL import Image | |
| img = Image.open(io.BytesIO(decoded)) | |
| info = img.info or {} | |
| def _to_int(v): | |
| try: | |
| return int(v) | |
| except Exception: | |
| return None | |
| def _to_float(v): | |
| try: | |
| return float(v) | |
| except Exception: | |
| return None | |
| meta: Dict[str, Any] = {} | |
| if "prompt" in info: | |
| meta["prompt"] = info.get("prompt") | |
| if "negative_prompt" in info: | |
| meta["negative_prompt"] = info.get("negative_prompt") | |
| if "seed" in info: | |
| meta["seed"] = _to_int(info.get("seed")) | |
| if "steps" in info: | |
| meta["steps"] = _to_int(info.get("steps")) | |
| # Context.build_metadata uses key 'cfg' for CFG value — map it to cfg_scale | |
| if "cfg" in info: | |
| meta["cfg_scale"] = _to_float(info.get("cfg")) | |
| if "sampler" in info: | |
| meta["sampler"] = info.get("sampler") | |
| if "scheduler" in info: | |
| meta["scheduler"] = info.get("scheduler") | |
| if "model_path" in info: | |
| meta["model_path"] = info.get("model_path") | |
| if "width" in info: | |
| meta["width"] = _to_int(info.get("width")) | |
| if "height" in info: | |
| meta["height"] = _to_int(info.get("height")) | |
| return {"metadata": meta} | |
| except HTTPException: | |
| raise | |
| except Exception as e: | |
| logger.exception("Failed to decode image metadata: %s", e) | |
| raise HTTPException(status_code=500, detail=str(e)) | |
| def _encode_png_to_base64(path: str) -> str: | |
| # Retry a few times in case the file is still being finalized on disk | |
| last_err: Optional[Exception] = None | |
| for attempt in range(20): # up to ~2s total | |
| try: | |
| with open(path, "rb") as f: | |
| data = f.read() | |
| if attempt > 0: | |
| logger.debug("Read image after %d retries: %s", attempt, path) | |
| return base64.b64encode(data).decode("utf-8") | |
| except Exception as e: | |
| last_err = e | |
| time.sleep(0.1) | |
| # One last attempt or raise detailed error | |
| try: | |
| with open(path, "rb") as f: | |
| logger.debug("Final attempt succeeded reading: %s", path) | |
| return base64.b64encode(f.read()).decode("utf-8") | |
| except Exception as e: | |
| logger.error("Failed to read generated image %s: %s", path, e if e else last_err) | |
| raise HTTPException(status_code=500, detail=f"Failed to read generated image: {e if e else last_err}") | |
| def _save_img2img_image_to_file(value: Optional[str], max_size_bytes: int = 10 * 1024 * 1024) -> Optional[str]: | |
| """Ensure img2img_image is a local file path. | |
| Accepts either: | |
| - an existing filesystem path (returned unchanged), | |
| - a data URL (data:image/...;base64,...) which will be decoded and saved to the system temp directory, or | |
| - a bare base64 string which will be decoded and saved. | |
| Returns the path to the saved file, or None if no value was provided. | |
| Raises HTTPException on invalid data or if the decoded payload exceeds max_size_bytes. | |
| """ | |
| if not value: | |
| return None | |
| # If it's already a file path that exists, return as-is | |
| if os.path.exists(value) and os.path.isfile(value): | |
| return value | |
| # Try to parse as a data URL or bare base64 | |
| b64_data = None | |
| try: | |
| if isinstance(value, str) and value.startswith("data:"): | |
| # data:[<mediatype>][;base64],<data> | |
| m = re.match(r"^data:(?P<mime>image/[^;]+);base64,(?P<b64>.+)$", value, flags=re.DOTALL) | |
| if m: | |
| b64_data = m.group("b64") | |
| else: | |
| # Fallback: find 'base64,' and take the rest | |
| idx = value.find("base64,") | |
| if idx != -1: | |
| b64_data = value[idx + len("base64,"):] | |
| else: | |
| # Possibly a raw base64 string; strip whitespace/newlines | |
| s = re.sub(r"\s+", "", str(value)) | |
| if len(s) > 100 and re.fullmatch(r"[A-Za-z0-9+/=]+", s): | |
| b64_data = s | |
| if not b64_data: | |
| raise HTTPException(status_code=400, detail="img2img_image must be a file path, a data URL, or a base64-encoded image") | |
| decoded = base64.b64decode(b64_data) | |
| except HTTPException: | |
| raise | |
| except Exception: | |
| raise HTTPException(status_code=400, detail="Invalid base64 data for img2img_image") | |
| # Enforce size limit | |
| if len(decoded) > max_size_bytes: | |
| raise HTTPException(status_code=413, detail=f"img2img_image too large (max {max_size_bytes // 1024} KB)") | |
| # Try to detect format | |
| try: | |
| import imghdr | |
| fmt = imghdr.what(None, decoded) | |
| except Exception: | |
| fmt = None | |
| ext = None | |
| if fmt: | |
| ext = "jpg" if fmt == "jpeg" else fmt | |
| else: | |
| try: | |
| from PIL import Image | |
| img = Image.open(io.BytesIO(decoded)) | |
| fmt = img.format.lower() if img.format else "png" | |
| ext = "jpg" if fmt == "jpeg" else fmt | |
| except Exception: | |
| ext = "png" | |
| # Save to system temp directory | |
| tmp_dir = tempfile.gettempdir() | |
| os.makedirs(tmp_dir, exist_ok=True) | |
| fname = f"img2img-{uuid.uuid4().hex[:8]}.{ext}" | |
| path = os.path.join(tmp_dir, fname) | |
| try: | |
| with open(path, "wb") as f: | |
| f.write(decoded) | |
| except Exception as e: | |
| logger.exception("Failed to write img2img upload to %s: %s", path, e) | |
| raise HTTPException(status_code=500, detail="Failed to save img2img_image on server") | |
| # Don't log the incoming base64 content | |
| logger.info("Saved img2img image to %s", path) | |
| return path | |
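| # Illustrative `img2img_image` values accepted by the helper above: | |
| #   "/path/to/input.png"                      -> returned unchanged | |
| #   "data:image/png;base64,iVBORw0KGgo..."    -> decoded and saved to a temp file | |
| #   a bare base64 string (>100 chars)         -> decoded and saved to a temp file | |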
| def _list_existing_images() -> List[str]: | |
| exts = ["*.png", "*.jpg", "*.jpeg", "*.webp"] | |
| files: List[str] = [] | |
| for ext in exts: | |
| files.extend(glob.glob(os.path.join("./output", "**", ext), recursive=True)) | |
| logger.debug("Found %d existing images", len(files)) | |
| return files | |
| def _find_images_since(start_ts: float) -> List[str]: | |
| """Return images whose mtime is at or after start_ts (with small grace).""" | |
| grace = 0.25 | |
| files = _list_existing_images() | |
| recent = [p for p in files if os.path.getmtime(p) >= (start_ts - grace)] | |
| recent.sort(key=lambda p: os.path.getmtime(p), reverse=True) | |
| logger.debug("%d images modified since %.3f", len(recent), start_ts) | |
| return recent | |
| # WebSocket preview endpoint for real-time streaming | |
| _preview_clients: List[WebSocket] = [] | |
| _main_event_loop: Optional[asyncio.AbstractEventLoop] = None | |
| def sync_broadcast_preview( | |
| step: int, | |
| total_steps: int, | |
| images: Optional[List[str]] = None, | |
| message_type: str = "preview", | |
| generation_id: Optional[str] = None, | |
| ): | |
| """Synchronous wrapper to broadcast preview from pipeline thread. | |
| This function can be called from the pipeline callback running in a | |
| thread pool executor. It schedules the async broadcast on the main | |
| event loop. | |
| """ | |
| global _main_event_loop | |
| if not _preview_clients: | |
| if step % 10 == 0: | |
| logger.debug("No preview clients connected, skipping broadcast") | |
| return | |
| if _main_event_loop is None: | |
| logger.error("Main event loop is None! Cannot broadcast preview.") | |
| return | |
| try: | |
| if step % 5 == 0 or step == total_steps - 1: | |
| logger.info(f"Broadcasting preview step {step}/{total_steps}") | |
| future = asyncio.run_coroutine_threadsafe( | |
| broadcast_preview(step, total_steps, images, message_type, generation_id=generation_id), | |
| _main_event_loop | |
| ) | |
| # Wait for broadcast to complete to ensure ordering | |
| try: | |
| future.result(timeout=0.5) | |
| except Exception: | |
| pass # Don't block generation on slow clients | |
| except Exception as e: | |
| logger.error(f"Preview broadcast failed: {e}") | |
| pass # Don't let preview errors affect generation | |
| def _apply_preview_fidelity_to_app(req): | |
| """Apply preview fidelity settings from a GenerateRequest into the global app. | |
| Returns a dict with previous settings so callers can restore them later. | |
| """ | |
| prev = {} | |
| try: | |
| # Only apply fidelity changes if previewing is enabled for this request. | |
| if not getattr(req, "enable_preview", False): | |
| return None | |
| prev["preview_srgb"] = getattr(_app_instance.app, "preview_srgb", True) | |
| prev["preview_format"] = getattr(_app_instance.app, "preview_format", "WEBP") | |
| prev["preview_quality"] = getattr(_app_instance.app, "preview_quality", 90) | |
| prev["preview_resample"] = getattr(_app_instance.app, "preview_resample", "LANCZOS") | |
| prev["preview_apply_fast_autohdr"] = getattr(_app_instance.app, "preview_apply_fast_autohdr", False) | |
| pfid = getattr(req, "preview_fidelity", "balanced") or "balanced" | |
| # Map to a few conservative presets | |
| if pfid == "low": | |
| _app_instance.app.preview_srgb = True | |
| _app_instance.app.preview_format = "WEBP" | |
| _app_instance.app.preview_quality = 70 | |
| elif pfid == "high": | |
| _app_instance.app.preview_srgb = True | |
| _app_instance.app.preview_format = "PNG" | |
| _app_instance.app.preview_quality = 100 | |
| else: | |
| # balanced | |
| _app_instance.app.preview_srgb = True | |
| _app_instance.app.preview_format = "WEBP" | |
| _app_instance.app.preview_quality = 90 | |
| return prev | |
| except Exception: | |
| return None | |
| def _restore_preview_settings(prev): | |
| if not prev: | |
| return | |
| try: | |
| _app_instance.app.preview_srgb = prev.get("preview_srgb", True) | |
| _app_instance.app.preview_format = prev.get("preview_format", "WEBP") | |
| _app_instance.app.preview_quality = prev.get("preview_quality", 90) | |
| _app_instance.app.preview_resample = prev.get("preview_resample", "LANCZOS") | |
| _app_instance.app.preview_apply_fast_autohdr = prev.get("preview_apply_fast_autohdr", False) | |
| except Exception: | |
| pass | |
| def make_server_callback(total_steps: int, generation_id: Optional[str] = None): | |
| """Create a pipeline callback that broadcasts progress via WebSocket. | |
| Args: | |
| total_steps: Total number of sampling steps | |
| generation_id: Unique ID for this generation run, sent with every | |
| preview message so the frontend can ignore stale previews. | |
| Returns: | |
| Callback function compatible with pipeline | |
| """ | |
| def callback(args): | |
| # Extract step info from args dict | |
| step = args.get("i", 0) | |
| curr_total_steps = args.get("total_steps", total_steps) | |
| # Only process images on broadcast steps to save compute | |
| # Broadcast every 5 steps or last step | |
| is_broadcast_step = (step % 5 == 0) or (step == curr_total_steps - 1) | |
| images_b64 = None | |
| if is_broadcast_step: | |
| try: | |
| # prefer denoised, fallback to x ONLY if early step | |
| latents_tensor = args.get("denoised") | |
| if latents_tensor is None and step < 5: | |
| latents_tensor = args.get("x") | |
| if latents_tensor is not None: | |
| # Detect flux from shape (Flux has 16 or 32 channels) | |
| # This is a heuristic, ideal would be to pass it in args | |
| is_flux = (latents_tensor.shape[1] == 16 or latents_tensor.shape[1] == 32) | |
| pil_images = decode_latents_to_images(latents_tensor, flux=is_flux) | |
| images_b64 = [] | |
| for img in pil_images: | |
| buffered = io.BytesIO() | |
| fmt = getattr(_app_instance.app, "preview_format", "WEBP") | |
| q = getattr(_app_instance.app, "preview_quality", 90) | |
| try: | |
| img.save(buffered, format=fmt, quality=q) | |
| mime = f"image/{fmt.lower()}" | |
| except Exception: | |
| # Fallback to JPEG if preferred format is unsupported | |
| buffered = io.BytesIO() | |
| img.save(buffered, format="JPEG", quality=max(70, q)) | |
| mime = "image/jpeg" | |
| img_str = base64.b64encode(buffered.getvalue()).decode("utf-8") | |
| images_b64.append(f"data:{mime};base64,{img_str}") | |
| except Exception as e: | |
| logger.error(f"Preview generation failed: {e}") | |
| pass | |
| # Broadcast progress update with images | |
| sync_broadcast_preview(step, curr_total_steps, images=images_b64, message_type="preview" if images_b64 else "progress", generation_id=generation_id) | |
| return callback | |
| @app.websocket("/ws/preview") | |
| async def websocket_preview(websocket: WebSocket): | |
| """WebSocket endpoint for real-time preview streaming. | |
| Clients receive JSON messages with: | |
| - type: "preview" | "progress" | "complete" | "error" | |
| - step: Current step number | |
| - total_steps: Total number of steps | |
| - timestamp: Unix timestamp | |
| - images: List of base64 encoded preview images (for "preview" type) | |
| """ | |
| await websocket.accept() | |
| _preview_clients.append(websocket) | |
| logger.info("WebSocket client connected to /ws/preview (total: %d)", len(_preview_clients)) | |
| try: | |
| # Keep connection alive and listen for close | |
| while True: | |
| try: | |
| # Wait for any message (ping/pong or close) | |
| data = await asyncio.wait_for(websocket.receive_text(), timeout=30.0) | |
| # Echo back to confirm alive | |
| await websocket.send_json({"type": "pong", "timestamp": time.time()}) | |
| except asyncio.TimeoutError: | |
| # Send ping to keep connection alive | |
| try: | |
| await websocket.send_json({"type": "ping", "timestamp": time.time()}) | |
| except Exception: | |
| break | |
| except WebSocketDisconnect: | |
| pass | |
| except Exception as e: | |
| logger.debug("WebSocket connection error: %s", e) | |
| finally: | |
| if websocket in _preview_clients: | |
| _preview_clients.remove(websocket) | |
| logger.info("WebSocket client disconnected (remaining: %d)", len(_preview_clients)) | |
| async def broadcast_preview( | |
| step: int, | |
| total_steps: int, | |
| images: Optional[List[str]] = None, | |
| message_type: str = "preview", | |
| generation_id: Optional[str] = None, | |
| ): | |
| """Broadcast preview update to all connected WebSocket clients. | |
| Args: | |
| step: Current step number | |
| total_steps: Total number of steps | |
| images: Optional list of base64-encoded images | |
| message_type: Type of message (preview, progress, complete, error) | |
| generation_id: Unique ID for this generation run | |
| """ | |
| if not _preview_clients: | |
| return | |
| payload = { | |
| "type": message_type, | |
| "step": step, | |
| "total_steps": total_steps, | |
| "timestamp": time.time(), | |
| } | |
| if generation_id: | |
| payload["generation_id"] = generation_id | |
| if images: | |
| payload["images"] = images | |
| # Send to all clients, removing any that fail | |
| disconnected = [] | |
| for client in _preview_clients: | |
| try: | |
| await client.send_json(payload) | |
| except Exception: | |
| disconnected.append(client) | |
| for client in disconnected: | |
| if client in _preview_clients: | |
| _preview_clients.remove(client) | |
| @app.post("/api/generate") | |
| async def generate(req: GenerateRequest) -> Dict[str, Any]: | |
| rid = uuid.uuid4().hex[:8] | |
| log = logging.LoggerAdapter(logger, {"rid": rid}) | |
| log.info("/api/generate called") | |
| # Validate pipeline import | |
| global pipeline, _pipeline_import_error | |
| if pipeline is None: | |
| log.error("Pipeline import error: %s", _pipeline_import_error) | |
| raise HTTPException(status_code=500, detail=f"Pipeline import error: {_pipeline_import_error}") | |
| # Optionally honor requested seed by persisting it in SettingsStore and enabling reuse | |
| reuse_seed = req.reuse_seed | |
| if req.seed is not None and req.seed >= 0: | |
| try: | |
| from src.Core.SettingsStore import set_last_seed | |
| set_last_seed(int(req.seed)) | |
| except Exception: | |
| logger.exception("Failed to persist last seed to SettingsStore") | |
| reuse_seed = True | |
| req = _resolve_autotune_preferences(req) | |
| # For buffered execution we pass request data into the queue; the | |
| # background worker will control how the prompt and img2img path are | |
| # consumed when invoking the pipeline. | |
| # Log request summary (avoid dumping huge strings) | |
| def _truncate(s: Optional[str], n: int = 200) -> str: | |
| if not s: | |
| return "" | |
| return s if len(s) <= n else s[:n] + "…" | |
| log.debug( | |
| "Request: w=%s h=%s num_images=%s batch=%s scheduler=%s sampler=%s steps=%s hiresfix=%s adetailer=%s enhance=%s img2img=%s stable_fast=%s torch_compile=%s vae_autotune=%s reuse_seed=%s realistic=%s multiscale=%s intermittent=%s factor=%s fullres=[%s,%s] keep_models_loaded=%s enable_preview=%s prompt='%s' neg='%s' img2img_image_present=%s", | |
| req.width, | |
| req.height, | |
| req.num_images, | |
| req.batch_size, | |
| req.scheduler, | |
| req.sampler, | |
| req.steps, | |
| req.hiresfix, | |
| req.adetailer, | |
| req.enhance_prompt, | |
| req.img2img_mode, | |
| req.stable_fast, | |
| req.torch_compile, | |
| req.vae_autotune, | |
| reuse_seed, | |
| req.realistic_model, | |
| req.enable_multiscale, | |
| req.multiscale_intermittent, | |
| req.multiscale_factor, | |
| req.multiscale_fullres_start, | |
| req.multiscale_fullres_end, | |
| req.keep_models_loaded, | |
| req.enable_preview, | |
| _truncate(req.prompt, 200), | |
| _truncate(req.negative_prompt or "", 200), | |
| bool(req.img2img_image), | |
| ) | |
| # If client provided an img2img image as a data URL or raw base64, decode and save | |
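| # (Illustrative) accepted forms: a data URL such as "data:image/png;base64,<payload>" | |
| # or a bare base64 string; _save_img2img_image_to_file rejects payloads above 10 MB. | |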
| if req.img2img_image: | |
| try: | |
| saved_path = _save_img2img_image_to_file(req.img2img_image, max_size_bytes=10 * 1024 * 1024) | |
| if saved_path and saved_path != req.img2img_image: | |
| log.info("Img2Img upload received and written to %s", saved_path) | |
| req.img2img_image = saved_path | |
| except HTTPException: | |
| # Propagate well-formed HTTP exceptions (bad payloads, too large, etc.) | |
| raise | |
| except Exception as e: | |
| log.exception("Failed processing img2img_image: %s", e) | |
| # Avoid echoing the raw base64 content into logs or responses | |
| raise HTTPException(status_code=400, detail="Invalid img2img_image payload") | |
| # Enqueue the request for batched processing. The background worker will | |
| # perform the actual pipeline invocation and will restore any preview | |
| # state toggles after generation completes. | |
| pending = PendingRequest(req, rid) | |
| result = await _generation_buffer.enqueue(pending) | |
| # Return the result produced by the background worker (dict with | |
| # either 'image' or 'images'). | |
| return result | |
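| # Illustrative client sketch (not used by the server): calls the generate | |
| # endpoint. The POST path follows the "/api/generate" string logged above and | |
| # the default port defined in __main__; host, prompt, and timeout are example | |
| # values. The response carries either 'image' or 'images', as noted above. | |
| def _example_generate_request() -> None: | |
|     import requests  # third-party HTTP client, assumed installed | |
|     resp = requests.post( | |
|         "http://localhost:7861/api/generate", | |
|         json={"prompt": "a lighthouse at dusk", "width": 512, "height": 512, "steps": 20}, | |
|         timeout=600, | |
|     ) | |
|     resp.raise_for_status() | |
|     data = resp.json() | |
|     images = data.get("images") or [data.get("image")] | |
|     print(f"received {len(images)} image(s)") | |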
| async def list_models() -> List[Dict[str, Any]]: | |
| """List available models with type detection and capabilities.""" | |
| try: | |
| from src.Core.Models.ModelFactory import list_available_models, detect_model_type, create_model | |
| models = list_available_models(return_mapping=True) | |
| results = [] | |
| for name, path in models: | |
| try: | |
| # Detect the model type (fast), then create an uninitialized instance via | |
| # ModelFactory.create_model to read its capabilities without loading weights. | |
| mtype = detect_model_type(path) | |
| model_instance = create_model(model_path=path, model_type=mtype) | |
| caps = model_instance.capabilities | |
| # Convert capabilities dataclass to dict | |
| cap_dict = { | |
| "supports_hires_fix": caps.supports_hires_fix, | |
| "supports_img2img": caps.supports_img2img, | |
| "supports_controlnet": caps.supports_controlnet, | |
| "supports_inpainting": caps.supports_inpainting, | |
| "supports_stable_fast": caps.supports_stable_fast, | |
| "supports_deepcache": caps.supports_deepcache, | |
| "supports_tome": caps.supports_tome, | |
| "preferred_resolution": caps.preferred_resolution, | |
| } | |
| results.append({ | |
| "name": name, | |
| "path": path, | |
| "type": mtype, | |
| "capabilities": cap_dict | |
| }) | |
| except Exception as e: | |
| logger.warning(f"Failed to detect type/caps for {name}: {e}") | |
| results.append({ | |
| "name": name, | |
| "path": path, | |
| "type": "SD15", | |
| "capabilities": {} | |
| }) | |
| return results | |
| except Exception as e: | |
| logger.error(f"Failed to list models: {e}") | |
| raise HTTPException(status_code=500, detail=str(e)) | |
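| # Example of one returned entry (illustrative values; keys match cap_dict above): | |
| #   {"name": "<model file>", "path": "<path on disk>", "type": "SD15", | |
| #    "capabilities": {"supports_hires_fix": True, "supports_img2img": True, | |
| #                     "supports_controlnet": True, "supports_inpainting": False, | |
| #                     "supports_stable_fast": True, "supports_deepcache": True, | |
| #                     "supports_tome": True, "preferred_resolution": 512}} | |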
| async def interrupt_generation(): | |
| """Interrupt current generation.""" | |
| # Signal the pipeline to stop: it checks app_instance.app.interrupt_flag during generation | |
| if _app_instance and hasattr(_app_instance, "app") and _app_instance.app: | |
| _app_instance.app.request_interrupt() | |
| logger.info("Interrupt requested via API") | |
| return {"status": "interrupted"} | |
| else: | |
| logger.error("Cannot interrupt: app_instance not available") | |
| raise HTTPException(status_code=503, detail="App instance not available") | |
| # Mount frontend if build exists | |
| frontend_dist = os.path.join(os.path.dirname(__file__), "frontend", "dist") | |
| if os.path.exists(frontend_dist): | |
| app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend") | |
| logger.info(f"Serving frontend from {frontend_dist}") | |
| else: | |
| logger.warning(f"Frontend build not found at {frontend_dist}. Run 'npm run build' in frontend directory.") | |
| if __name__ == "__main__": | |
| import uvicorn | |
| import argparse | |
| import subprocess | |
| import signal | |
| parser = argparse.ArgumentParser(description="LightDiffusion Server") | |
| try: | |
| default_port = int(os.environ.get("PORT") or os.environ.get("UVICORN_PORT") or 7861) | |
| except Exception: | |
| default_port = 7861 | |
| parser.add_argument("--host", type=str, default=os.environ.get("HOST", "0.0.0.0"), help="Host to bind to") | |
| parser.add_argument("--port", type=int, default=default_port, help="Port to bind to") | |
| parser.add_argument("--frontend", action="store_true", help="Launch the frontend development server") | |
| args = parser.parse_args() | |
| frontend_proc = None | |
| if args.frontend: | |
| frontend_dir = os.path.join(os.path.dirname(__file__), "frontend") | |
| if os.path.exists(frontend_dir): | |
| logger.info("Launching frontend development server...") | |
| try: | |
| # shell=True is needed on Windows so the npm shim resolves via cmd.exe; | |
| # on POSIX, shell=True with a list argument would drop "run dev", so | |
| # only enable it there. | |
| frontend_proc = subprocess.Popen( | |
| ["npm", "run", "dev"], | |
| cwd=frontend_dir, | |
| shell=(sys.platform == "win32"), | |
| ) | |
| logger.info("Frontend development server launched") | |
| except Exception as e: | |
| logger.error(f"Failed to launch frontend: {e}") | |
| else: | |
| logger.warning(f"Frontend directory not found at {frontend_dir}") | |
| # Present helpful URL(s) to the user before starting uvicorn | |
| try: | |
| if args.host in ("0.0.0.0", "::", ""): | |
| try: | |
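| # Determine a LAN-reachable address by "connecting" a UDP socket toward a | |
| # public IP; no packet is actually sent, the OS just picks the outbound interface. | |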
| import socket | |
| s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) | |
| s.connect(("8.8.8.8", 80)) | |
| host_ip = s.getsockname()[0] | |
| s.close() | |
| except Exception: | |
| host_ip = "127.0.0.1" | |
| logger.info("Open the UI in a browser: http://localhost:%d/ (or on your network: http://%s:%d/)", args.port, host_ip, args.port) | |
| else: | |
| logger.info("Open the UI in a browser: http://%s:%d/", args.host, args.port) | |
| uvicorn.run("server:app", host=args.host, port=args.port, reload=False, ws="websockets") | |
| finally: | |
| if frontend_proc: | |
| logger.info("Shutting down frontend development server...") | |
| if sys.platform == "win32": | |
| # On Windows, we need to kill the process tree because shell=True creates a cmd.exe wrapper | |
| subprocess.run(["taskkill", "/F", "/T", "/PID", str(frontend_proc.pid)], capture_output=True) | |
| else: | |
| frontend_proc.terminate() | |
| logger.info("Frontend development server shut down") | |