"""
Perchance Image-Generation Server v2.0

This variant adds optional pyvirtualdisplay support so the server can be
hosted in headless environments (Hugging Face Spaces, etc.) while keeping
all original behaviour unchanged.

Behaviour:
  - If ZD_HEADLESS is True, zendriver runs headless as before.
  - If ZD_HEADLESS is False, USE_VIRTUAL_DISPLAY is True and no DISPLAY
    is set, we attempt to start a pyvirtualdisplay.Display (Xvfb)
    automatically before launching browsers. If pyvirtualdisplay is not
    installed or starting Xvfb fails, we log the problem and continue.
  - If USE_VIRTUAL_DISPLAY is False we will NOT attempt to start a virtual
    display — you must provide a DISPLAY yourself (or set ZD_HEADLESS=True)
    if running on a headless host.

To run on Hugging Face Spaces, add `pyvirtualdisplay` to requirements.txt
and ensure `xvfb` is available in the runtime (HF Spaces typically provide it).

All original defaults and constants are preserved from the original file.
"""

import asyncio
import base64
import json
import logging
import os
import random
import string
import time
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import asynccontextmanager
from datetime import datetime, timezone
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional

import cloudscraper
import zendriver as zd
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from sse_starlette.sse import EventSourceResponse
from zendriver import cdp

# Optional dependency: only needed to start Xvfb on hosts without a DISPLAY.
try:
    from pyvirtualdisplay import Display
    _HAS_PYVIRTUALDISPLAY = True
except Exception:
    Display = None
    _HAS_PYVIRTUALDISPLAY = False

# ---------------------------------------------------------------------------
# Perchance API
# ---------------------------------------------------------------------------
BASE_URL = "https://image-generation.perchance.org"
API_GENERATE = "/api/generate"
API_DOWNLOAD = "/api/downloadTemporaryImage"
API_AWAIT = "/api/awaitExistingGenerationRequest"
API_ACCESS_CODE = "/api/getAccessCodeForAdPoweredStuff"


def _env_flag(name: str, default: bool) -> bool:
    """Return the boolean value of environment variable *name*.

    Unset or blank -> *default*. Otherwise "1", "true", "yes" and "on"
    (case-insensitive) mean True and any other value means False. The
    truthy spellings are a superset of those accepted for NO_INITIAL_FETCH.
    """
    raw = os.environ.get(name, "").strip()
    if not raw:
        return default
    return raw.lower() in ("1", "true", "yes", "on")


# ---------------------------------------------------------------------------
# Browser automation (userKey retrieval through zendriver)
# ---------------------------------------------------------------------------
TARGET_URL = "https://perchance.org/ai-text-to-image-generator"
IMAGE_GEN_ORIGIN = "https://image-generation.perchance.org"
ZD_TIMEOUT = 90
# Overridable from the environment (e.g. ZD_HEADLESS=1 on a display-less
# host) without editing this file; the default is unchanged when unset.
ZD_HEADLESS = _env_flag("ZD_HEADLESS", False)
CLICK_INTERVAL = 0.35
CLICK_JITTER = 8.0
KEY_PREFIX = "userKey"

# Start Xvfb via pyvirtualdisplay when running headful without a DISPLAY.
# Overridable from the environment; defaults to True when unset.
USE_VIRTUAL_DISPLAY = _env_flag("USE_VIRTUAL_DISPLAY", True)

# ---------------------------------------------------------------------------
# HTTP / generation (times in seconds)
# ---------------------------------------------------------------------------
HTTP_TIMEOUT = 30
MAX_DOWNLOAD_WAIT = 180
BACKOFF_INIT = 0.7
MAX_GEN_RETRIES = 6

# ---------------------------------------------------------------------------
# userKey refresh policy
# ---------------------------------------------------------------------------
MAX_KEY_RETRIES = 3
KEY_REFRESH_COOLDOWN = 30
MAX_REFRESH_FAILURES = 5

# ---------------------------------------------------------------------------
# Concurrency / storage
# ---------------------------------------------------------------------------
WORKER_COUNT = 3
MAX_QUEUE_SIZE = 1000
EXECUTOR_THREADS = 16
OUTPUT_DIR = Path("outputs")
OUTPUT_DIR.mkdir(exist_ok=True, parents=True)

# ---------------------------------------------------------------------------
# Logging
# ---------------------------------------------------------------------------
LOG_FMT = "%(asctime)s | %(levelname)-7s | %(message)s"
logging.basicConfig(level=logging.INFO, format=LOG_FMT)
log = logging.getLogger("perchance")

# ---------------------------------------------------------------------------
# Shared mutable state
# ---------------------------------------------------------------------------
USER_KEY: Optional[str] = None

# asyncio primitives are created in lifespan() once the event loop exists.
_key_lock: Optional[asyncio.Lock] = None
_key_valid: Optional[asyncio.Event] = None
_key_refresh_lock: Optional[asyncio.Lock] = None
_key_last_ts: float = 0.0
_key_fail_count: int = 0

JOB_QUEUE: Optional[asyncio.Queue] = None

# task id -> task dict, and task id -> per-task SSE event queue.
TASKS: Dict[str, Dict[str, Any]] = {}
TASK_QUEUES: Dict[str, asyncio.Queue] = {}

EXECUTOR = ThreadPoolExecutor(max_workers=EXECUTOR_THREADS)
SCRAPER = cloudscraper.create_scraper()

# Xvfb handle while a virtual display is running, else None.
VDISPLAY: Optional[Display] = None


def _safe(s: str) -> str:
    """Make *s* filename-safe.

    Every character outside ASCII letters, digits and ``-_.()`` becomes
    ``_``, and the result is capped at 120 characters.
    """
    allowed = frozenset(string.ascii_letters + string.digits + "-_.()")
    cleaned = [ch if ch in allowed else "_" for ch in s]
    return "".join(cleaned)[:120]


def _sid() -> str:
    """Return an 8-character random id of lowercase ASCII letters and digits.

    Uses the non-cryptographic ``random`` module. It is meant for
    filename/request uniqueness only, not for secrets.
    """
    alphabet = string.ascii_lowercase + string.digits
    picks = random.choices(alphabet, k=8)
    return "".join(picks)


def _now() -> str:
    """Return the current UTC time as ISO-8601 with milliseconds and a ``Z``
    suffix, e.g. ``2024-01-31T12:34:56.789Z``.

    Uses a timezone-aware clock (``datetime.utcnow`` is deprecated since
    Python 3.12). The tzinfo is dropped before formatting so the output keeps
    the ``Z`` suffix instead of ``+00:00``.
    """
    return (
        datetime.now(timezone.utc)
        .replace(tzinfo=None)
        .isoformat(timespec="milliseconds")
        + "Z"
    )


def _reqid() -> str:
    """Build a request id: "<epoch seconds with 6 decimals>-<_sid() suffix>"."""
    stamp = "%.6f" % time.time()
    return stamp + "-" + _sid()


def _stamp() -> str:
    """Return the current UTC time as a compact ``YYYYMMDDTHHMMSS`` string
    suitable for filenames.

    Uses a timezone-aware clock; ``datetime.utcnow`` is deprecated since
    Python 3.12.
    """
    return datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S")


def _start_virtual_display_if_needed(headless: bool):
    """Best-effort start of an Xvfb display through pyvirtualdisplay.

    A display is started only when all of these hold: *headless* is False,
    USE_VIRTUAL_DISPLAY is True, DISPLAY is unset or empty, and
    pyvirtualdisplay was importable. On success the Display object is kept
    in the module-global VDISPLAY. On failure VDISPLAY is reset to None and
    the error is logged, never raised. The call blocks, so async callers run
    it in an executor.
    """
    global VDISPLAY

    # Guard clauses: every "nothing to do" case logs why and returns.
    if headless:
        log.info("ZD_HEADLESS=True β not starting virtual display")
        return
    if not USE_VIRTUAL_DISPLAY:
        log.info("USE_VIRTUAL_DISPLAY=False β not starting virtual display; expecting manual DISPLAY or headless mode.")
        return
    current_display = os.environ.get("DISPLAY")
    if current_display:
        log.info("DISPLAY already set: %s", current_display)
        return
    if Display is None or not _HAS_PYVIRTUALDISPLAY:
        log.warning(
            "pyvirtualdisplay not installed β cannot create virtual DISPLAY. "
            "Install pyvirtualdisplay in your environment to enable Xvfb."
        )
        return

    try:
        VDISPLAY = Display(visible=0, size=(1280, 720))
        VDISPLAY.start()
        # DISPLAY is re-read here to report the value in effect after start().
        log.info("Started virtual display via pyvirtualdisplay (DISPLAY=%s)", os.environ.get("DISPLAY"))
    except Exception as exc:
        VDISPLAY = None
        log.exception("Failed to start virtual display: %s", exc)


def _stop_virtual_display_if_needed():
    """Stop the display held in VDISPLAY, if any.

    Errors from ``stop()`` are logged rather than raised. VDISPLAY is
    cleared afterwards regardless of the outcome.
    """
    global VDISPLAY
    display = VDISPLAY
    if display is not None:
        try:
            display.stop()
            log.info("Stopped virtual display")
        except Exception:
            log.exception("Error while stopping virtual display")
        finally:
            VDISPLAY = None


class PerchanceClient:
    """All blocking HTTP work against the Perchance API.

    Methods perform synchronous requests through the shared ``SCRAPER``
    session, so async callers run them in an executor.
    """

    def __init__(self):
        self.base = BASE_URL.rstrip("/")
        self.s = SCRAPER
        # Browser-like headers; Origin/Referer point at the embedded generator.
        self.h = {
            "Accept": "*/*",
            "Content-Type": "application/json;charset=UTF-8",
            "Origin": IMAGE_GEN_ORIGIN,
            "Referer": f"{IMAGE_GEN_ORIGIN}/embed",
            "User-Agent": (
                "Mozilla/5.0 (Windows NT 10.0; Win64; x64) "
                "AppleWebKit/537.36 (KHTML, like Gecko) "
                "Chrome/131.0.0.0 Safari/537.36"
            ),
        }

    # ---- low-level API calls ----------------------------------------------

    def get_ad_code(self) -> str:
        """Return a fresh ad access code, or "" if the request fails."""
        try:
            r = self.s.get(
                f"{self.base}{API_ACCESS_CODE}",
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            return r.text.strip()
        except Exception as exc:
            # Best effort: callers treat "" as "no code available".
            log.debug("get_ad_code failed: %s", exc)
            return ""

    def _post(self, body: dict, params: dict) -> dict:
        """POST to the generate endpoint and return a status dict.

        Transport/HTTP errors map to ``{"status": "fetch_failure", ...}``.
        Unparseable or non-object JSON maps to ``{"status": "invalid_json",
        ...}``, so callers can always use ``.get("status")``.
        """
        try:
            r = self.s.post(
                f"{self.base}{API_GENERATE}",
                json=body, params=params,
                timeout=HTTP_TIMEOUT, headers=self.h,
            )
            r.raise_for_status()
            try:
                data = r.json()
            except Exception:
                return {"status": "invalid_json", "raw": r.text}
            if not isinstance(data, dict):
                # A JSON array/scalar would crash generate_one's .get() calls.
                return {"status": "invalid_json", "raw": r.text}
            return data
        except Exception as exc:
            return {"status": "fetch_failure", "error": str(exc)}

    def _await_prev(self, key: str):
        """Block until the server reports the previous request for *key* is done.

        This is best effort: failures are logged at debug level and ignored.
        """
        try:
            self.s.get(
                f"{self.base}{API_AWAIT}",
                params={"userKey": key, "__cacheBust": random.random()},
                timeout=20, headers=self.h,
            )
        except Exception as exc:
            log.debug("awaitExistingGenerationRequest failed: %s", exc)

    # ---- generation ---------------------------------------------------------

    def generate_one(
        self, *,
        prompt: str,
        negative_prompt: str = "",
        seed: int = -1,
        resolution: str = "512x768",
        guidance_scale: float = 7.0,
        channel: str = "ai-text-to-image-generator",
        sub_channel: str = "private",
        user_key: str = "",
        ad_access_code: str = "",
        request_id: str = "",
    ) -> dict:
        """
        Request one image, retrying transient statuses up to MAX_GEN_RETRIES
        times (blocking with time.sleep between attempts).

        Returns ONE of:
            {"imageId": ..., "seed": ...}   fetch later via download_image
            {"inline": ..., "seed": ...}    data URL returned directly
            {"error": "invalid_key"}        caller must refresh key
            {"error": "<other>", ...}
        """
        request_id = request_id or _reqid()
        params = {
            "userKey": user_key,
            "requestId": request_id,
            "adAccessCode": ad_access_code,
            "__cacheBust": random.random(),
        }
        body = {
            "prompt": prompt,
            "negativePrompt": negative_prompt,
            "seed": seed,
            "resolution": resolution,
            "guidanceScale": guidance_scale,
            "channel": channel,
            "subChannel": sub_channel,
            "userKey": user_key,
            "adAccessCode": ad_access_code,
            "requestId": request_id,
        }

        # The ad access code is refreshed at most once per call.
        ad_refreshed = False

        for att in range(1, MAX_GEN_RETRIES + 1):
            res = self._post(body, params)
            st = res.get("status")

            if st == "success":
                iid = res.get("imageId")
                urls = res.get("imageDataUrls")
                if iid:
                    log.info("Got imageId: %s", iid)
                    return {"imageId": iid, "seed": res.get("seed")}
                if urls:
                    return {"inline": urls[0], "seed": res.get("seed")}
                log.error("success but empty payload: %s", str(res)[:300])
                return {"error": "empty_success", "raw": res}

            if st == "invalid_key":
                log.warning("Server says invalid_key")
                return {"error": "invalid_key"}

            if st == "waiting_for_prev_request_to_finish":
                log.info("Waiting for prev request to finish β¦")
                self._await_prev(user_key)
                time.sleep(0.3 + random.random() * 0.3)
                continue

            if st == "invalid_ad_access_code" and not ad_refreshed:
                code = self.get_ad_code()
                if code:
                    params["adAccessCode"] = code
                    body["adAccessCode"] = code
                    ad_refreshed = True
                    log.info("Refreshed ad code β retry")
                    time.sleep(0.8)
                    continue
                return {"error": "invalid_ad_access_code"}

            if st == "gen_failure" and res.get("type") == 1:
                log.warning("gen_failure type 1 β retry after 2.5 s")
                time.sleep(2.5)
                continue

            if st in (None, "fetch_failure", "invalid_json", "stale_request"):
                log.info("Transient error (status=%s) attempt %d/%d", st, att, MAX_GEN_RETRIES)
                time.sleep(1.0)
                continue

            log.error("Unhandled status '%s': %s", st, str(res)[:300])
            return {"error": f"unhandled_{st}", "raw": res}

        return {"error": "max_retries_exceeded"}

    # ---- download -----------------------------------------------------------

    def download_image(self, image_id: str, prefix: str = "img") -> str:
        """Poll until the image is ready, save it to OUTPUT_DIR and return the path.

        Polls with exponential backoff (BACKOFF_INIT, x1.8, capped at 8 s).

        Raises:
            TimeoutError: if no 200 response arrives within MAX_DOWNLOAD_WAIT s.
        """
        url = f"{self.base}{API_DOWNLOAD}?imageId={image_id}"
        t0 = time.time()
        bk = BACKOFF_INIT

        while True:
            elapsed = time.time() - t0
            if elapsed >= MAX_DOWNLOAD_WAIT:
                raise TimeoutError(
                    f"Download timed out ({elapsed:.0f}s) for {image_id}"
                )
            try:
                # `with` closes the streamed response even for non-200 "not
                # ready yet" polls; otherwise each poll keeps a pooled
                # connection checked out.
                with self.s.get(url, timeout=HTTP_TIMEOUT,
                                headers=self.h, stream=True) as r:
                    if r.status_code == 200:
                        return self._save_response(r, image_id, prefix)
            except Exception as exc:
                log.debug("Download poll for %s failed: %s", image_id, exc)

            time.sleep(bk)
            bk = min(bk * 1.8, 8.0)

    def _save_response(self, r, image_id: str, prefix: str) -> str:
        """Stream a 200 response body into OUTPUT_DIR and return the file path.

        The extension follows Content-Type (png/webp, otherwise jpg). Data is
        written to a ``.part`` file and renamed into place only when complete,
        so an interrupted transfer never leaves a truncated image under the
        final name.
        """
        ct = r.headers.get("Content-Type", "")
        ext = (
            ".png" if "png" in ct else
            ".webp" if "webp" in ct else ".jpg"
        )
        fp = OUTPUT_DIR / _safe(f"{prefix}_{image_id[:12]}{ext}")
        tmp = fp.with_name(fp.name + ".part")
        try:
            with open(tmp, "wb") as f:
                for chunk in r.iter_content(8192):
                    if chunk:
                        f.write(chunk)
            os.replace(tmp, fp)
        except Exception:
            try:
                tmp.unlink()
            except OSError:
                pass
            raise
        log.info("Saved β %s", str(fp))
        return str(fp)


# Module-wide client shared by every worker. Its methods block, so async
# code dispatches them onto EXECUTOR.
CLIENT: PerchanceClient = PerchanceClient()


async def _cdp_mouse(tab, typ, x, y, **kw):
    """Send one CDP ``dispatch_mouse_event`` of kind *typ* at (x, y) to *tab*.

    The coordinates are coerced to float. Extra keyword arguments (button,
    click_count, ...) are forwarded unchanged.
    """
    event = cdp.input_.dispatch_mouse_event(type_=typ, x=float(x), y=float(y), **kw)
    await tab.send(event)


async def _viewport_center(tab):
    """Return half of *tab*'s (innerWidth, innerHeight) as floats.

    Any failure, including an unexpected evaluation result, falls back to
    (600.0, 400.0).
    """
    try:
        dims = await tab.evaluate(
            "(()=>({w:innerWidth,h:innerHeight}))()",
            await_promise=False, return_by_value=True,
        )
        half_w = dims["w"] / 2.0
        half_h = dims["h"] / 2.0
    except Exception:
        return (600.0, 400.0)
    return (half_w, half_h)


async def _ls_get(tab, key):
    """Evaluate ``localStorage.getItem(key)`` in *tab* and return the value.

    Any failure (evaluation error, key not JSON-serialisable) yields None.
    """
    try:
        js = "localStorage&&localStorage.getItem(%s)" % json.dumps(key)
        value = await tab.evaluate(js, await_promise=True, return_by_value=True)
    except Exception:
        value = None
    return value


async def _clicker_loop(tab, stop: asyncio.Event):
    """Click near the centre of *tab* roughly every CLICK_INTERVAL s until *stop* is set.

    Each click is a CDP mouseMoved -> mousePressed -> mouseReleased sequence
    at the viewport centre, offset by up to +/-CLICK_JITTER px, with short
    random pauses between events. The centre is re-measured once more than
    2.5 s have passed since the last measurement. Individual CDP failures are
    ignored so one bad click does not end the loop.
    """
    try:
        await tab.evaluate(
            "window.focus&&window.focus()",
            await_promise=False, return_by_value=False,
        )
    except Exception:
        pass  # focusing is best effort

    center = await _viewport_center(tab)
    measured_at = time.time()

    while not stop.is_set():
        if time.time() - measured_at > 2.5:
            center, measured_at = await _viewport_center(tab), time.time()

        x = center[0] + random.uniform(-CLICK_JITTER, CLICK_JITTER)
        y = center[1] + random.uniform(-CLICK_JITTER, CLICK_JITTER)

        try:
            await _cdp_mouse(tab, "mouseMoved", x, y, pointer_type="mouse")
            await asyncio.sleep(random.uniform(0.02, 0.08))
            await _cdp_mouse(
                tab, "mousePressed", x, y,
                button=cdp.input_.MouseButton.LEFT, click_count=1, buttons=1,
            )
            await asyncio.sleep(random.uniform(0.03, 0.12))
            await _cdp_mouse(
                tab, "mouseReleased", x, y,
                button=cdp.input_.MouseButton.LEFT, click_count=1, buttons=0,
            )
        except Exception:
            pass

        # Sleep ~CLICK_INTERVAL (+/-15%), waking early if stop is set.
        pause = CLICK_INTERVAL * random.uniform(0.85, 1.15)
        try:
            await asyncio.wait_for(stop.wait(), timeout=pause)
        except asyncio.TimeoutError:
            continue
        break


async def _poll_for_key(tab, stop: asyncio.Event, max_sec: int):
    """Poll *tab*'s localStorage every 250 ms for a userKey value.

    Each round checks ``f"{KEY_PREFIX}-0"`` first, then every localStorage
    key containing 'userKey'. Returns the first truthy value found, or None
    once *stop* is set or *max_sec* seconds have elapsed.
    """

    async def _scan_matching_keys():
        # Fallback: any localStorage entry whose name contains 'userKey'.
        try:
            names = await tab.evaluate(
                "Object.keys(localStorage||{}).filter(k=>k.includes('userKey'))",
                await_promise=False, return_by_value=True,
            )
            for name in names or []:
                found = await _ls_get(tab, name)
                if found:
                    return found
        except Exception:
            pass
        return None

    started = time.time()
    while not stop.is_set() and (time.time() - started) < max_sec:
        value = await _ls_get(tab, f"{KEY_PREFIX}-0") or await _scan_matching_keys()
        if value:
            return value
        await asyncio.sleep(0.25)
    return None


async def fetch_key_via_browser(
    timeout: int = ZD_TIMEOUT,
    headless: bool = ZD_HEADLESS,
) -> Optional[str]:
    """
    Launch Chrome -> navigate to Perchance -> click to trigger
    ad/verification -> read userKey from localStorage -> close browser.

    A clicker task clicks on the generator page while a poller task reads
    localStorage on the image-generation origin. Both helper tasks are
    stopped and awaited before the tabs and the browser are closed.

    Returns the key string, or None if the browser cannot start or no key
    appears within *timeout* seconds. Errors raised while driving the page
    propagate to the caller after the browser has been stopped.
    """
    log.info(
        "Launching browser for userKey (timeout=%ds, headless=%s)",
        timeout, headless,
    )

    # The helper itself decides whether Xvfb is needed. It blocks, so run it
    # off the event loop.
    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, partial(_start_virtual_display_if_needed, headless))
    except Exception:
        log.exception("Error while attempting to start virtual display")

    try:
        browser = await zd.start(headless=headless)
    except Exception as exc:
        log.exception("Browser start failed: %s", exc)
        return None

    stop = asyncio.Event()
    result = None

    try:
        page_tab = await browser.get(TARGET_URL)
        log.info("Opened %s", TARGET_URL)
        await asyncio.sleep(2.0)

        origin_tab = await browser.get(IMAGE_GEN_ORIGIN, new_tab=True)
        log.info("Opened %s", IMAGE_GEN_ORIGIN)
        await asyncio.sleep(1.0)

        await page_tab.bring_to_front()
        await asyncio.sleep(0.5)

        clicker = asyncio.create_task(_clicker_loop(page_tab, stop))
        poller = asyncio.create_task(_poll_for_key(origin_tab, stop, timeout))

        try:
            done, _ = await asyncio.wait({poller}, timeout=timeout)
            if poller in done:
                result = poller.result()
        finally:
            stop.set()
            for helper in (clicker, poller):
                if not helper.done():
                    helper.cancel()
            try:
                await clicker
            except asyncio.CancelledError:
                pass
            # Await the poller too, so it never keeps evaluating against tabs
            # that are about to be closed. Any error it raised was already
            # surfaced through poller.result() above when it finished in time.
            try:
                await poller
            except asyncio.CancelledError:
                pass
            except Exception:
                log.debug("userKey poller ended with an error", exc_info=True)

        for t in (origin_tab, page_tab):
            try:
                await t.close()
            except Exception:
                pass
    finally:
        try:
            await browser.stop()
        except Exception:
            pass

    if result:
        log.info("Fetched userKey (len=%d)", len(result))
    else:
        log.warning("Could not fetch userKey within %ds", timeout)
    return result


async def _broadcast(event: dict):
    """Push *event* onto the SSE queue of every task still queued or running.

    put_nowait is used, so a full queue would drop the event for that task
    instead of blocking the caller.
    """
    live_states = ("queued", "running")
    for task_id, sse_queue in TASK_QUEUES.items():
        task = TASKS.get(task_id)
        if not task or task["status"] not in live_states:
            continue
        try:
            sse_queue.put_nowait(event)
        except asyncio.QueueFull:
            pass


async def refresh_user_key() -> Optional[str]:
    """
    Coordinate a single browser-based userKey refresh.

    Refreshes are serialised by ``_key_refresh_lock``. A caller that queued
    behind an in-flight refresh re-checks the cooldown once it holds the
    lock. If ``USER_KEY`` was set less than ``KEY_REFRESH_COOLDOWN`` seconds
    ago, that key is returned without launching another browser. Otherwise
    (e.g. the previous refresh failed) this caller performs its own refresh.

    While a refresh runs, ``_key_valid`` is cleared so coroutines waiting on
    it pause. It is set again in ``finally`` whatever the outcome.

    After ``MAX_REFRESH_FAILURES`` consecutive failures, auto-refresh is
    disabled and this returns None immediately. The counter is reset by a
    successful refresh, ``/set_user_key`` or ``/fetch_user_key``.

    Every outcome is broadcast to active tasks' SSE queues.

    Returns the new (or recently refreshed) key string, or None on failure.
    """
    global USER_KEY, _key_last_ts, _key_fail_count

    async with _key_refresh_lock:
        # Cooldown: a refresh that just succeeded satisfies queued callers.
        age = time.time() - _key_last_ts
        if age < KEY_REFRESH_COOLDOWN and USER_KEY:
            log.info(
                "Key was refreshed %.1fs ago β reusing existing key", age,
            )
            return USER_KEY

        # Circuit breaker: stop launching browsers after repeated failures.
        if _key_fail_count >= MAX_REFRESH_FAILURES:
            log.error(
                "Key refresh disabled: %d consecutive failures. "
                "Set key manually via POST /set_user_key",
                _key_fail_count,
            )
            await _broadcast({
                "type": "key_refresh_failed",
                "time": _now(),
                "message": (
                    f"Auto-refresh disabled after {_key_fail_count} failures. "
                    "Please set userKey manually via /set_user_key"
                ),
            })
            return None

        # Pause key consumers until this refresh finishes.
        # NOTE(review): clear() happens before the try/finally below, so an
        # exception raised between here and `try` would leave _key_valid
        # cleared; today only log.info and _broadcast run in that gap.
        _key_valid.clear()
        log.info("Starting userKey refresh via browser β¦")

        await _broadcast({
            "type": "key_refreshing",
            "time": _now(),
            "message": "UserKey expired β refreshing via browser automation β¦",
        })

        try:
            new_key = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )

            if new_key:
                # Publish the key under _key_lock, which readers also take.
                async with _key_lock:
                    USER_KEY = new_key
                    _key_last_ts = time.time()
                    _key_fail_count = 0

                log.info("UserKey refreshed OK (len=%d)", len(new_key))
                await _broadcast({
                    "type": "key_refreshed",
                    "time": _now(),
                    "message": "UserKey refreshed β resuming generation.",
                })
                return new_key

            # The browser ran but produced no key: count it as a failure.
            _key_fail_count += 1
            log.error(
                "Key refresh returned nothing (failure #%d/%d)",
                _key_fail_count, MAX_REFRESH_FAILURES,
            )
            await _broadcast({
                "type": "key_refresh_failed",
                "time": _now(),
                "message": (
                    f"Key refresh failed (attempt {_key_fail_count}"
                    f"/{MAX_REFRESH_FAILURES})"
                ),
            })
            return None

        except Exception as exc:
            _key_fail_count += 1
            log.exception(
                "Key refresh error (failure #%d/%d): %s",
                _key_fail_count, MAX_REFRESH_FAILURES, exc,
            )
            await _broadcast({
                "type": "key_refresh_failed",
                "time": _now(),
                "message": f"Key refresh error: {exc}",
            })
            return None

        finally:
            # Always release waiting consumers, whether or not a key was obtained.
            _key_valid.set()


def create_task(
    prompts: List[str],
    count: int,
    resolution: str,
    guidance: float,
    negative: str,
    sub_channel: str,
) -> dict:
    """Register a new generation task in TASKS and give it an SSE event queue.

    The task starts in status "queued" with ``len(prompts) * count`` images
    to produce. Its id is a random UUID4 string. Returns the task dict,
    which is the same object stored in TASKS.
    """
    task_id = str(uuid.uuid4())
    task = dict(
        id=task_id,
        prompts=prompts,
        count=count,
        resolution=resolution,
        guidance=guidance,
        negative=negative,
        sub_channel=sub_channel,
        created_at=_now(),
        status="queued",
        total_images=len(prompts) * count,
        completed=0,
        results=[],
        error=None,
    )
    TASKS[task_id] = task
    TASK_QUEUES[task_id] = asyncio.Queue()
    return task


async def _save_inline(data_url: str, prompt: str) -> str:
    """Decode a base64 ``data:`` URL (or a bare base64 payload) into OUTPUT_DIR.

    The extension follows the MIME type in the data-URL header (png/webp,
    otherwise jpg), matching PerchanceClient.download_image. Decoding and
    writing both run on EXECUTOR so large payloads do not stall the event
    loop. Returns the saved path as a string.

    Raises:
        binascii.Error: if the payload is not valid base64.
    """
    loop = asyncio.get_running_loop()
    header, sep, payload = data_url.partition(",")
    if not sep:
        # No comma: treat the whole string as a bare base64 payload.
        header, payload = "", data_url
    ext = (
        ".png" if "png" in header else
        ".webp" if "webp" in header else ".jpg"
    )
    fp = OUTPUT_DIR / _safe(f"{prompt[:30]}_{_stamp()}_{_sid()}{ext}")
    raw = await loop.run_in_executor(EXECUTOR, base64.b64decode, payload)
    await loop.run_in_executor(EXECUTOR, fp.write_bytes, raw)
    log.info("Saved inline β %s", fp)
    return str(fp)


async def _download(image_id: str, prompt: str) -> str:
    """Run CLIENT.download_image for *image_id* on EXECUTOR and return the saved path.

    The filename prefix joins a sanitised 30-char prompt excerpt, a UTC
    timestamp and a random suffix with underscores.
    """
    name_prefix = "_".join((_safe(prompt[:30]), _stamp(), _sid()))
    job = partial(CLIENT.download_image, image_id, name_prefix)
    return await asyncio.get_running_loop().run_in_executor(EXECUTOR, job)


async def _generate_single(
    prompt: str,
    task: dict,
    idx: int,
    queue: asyncio.Queue,
    ad_code: str,
) -> Optional[str]:
    """
    Generate + save one image, reporting progress on *queue*.

    On 'invalid_key', triggers a coordinated key refresh (refresh_user_key)
    and retries, up to MAX_KEY_RETRIES generation attempts in total. Every
    path that gives up pushes an event onto *queue*, so SSE clients learn why
    the image is missing. On success, task["completed"] and task["results"]
    are updated.

    Returns the saved filepath, or None if no image was produced.
    """
    loop = asyncio.get_running_loop()
    tid = task["id"]

    async def _emit(event: dict) -> None:
        # worker_loop passes TASK_QUEUES.get(tid) and guards its own puts
        # with `if queue:`; tolerate a missing queue here as well.
        if queue is not None:
            await queue.put(event)

    for key_try in range(1, MAX_KEY_RETRIES + 1):
        # Block while another coroutine is refreshing the key.
        await _key_valid.wait()

        async with _key_lock:
            active_key = USER_KEY

        if not active_key:
            await _emit({
                "type": "error",
                "time": _now(),
                "task_id": tid,
                "message": "No userKey available. Set via /set_user_key",
            })
            return None

        # generate_one blocks (HTTP + sleeps), so run it on the thread pool.
        result = await loop.run_in_executor(
            EXECUTOR,
            partial(
                CLIENT.generate_one,
                prompt=prompt,
                negative_prompt=task["negative"],
                seed=-1,
                resolution=task["resolution"],
                guidance_scale=task["guidance"],
                channel="ai-text-to-image-generator",
                sub_channel=task["sub_channel"],
                user_key=active_key,
                ad_access_code=ad_code,
                request_id=_reqid(),
            ),
        )

        if result.get("error") == "invalid_key":
            log.warning(
                "invalid_key for task %s (key_try %d/%d) β refreshing",
                tid, key_try, MAX_KEY_RETRIES,
            )
            await _emit({
                "type": "key_invalid",
                "time": _now(),
                "task_id": tid,
                "attempt": key_try,
                "max_attempts": MAX_KEY_RETRIES,
                "message": "UserKey invalid β refreshing β¦",
            })

            new_key = await refresh_user_key()
            if not new_key:
                await _emit({
                    "type": "error",
                    "time": _now(),
                    "task_id": tid,
                    "message": "Could not refresh userKey β aborting image",
                })
                return None
            # Fetch a fresh ad access code to go with the new key.
            ad_code = await loop.run_in_executor(
                EXECUTOR, CLIENT.get_ad_code,
            )
            continue

        if result.get("error"):
            log.warning(
                "Gen error task=%s prompt='%.40s': %s",
                tid, prompt, result,
            )
            await _emit({
                "type": "gen_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": result,
            })
            return None

        if not (result.get("inline") or result.get("imageId")):
            log.error("Unexpected result: %s", result)
            await _emit({
                "type": "gen_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": result,
            })
            return None

        try:
            if result.get("inline"):
                fp = await _save_inline(result["inline"], prompt)
            else:
                fp = await _download(result["imageId"], prompt)

            seed = result.get("seed")
            task["completed"] += 1
            task["results"].append({
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
            })
            await _emit({
                "type": "image_ready",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "path": fp,
                "seed": seed,
                "completed": task["completed"],
                "total": task["total_images"],
            })
            return fp

        except Exception as exc:
            log.exception("Save/download error task=%s: %s", tid, exc)
            await _emit({
                "type": "download_error",
                "time": _now(),
                "task_id": tid,
                "prompt": prompt,
                "index": idx,
                "error": str(exc),
            })
            return None

    # Every attempt hit invalid_key, even after successful refreshes.
    log.error("Exhausted key retries for task %s prompt='%.40s'", tid, prompt)
    await _emit({
        "type": "error",
        "time": _now(),
        "task_id": tid,
        "prompt": prompt,
        "index": idx,
        "message": (
            f"userKey still invalid after {MAX_KEY_RETRIES} attempts - "
            "aborting image"
        ),
    })
    return None


async def _publish_heartbeats(task: dict, queue) -> None:
    """While *task* is running, push a heartbeat onto *queue* every 5 s.

    Beats are dropped (not awaited) if the queue is full.
    """
    while task["status"] == "running":
        await asyncio.sleep(5.0)
        if queue and task["status"] == "running":
            try:
                queue.put_nowait({
                    "type": "heartbeat",
                    "time": _now(),
                    "task_id": task["id"],
                    "completed": task["completed"],
                    "total": task["total_images"],
                })
            except asyncio.QueueFull:
                pass


async def worker_loop(worker_id: int, semaphore: asyncio.Semaphore):
    """Consume JOB_QUEUE until a ``None`` sentinel arrives.

    For each task the worker:
      - marks it running and announces "started";
      - fetches one ad access code;
      - runs a heartbeat publisher;
      - generates every (prompt, index) image in order, each under *semaphore*;
      - publishes the final done/failed status followed by "eof".
    """
    log.info("Worker %d started", worker_id)
    loop = asyncio.get_running_loop()

    while True:
        job = await JOB_QUEUE.get()
        if job is None:
            log.info("Worker %d shutting down", worker_id)
            JOB_QUEUE.task_done()
            return

        task = job["task"]
        tid = task["id"]
        queue = TASK_QUEUES.get(tid)

        log.info(
            "Worker %d β task %s (%d images)",
            worker_id, tid, task["total_images"],
        )
        task["status"] = "running"
        if queue:
            await queue.put({
                "type": "started",
                "time": _now(),
                "task_id": tid,
                "total_images": task["total_images"],
            })

        ad_code = await loop.run_in_executor(EXECUTOR, CLIENT.get_ad_code)
        heartbeat = asyncio.create_task(_publish_heartbeats(task, queue))

        try:
            work_items = (
                (prompt, n)
                for prompt in task["prompts"]
                for n in range(task["count"])
            )
            for prompt, n in work_items:
                async with semaphore:
                    await _generate_single(prompt, task, n, queue, ad_code)
                if task["status"] == "failed":
                    break

            # A task that produced nothing at all counts as failed.
            if task["status"] != "failed":
                nothing_made = task["completed"] == 0 and task["total_images"] > 0
                if nothing_made:
                    task["status"] = "failed"
                    task["error"] = "No images generated successfully"
                else:
                    task["status"] = "done"

            if queue:
                await queue.put({
                    "type": task["status"],
                    "time": _now(),
                    "task_id": tid,
                    "completed": task["completed"],
                    "total": task["total_images"],
                    "error": task.get("error"),
                })

        except Exception as exc:
            log.exception("Worker %d task %s crashed: %s", worker_id, tid, exc)
            task["status"] = "failed"
            task["error"] = str(exc)
            if queue:
                await queue.put({
                    "type": "failed",
                    "time": _now(),
                    "task_id": tid,
                    "error": str(exc),
                })

        finally:
            heartbeat.cancel()
            try:
                await heartbeat
            except asyncio.CancelledError:
                pass

            if queue:
                await queue.put({"type": "eof", "time": _now(), "task_id": tid})

            JOB_QUEUE.task_done()
            log.info(
                "Worker %d task %s finished (%s, %d/%d)",
                worker_id, tid, task["status"],
                task["completed"], task["total_images"],
            )


@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook.

    Startup:
      - creates the asyncio primitives;
      - optionally starts Xvfb;
      - obtains an initial userKey unless one is preset or NO_INITIAL_FETCH is truthy;
      - launches WORKER_COUNT workers.

    Shutdown:
      - sends one stop sentinel per worker and waits for all workers;
      - closes the HTTP session, the thread pool and the virtual display.
    """
    global USER_KEY, _key_lock, _key_valid, _key_refresh_lock
    global _key_last_ts, _key_fail_count, JOB_QUEUE

    # asyncio primitives must be created inside the running event loop.
    _key_lock = asyncio.Lock()
    _key_valid = asyncio.Event()
    _key_valid.set()
    _key_refresh_lock = asyncio.Lock()
    JOB_QUEUE = asyncio.Queue(maxsize=MAX_QUEUE_SIZE)

    loop = asyncio.get_running_loop()
    try:
        await loop.run_in_executor(None, partial(_start_virtual_display_if_needed, ZD_HEADLESS))
    except Exception:
        log.exception("Failed to ensure virtual display at startup")

    if USER_KEY:
        log.info("Using pre-fetched userKey (len=%d)", len(USER_KEY))
        _key_last_ts = time.time()
    elif os.environ.get("NO_INITIAL_FETCH", "") in ("1", "true", "True"):
        log.info("NO_INITIAL_FETCH=1 β skipping browser key fetch")
    else:
        try:
            key = await fetch_key_via_browser(
                timeout=ZD_TIMEOUT, headless=ZD_HEADLESS,
            )
            if not key:
                log.warning(
                    "Startup key fetch failed. "
                    "Use /set_user_key or /fetch_user_key."
                )
            else:
                USER_KEY = key
                _key_last_ts = time.time()
                log.info("Fetched userKey at startup (len=%d)", len(key))
        except Exception as exc:
            log.exception("Startup key fetch error: %s", exc)

    # The semaphore caps concurrent _generate_single calls at WORKER_COUNT.
    sem = asyncio.Semaphore(WORKER_COUNT)
    workers = [
        asyncio.create_task(worker_loop(worker_no, sem))
        for worker_no in range(1, WORKER_COUNT + 1)
    ]
    log.info("Launched %d workers", WORKER_COUNT)

    yield

    log.info("Shutdown: sending stop sentinels to workers β¦")
    for _ in range(WORKER_COUNT):
        await JOB_QUEUE.put(None)
    await asyncio.gather(*workers, return_exceptions=True)

    try:
        SCRAPER.close()
    except Exception:
        pass
    EXECUTOR.shutdown(wait=True)

    try:
        await loop.run_in_executor(None, _stop_virtual_display_if_needed)
    except Exception:
        log.exception("Failed to stop virtual display cleanly")

    log.info("Shutdown complete")


app = FastAPI(
    lifespan=lifespan,
    title="Perchance Image Generation Server v2 (pyvirtualdisplay)",
)

# Fully permissive CORS, so browser front-ends on any origin can call the API.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


@app.get("/health")
async def health():
    """Liveness probe.

    Reports whether a userKey is set, the job-queue depth and the number of
    queued or running tasks.
    """
    async with _key_lock:
        key_present = USER_KEY is not None
    pending_jobs = JOB_QUEUE.qsize() if JOB_QUEUE else 0
    active = len([t for t in TASKS.values() if t["status"] in ("queued", "running")])
    return {
        "status": "ok",
        "has_user_key": key_present,
        "queue_size": pending_jobs,
        "active_tasks": active,
    }


@app.get("/user_key")
async def user_key_info():
    """Report whether a userKey is configured and its length (never the key itself)."""
    async with _key_lock:
        current = USER_KEY
    present = current is not None
    return {"has_user_key": present, "key_length": len(current) if present else 0}


@app.post("/set_user_key")
async def set_user_key(payload: Dict[str, str]):
    """Install a client-supplied userKey.

    Resets the refresh-failure counter and the cooldown timestamp, and sets
    _key_valid so paused generators resume. Body: ``{"userKey": "<key>"}``.
    Responds 400 when the key is missing or blank.
    """
    global USER_KEY, _key_last_ts, _key_fail_count
    new_key = payload.get("userKey", "").strip()
    if not new_key:
        raise HTTPException(400, "userKey required")
    async with _key_lock:
        USER_KEY, _key_last_ts, _key_fail_count = new_key, time.time(), 0
        _key_valid.set()
    log.info("userKey set via API (len=%d)", len(new_key))
    return {"status": "ok", "key_length": len(new_key)}


# Strong references to fire-and-forget tasks. The event loop keeps only weak
# references, so an unreferenced task can be garbage-collected mid-run.
_BACKGROUND_TASKS: set = set()


@app.post("/fetch_user_key")
async def fetch_user_key_endpoint():
    """Trigger a background browser-based key fetch and return immediately.

    The consecutive-failure counter is reset first, so the refresh runs even
    if auto-refresh had been disabled. refresh_user_key may still return the
    current key without launching a browser if it was refreshed within
    KEY_REFRESH_COOLDOWN seconds.
    """

    async def _bg():
        global _key_fail_count
        _key_fail_count = 0
        await refresh_user_key()

    task = asyncio.create_task(_bg())
    _BACKGROUND_TASKS.add(task)
    task.add_done_callback(_BACKGROUND_TASKS.discard)
    return {"status": "started", "note": "Browser key fetch running in background"}


@app.post("/generate")
async def submit_job(payload: Dict[str, Any]):
    """
    POST /generate
    Body:
    {
        "prompts": ["a cat in space", "sunset over mountains"],
        "count": 2,
        "resolution": "512x768",
        "guidance": 7.0,
        "negative": "",
        "subChannel": "private"
    }
    Returns:
    { "task_id": "...", "stream_url": "/stream/...", "queue_position": N }

    Errors:
        400 for empty/non-string prompts or a non-numeric count/guidance.
        503 immediately (instead of hanging) when JOB_QUEUE is full.
    """
    prompts = payload.get("prompts") or payload.get("prompt") or []
    if isinstance(prompts, str):
        prompts = [prompts]
    if not isinstance(prompts, list) or not prompts:
        raise HTTPException(400, "prompts must be a non-empty list")
    if not all(isinstance(p, str) for p in prompts):
        raise HTTPException(400, "every prompt must be a string")

    try:
        count = max(1, int(payload.get("count", 1)))
        guidance = float(payload.get("guidance", 7.0))
    except (TypeError, ValueError):
        raise HTTPException(
            400, "count must be an integer and guidance a number",
        ) from None
    resolution = payload.get("resolution", "512x768")
    negative = payload.get("negative", "") or ""
    sub_channel = payload.get("subChannel", "private")

    task = create_task(prompts, count, resolution, guidance, negative, sub_channel)

    # JOB_QUEUE is bounded: `await put()` would block the request until a
    # slot frees up, so use put_nowait to reject immediately when full.
    try:
        JOB_QUEUE.put_nowait({"task": task})
    except asyncio.QueueFull:
        # Forget the rejected task so it does not linger as "queued" forever.
        TASKS.pop(task["id"], None)
        TASK_QUEUES.pop(task["id"], None)
        raise HTTPException(503, "Server queue full β try again later") from None

    position = JOB_QUEUE.qsize()
    q = TASK_QUEUES.get(task["id"])
    if q:
        await q.put({
            "type": "queued",
            "time": _now(),
            "task_id": task["id"],
            "queue_position": position,
            "total_images": task["total_images"],
        })

    return {
        "task_id": task["id"],
        "stream_url": f"/stream/{task['id']}",
        "queue_position": position,
    }


@app.get("/stream/{task_id}")
async def stream_task(request: Request, task_id: str):
    """Stream a task's progress as Server-Sent Events.

    The first event is always ``meta``. If the task has already finished
    (status ``done`` or ``failed``), its stored results are replayed as
    ``image_ready`` events, followed by a terminal ``done``/``failed`` event
    and ``eof``. Otherwise events are relayed from the task's queue until an
    ``eof`` event arrives. After 30 s without an event, a ``ping`` is sent,
    and the stream stops if the client has disconnected.

    Event types: meta, queued, started, heartbeat, image_ready, key_invalid,
    key_refreshing, key_refreshed, key_refresh_failed, gen_error,
    download_error, done, failed, ping, eof.

    Raises:
        HTTPException(404): ``task_id`` is not a known task.
    """
    if task_id not in TASKS:
        raise HTTPException(404, "unknown task id")

    task = TASKS[task_id]
    queue = TASK_QUEUES[task_id]

    def _sse(event_name, body):
        # sse_starlette accepts dicts with "event" and "data" keys.
        return {"event": event_name, "data": json.dumps(body)}

    async def event_gen():
        yield _sse("meta", {
            "task_id": task_id,
            "status": task["status"],
            "total_images": task["total_images"],
            "created_at": task["created_at"],
        })

        # Finished task: replay what is stored instead of waiting on the queue.
        finished = task["status"] in ("done", "failed")
        if finished:
            for res in task["results"]:
                yield _sse("image_ready", {
                    "task_id": task_id,
                    "prompt": res["prompt"],
                    "index": res["index"],
                    "path": res["path"],
                    "seed": res["seed"],
                    "completed": task["completed"],
                    "total": task["total_images"],
                })
            yield _sse(task["status"], {
                "task_id": task_id,
                "completed": task["completed"],
                "total": task["total_images"],
                "error": task.get("error"),
            })
            yield _sse("eof", {"task_id": task_id})
            return

        # Live task: relay queued events, pinging while the queue is idle.
        while True:
            try:
                ev = await asyncio.wait_for(queue.get(), timeout=30.0)
            except asyncio.TimeoutError:
                yield _sse("ping", {"time": _now()})
                if await request.is_disconnected():
                    log.info("SSE client disconnected (task %s)", task_id)
                    return
                continue

            kind = ev.get("type", "event")
            yield _sse(kind, ev)
            if kind == "eof":
                return

    return EventSourceResponse(event_gen())
| |
|
| |
|
@app.get("/status/{task_id}")
async def get_status(task_id: str):
    """Return the full stored record for ``task_id``, or 404 if it is unknown."""
    task = TASKS.get(task_id)
    if task:
        return {"task": task}
    raise HTTPException(404, "unknown task id")
| |
|
| |
|
@app.get("/outputs/{filename}")
async def get_output(filename: str):
    """Serve a generated file from ``OUTPUT_DIR`` as a download.

    ``filename`` comes straight from the URL, so the joined path is resolved
    and must stay inside ``OUTPUT_DIR``. Names such as ``..`` (or backslash
    paths on Windows) cannot escape it. Only regular files are served;
    directories are rejected because ``FileResponse`` cannot stream them.

    Raises:
        HTTPException(404): the path escapes ``OUTPUT_DIR``, cannot be
            resolved, or is not a regular file.
    """
    base = OUTPUT_DIR.resolve()
    try:
        fp = (base / filename).resolve()
        fp.relative_to(base)  # ValueError if fp lies outside base
    except (OSError, RuntimeError, ValueError):
        raise HTTPException(404, "file not found") from None
    if not fp.is_file():
        raise HTTPException(404, "file not found")
    return FileResponse(fp, media_type="application/octet-stream", filename=filename)
| |
|
| |
|
| | |
| | |
| | |
| |
|
if __name__ == "__main__":
    import uvicorn

    def _env_flag(name: str, default: bool) -> bool:
        """Read a boolean flag from the environment.

        "1", "true", "yes" and "on" (case-insensitive, surrounding whitespace
        ignored) mean True; any other set value means False. ``default`` is
        used when the variable is unset.
        """
        raw = os.environ.get(name)
        if raw is None:
            return bool(default)
        return raw.strip().lower() in ("1", "true", "yes", "on")

    # Environment overrides for the module-level defaults. Rebinding here
    # updates the module globals, because this block runs at module scope.
    ZD_HEADLESS = _env_flag("ZD_HEADLESS", ZD_HEADLESS)
    USE_VIRTUAL_DISPLAY = _env_flag("USE_VIRTUAL_DISPLAY", USE_VIRTUAL_DISPLAY)

    # Best effort: a failure to start the virtual display is logged, not fatal.
    try:
        _start_virtual_display_if_needed(ZD_HEADLESS)
    except Exception:
        log.exception("Failed to ensure virtual display in __main__")

    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
| |
|