CSSSC / server.py
AdarshJi's picture
Update server.py
a108652 verified
# perchance_server_with_pyvirtualdisplay.py
"""
Perchance Image-Generation Server v2.0
This variant adds optional pyvirtualdisplay support so the server can be
hosted on headless environments (Hugging Face Spaces etc.) while keeping
all original behaviour unchanged.
Behaviour:
- If ZD_HEADLESS is True, zendriver will run headless as before.
- If ZD_HEADLESS is False and USE_VIRTUAL_DISPLAY is True and a DISPLAY
is not present, we attempt to start a pyvirtualdisplay.Display (Xvfb)
automatically before launching browsers. If pyvirtualdisplay is not
installed or starting Xvfb fails, we log a warning and continue.
- If USE_VIRTUAL_DISPLAY is False we will NOT attempt to start a virtual
display — you must provide a DISPLAY yourself (or set ZD_HEADLESS=True)
if running on a headless host.
To run on Hugging Face Spaces, add `pyvirtualdisplay` to requirements.txt
and ensure `xvfb` is available in the runtime (HF Spaces typically provide it).
All original defaults and constants are preserved from the original file.
"""
import asyncio
import base64
import json
import logging
import os
import random
import string
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional
import cloudscraper
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from sse_starlette.sse import EventSourceResponse
import zendriver as zd
from zendriver import cdp
# Try to import pyvirtualdisplay (optional)
try:
from pyvirtualdisplay import Display
_HAS_PYVIRTUALDISPLAY = True
except Exception:
Display = None
_HAS_PYVIRTUALDISPLAY = False
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
# --- Perchance API ---
BASE_URL = "https://image-generation.perchance.org"
API_GENERATE = "/api/generate"
API_DOWNLOAD = "/api/downloadTemporaryImage"
API_AWAIT = "/api/awaitExistingGenerationRequest"
API_ACCESS_CODE = "/api/getAccessCodeForAdPoweredStuff"
# --- Browser automation (zendriver) ---
TARGET_URL = "https://perchance.org/ai-text-to-image-generator"
IMAGE_GEN_ORIGIN = "https://image-generation.perchance.org"
ZD_TIMEOUT = 90 # seconds for key-fetch attempt
ZD_HEADLESS = False # True β†’ hide browser window
CLICK_INTERVAL = 0.35
CLICK_JITTER = 8.0
KEY_PREFIX = "userKey"
# --- Virtual display toggle (new) ---
# If True the server will attempt to auto-start a pyvirtualdisplay (Xvfb)
# when no DISPLAY is present and ZD_HEADLESS is False. If False the server
# will not try to start a virtual display and you must provide a DISPLAY
# or set ZD_HEADLESS=True.
USE_VIRTUAL_DISPLAY = True
# --- HTTP / generation ---
HTTP_TIMEOUT = 30
MAX_DOWNLOAD_WAIT = 180
BACKOFF_INIT = 0.7
MAX_GEN_RETRIES = 6 # retries inside generate_one()
# --- Key-refresh policy ---
MAX_KEY_RETRIES = 3 # per-image retries when key is invalid
KEY_REFRESH_COOLDOWN = 30 # min seconds between two refreshes
MAX_REFRESH_FAILURES = 5 # consecutive failures β†’ stop auto-refresh
# --- Server ---
WORKER_COUNT = 3
MAX_QUEUE_SIZE = 1000
EXECUTOR_THREADS = 16
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# ═══════════════════════════════════════════════════════════════
# LOGGING
# ═══════════════════════════════════════════════════════════════
LOG_FMT = "%(asctime)s | %(levelname)-7s | %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FMT)
log = logging.getLogger("perchance")
# ═══════════════════════════════════════════════════════════════
# GLOBAL STATE
# ═══════════════════════════════════════════════════════════════
USER_KEY: Optional[str] = None
# -- set in lifespan --
_key_lock: Optional[asyncio.Lock] = None # guard USER_KEY reads/writes
_key_valid: Optional[asyncio.Event] = None # cleared while refreshing
_key_refresh_lock: Optional[asyncio.Lock] = None # one refresh at a time
_key_last_ts: float = 0.0 # last successful refresh
_key_fail_count: int = 0 # consecutive refresh failures
JOB_QUEUE: Optional[asyncio.Queue] = None
TASKS: Dict[str, Dict[str, Any]] = {}
TASK_QUEUES: Dict[str, asyncio.Queue] = {} # SSE event queues
EXECUTOR = ThreadPoolExecutor(max_workers=EXECUTOR_THREADS)
SCRAPER = cloudscraper.create_scraper()
# pyvirtualdisplay handle (optional)
VDISPLAY: Optional[Display] = None
# ═══════════════════════════════════════════════════════════════
# SMALL HELPERS
# ═══════════════════════════════════════════════════════════════
def _safe(s: str) -> str:
"""Sanitise string for filenames."""
ok = set(string.ascii_letters + string.digits + "-_.()")
return "".join(c if c in ok else "_" for c in s)[:120]
def _sid() -> str:
return "".join(random.choices(string.ascii_lowercase + string.digits, k=8))
def _now() -> str:
return datetime.utcnow().isoformat(timespec="milliseconds") + "Z"
def _reqid() -> str:
    """Return a request id: epoch seconds (6 decimals), a dash, a short random id."""
    return "{:.6f}-{}".format(time.time(), _sid())
def _stamp() -> str:
return datetime.utcnow().strftime("%Y%m%dT%H%M%S")
# ═══════════════════════════════════════════════════════════════
# Virtual display helpers (pyvirtualdisplay)
# ═══════════════════════════════════════════════════════════════
def _start_virtual_display_if_needed(headless: bool):
    """
    Start a pyvirtualdisplay ``Display`` when one is actually required.

    Nothing is started when *headless* is true, when USE_VIRTUAL_DISPLAY is
    off, when a DISPLAY environment variable already exists, or when
    pyvirtualdisplay could not be imported. A failure to start is logged and
    leaves VDISPLAY as None. The function is synchronous.
    """
    global VDISPLAY

    if headless:
        log.info("ZD_HEADLESS=True β†’ not starting virtual display")
    elif not USE_VIRTUAL_DISPLAY:
        log.info("USE_VIRTUAL_DISPLAY=False β†’ not starting virtual display; expecting manual DISPLAY or headless mode.")
    elif os.environ.get("DISPLAY"):
        log.info("DISPLAY already set: %s", os.environ.get("DISPLAY"))
    elif not _HAS_PYVIRTUALDISPLAY or Display is None:
        log.warning(
            "pyvirtualdisplay not installed β€” cannot create virtual DISPLAY. "
            "Install pyvirtualdisplay in your environment to enable Xvfb."
        )
    else:
        try:
            VDISPLAY = Display(visible=0, size=(1280, 720))
            VDISPLAY.start()
            # pyvirtualdisplay exports DISPLAY itself; log it for visibility
            log.info(
                "Started virtual display via pyvirtualdisplay (DISPLAY=%s)",
                os.environ.get("DISPLAY"),
            )
        except Exception as exc:
            VDISPLAY = None
            log.exception("Failed to start virtual display: %s", exc)
def _stop_virtual_display_if_needed():
    """Stop the virtual display started by this module, if there is one.

    VDISPLAY is reset to None afterwards whether or not stopping succeeded.
    """
    global VDISPLAY

    if VDISPLAY is not None:
        try:
            VDISPLAY.stop()
            log.info("Stopped virtual display")
        except Exception:
            log.exception("Error while stopping virtual display")
        finally:
            VDISPLAY = None
# ═══════════════════════════════════════════════════════════════
# PERCHANCE HTTP CLIENT (blocking – runs in ThreadPoolExecutor)
# ═══════════════════════════════════════════════════════════════
class PerchanceClient:
    """All blocking HTTP work against the Perchance API.

    Every method performs synchronous network I/O through the shared
    ``SCRAPER`` session, so async callers run them in a thread pool rather
    than directly on the event loop.
    """

    def __init__(self):
        self.base = BASE_URL.rstrip("/")
        self.s = SCRAPER
        # Headers sent with every request.
        self.h = {
            "Accept": "*/*",
            "Content-Type": "application/json;charset=UTF-8",
            "Origin": IMAGE_GEN_ORIGIN,
            "Referer": f"{IMAGE_GEN_ORIGIN}/embed",
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/131.0.0.0 Safari/537.36"
            ),
        }

    # ---- low-level helpers ----
    def get_ad_code(self) -> str:
        """Fetch an ad access code; return ``""`` on any failure (best effort)."""
        try:
            r = self.s.get(
                f"{self.base}{API_ACCESS_CODE}",
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            return r.text.strip()
        except Exception as exc:
            log.debug("Ad access code fetch failed: %s", exc)
            return ""

    def _post(self, body: dict, params: dict) -> dict:
        """POST a generation request and return the decoded JSON reply.

        Never raises: a transport/HTTP error becomes
        ``{"status": "fetch_failure", "error": ...}`` and an undecodable body
        becomes ``{"status": "invalid_json", "raw": ...}``.
        """
        try:
            r = self.s.post(
                f"{self.base}{API_GENERATE}",
                json=body, params=params,
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            try:
                return r.json()
            except Exception:
                return {"status": "invalid_json", "raw": r.text}
        except Exception as exc:
            return {"status": "fetch_failure", "error": str(exc)}

    def _await_prev(self, key: str):
        """Call the await-previous-request endpoint for *key*; errors are ignored."""
        try:
            self.s.get(
                f"{self.base}{API_AWAIT}",
                params={"userKey": key, "__cacheBust": random.random()},
                timeout=20, headers=self.h,
            )
        except Exception as exc:
            # Best effort only: the caller retries the generation regardless.
            log.debug("awaitExistingGenerationRequest failed: %s", exc)

    # ---- generate one image ----
    def generate_one(
        self, *,
        prompt: str,
        negative_prompt: str = "",
        seed: int = -1,
        resolution: str = "512x768",
        guidance_scale: float = 7.0,
        channel: str = "ai-text-to-image-generator",
        sub_channel: str = "private",
        user_key: str = "",
        ad_access_code: str = "",
        request_id: str = "",
    ) -> dict:
        """
        Request one image, retrying transient failures up to MAX_GEN_RETRIES.

        Returns ONE of:
        {"imageId": ..., "seed": ...}
        {"inline": ..., "seed": ...}
        {"error": "invalid_key"} ← caller must refresh key
        {"error": "<other>", ...}
        """
        request_id = request_id or _reqid()
        params = {
            "userKey": user_key,
            "requestId": request_id,
            "adAccessCode": ad_access_code,
            "__cacheBust": random.random(),
        }
        body = {
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "seed": seed,
            "resolution": resolution,
            "guidanceScale": guidance_scale,
            "channel": channel,
            "subChannel": sub_channel,
            "userKey": user_key,
            "adAccessCode": ad_access_code,
            "requestId": request_id,
        }
        # The ad code is refreshed at most once per call.
        ad_refreshed = False
        for att in range(1, MAX_GEN_RETRIES + 1):
            res = self._post(body, params)
            st = res.get("status")
            # ---- success ----
            if st == "success":
                iid = res.get("imageId")
                urls = res.get("imageDataUrls")
                if iid:
                    log.info("Got imageId: %s", iid)
                    return {"imageId": iid, "seed": res.get("seed")}
                if urls:
                    return {"inline": urls[0], "seed": res.get("seed")}
                log.error("success but empty payload: %s", str(res)[:300])
                return {"error": "empty_success", "raw": res}
            # ---- invalid key → return immediately (do NOT retry here) ----
            if st == "invalid_key":
                log.warning("Server says invalid_key")
                return {"error": "invalid_key"}
            # ---- previous request still running ----
            if st == "waiting_for_prev_request_to_finish":
                log.info("Waiting for prev request to finish …")
                self._await_prev(user_key)
                time.sleep(0.3 + random.random() * 0.3)
                continue
            # ---- ad access code expired ----
            if st == "invalid_ad_access_code" and not ad_refreshed:
                code = self.get_ad_code()
                if code:
                    ad_access_code = code
                    params["adAccessCode"] = code
                    body["adAccessCode"] = code
                    ad_refreshed = True
                    log.info("Refreshed ad code β†’ retry")
                    time.sleep(0.8)
                    continue
                return {"error": "invalid_ad_access_code"}
            # ---- transient gen failure ----
            if st == "gen_failure" and res.get("type") == 1:
                log.warning("gen_failure type 1 β†’ retry after 2.5 s")
                time.sleep(2.5)
                continue
            # ---- network / stale ----
            if st in (None, "fetch_failure", "invalid_json", "stale_request"):
                log.info("Transient error (status=%s) attempt %d/%d", st, att, MAX_GEN_RETRIES)
                time.sleep(1.0)
                continue
            # ---- anything else ----
            log.error("Unhandled status '%s': %s", st, str(res)[:300])
            return {"error": f"unhandled_{st}", "raw": res}
        return {"error": "max_retries_exceeded"}

    # ---- download ----
    def download_image(self, image_id: str, prefix: str = "img") -> str:
        """Poll until the image is ready, save to OUTPUT_DIR, return path.

        Polls the download endpoint with a growing back-off (×1.8, capped at
        8 s) and writes the first HTTP 200 body to
        ``OUTPUT_DIR/<prefix>_<image_id[:12]><ext>``, where the extension is
        chosen from the Content-Type header.

        Raises:
            TimeoutError: no image was obtained within MAX_DOWNLOAD_WAIT seconds.
        """
        url = f"{self.base}{API_DOWNLOAD}?imageId={image_id}"
        t0 = time.time()
        bk = BACKOFF_INIT
        while True:
            elapsed = time.time() - t0
            if elapsed >= MAX_DOWNLOAD_WAIT:
                raise TimeoutError(
                    f"Download timed out ({elapsed:.0f}s) for {image_id}"
                )
            r = None
            try:
                r = self.s.get(url, timeout=HTTP_TIMEOUT,
                               headers=self.h, stream=True)
                if r.status_code == 200:
                    ct = r.headers.get("Content-Type", "")
                    ext = (
                        ".png" if "png" in ct else
                        ".webp" if "webp" in ct else ".jpg"
                    )
                    fn = _safe(f"{prefix}_{image_id[:12]}{ext}")
                    fp = str(OUTPUT_DIR / fn)
                    with open(fp, "wb") as f:
                        for chunk in r.iter_content(8192):
                            if chunk:
                                f.write(chunk)
                    log.info("Saved β†’ %s", fp)
                    return fp
            except Exception as exc:
                # Still best effort (the poll continues), but leave a trace so
                # persistent failures can be diagnosed before the timeout.
                log.debug("Download attempt for %s failed: %s", image_id, exc)
            finally:
                # A streamed response holds its connection until closed;
                # release it after every attempt, not only on success.
                if r is not None:
                    r.close()
            time.sleep(bk)
            bk = min(bk * 1.8, 8.0)
# Module-wide client instance used by the async helpers below.
CLIENT = PerchanceClient()
# ═══════════════════════════════════════════════════════════════
# ZENDRIVER – browser automation to extract userKey
# ═══════════════════════════════════════════════════════════════
async def _cdp_mouse(tab, typ, x, y, **kw):
    """Send one CDP mouse event of type *typ* at (*x*, *y*) to *tab*.

    Coordinates are converted to float; extra keyword arguments are passed
    through to ``dispatch_mouse_event``.
    """
    event = cdp.input_.dispatch_mouse_event(
        type_=typ,
        x=float(x),
        y=float(y),
        **kw,
    )
    await tab.send(event)
async def _viewport_center(tab):
try:
v = await tab.evaluate(
"(()=>({w:innerWidth,h:innerHeight}))()",
await_promise=False, return_by_value=True,
)
return (v["w"] / 2.0, v["h"] / 2.0)
except Exception:
return (600.0, 400.0)
async def _ls_get(tab, key):
try:
return await tab.evaluate(
f"localStorage&&localStorage.getItem({json.dumps(key)})",
await_promise=True, return_by_value=True,
)
except Exception:
return None
async def _clicker_loop(tab, stop: asyncio.Event):
    """Keep clicking near the centre of *tab* until *stop* is set.

    Each round moves the pointer, presses and releases the left button at a
    slightly jittered position, then waits roughly CLICK_INTERVAL seconds.
    The wait ends early when *stop* is set. Click errors are ignored.
    """

    async def _click_at(px, py):
        # move → press → release, with short random pauses in between
        await _cdp_mouse(tab, "mouseMoved", px, py, pointer_type="mouse")
        await asyncio.sleep(random.uniform(0.02, 0.08))
        await _cdp_mouse(
            tab, "mousePressed", px, py,
            button=cdp.input_.MouseButton.LEFT,
            click_count=1, buttons=1,
        )
        await asyncio.sleep(random.uniform(0.03, 0.12))
        await _cdp_mouse(
            tab, "mouseReleased", px, py,
            button=cdp.input_.MouseButton.LEFT,
            click_count=1, buttons=0,
        )

    try:
        await tab.evaluate(
            "window.focus&&window.focus()",
            await_promise=False, return_by_value=False,
        )
    except Exception:
        pass

    middle = await _viewport_center(tab)
    measured_at = time.time()
    while not stop.is_set():
        # re-measure the viewport every 2.5 s in case the window was resized
        if time.time() - measured_at > 2.5:
            middle = await _viewport_center(tab)
            measured_at = time.time()
        dx = random.uniform(-CLICK_JITTER, CLICK_JITTER)
        dy = random.uniform(-CLICK_JITTER, CLICK_JITTER)
        target_x = middle[0] + dx
        target_y = middle[1] + dy
        try:
            await _click_at(target_x, target_y)
        except Exception:
            pass
        # interruptible sleep: leave as soon as *stop* fires
        try:
            await asyncio.wait_for(
                stop.wait(),
                timeout=CLICK_INTERVAL * random.uniform(0.85, 1.15),
            )
        except asyncio.TimeoutError:
            continue
        break
async def _poll_for_key(tab, stop: asyncio.Event, max_sec: int):
    """Poll *tab*'s localStorage for a userKey roughly every 250 ms.

    The ``<KEY_PREFIX>-0`` entry is tried first, then any localStorage key
    whose name contains ``userKey``. Returns the first non-empty value, or
    None once *stop* is set or *max_sec* seconds have passed.
    """
    started = time.time()
    while True:
        if stop.is_set() or (time.time() - started) >= max_sec:
            return None
        primary = await _ls_get(tab, f"{KEY_PREFIX}-0")
        if primary:
            return primary
        try:
            candidates = await tab.evaluate(
                "Object.keys(localStorage||{}).filter(k=>k.includes('userKey'))",
                await_promise=False, return_by_value=True,
            )
            for name in (candidates or []):
                found = await _ls_get(tab, name)
                if found:
                    return found
        except Exception:
            pass
        await asyncio.sleep(0.25)
async def fetch_key_via_browser(
    timeout: int = ZD_TIMEOUT,
    headless: bool = ZD_HEADLESS,
) -> Optional[str]:
    """
    Launch Chrome → navigate to Perchance → click to trigger
    ad/verification → read userKey from localStorage → close browser.

    Args:
        timeout: seconds to wait for the key to appear.
        headless: passed to ``zd.start``; when false, a virtual display is
            started first if one is needed.

    Returns the key string or None.
    """
    log.info(
        "Launching browser for userKey (timeout=%ds, headless=%s)",
        timeout, headless,
    )
    # If we're in non-headless mode on a display-less host, ensure a virtual
    # DISPLAY is started first. The helper is synchronous, so it runs in the
    # event loop's default executor.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, partial(_start_virtual_display_if_needed, headless))
    except Exception:
        log.exception("Error while attempting to start virtual display")
    try:
        browser = await zd.start(headless=headless)
    except Exception as exc:
        log.exception("Browser start failed: %s", exc)
        return None
    stop = asyncio.Event()
    result = None
    try:
        page_tab = await browser.get(TARGET_URL)
        log.info("Opened %s", TARGET_URL)
        await asyncio.sleep(2.0)
        origin_tab = await browser.get(IMAGE_GEN_ORIGIN, new_tab=True)
        log.info("Opened %s", IMAGE_GEN_ORIGIN)
        await asyncio.sleep(1.0)
        await page_tab.bring_to_front()
        await asyncio.sleep(0.5)
        clicker = asyncio.create_task(_clicker_loop(page_tab, stop))
        poller = asyncio.create_task(_poll_for_key(origin_tab, stop, timeout))
        helpers = (clicker, poller)
        try:
            done, _ = await asyncio.wait({poller}, timeout=timeout)
            if poller in done:
                result = poller.result()
        finally:
            stop.set()
            # Cancel and reap BOTH helper tasks before the tabs are closed.
            # Previously the poller was left running after a timeout, and a
            # clicker failure surfacing here could discard a fetched key.
            for helper in helpers:
                if not helper.done():
                    helper.cancel()
            outcomes = await asyncio.gather(*helpers, return_exceptions=True)
            for outcome in outcomes:
                if isinstance(outcome, Exception) and not isinstance(
                    outcome, asyncio.CancelledError
                ):
                    log.debug("Key-fetch helper task ended with error: %r", outcome)
            for t in (origin_tab, page_tab):
                try:
                    await t.close()
                except Exception:
                    pass
    finally:
        try:
            await browser.stop()
        except Exception:
            pass
    if result:
        log.info("Fetched userKey (len=%d)", len(result))
    else:
        log.warning("Could not fetch userKey within %ds", timeout)
    return result
# ═══════════════════════════════════════════════════════════════
# KEY MANAGEMENT – coordinated refresh across workers
# ═══════════════════════════════════════════════════════════════
async def _broadcast(event: dict):
    """Queue *event* on the SSE queue of every task that is queued or running."""
    live_states = ("queued", "running")
    for task_id, events in TASK_QUEUES.items():
        record = TASKS.get(task_id)
        if not record or record["status"] not in live_states:
            continue
        try:
            events.put_nowait(event)
        except asyncio.QueueFull:
            pass
async def refresh_user_key() -> Optional[str]:
    """
    Run one coordinated userKey refresh and return the usable key.

    Callers are serialised by ``_key_refresh_lock``. A caller that gets the
    lock within KEY_REFRESH_COOLDOWN seconds of the last refresh receives the
    existing key. After MAX_REFRESH_FAILURES consecutive failures no browser
    is launched any more. While the browser runs, ``_key_valid`` is cleared;
    it is always set again afterwards. Returns the key, or None on failure.
    """
    global USER_KEY, _key_last_ts, _key_fail_count

    async def _announce(kind: str, message: str) -> None:
        # Same event shape for every notification sent from here.
        await _broadcast({"type": kind, "time": _now(), "message": message})

    async with _key_refresh_lock:
        # ── another coroutine may have refreshed just before us ──
        since_refresh = time.time() - _key_last_ts
        if since_refresh < KEY_REFRESH_COOLDOWN and USER_KEY:
            log.info(
                "Key was refreshed %.1fs ago β†’ reusing existing key",
                since_refresh,
            )
            return USER_KEY

        # ── give up after too many consecutive failures ──
        if _key_fail_count >= MAX_REFRESH_FAILURES:
            log.error(
                "Key refresh disabled: %d consecutive failures. "
                "Set key manually via POST /set_user_key",
                _key_fail_count,
            )
            await _announce(
                "key_refresh_failed",
                f"Auto-refresh disabled after {_key_fail_count} failures. "
                "Please set userKey manually via /set_user_key",
            )
            return None

        # ── make workers wait while the browser runs ──
        _key_valid.clear()
        log.info("Starting userKey refresh via browser …")
        await _announce(
            "key_refreshing",
            "UserKey expired β€” refreshing via browser automation …",
        )
        try:
            fetched = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )
            if not fetched:
                # the browser ran but produced no key
                _key_fail_count += 1
                log.error(
                    "Key refresh returned nothing (failure #%d/%d)",
                    _key_fail_count, MAX_REFRESH_FAILURES,
                )
                await _announce(
                    "key_refresh_failed",
                    f"Key refresh failed (attempt {_key_fail_count}"
                    f"/{MAX_REFRESH_FAILURES})",
                )
                return None

            async with _key_lock:
                USER_KEY = fetched
                _key_last_ts = time.time()
                _key_fail_count = 0
            log.info("UserKey refreshed OK (len=%d)", len(fetched))
            await _announce(
                "key_refreshed",
                "UserKey refreshed – resuming generation.",
            )
            return fetched
        except Exception as exc:
            _key_fail_count += 1
            log.exception(
                "Key refresh error (failure #%d/%d): %s",
                _key_fail_count, MAX_REFRESH_FAILURES, exc,
            )
            await _announce("key_refresh_failed", f"Key refresh error: {exc}")
            return None
        finally:
            # waiters are released on every path, success or failure
            _key_valid.set()
# ═══════════════════════════════════════════════════════════════
# TASK MODEL
# ═══════════════════════════════════════════════════════════════
def create_task(
    prompts: List[str],
    count: int,
    resolution: str,
    guidance: float,
    negative: str,
    sub_channel: str,
) -> dict:
    """Build a new task record, register it, and return it.

    The record is stored in TASKS under a fresh UUID and gets its own event
    queue in TASK_QUEUES. ``total_images`` is ``len(prompts) * count``.
    """
    task_id = str(uuid.uuid4())
    record = {
        "id": task_id,
        "prompts": prompts,
        "count": count,
        "resolution": resolution,
        "guidance": guidance,
        "negative": negative,
        "sub_channel": sub_channel,
        "created_at": _now(),
        # lifecycle: queued → running → done / failed
        "status": "queued",
        "total_images": len(prompts) * count,
        "completed": 0,
        "results": [],
        "error": None,
    }
    TASKS[task_id] = record
    TASK_QUEUES[task_id] = asyncio.Queue()
    return record
# ═══════════════════════════════════════════════════════════════
# WORKER β€” image generation + key-refresh retry loop
# ═══════════════════════════════════════════════════════════════
async def _save_inline(data_url: str, prompt: str) -> str:
    """Decode a base-64 data URL (or bare base-64 string) to a file.

    The extension is taken from the data-URL header: ``.png``, ``.webp`` or
    ``.jpg`` as the fallback, matching the choice made by
    ``PerchanceClient.download_image``. The file name is built from the first
    30 characters of *prompt*, a UTC stamp and a random id. The write runs in
    the thread pool. Returns the saved path as a string.
    """
    loop = asyncio.get_running_loop()
    # "data:image/png;base64,<payload>" → header / payload. Without a comma
    # the whole string is treated as the payload.
    header, comma, b64 = data_url.partition(",")
    if not comma:
        header, b64 = "", data_url
    if "png" in header:
        ext = ".png"
    elif "webp" in header:
        # previously fell through to ".jpg", mislabelling WebP images
        ext = ".webp"
    else:
        ext = ".jpg"
    fn = _safe(f"{prompt[:30]}_{_stamp()}_{_sid()}{ext}")
    fp = OUTPUT_DIR / fn
    raw = base64.b64decode(b64)
    await loop.run_in_executor(EXECUTOR, fp.write_bytes, raw)
    log.info("Saved inline β†’ %s", fp)
    return str(fp)
async def _download(image_id: str, prompt: str) -> str:
    """Download *image_id* through the blocking client in the thread pool.

    The file prefix combines the sanitised first 30 characters of *prompt*,
    a UTC stamp and a random id. Returns the saved path.
    """
    file_prefix = "_".join((_safe(prompt[:30]), _stamp(), _sid()))
    blocking_call = partial(CLIENT.download_image, image_id, file_prefix)
    running_loop = asyncio.get_running_loop()
    return await running_loop.run_in_executor(EXECUTOR, blocking_call)
async def _generate_single(
    prompt: str,
    task: dict,
    idx: int,
    queue: asyncio.Queue,
    ad_code: str,
) -> Optional[str]:
    """
    Generate + save one image.

    On 'invalid_key', triggers a coordinated key refresh and retries
    up to MAX_KEY_RETRIES times. Progress and failures are reported as
    events on *queue*. On success the task's ``completed`` counter and
    ``results`` list are updated.

    Returns the saved filepath or None.
    """
    loop = asyncio.get_running_loop()
    tid = task["id"]
    for key_try in range(1, MAX_KEY_RETRIES + 1):
        # ── wait if a refresh is in progress ──
        await _key_valid.wait()
        # ── read current key ──
        async with _key_lock:
            active_key = USER_KEY
        if not active_key:
            await queue.put({
                "type": "error",
                "time": _now(),
                "task_id": tid,
                "message": "No userKey available. Set via /set_user_key",
            })
            return None
        # ── blocking generation in thread-pool ──
        result = await loop.run_in_executor(
            EXECUTOR,
            partial(
                CLIENT.generate_one,
                prompt=prompt,
                negative_prompt=task["negative"],
                seed=-1,
                resolution=task["resolution"],
                guidance_scale=task["guidance"],
                channel="ai-text-to-image-generator",
                sub_channel=task["sub_channel"],
                user_key=active_key,
                ad_access_code=ad_code,
                request_id=_reqid(),
            ),
        )
        # ── invalid_key → refresh + retry ──
        if result.get("error") == "invalid_key":
            log.warning(
                "invalid_key for task %s (key_try %d/%d) β†’ refreshing",
                tid, key_try, MAX_KEY_RETRIES,
            )
            await queue.put({
                "type": "key_invalid",
                "time": _now(),
                "task_id": tid,
                "attempt": key_try,
                "max_attempts": MAX_KEY_RETRIES,
                "message": "UserKey invalid β€” refreshing …",
            })
            # NOTE(review): on the final key_try the key is still refreshed,
            # but the loop then ends without another generation attempt and
            # without an SSE event — confirm this is intended.
            new_key = await refresh_user_key()
            if new_key:
                # also refresh ad code with fresh key
                ad_code = await loop.run_in_executor(
                    EXECUTOR, CLIENT.get_ad_code,
                )
                continue # ← retry generation
            else:
                await queue.put({
                    "type": "error",
                    "time": _now(),
                    "task_id": tid,
                    "message": "Could not refresh userKey β€” aborting image",
                })
                return None
        # ── other errors ──
        if result.get("error"):
            log.warning(
                "Gen error task=%s prompt='%.40s': %s",
                tid, prompt, result,
            )
            await queue.put({
                "type": "gen_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": result,
            })
            return None
        # ── success → save ──
        try:
            # inline data URL is decoded locally; an imageId is downloaded
            if result.get("inline"):
                fp = await _save_inline(result["inline"], prompt)
            elif result.get("imageId"):
                fp = await _download(result["imageId"], prompt)
            else:
                log.error("Unexpected result: %s", result)
                return None
            seed = result.get("seed")
            task["completed"] += 1
            task["results"].append({
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
            })
            await queue.put({
                "type": "image_ready",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
                "completed": task["completed"],
                "total": task["total_images"],
            })
            return fp
        except Exception as exc:
            log.exception("Save/download error task=%s: %s", tid, exc)
            await queue.put({
                "type": "download_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": str(exc),
            })
            return None
    # exhausted key retries
    log.error("Exhausted key retries for task %s prompt='%.40s'", tid, prompt)
    return None
async def worker_loop(worker_id: int, semaphore: asyncio.Semaphore):
    """Long-running coroutine: pull jobs → generate images.

    Takes jobs from JOB_QUEUE until a ``None`` sentinel arrives. For each job
    it marks the task running, generates every prompt ``count`` times via
    ``_generate_single`` (one at a time, under *semaphore*), then sets the
    final status and emits a closing ``eof`` event. A heartbeat event is
    emitted every 5 s while the task is running.
    """
    log.info("Worker %d started", worker_id)
    loop = asyncio.get_running_loop()
    while True:
        job = await JOB_QUEUE.get()
        # shutdown sentinel
        if job is None:
            log.info("Worker %d shutting down", worker_id)
            JOB_QUEUE.task_done()
            break
        task = job["task"]
        tid = task["id"]
        queue = TASK_QUEUES.get(tid)
        log.info(
            "Worker %d β†’ task %s (%d images)",
            worker_id, tid, task["total_images"],
        )
        task["status"] = "running"
        if queue:
            await queue.put({
                "type": "started",
                "time": _now(),
                "task_id": tid,
                "total_images": task["total_images"],
            })
        # fetch ad code once per task
        ad_code = await loop.run_in_executor(EXECUTOR, CLIENT.get_ad_code)

        # heartbeat coroutine
        async def _heartbeat():
            while task["status"] == "running":
                await asyncio.sleep(5.0)
                # re-check after the sleep: the task may have finished meanwhile
                if queue and task["status"] == "running":
                    try:
                        queue.put_nowait({
                            "type": "heartbeat",
                            "time": _now(),
                            "task_id": tid,
                            "completed": task["completed"],
                            "total": task["total_images"],
                        })
                    except asyncio.QueueFull:
                        pass

        hb = asyncio.create_task(_heartbeat())
        try:
            for prompt in task["prompts"]:
                for i in range(task["count"]):
                    async with semaphore:
                        await _generate_single(
                            prompt, task, i, queue, ad_code,
                        )
                    if task["status"] == "failed":
                        break
                if task["status"] == "failed":
                    break
            # decide final status
            if task["status"] != "failed":
                if task["completed"] == 0 and task["total_images"] > 0:
                    task["status"] = "failed"
                    task["error"] = "No images generated successfully"
                else:
                    task["status"] = "done"
            if queue:
                await queue.put({
                    "type": task["status"], # "done" or "failed"
                    "time": _now(),
                    "task_id": tid,
                    "completed": task["completed"],
                    "total": task["total_images"],
                    "error": task.get("error"),
                })
        except Exception as exc:
            log.exception("Worker %d task %s crashed: %s", worker_id, tid, exc)
            task["status"] = "failed"
            task["error"] = str(exc)
            if queue:
                await queue.put({
                    "type": "failed",
                    "time": _now(),
                    "task_id": tid,
                    "error": str(exc),
                })
        finally:
            # stop the heartbeat before emitting the closing event
            hb.cancel()
            try:
                await hb
            except asyncio.CancelledError:
                pass
            if queue:
                await queue.put({"type": "eof", "time": _now(), "task_id": tid})
            JOB_QUEUE.task_done()
            log.info(
                "Worker %d task %s finished (%s, %d/%d)",
                worker_id, tid, task["status"],
                task["completed"], task["total_images"],
            )
# ═══════════════════════════════════════════════════════════════
# FASTAPI – lifespan + app + endpoints
# ═══════════════════════════════════════════════════════════════
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: set up shared state, run workers, tear down.

    Startup creates the asyncio primitives and JOB_QUEUE, starts a virtual
    display if needed, fetches an initial userKey through the browser (unless
    one is already set or NO_INITIAL_FETCH is ``1``/``true``/``True``) and
    launches WORKER_COUNT workers. Shutdown sends one ``None`` sentinel per
    worker, waits for them, closes the HTTP session, shuts the thread pool
    down and stops the virtual display.
    """
    global USER_KEY, _key_lock, _key_valid, _key_refresh_lock
    global _key_last_ts, _key_fail_count, JOB_QUEUE
    # ── create asyncio primitives in uvicorn's loop ──
    _key_lock = asyncio.Lock()
    _key_valid = asyncio.Event()
    _key_valid.set() # assume usable initially
    _key_refresh_lock = asyncio.Lock()
    JOB_QUEUE = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)
    # If configured to start virtual display, attempt to do so when needed.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, partial(_start_virtual_display_if_needed, ZD_HEADLESS))
    except Exception:
        log.exception("Failed to ensure virtual display at startup")
    # ── initial key fetch (skip if already set from __main__) ──
    if USER_KEY:
        log.info("Using pre-fetched userKey (len=%d)", len(USER_KEY))
        _key_last_ts = time.time()
    else:
        skip = os.environ.get("NO_INITIAL_FETCH", "") in ("1", "true", "True")
        if not skip:
            try:
                key = await fetch_key_via_browser(
                    timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
                )
                if key:
                    USER_KEY = key
                    _key_last_ts = time.time()
                    log.info("Fetched userKey at startup (len=%d)", len(key))
                else:
                    log.warning(
                        "Startup key fetch failed. "
                        "Use /set_user_key or /fetch_user_key."
                    )
            except Exception as exc:
                log.exception("Startup key fetch error: %s", exc)
        else:
            log.info("NO_INITIAL_FETCH=1 β†’ skipping browser key fetch")
    # ── launch workers ──
    # The semaphore has as many permits as there are workers.
    sem = asyncio.Semaphore(WORKER_COUNT)
    workers = [
        asyncio.create_task(worker_loop(i + 1, sem))
        for i in range(WORKER_COUNT)
    ]
    log.info("Launched %d workers", WORKER_COUNT)
    # ---------- server is running ----------
    yield
    # ---------- shutdown begins ------------
    log.info("Shutdown: sending stop sentinels to workers …")
    # one sentinel per worker; each worker exits after consuming one
    for _ in range(WORKER_COUNT):
        await JOB_QUEUE.put(None)
    await asyncio.gather(*workers, return_exceptions=True)
    try:
        SCRAPER.close()
    except Exception:
        pass
    # NOTE(review): shutdown(wait=True) blocks the event loop until pending
    # executor jobs finish — presumably acceptable during shutdown; confirm.
    EXECUTOR.shutdown(wait=True)
    # Stop virtual display if we started one
    try:
        await loop.run_in_executor(None, _stop_virtual_display_if_needed)
    except Exception:
        log.exception("Failed to stop virtual display cleanly")
    log.info("Shutdown complete")
# ── app ──
app = FastAPI(
title="Perchance Image Generation Server v2 (pyvirtualdisplay)",
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["*"],
allow_headers=["*"],
)
# ───────────── endpoints ─────────────
@app.get("/health")
async def health():
    """Report liveness, whether a userKey is set, queue depth and active task count."""
    async with _key_lock:
        key_present = USER_KEY is not None

    pending_jobs = JOB_QUEUE.qsize() if JOB_QUEUE else 0
    active = 0
    for record in TASKS.values():
        if record["status"] in ("queued", "running"):
            active += 1
    return {
        "status": "ok",
        "has_user_key": key_present,
        "queue_size": pending_jobs,
        "active_tasks": active,
    }
@app.get("/user_key")
async def user_key_info():
    """Report whether a userKey is set and its length; the key itself is not returned."""
    async with _key_lock:
        if USER_KEY is None:
            present, length = False, 0
        else:
            present, length = True, len(USER_KEY)
    return {"has_user_key": present, "key_length": length}
@app.post("/set_user_key")
async def set_user_key(payload: Dict[str, str]):
    """Install a userKey supplied by the caller.

    Expects ``{"userKey": "<key>"}``. Surrounding whitespace is stripped and
    an empty key is rejected with HTTP 400. Setting a key also resets the
    refresh timestamp and failure counter and releases workers waiting on
    ``_key_valid``.
    """
    global USER_KEY, _key_last_ts, _key_fail_count

    supplied = payload.get("userKey", "").strip()
    if supplied == "":
        raise HTTPException(400, "userKey required")

    key_length = len(supplied)
    async with _key_lock:
        USER_KEY = supplied
        _key_last_ts = time.time()
        _key_fail_count = 0
    # unblock any waiting workers
    _key_valid.set()
    log.info("userKey set via API (len=%d)", key_length)
    return {"status": "ok", "key_length": key_length}
# Strong references to in-flight background key fetches. The event loop only
# keeps weak references to tasks, so a fire-and-forget task with no other
# reference may be garbage-collected before it finishes.
_BG_KEY_FETCH_TASKS: set = set()


@app.post("/fetch_user_key")
async def fetch_user_key_endpoint():
    """Trigger a background browser-based key fetch."""

    async def _bg():
        global _key_fail_count
        _key_fail_count = 0 # reset so refresh is allowed
        await refresh_user_key()

    bg_task = asyncio.create_task(_bg())
    _BG_KEY_FETCH_TASKS.add(bg_task)
    # Drop the reference once the task has finished.
    bg_task.add_done_callback(_BG_KEY_FETCH_TASKS.discard)
    return {"status": "started", "note": "Browser key fetch running in background"}
@app.post("/generate")
async def submit_job(payload: Dict[str, Any]):
    """
    POST /generate
    Body:
    {
    "prompts": ["a cat in space", "sunset over mountains"],
    "count": 2,
    "resolution": "512x768",
    "guidance": 7.0,
    "negative": "",
    "subChannel": "private"
    }
    Returns:
    { "task_id": "...", "stream_url": "/stream/...", "queue_position": N }

    Responds 400 for missing or non-string prompts or a non-numeric
    count/guidance, and 503 when the job queue is full.
    """
    prompts = payload.get("prompts") or payload.get("prompt") or []
    if isinstance(prompts, str):
        prompts = [prompts]
    if not isinstance(prompts, list) or not prompts:
        raise HTTPException(400, "prompts must be a non-empty list")
    # Prompts are sliced and sent as text later; reject anything else up front.
    if not all(isinstance(p, str) for p in prompts):
        raise HTTPException(400, "every prompt must be a string")
    try:
        count = max(1, int(payload.get("count", 1)))
        guidance = float(payload.get("guidance", 7.0))
    except (TypeError, ValueError):
        # previously surfaced as an unhandled 500
        raise HTTPException(400, "count must be an integer and guidance a number")
    resolution = payload.get("resolution", "512x768")
    negative = payload.get("negative", "") or ""
    sub_channel = payload.get("subChannel", "private")
    task = create_task(prompts, count, resolution, guidance, negative, sub_channel)
    try:
        # put_nowait raises QueueFull; ``await put()`` would instead block
        # the request until a slot frees up and never reach the 503 branch.
        JOB_QUEUE.put_nowait({"task": task})
    except asyncio.QueueFull:
        # Unregister the rejected task so it does not linger as "queued".
        TASKS.pop(task["id"], None)
        TASK_QUEUES.pop(task["id"], None)
        raise HTTPException(503, "Server queue full β€” try again later")
    position = JOB_QUEUE.qsize()
    q = TASK_QUEUES.get(task["id"])
    if q:
        await q.put({
            "type": "queued",
            "time": _now(),
            "task_id": task["id"],
            "queue_position": position,
            "total_images": task["total_images"],
        })
    return {
        "task_id": task["id"],
        "stream_url": f"/stream/{task['id']}",
        "queue_position": position,
    }
@app.get("/stream/{task_id}")
async def stream_task(request: Request, task_id: str):
    """
    SSE stream. Event types:
    meta · queued · started · heartbeat · image_ready
    key_invalid · key_refreshing · key_refreshed · key_refresh_failed
    gen_error · download_error · done · failed · eof
    """
    task = TASKS.get(task_id)
    if task is None:
        raise HTTPException(404, "unknown task id")
    queue = TASK_QUEUES[task_id]

    def _sse(name, payload):
        # One server-sent event: event name plus JSON-encoded data.
        return {"event": name, "data": json.dumps(payload)}

    async def event_gen():
        # ── snapshot of the task at connect time ──
        yield _sse("meta", {
            "task_id": task_id,
            "status": task["status"],
            "total_images": task["total_images"],
            "created_at": task["created_at"],
        })

        # ── finished task: replay stored results, final status, EOF ──
        if task["status"] in ("done", "failed"):
            for item in task["results"]:
                yield _sse("image_ready", {
                    "task_id": task_id,
                    "prompt": item["prompt"],
                    "index": item["index"],
                    "path": item["path"],
                    "seed": item["seed"],
                    "completed": task["completed"],
                    "total": task["total_images"],
                })
            yield _sse(task["status"], {
                "task_id": task_id,
                "completed": task["completed"],
                "total": task["total_images"],
                "error": task.get("error"),
            })
            yield _sse("eof", {"task_id": task_id})
            return

        # ── live stream: forward queued events until "eof" ──
        while True:
            try:
                ev = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # 30 s of silence: send a keep-alive ping, then check the client
                yield _sse("ping", {"time": _now()})
                if await request.is_disconnected():
                    log.info("SSE client disconnected (task %s)", task_id)
                    break
                continue
            yield _sse(ev.get("type", "event"), ev)
            if ev.get("type") == "eof":
                break

    return EventSourceResponse(event_gen())
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the full stored record for *task_id*, or HTTP 404 if unknown."""
    record = TASKS.get(task_id)
    if record:
        return {"task": record}
    raise HTTPException(404, "unknown task id")
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Serve one generated file from OUTPUT_DIR as a download.

    *filename* comes straight from the URL, so the resolved path must be a
    regular file located directly inside OUTPUT_DIR. Anything else (``..``,
    a directory, a path that escapes the folder) gets HTTP 404.
    """
    root = OUTPUT_DIR.resolve()
    fp = (root / filename).resolve()
    # Generated files are written directly into OUTPUT_DIR, so a legitimate
    # request always resolves to a direct child of it.
    if fp.parent != root or not fp.is_file():
        raise HTTPException(404, "file not found")
    return FileResponse(fp, media_type="application/octet-stream", filename=filename)
# ═══════════════════════════════════════════════════════════════
# MAIN (when run directly)
# ═══════════════════════════════════════════════════════════════
if __name__ == "__main__":
    import uvicorn

    # Environment overrides: only "1", "true" and "True" count as enabled.
    # An unset variable keeps the module default.
    ZD_HEADLESS = os.environ.get(
        "ZD_HEADLESS", str(ZD_HEADLESS)
    ) in ("1", "true", "True")
    USE_VIRTUAL_DISPLAY = os.environ.get(
        "USE_VIRTUAL_DISPLAY", str(USE_VIRTUAL_DISPLAY)
    ) in ("1", "true", "True")

    # Bring up a virtual display before the server starts, when one is needed.
    try:
        _start_virtual_display_if_needed(ZD_HEADLESS)
    except Exception:
        log.exception("Failed to ensure virtual display in __main__")

    listen_port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)