CSSSC / server.py
AdarshJi's picture
Create server.py
560edb5 verified
raw
history blame
44.9 kB
# perchance_server.py
"""
Perchance Image-Generation Server v2.0

Changes from v1:
  • Auto-detects 'invalid_key' and re-fetches userKey via zendriver
    without manual intervention.
  • Only ONE browser launch at startup (no duplicate fetch).
  • Uses FastAPI *lifespan* context-manager → zero deprecation warnings.
  • Coordinated key refresh: one refresh at a time; other workers wait.
  • SSE events: key_refreshing / key_refreshed / key_refresh_failed
    so the client knows exactly what's happening.
  • Cleaner separation of concerns, better error handling.
"""
# ═══════════════════════════════════════════════════════════════
# IMPORTS
# ═══════════════════════════════════════════════════════════════
import asyncio
import base64
import json
import logging
import os
import random
import string
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional
import cloudscraper
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from sse_starlette.sse import EventSourceResponse
import zendriver as zd
from zendriver import cdp
# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════
# --- Perchance API ---
# Endpoint paths below are appended to BASE_URL by PerchanceClient.
BASE_URL = "https://image-generation.perchance.org"
API_GENERATE = "/api/generate"
API_DOWNLOAD = "/api/downloadTemporaryImage"
API_AWAIT = "/api/awaitExistingGenerationRequest"
API_ACCESS_CODE = "/api/getAccessCodeForAdPoweredStuff"
# --- Browser automation (zendriver) ---
TARGET_URL = "https://perchance.org/ai-text-to-image-generator"
IMAGE_GEN_ORIGIN = "https://image-generation.perchance.org"
ZD_TIMEOUT = 90  # seconds for key-fetch attempt
ZD_HEADLESS = False  # True -> hide browser window
CLICK_INTERVAL = 0.35  # base pause (seconds) between simulated clicks
CLICK_JITTER = 8.0  # max random offset applied to each click coordinate
KEY_PREFIX = "userKey"  # localStorage entry polled is f"{KEY_PREFIX}-0"
# --- HTTP / generation ---
HTTP_TIMEOUT = 30  # per-request timeout passed to the scraper session
MAX_DOWNLOAD_WAIT = 180  # seconds download_image() polls before TimeoutError
BACKOFF_INIT = 0.7  # first sleep between download polls (grows x1.8, max 8 s)
MAX_GEN_RETRIES = 6  # retries inside generate_one()
# --- Key-refresh policy ---
MAX_KEY_RETRIES = 3  # per-image retries when key is invalid
KEY_REFRESH_COOLDOWN = 30  # min seconds between two refreshes
MAX_REFRESH_FAILURES = 5  # consecutive failures -> stop auto-refresh
# --- Server ---
WORKER_COUNT = 3  # worker coroutines launched in lifespan (also semaphore size)
MAX_QUEUE_SIZE = 1000  # maxsize of JOB_QUEUE
EXECUTOR_THREADS = 16  # threads available for blocking HTTP / file work
OUTPUT_DIR = Path("outputs")
# Created at import time so workers can write images immediately.
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)
# ═══════════════════════════════════════════════════════════════
# LOGGING
# ═══════════════════════════════════════════════════════════════
# Root logging is configured at import time; every message in this module
# goes through the "perchance" logger.
LOG_FMT = "%(asctime)s | %(levelname)-7s | %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FMT)
log = logging.getLogger("perchance")
# ═══════════════════════════════════════════════════════════════
# GLOBAL STATE
#
# Asyncio primitives (Lock, Event, Queue) are created inside
# the lifespan handler so they live in uvicorn's event loop.
# ═══════════════════════════════════════════════════════════════
# Current Perchance userKey; None until fetched or set via /set_user_key.
USER_KEY: Optional[str] = None
# -- set in lifespan --
_key_lock: Optional[asyncio.Lock] = None  # guard USER_KEY reads/writes
_key_valid: Optional[asyncio.Event] = None  # cleared while refreshing
_key_refresh_lock: Optional[asyncio.Lock] = None  # one refresh at a time
_key_last_ts: float = 0.0  # last successful refresh
_key_fail_count: int = 0  # consecutive refresh failures
# Jobs waiting for a worker; a None item is the worker shutdown sentinel.
JOB_QUEUE: Optional[asyncio.Queue] = None
# task id -> task record (see create_task); entries are never removed here.
TASKS: Dict[str, Dict[str, Any]] = {}
TASK_QUEUES: Dict[str, asyncio.Queue] = {}  # SSE event queues
# Thread pool for the blocking PerchanceClient calls and file writes.
EXECUTOR = ThreadPoolExecutor(max_workers=EXECUTOR_THREADS)
# Single cloudscraper session shared by PerchanceClient.
SCRAPER = cloudscraper.create_scraper()
# ═══════════════════════════════════════════════════════════════
# SMALL HELPERS
# ═══════════════════════════════════════════════════════════════
def _safe(s: str) -> str:
    """Make *s* filename-safe.

    Characters other than ASCII letters, digits and ``-_.()`` become ``_``;
    the result is truncated to 120 characters.
    """
    allowed = frozenset(string.ascii_letters + string.digits + "-_.()")
    cleaned = [ch if ch in allowed else "_" for ch in s]
    return "".join(cleaned)[:120]
def _sid() -> str:
    """Return a random 8-character id of lowercase letters and digits."""
    alphabet = string.ascii_lowercase + string.digits
    picked = random.choices(alphabet, k=8)
    return "".join(picked)
def _now() -> str:
    """Return the current UTC time as ISO-8601 with milliseconds and a ``Z`` suffix.

    Uses an aware ``datetime.now(timezone.utc)`` instead of the deprecated
    ``datetime.utcnow()``; the tzinfo is stripped before formatting so the
    output keeps the ``...sssZ`` shape rather than gaining ``+00:00``.
    """
    from datetime import timezone  # local import: file header only imports datetime

    utc_naive = datetime.now(timezone.utc).replace(tzinfo=None)
    return utc_naive.isoformat(timespec="milliseconds") + "Z"
def _reqid() -> str:
    """Build a request id: epoch seconds (6 decimals), a dash, then a random short id."""
    epoch = format(time.time(), ".6f")
    return "-".join((epoch, _sid()))
def _stamp() -> str:
    """Return the current UTC time as a compact ``YYYYmmddTHHMMSS`` stamp.

    Uses ``datetime.now(timezone.utc)`` because ``datetime.utcnow()`` is
    deprecated; the formatted output is unchanged.
    """
    from datetime import timezone  # local import: file header only imports datetime

    return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")
# ═══════════════════════════════════════════════════════════════
# PERCHANCE HTTP CLIENT (blocking – runs in ThreadPoolExecutor)
# ═══════════════════════════════════════════════════════════════
class PerchanceClient:
    """Blocking HTTP client for the Perchance image-generation API.

    Every method performs synchronous network I/O through the shared
    scraper session, so async callers run them in the thread pool.
    """

    def __init__(self):
        self.base = BASE_URL.rstrip("/")
        self.s = SCRAPER
        # Headers sent with every request.
        self.h = {
            "Accept": "*/*",
            "Content-Type": "application/json;charset=UTF-8",
            "Origin": IMAGE_GEN_ORIGIN,
            "Referer": f"{IMAGE_GEN_ORIGIN}/embed",
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/131.0.0.0 Safari/537.36"
            ),
        }

    # ---- low-level helpers ----
    def get_ad_code(self) -> str:
        """Fetch an ad access code; return ``""`` on any failure (best effort)."""
        try:
            r = self.s.get(
                f"{self.base}{API_ACCESS_CODE}",
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            return r.text.strip()
        except Exception:
            return ""

    def _post(self, body: dict, params: dict) -> dict:
        """POST to the generate endpoint and return the decoded JSON reply.

        Never raises: transport/HTTP errors become
        ``{"status": "fetch_failure", ...}`` and an undecodable body becomes
        ``{"status": "invalid_json", ...}``.
        """
        try:
            r = self.s.post(
                f"{self.base}{API_GENERATE}",
                json=body, params=params,
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            try:
                return r.json()
            except Exception:
                return {"status": "invalid_json", "raw": r.text}
        except Exception as exc:
            return {"status": "fetch_failure", "error": str(exc)}

    def _await_prev(self, key: str):
        """Call the await-existing-request endpoint for *key*; errors are ignored."""
        try:
            self.s.get(
                f"{self.base}{API_AWAIT}",
                params={"userKey": key, "__cacheBust": random.random()},
                timeout=20, headers=self.h,
            )
        except Exception:
            pass

    # ---- generate one image ----
    def generate_one(
        self, *,
        prompt: str,
        negative_prompt: str = "",
        seed: int = -1,
        resolution: str = "512x768",
        guidance_scale: float = 7.0,
        channel: str = "ai-text-to-image-generator",
        sub_channel: str = "private",
        user_key: str = "",
        ad_access_code: str = "",
        request_id: str = "",
    ) -> dict:
        """Request one image, retrying transient failures up to MAX_GEN_RETRIES.

        Returns ONE of:
            {"imageId": ..., "seed": ...}
            {"inline": ..., "seed": ...}
            {"error": "invalid_key"}      <- caller must refresh key
            {"error": "<other>", ...}
        """
        request_id = request_id or _reqid()
        params = {
            "userKey": user_key,
            "requestId": request_id,
            "adAccessCode": ad_access_code,
            "__cacheBust": random.random(),
        }
        body = {
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "seed": seed,
            "resolution": resolution,
            "guidanceScale": guidance_scale,
            "channel": channel,
            "subChannel": sub_channel,
            "userKey": user_key,
            "adAccessCode": ad_access_code,
            "requestId": request_id,
        }
        ad_refreshed = False  # the ad code is refreshed at most once per call
        for att in range(1, MAX_GEN_RETRIES + 1):
            res = self._post(body, params)
            st = res.get("status")
            # ---- success ----
            if st == "success":
                iid = res.get("imageId")
                urls = res.get("imageDataUrls")
                if iid:
                    log.info("Got imageId: %s", iid)
                    return {"imageId": iid, "seed": res.get("seed")}
                if urls:
                    return {"inline": urls[0], "seed": res.get("seed")}
                log.error("success but empty payload: %s", str(res)[:300])
                return {"error": "empty_success", "raw": res}
            # ---- invalid key -> return immediately (do NOT retry here) ----
            if st == "invalid_key":
                log.warning("Server says invalid_key")
                return {"error": "invalid_key"}
            # ---- previous request still running ----
            if st == "waiting_for_prev_request_to_finish":
                log.info("Waiting for prev request to finish …")
                self._await_prev(user_key)
                time.sleep(0.3 + random.random() * 0.3)
                continue
            # ---- ad access code expired ----
            if st == "invalid_ad_access_code" and not ad_refreshed:
                code = self.get_ad_code()
                if code:
                    ad_access_code = code
                    params["adAccessCode"] = code
                    body["adAccessCode"] = code
                    ad_refreshed = True
                    log.info("Refreshed ad code β†’ retry")
                    time.sleep(0.8)
                    continue
                return {"error": "invalid_ad_access_code"}
            # ---- transient gen failure ----
            if st == "gen_failure" and res.get("type") == 1:
                log.warning("gen_failure type 1 β†’ retry after 2.5 s")
                time.sleep(2.5)
                continue
            # ---- network / stale ----
            if st in (None, "fetch_failure", "invalid_json", "stale_request"):
                log.info("Transient error (status=%s) attempt %d/%d", st, att, MAX_GEN_RETRIES)
                time.sleep(1.0)
                continue
            # ---- anything else ----
            log.error("Unhandled status '%s': %s", st, str(res)[:300])
            return {"error": f"unhandled_{st}", "raw": res}
        return {"error": "max_retries_exceeded"}

    # ---- download ----
    def download_image(self, image_id: str, prefix: str = "img") -> str:
        """Poll until the image is ready, save it to OUTPUT_DIR, return the path.

        Polls with exponential backoff (BACKOFF_INIT, x1.8, capped at 8 s).
        Each streamed response is closed before the next poll so non-200
        replies do not pin pooled connections, and failed attempts are logged
        at debug level instead of being dropped silently.

        Raises:
            TimeoutError: no 200 response was saved within MAX_DOWNLOAD_WAIT
                seconds.
        """
        url = f"{self.base}{API_DOWNLOAD}?imageId={image_id}"
        t0 = time.time()
        bk = BACKOFF_INIT
        while True:
            elapsed = time.time() - t0
            if elapsed >= MAX_DOWNLOAD_WAIT:
                raise TimeoutError(
                    f"Download timed out ({elapsed:.0f}s) for {image_id}"
                )
            try:
                r = self.s.get(url, timeout=HTTP_TIMEOUT,
                               headers=self.h, stream=True)
                try:
                    if r.status_code == 200:
                        ct = r.headers.get("Content-Type", "")
                        ext = (
                            ".png" if "png" in ct else
                            ".webp" if "webp" in ct else ".jpg"
                        )
                        fn = _safe(f"{prefix}_{image_id[:12]}{ext}")
                        fp = str(OUTPUT_DIR / fn)
                        with open(fp, "wb") as f:
                            for chunk in r.iter_content(8192):
                                if chunk:
                                    f.write(chunk)
                        log.info("Saved β†’ %s", fp)
                        return fp
                finally:
                    # stream=True keeps the connection checked out until closed.
                    r.close()
            except Exception as exc:
                # Best effort: keep polling, but leave a trace for debugging.
                log.debug("Download attempt for %s failed: %s", image_id, exc)
            time.sleep(bk)
            bk = min(bk * 1.8, 8.0)
# Module-wide client instance used by the async wrappers and workers.
CLIENT = PerchanceClient()
# ═══════════════════════════════════════════════════════════════
# ZENDRIVER – browser automation to extract userKey
# ═══════════════════════════════════════════════════════════════
async def _cdp_mouse(tab, typ, x, y, **kw):
    """Send one CDP mouse event of type *typ* at (*x*, *y*) to *tab*.

    Extra keyword arguments are forwarded to ``dispatch_mouse_event``.
    """
    event = cdp.input_.dispatch_mouse_event(
        type_=typ,
        x=float(x),
        y=float(y),
        **kw,
    )
    await tab.send(event)
async def _viewport_center(tab):
    """Return the (x, y) centre of *tab*'s viewport.

    Falls back to ``(600.0, 400.0)`` when the page cannot be queried.
    """
    fallback = (600.0, 400.0)
    try:
        size = await tab.evaluate(
            "(()=>({w:innerWidth,h:innerHeight}))()",
            await_promise=False,
            return_by_value=True,
        )
        half_w = size["w"] / 2.0
        half_h = size["h"] / 2.0
        return (half_w, half_h)
    except Exception:
        return fallback
async def _ls_get(tab, key):
    """Read localStorage item *key* from *tab*; return None if evaluation fails."""
    quoted = json.dumps(key)  # JSON-quote the key so it is a valid JS string literal
    script = f"localStorage&&localStorage.getItem({quoted})"
    try:
        value = await tab.evaluate(
            script,
            await_promise=True,
            return_by_value=True,
        )
    except Exception:
        return None
    return value
async def _clicker_loop(tab, stop: asyncio.Event):
    """Click near the centre of *tab* repeatedly until *stop* is set.

    The viewport centre is re-measured at most every 2.5 s, each click gets a
    random jitter of up to CLICK_JITTER on both axes, and the pause between
    clicks is CLICK_INTERVAL +/- 15 %. Errors from individual clicks are ignored.
    """
    try:
        await tab.evaluate(
            "window.focus&&window.focus()",
            await_promise=False, return_by_value=False,
        )
    except Exception:
        pass

    left = cdp.input_.MouseButton.LEFT
    mid_x, mid_y = await _viewport_center(tab)
    measured_at = time.time()

    while not stop.is_set():
        # Re-measure occasionally in case the window was resized.
        if time.time() - measured_at > 2.5:
            mid_x, mid_y = await _viewport_center(tab)
            measured_at = time.time()

        dx = random.uniform(-CLICK_JITTER, CLICK_JITTER)
        dy = random.uniform(-CLICK_JITTER, CLICK_JITTER)
        px = mid_x + dx
        py = mid_y + dy

        try:
            await _cdp_mouse(tab, "mouseMoved", px, py, pointer_type="mouse")
            await asyncio.sleep(random.uniform(0.02, 0.08))
            await _cdp_mouse(
                tab, "mousePressed", px, py,
                button=left, click_count=1, buttons=1,
            )
            await asyncio.sleep(random.uniform(0.03, 0.12))
            await _cdp_mouse(
                tab, "mouseReleased", px, py,
                button=left, click_count=1, buttons=0,
            )
        except Exception:
            pass

        # Interruptible pause: leave as soon as *stop* fires.
        try:
            await asyncio.wait_for(
                stop.wait(),
                timeout=CLICK_INTERVAL * random.uniform(0.85, 1.15),
            )
        except asyncio.TimeoutError:
            continue
        break
async def _poll_for_key(tab, stop: asyncio.Event, max_sec: int):
    """Poll *tab*'s localStorage every 250 ms for a userKey value.

    Checks the ``<KEY_PREFIX>-0`` entry first, then any entry whose name
    contains ``userKey``. Returns the first truthy value found, or None when
    *stop* is set or *max_sec* seconds have elapsed.
    """
    started = time.time()
    primary = f"{KEY_PREFIX}-0"
    while not stop.is_set() and (time.time() - started) < max_sec:
        direct = await _ls_get(tab, primary)
        if direct:
            return direct
        try:
            names = await tab.evaluate(
                "Object.keys(localStorage||{}).filter(k=>k.includes('userKey'))",
                await_promise=False, return_by_value=True,
            )
            for name in (names or []):
                candidate = await _ls_get(tab, name)
                if candidate:
                    return candidate
        except Exception:
            pass
        await asyncio.sleep(0.25)
    return None
async def fetch_key_via_browser(
    timeout: int = ZD_TIMEOUT,
    headless: bool = ZD_HEADLESS,
) -> Optional[str]:
    """Obtain a userKey by driving a real browser.

    Launches the browser, opens TARGET_URL and (in a second tab)
    IMAGE_GEN_ORIGIN, clicks on the first tab while polling the second tab's
    localStorage for the key, then closes everything.

    Both background tasks (clicker and poller) are cancelled and awaited
    before the tabs are closed, so neither keeps running against a closed tab
    when the wait times out or fails.

    Args:
        timeout: Maximum seconds to wait for the key to appear.
        headless: Passed through to ``zd.start``.

    Returns:
        The key string, or None if the browser could not start or no key
        appeared within *timeout*.
    """
    log.info(
        "Launching browser for userKey (timeout=%ds, headless=%s)",
        timeout, headless,
    )
    try:
        browser = await zd.start(headless=headless)
    except Exception as exc:
        log.exception("Browser start failed: %s", exc)
        return None
    stop = asyncio.Event()
    result = None
    try:
        page_tab = await browser.get(TARGET_URL)
        log.info("Opened %s", TARGET_URL)
        await asyncio.sleep(2.0)
        origin_tab = await browser.get(IMAGE_GEN_ORIGIN, new_tab=True)
        log.info("Opened %s", IMAGE_GEN_ORIGIN)
        await asyncio.sleep(1.0)
        await page_tab.bring_to_front()
        await asyncio.sleep(0.5)
        clicker = asyncio.create_task(_clicker_loop(page_tab, stop))
        poller = asyncio.create_task(_poll_for_key(origin_tab, stop, timeout))
        try:
            done, _ = await asyncio.wait({poller}, timeout=timeout)
            if poller in done:
                result = poller.result()
        finally:
            stop.set()
            # Stop BOTH background tasks before tearing the tabs down.
            for bg in (clicker, poller):
                if not bg.done():
                    bg.cancel()
                try:
                    await bg
                except asyncio.CancelledError:
                    pass
                except Exception as exc:
                    # Already surfaced via poller.result() where relevant.
                    log.debug("Background browser task ended with error: %s", exc)
            for t in (origin_tab, page_tab):
                try:
                    await t.close()
                except Exception:
                    pass
    finally:
        try:
            await browser.stop()
        except Exception:
            pass
    if result:
        log.info("Fetched userKey (len=%d)", len(result))
    else:
        log.warning("Could not fetch userKey within %ds", timeout)
    return result
# ═══════════════════════════════════════════════════════════════
# KEY MANAGEMENT – coordinated refresh across workers
# ═══════════════════════════════════════════════════════════════
async def _broadcast(event: dict):
    """Deliver *event* to the SSE queue of every task that is queued or running."""
    live_states = ("queued", "running")
    for task_id, sse_queue in TASK_QUEUES.items():
        record = TASKS.get(task_id)
        if not record or record["status"] not in live_states:
            continue
        try:
            sse_queue.put_nowait(event)
        except asyncio.QueueFull:
            pass
async def refresh_user_key() -> Optional[str]:
    """Run one coordinated userKey refresh and return the usable key.

    Callers serialise on ``_key_refresh_lock``. A caller that gets the lock
    within KEY_REFRESH_COOLDOWN seconds of the last successful refresh simply
    receives the current key. After MAX_REFRESH_FAILURES consecutive failures
    auto-refresh is refused. While the browser fetch runs, ``_key_valid`` is
    cleared; it is always set again afterwards.

    Returns:
        The key string, or None when the refresh failed or is disabled.
    """
    global USER_KEY, _key_last_ts, _key_fail_count

    async def _announce(kind: str, message: str) -> None:
        # Same payload shape for every key-related SSE notification.
        await _broadcast({"type": kind, "time": _now(), "message": message})

    async with _key_refresh_lock:
        # Another coroutine may have refreshed while we waited for the lock.
        since_last = time.time() - _key_last_ts
        if since_last < KEY_REFRESH_COOLDOWN and USER_KEY:
            log.info(
                "Key was refreshed %.1fs ago β†’ reusing existing key", since_last,
            )
            return USER_KEY

        # Give up once too many refreshes in a row have failed.
        if _key_fail_count >= MAX_REFRESH_FAILURES:
            log.error(
                "Key refresh disabled: %d consecutive failures. "
                "Set key manually via POST /set_user_key",
                _key_fail_count,
            )
            await _announce(
                "key_refresh_failed",
                f"Auto-refresh disabled after {_key_fail_count} failures. "
                "Please set userKey manually via /set_user_key",
            )
            return None

        # Make workers wait until the refresh attempt is over.
        _key_valid.clear()
        log.info("Starting userKey refresh via browser …")
        await _announce(
            "key_refreshing",
            "UserKey expired β€” refreshing via browser automation …",
        )
        try:
            fetched = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )
            if not fetched:
                _key_fail_count += 1
                log.error(
                    "Key refresh returned nothing (failure #%d/%d)",
                    _key_fail_count, MAX_REFRESH_FAILURES,
                )
                await _announce(
                    "key_refresh_failed",
                    f"Key refresh failed (attempt {_key_fail_count}"
                    f"/{MAX_REFRESH_FAILURES})",
                )
                return None

            async with _key_lock:
                USER_KEY = fetched
                _key_last_ts = time.time()
                _key_fail_count = 0
            log.info("UserKey refreshed OK (len=%d)", len(fetched))
            await _announce(
                "key_refreshed",
                "UserKey refreshed – resuming generation.",
            )
            return fetched
        except Exception as exc:
            _key_fail_count += 1
            log.exception(
                "Key refresh error (failure #%d/%d): %s",
                _key_fail_count, MAX_REFRESH_FAILURES, exc,
            )
            await _announce("key_refresh_failed", f"Key refresh error: {exc}")
            return None
        finally:
            # Waiters are released on success and on failure alike.
            _key_valid.set()
# ═══════════════════════════════════════════════════════════════
# TASK MODEL
# ═══════════════════════════════════════════════════════════════
def create_task(
    prompts: List[str],
    count: int,
    resolution: str,
    guidance: float,
    negative: str,
    sub_channel: str,
) -> dict:
    """Create a task record, register it in TASKS, and give it an SSE queue.

    ``total_images`` is ``len(prompts) * count``; the task starts in the
    ``"queued"`` state with no results.
    """
    task_id = str(uuid.uuid4())
    record = dict(
        id=task_id,
        prompts=prompts,
        count=count,
        resolution=resolution,
        guidance=guidance,
        negative=negative,
        sub_channel=sub_channel,
        created_at=_now(),
        status="queued",  # queued -> running -> done / failed
        total_images=len(prompts) * count,
        completed=0,
        results=[],
        error=None,
    )
    TASKS[task_id] = record
    TASK_QUEUES[task_id] = asyncio.Queue()
    return record
# ═══════════════════════════════════════════════════════════════
# WORKER — image generation + key-refresh retry loop
# ═══════════════════════════════════════════════════════════════
async def _save_inline(data_url: str, prompt: str) -> str:
    """Decode a base-64 data URL (or bare base-64 string) to a file in OUTPUT_DIR.

    The extension is taken from the data-URL header: ``.png``, ``.webp``
    (matching PerchanceClient.download_image), otherwise ``.jpg``. The file
    write runs in the thread pool.

    Returns:
        The saved file path as a string.
    """
    loop = asyncio.get_running_loop()
    header, sep, b64 = data_url.partition(",")
    if not sep:
        # No header present: the whole string is the payload.
        header, b64 = "", data_url
    ext = (
        ".png" if "png" in header else
        ".webp" if "webp" in header else ".jpg"
    )
    fn = _safe(f"{prompt[:30]}_{_stamp()}_{_sid()}{ext}")
    fp = OUTPUT_DIR / fn
    raw = base64.b64decode(b64)
    await loop.run_in_executor(EXECUTOR, fp.write_bytes, raw)
    log.info("Saved inline β†’ %s", fp)
    return str(fp)
async def _download(image_id: str, prompt: str) -> str:
    """Run the blocking ``CLIENT.download_image`` in the thread pool.

    The filename prefix is built from the sanitised first 30 characters of
    *prompt*, a UTC timestamp and a random short id.
    """
    running_loop = asyncio.get_running_loop()
    name_prefix = "_".join((_safe(prompt[:30]), _stamp(), _sid()))
    fetch = partial(CLIENT.download_image, image_id, name_prefix)
    saved_path = await running_loop.run_in_executor(EXECUTOR, fetch)
    return saved_path
async def _generate_single(
    prompt: str,
    task: dict,
    idx: int,
    queue: asyncio.Queue,
    ad_code: str,
) -> Optional[str]:
    """Generate and save one image for *task*, reporting progress on *queue*.

    On ``invalid_key`` a coordinated key refresh is triggered and generation
    is retried, up to MAX_KEY_RETRIES attempts in total. When those attempts
    are used up, a ``gen_error`` event is emitted so the SSE client learns
    that this image failed.

    Args:
        prompt: Text prompt for this image.
        task: Task record from create_task; ``completed`` and ``results`` are
            updated on success.
        idx: Index of this image within its prompt's batch.
        queue: The task's SSE event queue.
        ad_code: Ad access code to send; replaced after a key refresh.

    Returns:
        The saved file path, or None on any failure.
    """
    loop = asyncio.get_running_loop()
    tid = task["id"]
    for key_try in range(1, MAX_KEY_RETRIES + 1):
        # -- wait if a refresh is in progress --
        await _key_valid.wait()
        # -- read current key --
        async with _key_lock:
            active_key = USER_KEY
        if not active_key:
            await queue.put({
                "type": "error",
                "time": _now(),
                "task_id": tid,
                "message": "No userKey available. Set via /set_user_key",
            })
            return None
        # -- blocking generation in thread-pool --
        result = await loop.run_in_executor(
            EXECUTOR,
            partial(
                CLIENT.generate_one,
                prompt=prompt,
                negative_prompt=task["negative"],
                seed=-1,
                resolution=task["resolution"],
                guidance_scale=task["guidance"],
                channel="ai-text-to-image-generator",
                sub_channel=task["sub_channel"],
                user_key=active_key,
                ad_access_code=ad_code,
                request_id=_reqid(),
            ),
        )
        # -- invalid_key -> refresh + retry --
        if result.get("error") == "invalid_key":
            log.warning(
                "invalid_key for task %s (key_try %d/%d) β†’ refreshing",
                tid, key_try, MAX_KEY_RETRIES,
            )
            await queue.put({
                "type": "key_invalid",
                "time": _now(),
                "task_id": tid,
                "attempt": key_try,
                "max_attempts": MAX_KEY_RETRIES,
                "message": "UserKey invalid β€” refreshing …",
            })
            new_key = await refresh_user_key()
            if new_key:
                # also refresh ad code with fresh key
                ad_code = await loop.run_in_executor(
                    EXECUTOR, CLIENT.get_ad_code,
                )
                continue  # retry generation
            await queue.put({
                "type": "error",
                "time": _now(),
                "task_id": tid,
                "message": "Could not refresh userKey β€” aborting image",
            })
            return None
        # -- other errors --
        if result.get("error"):
            log.warning(
                "Gen error task=%s prompt='%.40s': %s",
                tid, prompt, result,
            )
            await queue.put({
                "type": "gen_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": result,
            })
            return None
        # -- success -> save --
        try:
            if result.get("inline"):
                fp = await _save_inline(result["inline"], prompt)
            elif result.get("imageId"):
                fp = await _download(result["imageId"], prompt)
            else:
                log.error("Unexpected result: %s", result)
                return None
            seed = result.get("seed")
            task["completed"] += 1
            task["results"].append({
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
            })
            await queue.put({
                "type": "image_ready",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
                "completed": task["completed"],
                "total": task["total_images"],
            })
            return fp
        except Exception as exc:
            log.exception("Save/download error task=%s: %s", tid, exc)
            await queue.put({
                "type": "download_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": str(exc),
            })
            return None
    # Exhausted key retries: tell the client, not only the log.
    log.error("Exhausted key retries for task %s prompt='%.40s'", tid, prompt)
    await queue.put({
        "type": "gen_error",
        "time": _now(),
        "task_id": tid,
        "prompt": prompt,
        "index": idx,
        "error": {"error": "key_retries_exhausted"},
    })
    return None
async def worker_loop(worker_id: int, semaphore: asyncio.Semaphore):
    """Long-running coroutine: pull jobs from JOB_QUEUE and generate images.

    For each job the task is marked "running", a heartbeat event is emitted
    every 5 s, and each (prompt, index) pair is generated one after another
    under *semaphore*. The task ends as "done" if at least one image
    succeeded (or none were requested), otherwise "failed". An "eof" event is
    always the last event put on the task's SSE queue. A ``None`` job is the
    shutdown sentinel.
    """
    log.info("Worker %d started", worker_id)
    loop = asyncio.get_running_loop()
    while True:
        job = await JOB_QUEUE.get()
        # shutdown sentinel
        if job is None:
            log.info("Worker %d shutting down", worker_id)
            JOB_QUEUE.task_done()
            break
        task = job["task"]
        tid = task["id"]
        queue = TASK_QUEUES.get(tid)
        log.info(
            "Worker %d β†’ task %s (%d images)",
            worker_id, tid, task["total_images"],
        )
        task["status"] = "running"
        if queue:
            await queue.put({
                "type": "started",
                "time": _now(),
                "task_id": tid,
                "total_images": task["total_images"],
            })
        # fetch ad code once per task
        ad_code = await loop.run_in_executor(EXECUTOR, CLIENT.get_ad_code)
        # heartbeat coroutine: emits progress every 5 s while the task runs
        async def _heartbeat():
            while task["status"] == "running":
                await asyncio.sleep(5.0)
                # re-check after the sleep: the task may have finished meanwhile
                if queue and task["status"] == "running":
                    try:
                        queue.put_nowait({
                            "type": "heartbeat",
                            "time": _now(),
                            "task_id": tid,
                            "completed": task["completed"],
                            "total": task["total_images"],
                        })
                    except asyncio.QueueFull:
                        pass
        hb = asyncio.create_task(_heartbeat())
        try:
            for prompt in task["prompts"]:
                for i in range(task["count"]):
                    async with semaphore:
                        await _generate_single(
                            prompt, task, i, queue, ad_code,
                        )
                    # stop early if the task was marked failed meanwhile
                    if task["status"] == "failed":
                        break
                if task["status"] == "failed":
                    break
            # decide final status
            if task["status"] != "failed":
                if task["completed"] == 0 and task["total_images"] > 0:
                    task["status"] = "failed"
                    task["error"] = "No images generated successfully"
                else:
                    task["status"] = "done"
            if queue:
                await queue.put({
                    "type": task["status"],  # "done" or "failed"
                    "time": _now(),
                    "task_id": tid,
                    "completed": task["completed"],
                    "total": task["total_images"],
                    "error": task.get("error"),
                })
        except Exception as exc:
            log.exception("Worker %d task %s crashed: %s", worker_id, tid, exc)
            task["status"] = "failed"
            task["error"] = str(exc)
            if queue:
                await queue.put({
                    "type": "failed",
                    "time": _now(),
                    "task_id": tid,
                    "error": str(exc),
                })
        finally:
            # stop the heartbeat before emitting the terminal "eof" event
            hb.cancel()
            try:
                await hb
            except asyncio.CancelledError:
                pass
            if queue:
                await queue.put({"type": "eof", "time": _now(), "task_id": tid})
            JOB_QUEUE.task_done()
            log.info(
                "Worker %d task %s finished (%s, %d/%d)",
                worker_id, tid, task["status"],
                task["completed"], task["total_images"],
            )
# ═══════════════════════════════════════════════════════════════
# FASTAPI – lifespan + app + endpoints
# ═══════════════════════════════════════════════════════════════
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler.

    Startup: create the asyncio primitives inside the running loop, obtain a
    userKey (pre-fetched, or via the browser unless NO_INITIAL_FETCH is set)
    and start WORKER_COUNT workers. Shutdown: send one ``None`` sentinel per
    worker, wait for them, then close the scraper and the thread pool.
    """
    global USER_KEY, _key_lock, _key_valid, _key_refresh_lock
    global _key_last_ts, _key_fail_count, JOB_QUEUE

    # -- asyncio primitives must be created in the serving loop --
    _key_lock = asyncio.Lock()
    _key_valid = asyncio.Event()
    _key_valid.set()  # the key is assumed usable until proven otherwise
    _key_refresh_lock = asyncio.Lock()
    JOB_QUEUE = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

    # -- initial key: reuse a pre-fetched one, skip, or fetch now --
    if USER_KEY:
        log.info("Using pre-fetched userKey (len=%d)", len(USER_KEY))
        _key_last_ts = time.time()
    elif os.environ.get("NO_INITIAL_FETCH", "") in ("1", "true", "True"):
        log.info("NO_INITIAL_FETCH=1 β†’ skipping browser key fetch")
    else:
        try:
            fetched = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )
            if fetched:
                USER_KEY = fetched
                _key_last_ts = time.time()
                log.info("Fetched userKey at startup (len=%d)", len(fetched))
            else:
                log.warning(
                    "Startup key fetch failed. "
                    "Use /set_user_key or /fetch_user_key."
                )
        except Exception as exc:
            log.exception("Startup key fetch error: %s", exc)

    # -- launch workers --
    gate = asyncio.Semaphore(WORKER_COUNT)
    workers = []
    for number in range(1, WORKER_COUNT + 1):
        workers.append(asyncio.create_task(worker_loop(number, gate)))
    log.info("Launched %d workers", WORKER_COUNT)

    # ---------- server is running ----------
    yield
    # ---------- shutdown begins ------------
    log.info("Shutdown: sending stop sentinels to workers …")
    for _ in workers:
        await JOB_QUEUE.put(None)
    await asyncio.gather(*workers, return_exceptions=True)
    try:
        SCRAPER.close()
    except Exception:
        pass
    EXECUTOR.shutdown(wait=True)
    log.info("Shutdown complete")
# ── app ──
# ASGI application; startup and shutdown are handled by `lifespan`.
app = FastAPI(
    title="Perchance Image Generation Server v2",
    lifespan=lifespan,
)
# CORS is fully open: any origin, method and header is accepted.
# NOTE(review): confirm this is intended if the server is reachable publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)
# ───────────── endpoints ─────────────
@app.get("/health")
async def health():
    """Report liveness, whether a userKey is set, queue depth and active task count."""
    async with _key_lock:
        key_present = USER_KEY is not None
    depth = JOB_QUEUE.qsize() if JOB_QUEUE else 0
    active = [
        record for record in TASKS.values()
        if record["status"] in ("queued", "running")
    ]
    return {
        "status": "ok",
        "has_user_key": key_present,
        "queue_size": depth,
        "active_tasks": len(active),
    }
@app.get("/user_key")
async def user_key_info():
    """Report whether a userKey is set and its length (never the key itself)."""
    async with _key_lock:
        current = USER_KEY
    if current is None:
        return {"has_user_key": False, "key_length": 0}
    return {"has_user_key": True, "key_length": len(current)}
@app.post("/set_user_key")
async def set_user_key(payload: Dict[str, str]):
    """Install a userKey supplied by the client.

    Expects ``{"userKey": "..."}``; responds 400 when the key is missing or
    blank. Resets the refresh failure counter and releases waiting workers.
    """
    global USER_KEY, _key_last_ts, _key_fail_count
    supplied = payload.get("userKey", "").strip()
    if supplied == "":
        raise HTTPException(400, "userKey required")
    async with _key_lock:
        USER_KEY = supplied
        _key_last_ts = time.time()
        _key_fail_count = 0
    _key_valid.set()  # wake workers blocked on a refresh
    size = len(supplied)
    log.info("userKey set via API (len=%d)", size)
    return {"status": "ok", "key_length": size}
# Strong references to in-flight background refresh tasks. The event loop only
# keeps weak references to tasks, so an unreferenced task may be
# garbage-collected before it finishes.
_BG_REFRESH_TASKS: set = set()


@app.post("/fetch_user_key")
async def fetch_user_key_endpoint():
    """Trigger a background browser-based key fetch and return immediately.

    The failure counter is reset first so that a refresh is allowed even
    after auto-refresh was disabled by repeated failures.
    """
    async def _bg():
        global _key_fail_count
        _key_fail_count = 0  # reset so refresh is allowed
        await refresh_user_key()

    bg_task = asyncio.create_task(_bg())
    # Keep the task alive until done, then drop the reference.
    _BG_REFRESH_TASKS.add(bg_task)
    bg_task.add_done_callback(_BG_REFRESH_TASKS.discard)
    return {"status": "started", "note": "Browser key fetch running in background"}
@app.post("/generate")
async def submit_job(payload: Dict[str, Any]):
    """
    POST /generate
    Body:
    {
        "prompts": ["a cat in space", "sunset over mountains"],
        "count": 2,
        "resolution": "512x768",
        "guidance": 7.0,
        "negative": "",
        "subChannel": "private"
    }
    Returns:
    { "task_id": "...", "stream_url": "/stream/...", "queue_position": N }

    Raises:
        HTTPException 400: prompts missing / not strings, or count / guidance
            not numeric.
        HTTPException 503: the job queue is full.
    """
    prompts = payload.get("prompts") or payload.get("prompt") or []
    if isinstance(prompts, str):
        prompts = [prompts]
    if not isinstance(prompts, list) or not prompts:
        raise HTTPException(400, "prompts must be a non-empty list")
    if not all(isinstance(p, str) for p in prompts):
        # Non-string prompts would only fail later, inside a worker.
        raise HTTPException(400, "prompts must be strings")
    try:
        count = max(1, int(payload.get("count", 1)))
        guidance = float(payload.get("guidance", 7.0))
    except (TypeError, ValueError):
        # Report malformed numbers as a client error instead of a 500.
        raise HTTPException(
            400, "count must be an integer and guidance a number",
        ) from None
    resolution = payload.get("resolution", "512x768")
    negative = payload.get("negative", "") or ""
    sub_channel = payload.get("subChannel", "private")
    task = create_task(prompts, count, resolution, guidance, negative, sub_channel)
    try:
        # put_nowait raises QueueFull; `await put()` would block instead.
        JOB_QUEUE.put_nowait({"task": task})
    except asyncio.QueueFull:
        # Do not leave a task registered that will never run.
        TASKS.pop(task["id"], None)
        TASK_QUEUES.pop(task["id"], None)
        raise HTTPException(503, "Server queue full β€” try again later") from None
    position = JOB_QUEUE.qsize()
    q = TASK_QUEUES.get(task["id"])
    if q:
        await q.put({
            "type": "queued",
            "time": _now(),
            "task_id": task["id"],
            "queue_position": position,
            "total_images": task["total_images"],
        })
    return {
        "task_id": task["id"],
        "stream_url": f"/stream/{task['id']}",
        "queue_position": position,
    }
@app.get("/stream/{task_id}")
async def stream_task(request: Request, task_id: str):
    """
    SSE stream for one task. Event types:
        meta · queued · started · heartbeat · image_ready
        key_invalid · key_refreshing · key_refreshed · key_refresh_failed
        gen_error · download_error · done · failed · eof

    A finished task is replayed from its stored results. A live task is
    streamed from its event queue, with a ``ping`` after 30 s of silence.
    Responds 404 for an unknown task id.
    """
    if task_id not in TASKS:
        raise HTTPException(404, "unknown task id")
    task = TASKS[task_id]
    queue = TASK_QUEUES[task_id]

    def _sse(name, payload):
        # One SSE message: event name plus JSON-encoded data.
        return {"event": name, "data": json.dumps(payload)}

    async def event_gen():
        # -- initial snapshot --
        yield _sse("meta", {
            "task_id": task_id,
            "status": task["status"],
            "total_images": task["total_images"],
            "created_at": task["created_at"],
        })

        # -- already finished: replay results, final status, EOF --
        if task["status"] in ("done", "failed"):
            for item in task["results"]:
                yield _sse("image_ready", {
                    "task_id": task_id,
                    "prompt": item["prompt"],
                    "index": item["index"],
                    "path": item["path"],
                    "seed": item["seed"],
                    "completed": task["completed"],
                    "total": task["total_images"],
                })
            yield _sse(task["status"], {
                "task_id": task_id,
                "completed": task["completed"],
                "total": task["total_images"],
                "error": task.get("error"),
            })
            yield _sse("eof", {"task_id": task_id})
            return

        # -- live stream --
        while True:
            try:
                ev = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                # keep-alive ping, then check whether anyone is still listening
                yield _sse("ping", {"time": _now()})
                if await request.is_disconnected():
                    log.info("SSE client disconnected (task %s)", task_id)
                    break
                continue
            yield _sse(ev.get("type", "event"), ev)
            if ev.get("type") == "eof":
                break

    return EventSourceResponse(event_gen())
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the full task record for *task_id*, or 404 if it is unknown."""
    record = TASKS.get(task_id)
    if record:
        return {"task": record}
    raise HTTPException(404, "unknown task id")
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Serve a generated image from OUTPUT_DIR as a download.

    The requested name is resolved and must be a regular file located
    directly inside OUTPUT_DIR. Anything else (``..``, a directory, a name
    that resolves outside the directory) gets a 404, so the endpoint cannot
    be used to read other files.
    """
    root = OUTPUT_DIR.resolve()
    fp = (root / filename).resolve()
    if fp.parent != root or not fp.is_file():
        raise HTTPException(404, "file not found")
    return FileResponse(fp, media_type="application/octet-stream", filename=fp.name)
# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════
def _run_uvicorn():
    """Serve ``app`` with uvicorn on 0.0.0.0:8000 (blocks until the server exits)."""
    import uvicorn  # imported lazily, only needed when run as a script

    server_options = {"host": "0.0.0.0", "port": 8000, "log_level": "info"}
    uvicorn.run(app, **server_options)
if __name__ == "__main__":
    # Script entry point: optionally pre-fetch the userKey in a throwaway
    # event loop, then start the server. Setting NO_INITIAL_FETCH to
    # 1/true/True skips the pre-fetch.
    if os.environ.get("NO_INITIAL_FETCH", "") in ("1", "true", "True"):
        log.info("NO_INITIAL_FETCH=1 β†’ skipping pre-startup browser fetch")
    else:
        log.info("Pre-startup: fetching userKey via browser (timeout %ds) …", ZD_TIMEOUT)
        try:
            prefetched = asyncio.run(
                fetch_key_via_browser(timeout=ZD_TIMEOUT, headless=ZD_HEADLESS)
            )
        except Exception as exc:
            log.exception("Pre-startup fetch error: %s", exc)
        else:
            if prefetched:
                USER_KEY = prefetched
                log.info("Pre-startup fetch OK (len=%d)", len(prefetched))
            else:
                log.warning(
                    "Pre-startup fetch returned nothing. "
                    "Server will start; set key via /set_user_key."
                )
    _run_uvicorn()