Spaces:
Running on Zero
Running on Zero
File size: 71,591 Bytes
b701455 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 
525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 
1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 1085 1086 1087 1088 1089 1090 1091 1092 1093 1094 1095 1096 1097 1098 1099 1100 1101 1102 1103 1104 1105 1106 1107 1108 1109 1110 1111 1112 1113 1114 1115 1116 1117 1118 1119 1120 1121 1122 1123 1124 1125 1126 1127 1128 1129 1130 1131 1132 1133 1134 1135 1136 1137 1138 1139 1140 1141 1142 1143 1144 1145 1146 1147 1148 1149 1150 1151 1152 1153 1154 1155 1156 1157 1158 1159 1160 1161 1162 1163 1164 1165 1166 1167 1168 1169 1170 1171 1172 1173 1174 1175 1176 1177 1178 1179 1180 1181 1182 1183 1184 1185 1186 1187 1188 1189 1190 1191 1192 1193 1194 1195 1196 1197 1198 1199 1200 1201 1202 1203 1204 1205 1206 1207 1208 1209 1210 1211 1212 1213 1214 1215 1216 1217 1218 1219 1220 1221 1222 1223 1224 1225 1226 1227 1228 1229 1230 1231 1232 1233 1234 1235 1236 1237 1238 1239 1240 1241 1242 1243 1244 1245 1246 1247 1248 1249 1250 1251 1252 1253 1254 1255 1256 1257 1258 1259 1260 1261 1262 1263 1264 1265 1266 1267 1268 1269 1270 1271 1272 1273 1274 1275 1276 1277 1278 1279 1280 1281 1282 1283 1284 1285 1286 1287 1288 1289 1290 1291 1292 1293 1294 1295 1296 1297 1298 1299 1300 1301 1302 1303 1304 1305 1306 1307 1308 1309 1310 1311 1312 1313 1314 1315 1316 1317 1318 1319 1320 1321 1322 1323 1324 1325 1326 1327 1328 1329 1330 1331 1332 1333 1334 1335 1336 1337 1338 1339 1340 1341 1342 1343 1344 1345 1346 1347 1348 1349 1350 1351 1352 1353 1354 1355 1356 1357 1358 1359 1360 1361 1362 1363 1364 1365 1366 1367 1368 1369 1370 1371 1372 1373 1374 1375 1376 1377 1378 1379 1380 1381 1382 1383 1384 1385 1386 1387 1388 1389 1390 1391 1392 1393 1394 1395 1396 1397 1398 1399 1400 1401 1402 1403 1404 1405 1406 1407 1408 1409 1410 1411 1412 1413 1414 1415 1416 1417 1418 1419 
1420 1421 1422 1423 1424 1425 1426 1427 1428 1429 1430 1431 1432 1433 1434 1435 1436 1437 1438 1439 1440 1441 1442 1443 1444 1445 1446 1447 1448 1449 1450 1451 1452 1453 1454 1455 1456 1457 1458 1459 1460 1461 1462 1463 1464 1465 1466 1467 1468 1469 1470 1471 1472 1473 1474 1475 1476 1477 1478 1479 1480 1481 1482 1483 1484 1485 1486 1487 1488 1489 1490 1491 1492 1493 1494 1495 1496 1497 1498 1499 1500 1501 1502 1503 1504 1505 1506 1507 1508 1509 1510 1511 1512 1513 1514 1515 1516 1517 1518 1519 1520 1521 1522 1523 1524 1525 1526 1527 1528 1529 1530 1531 1532 1533 1534 1535 1536 1537 1538 1539 1540 1541 1542 1543 1544 1545 1546 1547 1548 1549 1550 1551 1552 1553 1554 1555 1556 1557 1558 1559 1560 1561 1562 1563 1564 1565 1566 1567 1568 1569 1570 1571 1572 1573 1574 1575 1576 1577 1578 1579 1580 1581 1582 1583 1584 1585 1586 1587 1588 1589 1590 1591 1592 1593 1594 1595 1596 1597 1598 1599 1600 1601 1602 1603 1604 1605 1606 1607 1608 1609 1610 1611 1612 1613 1614 1615 1616 1617 1618 1619 1620 1621 1622 1623 1624 1625 1626 1627 1628 1629 1630 1631 1632 1633 1634 1635 1636 1637 1638 1639 1640 1641 1642 1643 1644 1645 1646 1647 1648 1649 1650 1651 1652 1653 1654 1655 1656 1657 1658 1659 1660 1661 1662 1663 1664 1665 1666 1667 1668 1669 1670 1671 1672 1673 1674 1675 1676 1677 1678 1679 1680 1681 1682 1683 1684 1685 1686 1687 1688 1689 1690 1691 1692 1693 1694 1695 1696 1697 1698 1699 1700 1701 1702 1703 1704 1705 1706 1707 1708 1709 1710 1711 1712 1713 1714 1715 1716 1717 1718 1719 1720 1721 1722 1723 1724 1725 1726 1727 1728 1729 1730 1731 1732 1733 1734 1735 1736 1737 1738 1739 1740 1741 1742 1743 1744 1745 1746 1747 | from __future__ import annotations
import base64
import glob
import os
import io
import re
import tempfile
from src.AutoEncoders.taesd import decode_latents_to_images
# Ensure we can import pipeline from this repo
import sys
import time
from typing import Any, Dict, List, Optional
from fastapi import FastAPI, HTTPException, WebSocket, WebSocketDisconnect
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
from src.Device.ModelCache import get_model_cache
from src.Core.Models.ModelFactory import list_available_models, list_available_controlnets
from src.FileManaging.ImageSaver import pop_image_bytes
sys.path.append(os.path.abspath(os.path.dirname(__file__)))
# Logging setup
import asyncio
import logging
import uuid
from logging.handlers import RotatingFileHandler
# Create a module-level logger with rotating file handler and request-id support
class _RequestIdFilter(logging.Filter):
    """Logging filter that makes sure every record has a ``rid`` attribute.

    The formatter built in ``_setup_logger`` references ``%(rid)s``; any record
    that arrives without a ``rid`` attribute is given the placeholder ``"-"``
    so formatting cannot fail. Records are never dropped.
    """

    def filter(self, record: logging.LogRecord) -> bool:  # pragma: no cover - simple utility
        try:
            record.rid
        except AttributeError:
            record.rid = "-"
        return True
def _setup_logger() -> logging.Logger:
    """Create (once) and return the ``lightdiffusion.server`` logger.

    Output goes to ``./logs/server.log`` (rotating, 5 MiB per file, 3 backups)
    and to stderr. Both handlers use a format that includes ``rid=%(rid)s``,
    which ``_RequestIdFilter`` always populates. Propagation to the root logger
    is disabled so records are not emitted twice.

    The level comes from ``LD_SERVER_LOGLEVEL``: either a level name such as
    ``INFO`` (case-insensitive) or a non-negative integer such as ``20``.
    Anything else (including names of non-level attributes of the ``logging``
    module, e.g. ``BASIC_FORMAT``) falls back to DEBUG instead of making
    ``setLevel`` raise at import time.

    If the logger already has handlers it is returned unchanged, so repeated
    calls do not stack handlers.
    """
    os.makedirs("./logs", exist_ok=True)
    logger = logging.getLogger("lightdiffusion.server")
    if logger.handlers:
        return logger

    level_name = os.getenv("LD_SERVER_LOGLEVEL", "DEBUG").strip().upper()
    if level_name.isdigit():
        level = int(level_name)
    else:
        level = getattr(logging, level_name, logging.DEBUG)
        # getattr may hit a non-level module attribute (BASIC_FORMAT is a str);
        # only accept real integer levels.
        if not isinstance(level, int):
            level = logging.DEBUG
    logger.setLevel(level)

    formatter = logging.Formatter(
        fmt="%(asctime)s | %(levelname)s | %(name)s | rid=%(rid)s | %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    file_handler = RotatingFileHandler(
        filename=os.path.join("./logs", "server.log"),
        maxBytes=5 * 1024 * 1024,
        backupCount=3,
        encoding="utf-8",
    )
    file_handler.setFormatter(formatter)
    file_handler.addFilter(_RequestIdFilter())
    logger.addHandler(file_handler)

    # Also log to stderr for interactive runs; propagation is turned off below
    # so a root/uvicorn handler does not print the same record again.
    stream_handler = logging.StreamHandler()
    stream_handler.setFormatter(formatter)
    stream_handler.addFilter(_RequestIdFilter())
    logger.addHandler(stream_handler)
    logger.propagate = False
    return logger
# Module-wide logger used by every handler and the batching worker below.
logger = _setup_logger()
logger.debug("server module loaded; cwd=%s", os.getcwd())
# Record server start time for telemetry
SERVER_START_TS = time.time()
# Import the generation pipeline without letting a failure abort module import:
# on failure `pipeline` is None and `_pipeline_import_error` keeps the exception
# so it can be reported later; on success `_pipeline_import_error` is None.
# NOTE(review): on failure `_app_instance` is never bound, so later references
# to it raise NameError (the visible uses in _process_group are inside
# try/except Exception, which catches that).
try:
    # Import app_instance to control preview behavior during generation
    from src.user import app_instance as _app_instance
    from src.user.pipeline import pipeline
except Exception as e:
    # Defer import error to runtime response for clarity
    pipeline = None  # type: ignore
    _pipeline_import_error = e
    logger.exception("Failed to import pipeline: %s", e)
else:
    _pipeline_import_error = None
    logger.info("Pipeline and app_instance imported successfully")
class GenerateRequest(BaseModel):
    """Request body for one generation call.

    Only ``prompt`` is required. The batching worker merges queued requests
    whose batch-level fields compare equal (see
    ``GenerationBuffer._signature_for``) and configures the pipeline from the
    first request of such a group. ``prompt``, ``negative_prompt``, ``seed``,
    ``hiresfix`` and ``adetailer`` are carried per image, and ``num_images``
    decides how many images this request contributes.
    """

    prompt: str
    negative_prompt: Optional[str] = ""  # None/empty is sent to the pipeline as ""
    # Output size, forwarded to the pipeline as w / h.
    width: int = 512
    height: int = 512
    # Total images returned for this request (the worker treats values < 1 as 1).
    num_images: int = 1
    # Images per sampling pass (min 1); the worker also caps chunk size by
    # LD_MAX_IMAGES_PER_GROUP and the image saver's per-save limit.
    batch_size: int = 1
    scheduler: str = "ays"
    sampler: str = "dpmpp_sde_cfgpp"
    steps: int = 20
    hiresfix: bool = False
    adetailer: bool = False
    enhance_prompt: bool = False
    # img2img source: a file path or data: URL. Data URLs are expected to be
    # written to a temp file before queueing — presumably by the endpoint; verify.
    img2img_mode: bool = False
    img2img_image: Optional[str] = None
    img2img_denoise: float = 0.75  # Denoising strength: 0=keep original, 1=full generation
    stable_fast: bool = False
    reuse_seed: bool = False
    realistic_model: bool = False
    enable_multiscale: bool = False
    multiscale_preset: Optional[str] = "balanced"
    multiscale_intermittent: bool = True
    multiscale_factor: float = 0.5
    multiscale_fullres_start: int = 10
    multiscale_fullres_end: int = 8
    # Applied to the model cache's keep-loaded flag while this request's group runs.
    keep_models_loaded: bool = True
    # Applied to the app previewer flag while this request's group runs.
    enable_preview: bool = False
    # Preview fidelity for this request: 'low' | 'balanced' | 'high' (default: balanced)
    preview_fidelity: str = "balanced"
    # CFG-free sampling parameters
    cfg_free_enabled: bool = False
    cfg_free_start_percent: float = 70.0
    # Token Merging parameters
    tome_enabled: bool = False
    tome_ratio: float = 0.5
    tome_max_downsample: int = 1
    # Advanced CFG optimization parameters (batched_cfg enabled by default for 8% speedup)
    batched_cfg: bool = True
    dynamic_cfg_rescaling: bool = False
    dynamic_cfg_method: str = "variance"
    dynamic_cfg_percentile: float = 95.0
    dynamic_cfg_target_scale: float = 7.0
    adaptive_noise_enabled: bool = False
    adaptive_noise_method: str = "complexity"
    # Guidance: when guidance_scale is set it overrides cfg_scale.
    cfg_scale: float = 7.0
    guidance_scale: Optional[float] = None
    seed: Optional[int] = None  # If provided >=0 we will reuse it
    # Model Selection
    model_path: Optional[str] = None
    # Refiner settings only take part in batch grouping when model_path is detected as SDXL.
    refiner_model_path: Optional[str] = None
    refiner_switch_step: Optional[int] = None
    # ControlNet: requests with it enabled are never batched with others, and
    # controlnet_model is only forwarded when controlnet_enabled is True.
    controlnet_enabled: bool = False
    controlnet_model: Optional[str] = None
    controlnet_strength: float = 1.0
    controlnet_type: str = "canny"
    # torch.compile optimization (mutually exclusive with stable_fast).
    # Optional: None is forwarded to the pipeline as-is.
    torch_compile: Optional[bool] = None
    vae_autotune: Optional[bool] = None
    # Weight quantization format: None, "fp8", or "nvfp4"
    weight_quantization: Optional[str] = None
    # FP8 inference (auto-gated to supported hardware: Ada Lovelace+)
    fp8_inference: bool = False
class SettingsPreferencesRequest(BaseModel):
    """Request body carrying the torch_compile / vae_autotune preference flags.

    Both flags default to False. The endpoint that consumes this model is not
    part of the visible code.
    """

    torch_compile: bool = False
    vae_autotune: bool = False
# The FastAPI application; the route and event decorators in this module attach to it.
app = FastAPI(title="LightDiffusion Server", version="1.0.0")
@app.get("/api/controlnets")
async def get_controlnets():
    """Return ``{"models": [...]}`` with the ControlNet models ModelFactory reports.

    Any failure while listing is logged with its traceback and surfaced to the
    client as HTTP 500 whose ``detail`` is the exception text.
    """
    try:
        available = list_available_controlnets()
    except Exception as exc:
        logger.exception("Failed to list controlnets")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"models": available}
def _detect_lan_ip(default: str = "127.0.0.1") -> str:
    """Best-effort guess of this host's outbound IPv4 address, else *default*.

    Connects a UDP socket toward a public address and reads back the local
    address the OS chose. The socket is always closed, even when connect() or
    getsockname() fails.
    """
    import socket

    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            return probe.getsockname()[0]
    except OSError:
        return default


@app.on_event("startup")
async def startup_event():
    """Capture the event loop, migrate legacy settings and start the batching worker.

    Side effects:
      * stores the running loop in the module global ``_main_event_loop``
        (used for preview broadcasting);
      * migrates ``include/last_seed.txt`` into the settings store, logging
        rather than raising on failure;
      * starts ``_generation_buffer``'s background worker;
      * logs the URLs the UI can be opened at.
    """
    global _main_event_loop
    _main_event_loop = asyncio.get_running_loop()

    # Migrate legacy include/last_seed.txt into the JSON settings store on startup
    try:
        from src.Core.SettingsStore import migrate_from_last_seed_txt

        migrated_seed = migrate_from_last_seed_txt()
        if migrated_seed is not None:
            logger.info("Migrated legacy include/last_seed.txt -> last_seed=%s", migrated_seed)
    except Exception:
        logger.exception("Failed to migrate legacy last_seed.txt on startup")

    await _generation_buffer.start()
    logger.info("Server startup complete, event loop captured for preview broadcasting")

    # Helpful, user-friendly startup URL(s) so users know what to open in a browser.
    try:
        port = int(os.environ.get("PORT") or os.environ.get("UVICORN_PORT") or 7861)
    except ValueError:
        # Non-numeric PORT/UVICORN_PORT: fall back to the default port.
        port = 7861
    local_ip = _detect_lan_ip()
    logger.info("Open the UI in a browser: http://localhost:%d/ (or on your network: http://%s:%d/)", port, local_ip, port)
# Batching buffer -----------------------------------------------------------
# Maximum number of queued requests merged into one pipeline run. Clamped to
# >= 1: with 0 the worker would pick an empty slice of the chosen group, never
# dequeue anything, and spin forever without serving requests.
LD_MAX_BATCH_SIZE = max(1, int(os.getenv("LD_MAX_BATCH_SIZE", "4")))
# Seconds a lone request may wait for companions (only used when
# LD_BATCH_WAIT_SINGLETONS is enabled).
LD_BATCH_TIMEOUT = float(os.getenv("LD_BATCH_TIMEOUT", "0.5"))
# If set to true (1/true/yes, case-insensitive), the worker will wait the
# coalescing timeout when there is a single candidate in a chosen group;
# otherwise singletons are processed immediately. Default is to process
# singletons immediately to favor throughput and avoid perceived "stuck"
# behavior.
LD_BATCH_WAIT_SINGLETONS = os.getenv("LD_BATCH_WAIT_SINGLETONS", "0").lower() in ("1", "true", "yes")
# Limit total number of images we will process in a single pipeline run when
# coalescing many requests into a group. If the sum of images across the group
# is larger than this, we will split the group into smaller chunks and run the
# pipeline sequentially to avoid memory pressure and downstream save failures.
# Clamped to >= 1 because it bounds the chunk step used with range().
LD_MAX_IMAGES_PER_GROUP = max(1, int(os.getenv("LD_MAX_IMAGES_PER_GROUP", "256")))
def _normalized_image_key(value: Optional[str]) -> str:
    """Map an img2img image reference to the key used for batch grouping.

    * ``None`` or ``""`` -> ``""``.
    * ``data:`` URLs -> their first 128 characters.
    * anything else is treated as a filesystem path and resolved (symlinks
      followed) to an absolute path; if resolution raises, ``str(value)`` is
      used instead.
    """
    if not value:
        key = ""
    elif value.startswith("data:"):
        # Data URLs are normally swapped for a temp file before enqueue; this
        # prefix is only a deterministic fallback for early calls.
        key = value[:128]
    else:
        try:
            key = os.path.abspath(os.path.realpath(value))
        except Exception:
            key = str(value)
    return key
def _effective_guidance_scale(req: "GenerateRequest") -> float:
    """Return the guidance scale actually used for *req*, as a float.

    ``req.guidance_scale`` wins whenever it is not None (even 0); otherwise
    ``req.cfg_scale`` is used. Shared by batch signatures and pipeline kwargs so
    both always agree.
    """
    override = req.guidance_scale
    if override is not None:
        return float(override)
    return float(req.cfg_scale)
def _has_running_loop() -> bool:
    """Return True when called from code running inside an asyncio event loop."""
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop raises RuntimeError when no loop is running in this thread.
        return False
    else:
        return True
class PendingRequest:
    """A queued generation request together with the future its caller awaits.

    Must be created while an event loop is running in the current thread:
    ``asyncio.get_running_loop()`` raises RuntimeError otherwise, and the
    future is created on that loop.
    """

    def __init__(self, req: GenerateRequest, request_id: str):
        owner_loop = asyncio.get_running_loop()
        self.req = req
        self.request_id = request_id
        # Wall-clock enqueue time; the worker orders groups by it, ages
        # singletons with it and accumulates wait metrics from it.
        self.arrival = time.time()
        # Awaited by GenerationBuffer.enqueue until the request completes.
        self.future: asyncio.Future = owner_loop.create_future()
class GenerationBuffer:
def __init__(self):
    """Create an empty buffer; loop-bound primitives come from _reset_async_primitives."""
    # Requests waiting to be grouped, in enqueue order.
    self._pending: List[PendingRequest] = []
    self._worker_task: Optional[asyncio.Task] = None

    # Loop-affine state. The annotations only declare the attributes; real
    # objects are assigned by _reset_async_primitives below.
    self._loop: Optional[asyncio.AbstractEventLoop] = None
    self._lock: asyncio.Lock
    self._new_request: asyncio.Event
    self._prefetch_lock: asyncio.Lock
    self._prefetch_task: Optional[asyncio.Task] = None
    self._current_prefetch_path: Optional[str] = None

    # Lightweight counters updated by the worker after each successful group.
    self._items_processed = 0
    self._batches_processed = 0
    self._requests_processed = 0
    self._cumulative_wait_time = 0.0
    self._last_batch_ts = 0.0

    owning_loop = asyncio.get_running_loop() if _has_running_loop() else None
    self._reset_async_primitives(owning_loop)
def _reset_async_primitives(self, loop: Optional[asyncio.AbstractEventLoop]) -> None:
"""Recreate loop-bound synchronization primitives.
Test runs can start the in-process server multiple times on different
event loops. The queue's Event/Lock objects must be recreated when the
owning loop changes to avoid cross-loop RuntimeError during teardown.
"""
self._loop = loop
self._lock = asyncio.Lock()
self._new_request = asyncio.Event()
self._prefetch_lock = asyncio.Lock()
self._prefetch_task = None
self._current_prefetch_path = None
async def start(self):
    """Ensure the background worker task is running on the current event loop.

    Rebinds the loop-bound primitives first if the buffer was last used on a
    different loop. Does nothing if a worker task exists and has not finished.
    """
    loop = asyncio.get_running_loop()
    if self._loop is not loop:
        self._reset_async_primitives(loop)
    existing = self._worker_task
    if existing is not None and not existing.done():
        return
    self._worker_task = asyncio.create_task(self._worker())
    logger.info("GenerationBuffer worker task started")
async def enqueue(self, pending: PendingRequest) -> dict:
"""Add a request to the queue and wait for completion."""
async with self._lock:
self._pending.append(pending)
self._new_request.set()
# Wait for the worker to process this request
return await pending.future
async def _look_ahead_and_prefetch(self, current_batch_signature: tuple):
    """Analyze remaining queue and pre-load the next model if different.

    Scans ``self._pending`` (under the queue lock) for the first request whose
    signature differs from *current_batch_signature*, resolves its checkpoint
    path, and schedules a background task that loads that file with
    ``util.load_torch_file`` in a worker thread and hands the result to
    ``get_model_cache().set_prefetched_model``. Any previous in-flight
    prefetch is cancelled first. Returns without doing anything when the
    queue is empty, when every pending request shares the current signature,
    or when the same path is already being prefetched.

    NOTE(review): ``util`` is not imported anywhere in the visible part of
    this module; if it is not bound at module level elsewhere, every prefetch
    fails with NameError (caught and logged as a warning below).
    NOTE(review): the path is resolved from ``realistic_model`` only;
    ``next_req.model_path`` is not passed — confirm resolve_checkpoint_path
    does not need it.
    """
    from src.user.pipeline import resolve_checkpoint_path
    async with self._lock:
        if not self._pending:
            return
        # Find the next group that has a different signature
        next_req = None
        for p in self._pending:
            sig = self._signature_for(p.req)
            if sig != current_batch_signature:
                next_req = p.req
                break
    if not next_req:
        return
    # Resolve the path for the next model
    target_path = resolve_checkpoint_path(
        realistic_model=next_req.realistic_model
    )
    # Perform prefetch outside the queue lock
    async with self._prefetch_lock:
        # Skip if already prefetched or currently prefetching the same path.
        # NOTE(review): prefetch_task clears _current_prefetch_path in its
        # `finally`, so this only skips while a load is in flight; a finished
        # prefetch of the same path is not remembered and will be reloaded.
        if target_path == self._current_prefetch_path:
            return
        # Cancel existing prefetch if it's for a different model
        if self._prefetch_task and not self._prefetch_task.done():
            self._prefetch_task.cancel()
            try:
                await self._prefetch_task
            except asyncio.CancelledError:
                pass
        self._current_prefetch_path = target_path

        async def prefetch_task():
            try:
                logger.info("Prefetcher: Starting background load of %s", target_path)
                # Load to CPU RAM using the optimized util
                sd = await asyncio.to_thread(util.load_torch_file, target_path)
                # Store in cache
                get_model_cache().set_prefetched_model(target_path, sd)
                logger.info("Prefetcher: Successfully pre-loaded %s into RAM", target_path)
            except Exception as e:
                # Best effort: a failed prefetch is only logged.
                logger.warning("Prefetcher: Failed to pre-load %s: %s", target_path, e)
            finally:
                self._current_prefetch_path = None

        self._prefetch_task = asyncio.create_task(prefetch_task())
def _signature_for(self, req: GenerateRequest) -> tuple:
    """Return the hashable grouping key for *req*.

    Requests are batched into one pipeline run only when their signatures are
    equal, and ``_process_group`` then configures the pipeline from the first
    request of the group. Every setting forwarded that way must therefore be
    part of this key, or a later request's setting would be silently replaced
    by the first one's. Per-sample values (prompt, negative prompt, seed,
    hiresfix, adetailer, num_images) are deliberately excluded.
    """
    # Detect model type to determine if refiner is relevant
    from src.Core.Models.ModelFactory import detect_model_type
    is_sdxl = (detect_model_type(req.model_path) == "SDXL")
    guidance_scale = _effective_guidance_scale(req)
    normalized_img2img_image = _normalized_image_key(req.img2img_image)

    def _tri_state(flag):
        # Optional[bool] fields are forwarded to the pipeline unchanged, so
        # None ("not specified") must not be merged with an explicit False.
        return None if flag is None else bool(flag)

    return (
        str(req.model_path),  # Model must match
        bool(req.realistic_model),
        int(req.width),
        int(req.height),
        int(max(1, req.batch_size)),
        bool(req.stable_fast),
        _tri_state(req.torch_compile),
        _tri_state(req.vae_autotune),
        bool(req.fp8_inference),
        str(req.weight_quantization),
        bool(req.img2img_mode),
        normalized_img2img_image,
        float(req.img2img_denoise),
        str(req.scheduler),
        str(req.sampler),
        int(req.steps),
        float(guidance_scale),
        bool(req.enhance_prompt),
        bool(req.reuse_seed),
        bool(req.enable_preview),
        str(req.preview_fidelity),
        # Treat multiscale options as batch-level — mixing them may
        # change the sampling schedule and therefore cannot be
        # safely combined into a single forward pass. The preset is
        # forwarded to the pipeline too, so it must match as well.
        bool(req.enable_multiscale),
        str(req.multiscale_preset),
        bool(req.multiscale_intermittent),
        float(req.multiscale_factor),
        int(req.multiscale_fullres_start),
        int(req.multiscale_fullres_end),
        bool(req.cfg_free_enabled),
        float(req.cfg_free_start_percent),
        bool(req.tome_enabled),
        float(req.tome_ratio),
        int(req.tome_max_downsample),
        bool(req.batched_cfg),
        bool(req.dynamic_cfg_rescaling),
        str(req.dynamic_cfg_method),
        float(req.dynamic_cfg_percentile),
        float(req.dynamic_cfg_target_scale),
        bool(req.adaptive_noise_enabled),
        str(req.adaptive_noise_method),
        # VRAM retention flags are also batch level
        bool(req.keep_models_loaded),
        # ControlNet (must match)
        bool(req.controlnet_enabled),
        str(req.controlnet_model),
        float(req.controlnet_strength),
        str(req.controlnet_type),
        # Refiner (must match only if it will actually be used)
        str(req.refiner_model_path) if is_sdxl else "",
        (int(req.refiner_switch_step) if req.refiner_switch_step is not None else -1) if is_sdxl else -1,
        # Note: hires_fix and adetailer remain intentionally NOT part of
        # this signature because they are executed per-sample.
    )
async def _worker(self):
    """Background loop that drains ``self._pending`` one signature group at a time.

    Each pass waits for ``_new_request``, groups pending requests by
    ``_signature_for``, picks the group holding the oldest request, dequeues up
    to ``LD_MAX_BATCH_SIZE`` of its members (always 1 for ControlNet), triggers
    a best-effort prefetch for the next different group and runs
    ``_process_group`` outside the queue lock.

    Robustness guarantees, so a single failure never parks every later
    ``enqueue`` caller forever:
      * a request whose signature cannot be computed is failed on its own;
      * prefetch errors are logged and ignored;
      * if ``_process_group`` raises, every still-unresolved future of that
        group receives the exception;
      * with LD_BATCH_WAIT_SINGLETONS the coalescing sleep happens outside the
        lock and the wake-up event stays set, so the singleton is re-examined
        (and processed) after the delay even if nothing else arrives.
    """
    logger.info("Batching worker started; max_batch=%s timeout=%s", LD_MAX_BATCH_SIZE, LD_BATCH_TIMEOUT)
    while True:
        await self._new_request.wait()
        # Small throttle to coalesce multiple arrivals
        await asyncio.sleep(0)

        # Set when a too-young singleton should wait for companions.
        coalesce_delay: Optional[float] = None
        to_process: List[PendingRequest] = []
        chosen_sig = None

        async with self._lock:
            if not self._pending:
                self._new_request.clear()
                continue

            # Group pending requests by signature
            groups: Dict[tuple, List[PendingRequest]] = {}
            for p in list(self._pending):
                try:
                    sig = self._signature_for(p.req)
                except Exception as exc:
                    logger.exception("Failed to compute batch signature for request %s", p.request_id)
                    self._pending.remove(p)
                    if not p.future.done():
                        p.future.set_exception(exc)
                    continue
                groups.setdefault(sig, []).append(p)

            # Choose the group with the oldest request
            oldest_time = float("inf")
            for sig, arr in groups.items():
                if arr and arr[0].arrival < oldest_time:
                    chosen_sig = sig
                    oldest_time = arr[0].arrival
            if chosen_sig is None:
                self._new_request.clear()
                continue

            candidates = groups[chosen_sig]
            # Sort by arrival time (oldest first)
            candidates.sort(key=lambda x: x.arrival)
            # Debug: show group sizes for observability
            try:
                group_summary = {str(sig): len(arr) for sig, arr in groups.items()}
                logger.debug("Batch worker: pending groups=%s chosen_sig=%s group_size=%d oldest_arrival=%.3f",
                             group_summary, str(chosen_sig), len(candidates), candidates[0].arrival if candidates else 0.0)
            except Exception:
                pass

            # Singleton handling is controlled by LD_BATCH_WAIT_SINGLETONS so
            # operators can opt into waiting for more arrivals.
            if len(candidates) == 1:
                age = time.time() - candidates[0].arrival
                if LD_BATCH_WAIT_SINGLETONS and age < LD_BATCH_TIMEOUT:
                    logger.debug("Singleton group for signature %s is too new (age=%.3fs < timeout=%.3fs). Sleeping to coalesce.", str(chosen_sig), age, LD_BATCH_TIMEOUT)
                    # Do NOT clear _new_request here: clearing it would make the
                    # next wait() block after the sleep, stranding this request
                    # until some unrelated request arrives.
                    coalesce_delay = LD_BATCH_TIMEOUT - age
                else:
                    # Eager processing path (default): process singletons
                    # immediately to avoid perceived "stuck" behavior.
                    logger.debug("Processing singleton group for signature %s immediately (age=%.3fs). LD_BATCH_WAIT_SINGLETONS=%s",
                                 str(chosen_sig), age, LD_BATCH_WAIT_SINGLETONS)

            if coalesce_delay is None:
                # Keep ControlNet requests singleton for now. Its image-conditioned
                # path has not been made batch-safe in the same way as text2img/img2img.
                # max(1, ...) guards against a non-positive configured batch size,
                # which would otherwise dequeue nothing and spin forever.
                max_group_size = 1 if candidates[0].req.controlnet_enabled else max(1, LD_MAX_BATCH_SIZE)
                # Pick up to the allowed group size
                to_process = candidates[:max_group_size]
                # Remove selected items from pending list
                for p in to_process:
                    try:
                        self._pending.remove(p)
                    except ValueError:
                        pass
                if not self._pending:
                    self._new_request.clear()

        if coalesce_delay is not None:
            # Sleep without holding the queue lock so enqueue() is not blocked;
            # the event is still set, so the next pass re-evaluates immediately.
            await asyncio.sleep(coalesce_delay)
            continue

        # Trigger prefetching for the NEXT group while we process this one.
        # Prefetch is an optimization only; it must never stop the worker.
        try:
            await self._look_ahead_and_prefetch(chosen_sig)
        except Exception:
            logger.exception("Prefetch look-ahead failed; continuing without prefetch")

        # Process the selected group outside the lock
        try:
            try:
                logger.debug("Processing group chosen_sig=%s items=%d request_ids=%s", str(chosen_sig), len(to_process), [p.request_id for p in to_process])
            except Exception:
                pass
            await self._process_group(to_process)
            # Update lightweight metrics only on success
            try:
                now_ts = time.time()
                self._batches_processed += 1
                self._items_processed += sum(
                    max(1, p.req.num_images) for p in to_process
                )
                self._requests_processed += len(to_process)
                # Update cumulative wait time per-request
                wait_total = sum(now_ts - p.arrival for p in to_process)
                self._cumulative_wait_time += wait_total
                self._last_batch_ts = now_ts
            except Exception:
                # Metrics must never crash the worker loop
                logger.exception("Failed updating batch metrics")
        except Exception as e:
            logger.exception("Batch processing failed: %s", e)
            # Never leave an enqueue() caller waiting on a future that
            # _process_group did not get to resolve.
            for p in to_process:
                if not p.future.done():
                    p.future.set_exception(e)
async def _process_group(self, items: List[PendingRequest]):
# All items share a signature as enforced by the grouping logic.
if not items:
return
first_req = items[0].req
flat_samples: List[dict[str, Any]] = []
for p in items:
for _ in range(max(1, p.req.num_images)):
flat_samples.append(
{
"request_id": p.request_id,
"filename_prefix": f"LD-REQ-{p.request_id}",
"seed": p.req.seed if (p.req.seed is not None and p.req.seed >= 0) else None,
"hires_fix": bool(p.req.hiresfix),
"adetailer": bool(p.req.adetailer),
"prompt": p.req.prompt,
"negative_prompt": p.req.negative_prompt or "",
}
)
# Prepare pipeline kwargs based on the shared signature (take from first)
# Unique ID for this generation run; sent with every preview message
# so the frontend can discard stale previews from previous runs.
_gen_id = uuid.uuid4().hex[:12]
pipeline_kwargs = dict(
prompt=[],
w=first_req.width,
h=first_req.height,
number=0,
batch=0,
scheduler=first_req.scheduler,
sampler=first_req.sampler,
steps=first_req.steps,
cfg_scale=_effective_guidance_scale(first_req),
enhance_prompt=first_req.enhance_prompt,
img2img=first_req.img2img_mode,
img2img_denoise=first_req.img2img_denoise,
stable_fast=first_req.stable_fast,
reuse_seed=first_req.reuse_seed,
autohdr=True,
realistic_model=first_req.realistic_model,
model_path=first_req.model_path,
refiner_model_path=first_req.refiner_model_path,
refiner_switch_step=first_req.refiner_switch_step,
negative_prompt=[],
multiscale_preset=first_req.multiscale_preset,
enable_multiscale=first_req.enable_multiscale,
multiscale_factor=first_req.multiscale_factor,
multiscale_fullres_start=first_req.multiscale_fullres_start,
multiscale_fullres_end=first_req.multiscale_fullres_end,
multiscale_intermittent_fullres=first_req.multiscale_intermittent,
img2img_image=first_req.img2img_image,
request_filename_prefix=f"LD-REQ-{items[0].request_id}",
per_sample_info=[],
cfg_free_enabled=first_req.cfg_free_enabled,
cfg_free_start_percent=first_req.cfg_free_start_percent,
tome_enabled=first_req.tome_enabled,
tome_ratio=first_req.tome_ratio,
tome_max_downsample=first_req.tome_max_downsample,
# Advanced CFG optimizations (batched_cfg always enabled)
batched_cfg=first_req.batched_cfg,
dynamic_cfg_rescaling=first_req.dynamic_cfg_rescaling,
dynamic_cfg_method=first_req.dynamic_cfg_method,
dynamic_cfg_percentile=first_req.dynamic_cfg_percentile,
dynamic_cfg_target_scale=first_req.dynamic_cfg_target_scale,
adaptive_noise_enabled=first_req.adaptive_noise_enabled,
adaptive_noise_method=first_req.adaptive_noise_method,
# ControlNet
controlnet_model=first_req.controlnet_model if first_req.controlnet_enabled else None,
controlnet_strength=first_req.controlnet_strength,
controlnet_type=first_req.controlnet_type,
# torch.compile
torch_compile=first_req.torch_compile,
vae_autotune=first_req.vae_autotune,
# Weight quantization
weight_quantization=first_req.weight_quantization,
# FP8 inference
fp8_inference=first_req.fp8_inference,
# Add callback for WebSocket preview broadcasting
callback=make_server_callback(first_req.steps, generation_id=_gen_id),
)
# Notify clients that a new generation is starting so they can
# discard stale previews from the previous run.
sync_broadcast_preview(
step=0, total_steps=first_req.steps,
message_type="generation_start",
generation_id=_gen_id,
)
# Toggle preview state for the duration of the pipeline call
prev_preview_state = None
prev_keep_models_loaded = None
prev_preview_settings = None
try:
try:
prev_preview_state = _app_instance.app.previewer_var.get()
_app_instance.app.previewer_var.set(bool(first_req.enable_preview))
except Exception:
prev_preview_state = None
# Apply per-request preview fidelity overrides (format / quality / sRGB)
try:
prev_preview_settings = _apply_preview_fidelity_to_app(first_req)
except Exception:
prev_preview_settings = None
# Respect per-group model cache directive: toggle "keep loaded"
# so the sampling pipeline sees the requested caching behavior.
try:
model_cache = get_model_cache()
prev_keep_models_loaded = model_cache.get_keep_models_loaded()
model_cache.set_keep_models_loaded(bool(first_req.keep_models_loaded))
except Exception:
prev_keep_models_loaded = None
saved_map: Dict[str, List[dict]] = {}
total_images = len(flat_samples)
# Respect ImageSaver.MAX_IMAGES_PER_SAVE and the requested batch size.
# Multi-image runs always execute in deterministic chunks so that
# `batch_size` means "images per sampling pass" and `num_images`
# means "total outputs returned".
try:
from src.FileManaging import ImageSaver as _ImageSaver
_max_save_limit = getattr(_ImageSaver, "MAX_IMAGES_PER_SAVE", LD_MAX_IMAGES_PER_GROUP)
except Exception:
_max_save_limit = LD_MAX_IMAGES_PER_GROUP
max_save_limit = _max_save_limit if _max_save_limit and _max_save_limit > 0 else LD_MAX_IMAGES_PER_GROUP
requested_batch_size = max(1, int(first_req.batch_size))
max_chunk_size = min(requested_batch_size, LD_MAX_IMAGES_PER_GROUP, max_save_limit)
logger.info(
"Processing group of %d request(s) -> %d image(s) with effective batch_size=%d across %d chunk(s)",
len(items),
total_images,
max_chunk_size,
(total_images + max_chunk_size - 1) // max_chunk_size if max_chunk_size > 0 else 0,
)
chunks: list[list[dict[str, Any]]] = [
flat_samples[i : i + max_chunk_size]
for i in range(0, total_images, max_chunk_size)
]
try:
for chunk in chunks:
c_prompts = [entry["prompt"] for entry in chunk]
c_negatives = [entry["negative_prompt"] for entry in chunk]
c_per_sample_info = [
{
"request_id": entry["request_id"],
"filename_prefix": entry["filename_prefix"],
"seed": entry["seed"],
"hires_fix": entry["hires_fix"],
"adetailer": entry["adetailer"],
}
for entry in chunk
]
chunk_kwargs = dict(pipeline_kwargs)
chunk_kwargs["prompt"] = c_prompts
chunk_kwargs["negative_prompt"] = c_negatives
chunk_kwargs["number"] = len(c_prompts)
chunk_kwargs["batch"] = len(c_prompts)
chunk_kwargs["per_sample_info"] = c_per_sample_info
chunk_kwargs["request_filename_prefix"] = c_per_sample_info[0]["filename_prefix"] if c_per_sample_info else None
chunk_start_ts = time.time()
result = await asyncio.to_thread(pipeline, **chunk_kwargs)
if isinstance(result, dict) and "batched_results" in result:
for request_id, entries in result["batched_results"].items():
saved_map.setdefault(request_id, []).extend(entries)
else:
files = _find_images_since(chunk_start_ts)
for f in files:
name = os.path.basename(f)
for entry in chunk:
rid = entry["request_id"]
if f"LD-REQ-{rid}" in name:
saved_map.setdefault(rid, []).append({
"filename": name,
"subfolder": os.path.relpath(os.path.dirname(f), "./output"),
})
except InterruptedError:
logger.info(
"Generation interrupted for request_ids=%s",
[p.request_id for p in items],
)
sync_broadcast_preview(
step=0,
total_steps=first_req.steps,
message_type="error",
generation_id=_gen_id,
)
for p in items:
if not p.future.done():
p.future.set_exception(HTTPException(status_code=409, detail="Generation interrupted"))
return
# For each pending item, collect its images and set future result
for p in items:
imgs = saved_map.get(p.request_id, [])
# Filter and select the first N images requested
selected = imgs[: max(1, p.req.num_images)]
if not selected:
p.future.set_exception(HTTPException(status_code=500, detail="No images produced"))
continue
# Try to use in-memory byte buffer first (avoids disk I/O)
buffered_images = pop_image_bytes(f"LD-REQ-{p.request_id}")
b64_list = []
if buffered_images:
# Use in-memory bytes directly - zero disk reads
for buf_filename, buf_subfolder, png_bytes in buffered_images[:max(1, p.req.num_images)]:
b64_data = base64.b64encode(png_bytes).decode("utf-8")
mime_type = "image/png"
if buf_filename.lower().endswith((".jpg", ".jpeg")):
mime_type = "image/jpeg"
elif buf_filename.lower().endswith(".webp"):
mime_type = "image/webp"
b64_list.append(f"data:{mime_type};base64,{b64_data}")
else:
# Fallback to disk reads
for entry in selected:
if isinstance(entry, list):
# Safeguard against nested lists if any processor still returns them
entry = entry[0] if entry else {}
if not isinstance(entry, dict):
continue
filename = entry.get("filename", "")
path = os.path.join("./output", entry.get("subfolder", ""), filename)
try:
b64_data = _encode_png_to_base64(path)
mime_type = "image/png"
if filename.lower().endswith(".jpg") or filename.lower().endswith(".jpeg"):
mime_type = "image/jpeg"
elif filename.lower().endswith(".webp"):
mime_type = "image/webp"
b64_list.append(f"data:{mime_type};base64,{b64_data}")
except Exception as e:
logger.exception("Failed to read image for request %s: %s", p.request_id, e)
if len(b64_list) == 0:
p.future.set_exception(HTTPException(status_code=500, detail="Failed to read generated images"))
elif len(b64_list) == 1:
p.future.set_result({"image": b64_list[0]})
else:
p.future.set_result({"images": b64_list})
finally:
try:
if prev_preview_settings is not None:
_restore_preview_settings(prev_preview_settings)
except Exception:
pass
try:
if prev_preview_state is not None:
_app_instance.app.previewer_var.set(prev_preview_state)
except Exception:
pass
try:
# Restore previous model cache keep-loaded setting if we
# changed it above.
if prev_keep_models_loaded is not None:
try:
model_cache = get_model_cache()
model_cache.set_keep_models_loaded(bool(prev_keep_models_loaded))
except Exception:
pass
except Exception:
pass
# Single shared batching buffer for the whole process; started when the app boots.
_generation_buffer = GenerationBuffer()


@app.on_event("startup")
async def _start_buffer():
    """Kick off the shared generation buffer once the ASGI app has started."""
    await _generation_buffer.start()
@app.get("/health")
def health() -> Dict[str, str]:
    """Liveness probe: always answers with ``{"status": "ok"}``."""
    status = {"status": "ok"}
    return status
@app.get("/api/telemetry")
async def telemetry() -> Dict[str, Any]:
    """Return basic server and batching-buffer telemetry.

    Fields:
    - uptime_seconds, server_start_ts
    - pending_count
    - pending_by_signature (human-readable signature -> count)
    - pending_preview (up to 20 small pending request summaries)
    - pending_avg_wait_s (mean wait of queued requests; 0.0 when the queue is empty)
    - worker_running
    - max_batch_size, batch_timeout, max_images_per_group
    - batches_processed, items_processed, requests_processed, last_batch_time
    - avg_processed_wait_s (None until at least one request has been processed)
    - memory_info, loaded_models_count, loaded_models
    - memory_info_error (short error string when the model-cache query failed, else None)
    - pipeline_import_ok and pipeline_import_error
    """
    rid = uuid.uuid4().hex[:8]
    log = logging.LoggerAdapter(logger, {"rid": rid})
    log.debug("telemetry requested")
    now = time.time()
    uptime = now - SERVER_START_TS

    # Snapshot queue state and counters under the buffer lock so that all
    # queue-derived values describe the same moment.
    async with _generation_buffer._lock:
        pending_count = len(_generation_buffer._pending)
        sig_counts: Dict[str, int] = {}
        pending_preview: List[Dict[str, Any]] = []
        total_pending_wait = 0.0
        for p in _generation_buffer._pending:
            try:
                sig_key = str(_generation_buffer._signature_for(p.req))
            except Exception:
                sig_key = "<unknown>"
            sig_counts[sig_key] = sig_counts.get(sig_key, 0) + 1

            waiting = now - p.arrival
            total_pending_wait += waiting
            # Keep the preview small to avoid large payloads.
            prompt = p.req.prompt or ""
            pending_preview.append({
                "request_id": p.request_id,
                "waiting_s": round(waiting, 3),
                "prompt_preview": (prompt[:120] + "…") if len(prompt) > 120 else prompt,
            })

        batches_processed = _generation_buffer._batches_processed
        items_processed = _generation_buffer._items_processed
        last_batch_ts = _generation_buffer._last_batch_ts
        worker_running = (
            _generation_buffer._worker_task is not None
            and not _generation_buffer._worker_task.done()
        )
        requests_processed = _generation_buffer._requests_processed
        cumulative_wait = _generation_buffer._cumulative_wait_time

    avg_processed_wait_s = (
        (cumulative_wait / requests_processed) if requests_processed > 0 else None
    )
    pending_avg_wait_s = (total_pending_wait / pending_count) if pending_count > 0 else 0.0

    # Model cache telemetry (memory and loaded models). Never let this fail the
    # endpoint: e.g. a device-side CUDA assert can leave later CUDA queries
    # failing, so report a concise error string instead of raising.
    memory_info_error: Optional[str] = None
    try:
        model_cache = get_model_cache()
        memory_info = model_cache.get_memory_info()
        loaded_models: List[str] = []
        for m in model_cache.get_cached_sampling_models():
            try:
                name = getattr(m, "name", None) or getattr(m, "__class__", type(m)).__name__
            except Exception:
                name = str(type(m))
            loaded_models.append(name)
        loaded_models_count = len(loaded_models)
    except Exception as e:
        try:
            memory_info_error = str(e)
        except Exception:
            memory_info_error = "unknown"
        logger.exception("Failed to fetch model cache telemetry: %s", memory_info_error)
        memory_info = None
        loaded_models = []
        loaded_models_count = 0

    return {
        "uptime_seconds": round(uptime, 3),
        "server_start_ts": SERVER_START_TS,
        "pending_count": pending_count,
        "pending_by_signature": sig_counts,
        "pending_preview": pending_preview[:20],
        "worker_running": worker_running,
        "max_batch_size": LD_MAX_BATCH_SIZE,
        "batch_timeout": LD_BATCH_TIMEOUT,
        "max_images_per_group": LD_MAX_IMAGES_PER_GROUP,
        "batches_processed": batches_processed,
        "items_processed": items_processed,
        "requests_processed": requests_processed,
        "last_batch_time": last_batch_ts,
        "avg_processed_wait_s": avg_processed_wait_s,
        "pending_avg_wait_s": pending_avg_wait_s,
        "memory_info": memory_info,
        # Surface the captured model-cache error so clients can show a hint.
        "memory_info_error": memory_info_error,
        "loaded_models_count": loaded_models_count,
        "loaded_models": loaded_models,
        "pipeline_import_ok": pipeline is not None,
        "pipeline_import_error": str(_pipeline_import_error) if _pipeline_import_error is not None else None,
    }
# Settings API ------------------------------------------------------------
def _read_settings_preferences() -> Dict[str, bool]:
    """Load the persisted server-wide preference flags from SettingsStore."""
    from src.Core.SettingsStore import get_preferences as _load_prefs

    prefs = _load_prefs()
    return prefs
def _resolve_autotune_preferences(req: GenerateRequest) -> GenerateRequest:
    """Fill unset (None) ``torch_compile`` / ``vae_autotune`` flags from stored preferences.

    An explicit per-request value always wins. The result is coerced to bool,
    written back onto ``req`` in place, and the same request object is returned.
    """
    stored = _read_settings_preferences()
    for flag in ("torch_compile", "vae_autotune"):
        explicit = getattr(req, flag)
        setattr(req, flag, bool(stored[flag] if explicit is None else explicit))
    return req
def _reset_autotune_runtime_state() -> None:
    """Drop the default pipeline, cached models and compiled models.

    Used after the autotune preferences change so that subsequent generations
    rebuild their runtime state under the new settings.
    """
    from src.Core.Pipeline import reset_default_pipeline
    from src.Device.Device import clear_compiled_models
    from src.Device.ModelCache import clear_model_cache

    for reset_step in (reset_default_pipeline, clear_model_cache, clear_compiled_models):
        reset_step()
@app.get("/api/settings/preferences")
async def api_get_settings_preferences():
    """Return persisted server-wide generation preferences (HTTP 500 if they cannot be read)."""
    try:
        prefs = _read_settings_preferences()
    except Exception as e:
        logger.exception("Failed to read settings preferences: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
    return prefs
@app.post("/api/settings/preferences")
async def api_post_settings_preferences(body: SettingsPreferencesRequest):
    """Persist server-wide generation preferences and return what was stored.

    When the stored preferences differ from the previous ones, runtime caches
    are reset so the new flags take effect. Any failure becomes HTTP 500.
    """
    try:
        from src.Core.SettingsStore import set_preferences

        before = _read_settings_preferences()
        after = set_preferences({
            "torch_compile": bool(body.torch_compile),
            "vae_autotune": bool(body.vae_autotune),
        })
        if after != before:
            # Cached pipeline/model state may have been built under the old flags.
            _reset_autotune_runtime_state()
    except Exception as e:
        logger.exception("Failed to update settings preferences: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
    return after
@app.get("/api/settings/last")
async def api_get_last_settings():
    """Return ``{"seed": <last persisted seed or null>}`` (HTTP 500 on read failure)."""
    try:
        from src.Core.SettingsStore import get_last_seed

        return {"seed": get_last_seed()}
    except Exception as e:
        logger.exception("Failed to read last seed: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@app.get("/api/settings/history")
async def api_get_settings_history():
    """Return ``{"history": [...]}`` as provided by SettingsStore (most-recent-first)."""
    try:
        from src.Core.SettingsStore import get_history

        entries = get_history()
    except Exception as e:
        logger.exception("Failed to read settings history: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
    return {"history": entries}
@app.post("/api/settings/history")
async def api_post_settings_history(body: Dict[str, Any]):
    """Append a settings snapshot to history.

    Body: { settings: GenerationSettings, include_prompt: bool }
    By default `include_prompt` is False, and only the generation parameters
    (seed, steps, cfg_scale, sampler, scheduler, model_path, width, height)
    are persisted. prompt/negative_prompt are dropped for privacy.

    Returns:
        {"snapshot": <value returned by SettingsStore.append_snapshot>}

    Raises:
        HTTPException: 400 when `settings` is missing/empty or is not a JSON
            object; 500 when persisting the snapshot fails.
    """
    # Parameter-only whitelist used when the caller did not opt in to storing prompts.
    allowed = ("seed", "steps", "cfg_scale", "sampler", "scheduler", "model_path", "width", "height")
    try:
        settings = body.get("settings")
        if not settings:
            raise HTTPException(status_code=400, detail="Missing 'settings' in request body")
        if not isinstance(settings, dict):
            # A list/string cannot be filtered by key; treat it as a client error.
            raise HTTPException(status_code=400, detail="'settings' must be a JSON object")
        include_prompt = bool(body.get("include_prompt", False))
        if include_prompt:
            stored = dict(settings)
        else:
            stored = {k: settings[k] for k in allowed if k in settings}
        from src.Core.SettingsStore import append_snapshot

        snap = append_snapshot({"settings": stored})
        return {"snapshot": snap}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to append settings history: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/images/metadata")
async def api_post_image_metadata(body: Dict[str, Any]):
    """Extract generation metadata from a base64 / data-URL image payload.

    The result is a normalized dictionary suitable for re-applying to UI settings.

    Body: { image: "data:image/png;base64,..." } or { image: "<base64>" }
    Returns: { metadata: { seed, steps, cfg_scale, sampler, scheduler, model_path, width, height, prompt?, negative_prompt? } }
    Only keys present in the image's info dict are returned. Numeric fields
    that fail to parse are returned as null.

    Raises:
        HTTPException: 400 for a missing, non-string, undecodable or
            unrecognized image payload; 500 for unexpected server failures.
    """

    def _to_int(v):
        try:
            return int(v)
        except Exception:
            return None

    def _to_float(v):
        try:
            return float(v)
        except Exception:
            return None

    # (info key, response key, converter) in response order. Context.build_metadata
    # stores the CFG value under 'cfg', which the UI calls cfg_scale.
    field_map = (
        ("prompt", "prompt", None),
        ("negative_prompt", "negative_prompt", None),
        ("seed", "seed", _to_int),
        ("steps", "steps", _to_int),
        ("cfg", "cfg_scale", _to_float),
        ("sampler", "sampler", None),
        ("scheduler", "scheduler", None),
        ("model_path", "model_path", None),
        ("width", "width", _to_int),
        ("height", "height", _to_int),
    )

    try:
        image_b64 = body.get("image")
        if not image_b64:
            raise HTTPException(status_code=400, detail="Missing 'image' in request body")
        # Accept a data URL or raw base64; non-string values end up rejected below.
        b64_data = None
        if isinstance(image_b64, str):
            if image_b64.startswith("data:"):
                idx = image_b64.find("base64,")
                if idx != -1:
                    b64_data = image_b64[idx + len("base64,"):]
            else:
                b64_data = image_b64.strip().replace("\n", "")
        if not b64_data:
            raise HTTPException(status_code=400, detail="Invalid image payload")

        from PIL import Image

        try:
            decoded = base64.b64decode(b64_data)
            with Image.open(io.BytesIO(decoded)) as img:
                info = dict(img.info or {})
        except Exception as e:
            # Bad base64 or bytes PIL cannot identify are client errors, not server faults.
            raise HTTPException(status_code=400, detail=f"Invalid image payload: {e}") from e

        meta: Dict[str, Any] = {}
        for src_key, dst_key, convert in field_map:
            if src_key in info:
                value = info.get(src_key)
                meta[dst_key] = convert(value) if convert is not None else value
        return {"metadata": meta}
    except HTTPException:
        raise
    except Exception as e:
        logger.exception("Failed to decode image metadata: %s", e)
        raise HTTPException(status_code=500, detail=str(e))
def _encode_png_to_base64(path: str, retries: int = 20, delay: float = 0.1) -> str:
    """Read a generated image file and return its bytes base64-encoded (UTF-8 str).

    Despite the name, the file's bytes are encoded verbatim whatever the format.
    The file may still be being finalized on disk, so a failed read is retried
    up to ``retries`` more times, sleeping ``delay`` seconds between attempts
    (about 2s in total with the defaults). The waiting uses blocking
    ``time.sleep``, which stalls the calling thread.

    Raises:
        HTTPException: 500 if the file still cannot be read after all attempts.
    """
    last_err: Optional[Exception] = None
    for attempt in range(max(0, retries) + 1):
        if attempt:
            time.sleep(delay)
        try:
            with open(path, "rb") as f:
                data = f.read()
        except Exception as e:
            last_err = e
            continue
        if attempt > 0:
            logger.debug("Read image after %d retries: %s", attempt, path)
        return base64.b64encode(data).decode("utf-8")
    logger.error("Failed to read generated image %s: %s", path, last_err)
    raise HTTPException(status_code=500, detail=f"Failed to read generated image: {last_err}") from last_err
def _save_img2img_image_to_file(value: Optional[str], max_size_bytes: int = 10 * 1024 * 1024) -> Optional[str]:
    """Normalize an img2img input into a path on the local filesystem.

    Accepted forms:
    - the path of an existing regular file, which is returned unchanged;
    - a data URL (``data:image/<type>;base64,<payload>``), decoded and written
      to the system temp directory;
    - a bare base64 string (more than 100 chars of the standard alphabet once
      whitespace is removed), decoded and written likewise.

    Returns None when ``value`` is empty/None, otherwise the usable file path.

    Raises:
        HTTPException: 400 for unrecognized or undecodable input, 413 when the
            decoded payload exceeds ``max_size_bytes``, 500 when the temp file
            cannot be written.

    NOTE(review): any existing server-side file path sent by a client is
    accepted as-is; confirm that exposing arbitrary readable paths to the
    pipeline is intended.
    """
    if not value:
        return None
    if os.path.exists(value) and os.path.isfile(value):
        return value

    try:
        payload = None
        if isinstance(value, str) and value.startswith("data:"):
            # data:[<mediatype>][;base64],<data>
            parsed = re.match(r"^data:(?P<mime>image/[^;]+);base64,(?P<b64>.+)$", value, flags=re.DOTALL)
            if parsed:
                payload = parsed.group("b64")
            else:
                marker = value.find("base64,")
                if marker != -1:
                    payload = value[marker + len("base64,"):]
        else:
            compact = re.sub(r"\s+", "", str(value))
            if len(compact) > 100 and re.fullmatch(r"[A-Za-z0-9+/=]+", compact):
                payload = compact
        if not payload:
            raise HTTPException(status_code=400, detail="img2img_image must be a file path, a data URL, or a base64-encoded image")
        raw = base64.b64decode(payload)
    except HTTPException:
        raise
    except Exception:
        raise HTTPException(status_code=400, detail="Invalid base64 data for img2img_image")

    if len(raw) > max_size_bytes:
        raise HTTPException(status_code=413, detail=f"img2img_image too large (max {max_size_bytes // 1024} KB)")

    # Pick a file extension: imghdr first, then PIL, finally assume PNG.
    try:
        import imghdr
        detected = imghdr.what(None, raw)
    except Exception:
        detected = None
    if not detected:
        try:
            from PIL import Image
            detected = Image.open(io.BytesIO(raw)).format
            detected = detected.lower() if detected else "png"
        except Exception:
            detected = "png"
    ext = "jpg" if detected == "jpeg" else detected

    temp_root = tempfile.gettempdir()
    os.makedirs(temp_root, exist_ok=True)
    target = os.path.join(temp_root, f"img2img-{uuid.uuid4().hex[:8]}.{ext}")
    try:
        with open(target, "wb") as out:
            out.write(raw)
    except Exception as e:
        logger.exception("Failed to write img2img upload to %s: %s", target, e)
        raise HTTPException(status_code=500, detail="Failed to save img2img_image on server")

    # Deliberately log only the destination, never the incoming base64 content.
    logger.info("Saved img2img image to %s", target)
    return target
def _list_existing_images() -> List[str]:
    """Return every .png/.jpg/.jpeg/.webp file found recursively under ./output.

    Results are grouped by extension in that order, each group in glob order.
    """
    patterns = ("*.png", "*.jpg", "*.jpeg", "*.webp")
    found: List[str] = [
        match
        for pattern in patterns
        for match in glob.glob(os.path.join("./output", "**", pattern), recursive=True)
    ]
    logger.debug("Found %d existing images", len(found))
    return found
def _find_images_since(start_ts: float, grace: float = 0.25) -> List[str]:
    """Return output images modified at or after ``start_ts - grace``, newest first.

    ``grace`` (seconds) absorbs small clock/mtime granularity differences.
    Each file is stat'ed once. Files that disappear (or become unreadable)
    between the directory scan and the stat are skipped instead of aborting
    the whole lookup. Ties keep the scan order (stable sort).
    """
    threshold = start_ts - grace
    stamped = []
    for path in _list_existing_images():
        try:
            mtime = os.path.getmtime(path)
        except OSError:
            continue
        if mtime >= threshold:
            stamped.append((mtime, path))
    stamped.sort(key=lambda item: item[0], reverse=True)
    recent = [path for _, path in stamped]
    logger.debug("%d images modified since %.3f", len(recent), start_ts)
    return recent
# WebSocket preview endpoint for real-time streaming
# Connected /ws/preview sockets. websocket_preview appends on connect and removes
# on close; broadcast_preview also drops sockets whose send failed.
_preview_clients: List[WebSocket] = []
# Event loop that thread-side callers submit broadcasts to. Starts as None;
# presumably assigned at server startup elsewhere in this module (TODO confirm).
_main_event_loop: Optional[asyncio.AbstractEventLoop] = None
# Strong references to broadcast tasks scheduled directly on the main loop.
# asyncio keeps only weak references to tasks, so an unreferenced task could be
# garbage-collected before it finishes.
_broadcast_tasks: set = set()


def sync_broadcast_preview(
    step: int,
    total_steps: int,
    images: Optional[List[str]] = None,
    message_type: str = "preview",
    generation_id: Optional[str] = None,
):
    """Broadcast a preview/progress message to /ws/preview clients from any context.

    - From a worker thread (e.g. the pipeline callback run via asyncio.to_thread),
      the broadcast is submitted to the main event loop and awaited for at most
      0.5s, so slow clients never hold generation up.
    - From a coroutine running on the main event loop itself, the broadcast is
      scheduled as a task and this function returns immediately. Blocking on
      the result there would stall the very loop that has to run it.

    Errors are logged and swallowed so preview problems never affect generation.
    """
    if not _preview_clients:
        if step % 10 == 0:
            logger.debug("No preview clients connected, skipping broadcast")
        return
    loop = _main_event_loop
    if loop is None:
        logger.error("Main event loop is None! Cannot broadcast preview.")
        return
    try:
        if step % 5 == 0 or step == total_steps - 1:
            logger.info(f"Broadcasting preview step {step}/{total_steps}")
        try:
            running_loop = asyncio.get_running_loop()
        except RuntimeError:
            running_loop = None  # plain thread: no loop is running here
        coro = broadcast_preview(step, total_steps, images, message_type, generation_id=generation_id)
        if running_loop is loop:
            task = loop.create_task(coro)
            _broadcast_tasks.add(task)
            task.add_done_callback(_broadcast_tasks.discard)
            return
        future = asyncio.run_coroutine_threadsafe(coro, loop)
        # Wait briefly so consecutive broadcasts stay ordered.
        try:
            future.result(timeout=0.5)
        except Exception:
            pass  # Don't block generation on slow clients
    except Exception as e:
        logger.error(f"Preview broadcast failed: {e}")
def _apply_preview_fidelity_to_app(req):
    """Apply a request's preview-fidelity preset to the global app.

    Only acts when ``req.enable_preview`` is truthy. ``req.preview_fidelity``
    selects a preset: "low" gives WEBP q70, "high" gives PNG q100, and anything
    else (including missing/empty) gives "balanced" WEBP q90. sRGB previews are
    always switched on.

    Returns:
        A dict of the app's previous preview settings, to hand to
        ``_restore_preview_settings``. Returns None when nothing was captured
        (previewing disabled, or a failure before any state was read). If a
        failure happens after the previous values were captured, they are still
        returned so the caller can undo any partial change.
    """
    presets = {"low": ("WEBP", 70), "high": ("PNG", 100)}
    balanced = ("WEBP", 90)
    prev: Dict[str, Any] = {}
    try:
        if not getattr(req, "enable_preview", False):
            return None
        app_state = _app_instance.app
        prev["preview_srgb"] = getattr(app_state, "preview_srgb", True)
        prev["preview_format"] = getattr(app_state, "preview_format", "WEBP")
        prev["preview_quality"] = getattr(app_state, "preview_quality", 90)
        prev["preview_resample"] = getattr(app_state, "preview_resample", "LANCZOS")
        prev["preview_apply_fast_autohdr"] = getattr(app_state, "preview_apply_fast_autohdr", False)

        pfid = getattr(req, "preview_fidelity", "balanced") or "balanced"
        fmt, quality = presets.get(pfid, balanced)
        app_state.preview_srgb = True
        app_state.preview_format = fmt
        app_state.preview_quality = quality
        return prev
    except Exception:
        return prev or None
def _restore_preview_settings(prev):
    """Write previously captured preview settings back onto the global app.

    ``prev`` is the mapping produced by ``_apply_preview_fidelity_to_app``. A
    falsy value means nothing was changed, so this is a no-op. Keys missing
    from ``prev`` fall back to fixed defaults. Any failure is swallowed so that
    restoring preview cosmetics can never break a request.
    """
    if prev:
        defaults = (
            ("preview_srgb", True),
            ("preview_format", "WEBP"),
            ("preview_quality", 90),
            ("preview_resample", "LANCZOS"),
            ("preview_apply_fast_autohdr", False),
        )
        try:
            target = _app_instance.app
            for attr, fallback in defaults:
                setattr(target, attr, prev.get(attr, fallback))
        except Exception:
            pass
def make_server_callback(total_steps: int, generation_id: Optional[str] = None):
    """Build a pipeline step callback that streams progress over /ws/preview.

    Args:
        total_steps: Step count used when the pipeline's callback args do not
            carry a "total_steps" entry.
        generation_id: Unique ID for this generation run. It is attached to
            every message so the frontend can ignore stale previews.

    Returns:
        A callable taking the pipeline's per-step ``args`` dict. On every 5th
        step and on the last step it decodes the current latents into encoded
        preview images. On every step it broadcasts a "preview" message (when
        images were produced) or a plain "progress" message.
    """

    def _encode_into(latents, sink: List[str]) -> None:
        # Heuristic: Flux latents have 16 or 32 channels on dim 1 (the pipeline
        # does not say which model family is running).
        flux_like = latents.shape[1] in (16, 32)
        for pil_img in decode_latents_to_images(latents, flux=flux_like):
            buf = io.BytesIO()
            fmt = getattr(_app_instance.app, "preview_format", "WEBP")
            quality = getattr(_app_instance.app, "preview_quality", 90)
            try:
                pil_img.save(buf, format=fmt, quality=quality)
                mime = f"image/{fmt.lower()}"
            except Exception:
                # Preferred format unsupported for this image: fall back to JPEG.
                buf = io.BytesIO()
                pil_img.save(buf, format="JPEG", quality=max(70, quality))
                mime = "image/jpeg"
            encoded = base64.b64encode(buf.getvalue()).decode("utf-8")
            sink.append(f"data:{mime};base64,{encoded}")

    def callback(args):
        step = args.get("i", 0)
        steps_total = args.get("total_steps", total_steps)
        previews = None
        # Decoding latents is expensive, so only do it on broadcast steps.
        if step % 5 == 0 or step == steps_total - 1:
            try:
                # Prefer the denoised estimate; raw x is only acceptable early on.
                latents = args.get("denoised")
                if latents is None and step < 5:
                    latents = args.get("x")
                if latents is not None:
                    previews = []
                    _encode_into(latents, previews)
            except Exception as e:
                logger.error(f"Preview generation failed: {e}")
        sync_broadcast_preview(
            step,
            steps_total,
            images=previews,
            message_type="preview" if previews else "progress",
            generation_id=generation_id,
        )

    return callback
@app.websocket("/ws/preview")
async def websocket_preview(websocket: WebSocket):
    """WebSocket endpoint for real-time preview streaming.

    Clients receive JSON messages with:
    - type: "preview" | "progress" | "complete" | "error" (plus "ping"/"pong" keep-alives)
    - step: Current step number
    - total_steps: Total number of steps
    - timestamp: Unix timestamp
    - images: List of base64 encoded preview images (for "preview" type)

    Any text the client sends is answered with a "pong". After 30s of silence
    the server sends a "ping", and closes the loop if that send fails.
    """

    async def _send_ping() -> bool:
        try:
            await websocket.send_json({"type": "ping", "timestamp": time.time()})
        except Exception:
            return False
        return True

    await websocket.accept()
    _preview_clients.append(websocket)
    logger.info("WebSocket client connected to /ws/preview (total: %d)", len(_preview_clients))
    try:
        while True:
            try:
                await asyncio.wait_for(websocket.receive_text(), timeout=30.0)
                await websocket.send_json({"type": "pong", "timestamp": time.time()})
            except asyncio.TimeoutError:
                if not await _send_ping():
                    break
    except WebSocketDisconnect:
        pass
    except Exception as e:
        logger.debug("WebSocket connection error: %s", e)
    finally:
        if websocket in _preview_clients:
            _preview_clients.remove(websocket)
        logger.info("WebSocket client disconnected (remaining: %d)", len(_preview_clients))
async def broadcast_preview(
    step: int,
    total_steps: int,
    images: Optional[List[str]] = None,
    message_type: str = "preview",
    generation_id: Optional[str] = None,
):
    """Broadcast a preview update to all connected WebSocket clients.

    Args:
        step: Current step number
        total_steps: Total number of steps
        images: Optional list of base64-encoded images (omitted from the payload when empty)
        message_type: Type of message (preview, progress, complete, error)
        generation_id: Unique ID for this generation run (omitted when falsy)

    Clients whose send raises are removed from the shared client list.
    """
    if not _preview_clients:
        return
    payload: Dict[str, Any] = {
        "type": message_type,
        "step": step,
        "total_steps": total_steps,
        "timestamp": time.time(),
    }
    if generation_id:
        payload["generation_id"] = generation_id
    if images:
        payload["images"] = images

    # Iterate over a snapshot. Each send awaits, and meanwhile the /ws/preview
    # handler (or another broadcast) may remove clients from the shared list.
    # Iterating the live list would then silently skip clients.
    failed = []
    for client in list(_preview_clients):
        try:
            await client.send_json(payload)
        except Exception:
            failed.append(client)
    for client in failed:
        if client in _preview_clients:
            _preview_clients.remove(client)
@app.post("/api/generate")
async def generate(req: GenerateRequest) -> Dict[str, Any]:
    """Queue a generation request and return the image(s) the worker produces.

    Flow:
    1. Fail with HTTP 500 if the pipeline module could not be imported.
    2. When a non-negative seed is given, persist it via SettingsStore
       (best effort; failures are only logged).
    3. Resolve unset torch_compile / vae_autotune flags from stored preferences.
    4. Decode an inline img2img image (data URL / base64) into a temp file.
    5. Enqueue on the generation buffer and await the worker's result, a dict
       with either 'image' or 'images'. The worker also restores any preview
       state it toggles.
    """
    request_tag = uuid.uuid4().hex[:8]
    rlog = logging.LoggerAdapter(logger, {"rid": request_tag})
    rlog.info("/api/generate called")

    if pipeline is None:
        rlog.error("Pipeline import error: %s", _pipeline_import_error)
        raise HTTPException(status_code=500, detail=f"Pipeline import error: {_pipeline_import_error}")

    explicit_seed = req.seed is not None and req.seed >= 0
    if explicit_seed:
        try:
            from src.Core.SettingsStore import set_last_seed
            set_last_seed(int(req.seed))
        except Exception:
            logger.exception("Failed to persist last seed to SettingsStore")
    # NOTE(review): this flag only feeds the debug log below; req.reuse_seed is
    # left untouched. Confirm whether the worker should see True here.
    reuse_seed = True if explicit_seed else req.reuse_seed

    req = _resolve_autotune_preferences(req)

    def _clip(text: Optional[str], limit: int = 200) -> str:
        # Keep log lines bounded; prompts can be very long.
        if not text:
            return ""
        return text[:limit] + "…" if len(text) > limit else text

    rlog.debug(
        "Request: w=%s h=%s num_images=%s batch=%s scheduler=%s sampler=%s steps=%s hiresfix=%s adetailer=%s enhance=%s img2img=%s stable_fast=%s torch_compile=%s vae_autotune=%s reuse_seed=%s realistic=%s multiscale=%s intermittent=%s factor=%s fullres=[%s,%s] keep_models_loaded=%s enable_preview=%s prompt='%s' neg='%s' img2img_image_present=%s",
        req.width,
        req.height,
        req.num_images,
        req.batch_size,
        req.scheduler,
        req.sampler,
        req.steps,
        req.hiresfix,
        req.adetailer,
        req.enhance_prompt,
        req.img2img_mode,
        req.stable_fast,
        req.torch_compile,
        req.vae_autotune,
        reuse_seed,
        req.realistic_model,
        req.enable_multiscale,
        req.multiscale_intermittent,
        req.multiscale_factor,
        req.multiscale_fullres_start,
        req.multiscale_fullres_end,
        req.keep_models_loaded,
        req.enable_preview,
        _clip(req.prompt, 200),
        _clip(req.negative_prompt or "", 200),
        bool(req.img2img_image),
    )

    if req.img2img_image:
        try:
            stored_path = _save_img2img_image_to_file(req.img2img_image, max_size_bytes=10 * 1024 * 1024)
            if stored_path and stored_path != req.img2img_image:
                rlog.info("Img2Img upload received and written to %s", stored_path)
                req.img2img_image = stored_path
        except HTTPException:
            # Well-formed HTTP errors (bad payload, too large, ...) pass through unchanged.
            raise
        except Exception as e:
            rlog.exception("Failed processing img2img_image: %s", e)
            # Never echo the raw base64 content back to the client.
            raise HTTPException(status_code=400, detail="Invalid img2img_image payload")

    return await _generation_buffer.enqueue(PendingRequest(req, request_tag))
@app.get("/api/models")
async def list_models() -> List[Dict[str, Any]]:
    """List available models with their detected type and capability flags.

    Each entry is {"name", "path", "type", "capabilities"}. If type detection or
    capability lookup fails for a model, it is still listed with type "SD15"
    and empty capabilities. Failure to enumerate models at all becomes HTTP 500.
    """
    capability_fields = (
        "supports_hires_fix",
        "supports_img2img",
        "supports_controlnet",
        "supports_inpainting",
        "supports_stable_fast",
        "supports_deepcache",
        "supports_tome",
        "preferred_resolution",
    )
    try:
        from src.Core.Models.ModelFactory import list_available_models, detect_model_type, create_model

        catalog = []
        for name, path in list_available_models(return_mapping=True):
            try:
                model_type = detect_model_type(path)
                # create_model is assumed to build an instance without loading full
                # weights, just to read its capabilities. TODO confirm in ModelFactory.
                caps = create_model(model_path=path, model_type=model_type).capabilities
                catalog.append({
                    "name": name,
                    "path": path,
                    "type": model_type,
                    "capabilities": {field: getattr(caps, field) for field in capability_fields},
                })
            except Exception as e:
                logger.warning(f"Failed to detect type/caps for {name}: {e}")
                catalog.append({"name": name, "path": path, "type": "SD15", "capabilities": {}})
        return catalog
    except Exception as e:
        logger.error(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail=str(e))
@app.post("/api/interrupt")
async def interrupt_generation():
    """Ask the running generation to stop.

    Calls ``request_interrupt()`` on the global app object and returns
    {"status": "interrupted"}. Raises HTTP 503 when no app instance is available.
    """
    app_ready = _app_instance and hasattr(_app_instance, "app") and _app_instance.app
    if not app_ready:
        logger.error("Cannot interrupt: app_instance not available")
        raise HTTPException(status_code=503, detail="App instance not available")
    _app_instance.app.request_interrupt()
    logger.info("Interrupt requested via API")
    return {"status": "interrupted"}
# Serve the built frontend at the site root when a build exists. This mount is
# registered after all API routes above, so "/" acts as the catch-all.
frontend_dist = os.path.join(os.path.dirname(__file__), "frontend", "dist")
if not os.path.exists(frontend_dist):
    logger.warning(f"Frontend build not found at {frontend_dist}. Run 'npm run build' in frontend directory.")
else:
    app.mount("/", StaticFiles(directory=frontend_dist, html=True), name="frontend")
    logger.info(f"Serving frontend from {frontend_dist}")
if __name__ == "__main__":
    # Script entry point: parse CLI options, optionally launch the frontend dev
    # server (npm run dev), print a reachable URL, then run uvicorn. The
    # frontend process is always cleaned up when uvicorn returns or raises.
    import argparse
    import signal
    import socket
    import subprocess
    import sys

    import uvicorn

    parser = argparse.ArgumentParser(description="LightDiffusion Server")
    # PORT / UVICORN_PORT let a hosting environment pick the port; fall back to
    # 7861 when neither is set or the value is not an integer.
    try:
        default_port = int(os.environ.get("PORT") or os.environ.get("UVICORN_PORT") or 7861)
    except ValueError:
        default_port = 7861
    parser.add_argument("--host", type=str, default=os.environ.get("HOST", "0.0.0.0"), help="Host to bind to")
    parser.add_argument("--port", type=int, default=default_port, help="Port to bind to")
    parser.add_argument("--frontend", action="store_true", help="Launch the frontend development server")
    args = parser.parse_args()

    is_windows = sys.platform == "win32"
    frontend_proc = None
    if args.frontend:
        frontend_dir = os.path.join(os.path.dirname(__file__), "frontend")
        if os.path.exists(frontend_dir):
            logger.info("Launching frontend development server...")
            try:
                if is_windows:
                    # shell=True lets cmd.exe resolve the npm.cmd shim on Windows.
                    frontend_proc = subprocess.Popen(
                        ["npm", "run", "dev"],
                        cwd=frontend_dir,
                        shell=True,
                    )
                else:
                    # No shell on POSIX: with shell=True a list argv runs only
                    # "npm" (the other items become /bin/sh positional args).
                    # A new session makes the child its own process-group
                    # leader so the whole tree (npm + dev server) can be
                    # signalled on shutdown.
                    frontend_proc = subprocess.Popen(
                        ["npm", "run", "dev"],
                        cwd=frontend_dir,
                        start_new_session=True,
                    )
                logger.info("Frontend development server launched")
            except Exception as e:
                logger.error(f"Failed to launch frontend: {e}")
        else:
            logger.warning(f"Frontend directory not found at {frontend_dir}")

    # Present helpful URL(s) to the user before starting uvicorn
    try:
        if args.host in ("0.0.0.0", "::", ""):
            # Find the outbound LAN address: connect() on a UDP socket sends no
            # packets, it only makes the OS choose a local interface/route.
            try:
                with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
                    probe.connect(("8.8.8.8", 80))
                    host_ip = probe.getsockname()[0]
            except OSError:
                host_ip = "127.0.0.1"
            logger.info("Open the UI in a browser: http://localhost:%d/ (or on your network: http://%s:%d/)", args.port, host_ip, args.port)
        else:
            logger.info("Open the UI in a browser: http://%s:%d/", args.host, args.port)
        # NOTE(review): the import string makes uvicorn import this file again
        # as module "server" (it is currently running as __main__), so
        # module-level code executes twice. Kept as-is in case other modules
        # rely on `import server` sharing state with the served app — confirm.
        uvicorn.run("server:app", host=args.host, port=args.port, reload=False, ws="websockets")
    finally:
        if frontend_proc is not None:
            logger.info("Shutting down frontend development server...")
            if is_windows:
                # shell=True interposes a cmd.exe wrapper, so kill the whole tree.
                subprocess.run(["taskkill", "/F", "/T", "/PID", str(frontend_proc.pid)], capture_output=True)
            else:
                try:
                    # start_new_session=True made the child's pid its process
                    # group id; signal the group so the dev server stops too.
                    os.killpg(frontend_proc.pid, signal.SIGTERM)
                except (ProcessLookupError, PermissionError):
                    frontend_proc.terminate()
            # Reap the child, escalating if it ignores the polite request.
            try:
                frontend_proc.wait(timeout=10)
            except subprocess.TimeoutExpired:
                frontend_proc.kill()
            logger.info("Frontend development server shut down")
|