multimodalart HF Staff
fix: explicit / route to stop StaticFiles shadowing /gradio_api/startup-events
14c1e5d verified | import os | |
# Set before torch is imported so the value is in place when the CUDA caching
# allocator reads it; setdefault keeps any value the environment already provides.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces  # noqa: E402 -- must precede torch / CUDA imports
import torch  # noqa: E402

# PersonaLive ships full-pickle .pth checkpoints; torch>=2.6 defaults weights_only=True.
# After this patch, every call made through `torch.load` (including calls inside
# libraries) defaults to weights_only=False unless the caller passes it explicitly.
# Full unpickling can execute arbitrary code, so only trusted checkpoints may be loaded.
_orig_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Forward to the original torch.load, defaulting weights_only to False."""
    kwargs.setdefault("weights_only", False)
    return _orig_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
# diffusers 0.27 / transformers 4.36 reference huggingface_hub symbols removed in
# hub>=1.0 (forced upon us by gradio 6). Re-inject them before those libs import.
import huggingface_hub as _hub  # noqa: E402

if not hasattr(_hub, "cached_download"):
    # NOTE(review): this alias only satisfies the import. The removed
    # cached_download(url, ...) had a different signature from
    # hf_hub_download(repo_id, filename, ...), so any real call site would need checking.
    _hub.cached_download = _hub.hf_hub_download

if not hasattr(_hub, "HfFolder"):
    class _HfFolder:
        """Minimal stand-in for the removed ``huggingface_hub.HfFolder`` token helper."""

        # staticmethod: works both as HfFolder.get_token() and on an instance
        # (the previous self-less plain function raised TypeError on instances).
        @staticmethod
        def get_token():
            """Return the stored HF token via ``huggingface_hub.get_token``."""
            return _hub.get_token()

    _hub.HfFolder = _HfFolder
| import sys # noqa: E402 | |
| import time # noqa: E402 | |
| from types import SimpleNamespace # noqa: E402 | |
| import cv2 # noqa: E402 | |
| import numpy as np # noqa: E402 | |
| import gradio as gr # noqa: E402 | |
| from PIL import Image # noqa: E402 | |
# Make this app's own packages (tools/, src/) importable regardless of the working dir.
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# ---- weights (downloaded once at startup, no GPU needed) ----
# These calls run at import time, before any route is served.
# NOTE(review): assumes the prepare_* helpers skip files already present; confirm in
# tools/download_weights.
from tools.download_weights import prepare_base_model, prepare_vae, prepare_personalive  # noqa: E402
prepare_base_model()
prepare_vae()
prepare_personalive()
# Relative to the current working directory, not to this file.
CONFIG_PATH = "./configs/prompts/personalive_online.yaml"
# Args namespace handed to the PersonaLive wrapper constructor in get_model().
ARGS = SimpleNamespace(config_path=CONFIG_PATH, acceleration="none")
CHUNK = 4  # temporal_window_size: frames consumed/produced per diffusion call
_model = None


def get_model():
    """Return the process-wide PersonaLive instance, constructing it on first use.

    Construction is deferred until a GPU worker asks for the model. This is a
    held-session model, so the init cost is paid once per cold GPU worker and
    amortized across the whole session. Module-scope init would gain nothing and
    would risk initializing CUDA in the main process.
    """
    global _model
    if _model is not None:
        return _model

    from src.wrapper import PersonaLive

    started = time.perf_counter()
    _model = PersonaLive(ARGS, device="cuda")
    elapsed = time.perf_counter() - started
    print(f"[model] loaded in {elapsed:.1f}s", flush=True)
    return _model
| AOTI_REPO = "multimodalart/PersonaLive-aoti" | |
| _aoti_lazy = None | |
| _aoti_bank_map = None | |
| def _load_aoti(): | |
| """Download + open the AoTI kernel package once. The .pt2 holds only | |
| hardware-specific kernels (no weights); constants are supplied per-fuse.""" | |
| global _aoti_lazy, _aoti_bank_map | |
| if _aoti_lazy is None: | |
| import json | |
| from pathlib import Path | |
| from huggingface_hub import snapshot_download | |
| from spaces.zero.torch.aoti import LazyAOTIModel | |
| repo = snapshot_download(AOTI_REPO, allow_patterns="package/*") | |
| pkg = Path(repo) / "package" | |
| meta = json.loads((pkg / "bank_constants.json").read_text()) | |
| _aoti_bank_map = meta["bank_constants"] | |
| _aoti_lazy = LazyAOTIModel(str(pkg / "submodules" / "denoising_unet" / "package.pt2")) | |
| print(f"[aoti] loaded kernels, {len(_aoti_bank_map)} bank constants", flush=True) | |
| return _aoti_lazy, _aoti_bank_map | |
def patch_unet_aoti(model):
    """Replace ``model.denoising_unet.forward`` with the AoTI-compiled kernel.

    Call this only after ``fuse_reference`` has populated the reference banks, and
    with keyframes disabled (``model.num_khf = 3``), so the bank set matches the
    compiled graph. The reference banks (the 16 lifted constants) are read live
    from each block's ``bank[0]``, which keeps the kernel portrait-agnostic.
    """
    lazy, bank_map = _load_aoti()
    unet = model.denoising_unet

    # Sort by the lifted index N so `_tensor_constant{i}` gets the i-th block's bank.
    by_lifted_index = sorted(
        bank_map, key=lambda fqn: int(fqn.rsplit("lifted_tensor_", 1)[1]))

    # Parameters first, then buffers; on a name clash the buffer wins.
    weights = dict(unet.named_parameters(remove_duplicate=False))
    weights.update(unet.named_buffers(remove_duplicate=False))

    def live_bank(block_path):
        # Current reference bank of one attention block (filled by fuse_reference).
        return unet.get_submodule(block_path).bank[0]

    # The compiled model references the reference banks as flat constants
    # `_tensor_constant{i}` (NOT the export-time `lifted_tensor_N` FQN). Supply both
    # names; LazyAOTIModel filters to the set the kernel actually wants.
    weights.update((fqn, live_bank(path)) for fqn, path in bank_map.items())
    weights.update(
        (f"_tensor_constant{i}", live_bank(bank_map[fqn]))
        for i, fqn in enumerate(by_lifted_index)
    )
    unet.forward = lazy.with_weights(weights)
def _frame_to_input(rgb: np.ndarray) -> torch.Tensor:
    """Convert an RGB uint8 HxWx3 frame into a (1, 3, 512, 512) CUDA tensor in [-1, 1].

    This mirrors the transform the original websocket server applied to driving
    frames: area-resample to 512x512, scale to [0, 1], then map to [-1, 1].
    """
    resized = cv2.resize(rgb, (512, 512), interpolation=cv2.INTER_AREA)
    unit = torch.from_numpy(resized).to("cuda").float() / 255.0
    chw = (unit * 2.0 - 1.0).permute(2, 0, 1)
    return chw.unsqueeze(0)
def selftest(num_frames: int):
    """Validate the model end-to-end on the bundled demo assets.

    Fuses ``demo/ref_image.png`` as the reference portrait, then animates it with
    up to ``num_frames`` frames of the demo driving video, in ``CHUNK``-sized
    batches. A trailing partial chunk is skipped.

    Returns:
        tuple: (up to the first 8 generated frames as PIL images, timing report
        string). The report is used to size GPU duration and measure fps.
    """
    model = get_model()
    model.reset()
    ref = Image.open("demo/ref_image.png").convert("RGB")
    t0 = time.perf_counter()
    model.fuse_reference(ref)
    fuse_s = time.perf_counter() - t0

    # Shared reader: same frames the old inline decode loop produced.
    frames = _read_frames(num_frames)

    outputs = []
    gen_t0 = time.perf_counter()
    n_chunks = len(frames) // CHUNK
    for c in range(n_chunks):
        chunk = frames[c * CHUNK:(c + 1) * CHUNK]
        batch = torch.cat([_frame_to_input(f) for f in chunk], dim=0)
        video = model.process_input(batch)  # (CHUNK,h,w,c) in [0,1]
        outputs.extend(Image.fromarray((img * 255.0).astype(np.uint8)) for img in video)
    gen_s = time.perf_counter() - gen_t0
    produced = len(outputs)
    fps = produced / gen_s if gen_s > 0 else 0.0
    report = (
        f"fuse: {fuse_s:.1f}s | generated {produced} frames in {gen_s:.1f}s "
        f"({fps:.1f} fps) | per-chunk {gen_s / max(n_chunks,1):.2f}s"
    )
    print("[selftest]", report, flush=True)
    return outputs[:8], report
| def _describe(x): | |
| if torch.is_tensor(x): | |
| return f"T{tuple(x.shape)}:{x.dtype}" | |
| if isinstance(x, (list, tuple)): | |
| return [_describe(i) for i in x] | |
| if isinstance(x, dict): | |
| return {k: _describe(v) for k, v in x.items()} | |
| return repr(x) | |
def export_probe():
    """Feasibility probe: capture the real denoising_unet inputs during a live
    chunk, then attempt torch.export.export on it. denoising_unet is the per-chunk
    bottleneck and the AoTI target; its hacked attention reads reference banks
    (self.bank/self.kv_bank) so we need to know empirically whether it traces.

    Steps: fuse the demo reference, capture the unet call made while processing
    one CHUNK of demo driving frames, export that call, then compare each lifted
    constant's shape with the owning block's live ``bank[0]``.

    Returns:
        str: the multi-line report (also printed). Errors raised by the export
        or the constant comparison are appended to the report. Errors before that
        (model load, frame decode, capture) propagate.
    """
    import traceback
    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    # Same decode loop as _read_frames(CHUNK).
    cap = cv2.VideoCapture("demo/driving_video.mp4")
    frames = []
    while len(frames) < CHUNK:
        ok, bgr = cap.read()
        if not ok:
            break
        frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    cap.release()
    # NOTE(review): raises IndexError (outside the try below) if the demo video has
    # fewer than CHUNK frames.
    batch = torch.cat([_frame_to_input(frames[i]) for i in range(CHUNK)], dim=0)
    unet = model.denoising_unet
    # Record the args/kwargs unet receives during one real process_input call.
    with spaces.aoti_capture(unet) as call:
        model.process_input(batch)
    args, kwargs = call.args, call.kwargs
    log = [
        f"torch {torch.__version__}",
        f"captured args: {_describe(list(args))}",
        f"captured kwargs: {_describe(kwargs)}",
    ]
    try:
        ep = torch.export.export(unet, args, kwargs)
        n_nodes = len(list(ep.graph.nodes))
        log.append(f"EXPORT OK — graph nodes: {n_nodes}")
        # Lifted constants: non-buffer tensor attributes (the reference banks live
        # here). For one .pt2 to serve any portrait these must be swappable at load.
        consts = dict(getattr(ep, "constants", {}) or {})
        lifted = {k: v for k, v in consts.items()
                  if torch.is_tensor(v) and k.rsplit(".", 1)[-1].startswith("lifted_tensor")}
        log.append(f"#constants: {len(consts)} | #lifted_tensor: {len(lifted)}")
        # For each lifted constant, derive its block path and compare against that
        # block's live self.bank — this validates the loader's FQN->bank mapping.
        match = mismatch = 0
        for fqn, ct in list(lifted.items()):
            # Owning block = the FQN with its last component (lifted_tensor_N) removed.
            block_path = fqn.rsplit(".", 1)[0]
            try:
                blk = unet.get_submodule(block_path)
            except Exception as e:
                log.append(f" {fqn}: get_submodule FAIL {e}")
                mismatch += 1
                continue
            bank = getattr(blk, "bank", None)
            kvb = getattr(blk, "kv_bank", None)
            blen = len(bank) if isinstance(bank, list) else ("None" if bank is None else "?")
            bshape = tuple(bank[0].shape) if isinstance(bank, list) and bank else None
            ok = bshape == tuple(ct.shape)
            match += int(ok)
            mismatch += int(not ok)
            # NOTE(review): because these are OR'ed, throttling only starts once BOTH
            # match > 2 and mismatch > 3. All-OK and all-MISMATCH runs log every
            # constant. Confirm that is the intended cap.
            if len(lifted) <= 4 or mismatch <= 3 or match <= 2:
                # NOTE(review): assumes kv_bank is None or has .shape. A list here
                # would raise and be reported below as EXPORT FAILED.
                log.append(f" {fqn} const{tuple(ct.shape)} | bank len={blen} "
                           f"bank0={bshape} kv_bank={'None' if kvb is None else tuple(kvb.shape)} "
                           f"{'OK' if ok else 'MISMATCH'}")
        log.append(f"bank-shape match: {match}/{len(lifted)}")
    except Exception:
        # Keep only the traceback tail so the report stays bounded.
        log.append("EXPORT FAILED:\n" + traceback.format_exc()[-3500:])
    report = "\n".join(log)
    print("[export_probe]\n" + report, flush=True)
    return report
def _read_frames(n, path="demo/driving_video.mp4"):
    """Decode up to ``n`` frames of a video as RGB uint8 arrays.

    Args:
        n: maximum number of frames to return.
        path: video file to read. Defaults to the bundled demo driving video, so
            existing ``_read_frames(n)`` calls are unchanged.

    Returns:
        list[np.ndarray]: decoded frames in order. The list is shorter than ``n``
        (possibly empty) once ``cap.read()`` reports failure, e.g. at end of stream
        or for an unreadable file.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while len(frames) < n:
            ok, bgr = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Release the decoder even if reading or colour conversion raises.
        cap.release()
    return frames
def aoti_test(num_frames: int):
    """Verify the AoTI kernel: (1) numerical match vs eager on one captured chunk,
    (2) end-to-end animated frames + fps with AoTI patched in.

    Step (1) runs twice: once with the demo reference, and once with a driving
    frame used as a different portrait.

    Returns:
        tuple: (up to the first 8 AoTI-generated frames as PIL images, report string).
    """
    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3  # disable keyframes (match compile-time bank set)
    # NOTE(review): may return fewer frames if the video is short; check() then
    # raises IndexError on frames[i].
    frames = _read_frames(max(num_frames, CHUNK))
    unet = model.denoising_unet
    # Once patch_unet_aoti overwrites unet.forward (instance attr), the eager path is
    # only reachable via the class method. Bind it so each portrait's eager reference
    # is the real network, not a stale AoTI wrapper from a previous portrait.
    eager_forward = type(unet).forward
    def check(tag, ref_img):
        # Re-fuse ref_img, capture one chunk's unet inputs, then compare eager vs AoTI
        # outputs on exactly those inputs. Returns a one-line diff summary.
        model.reset()
        model.fuse_reference(ref_img)
        model.num_khf = 3
        batch = torch.cat([_frame_to_input(frames[i]) for i in range(CHUNK)], dim=0)
        # NOTE(review): from the second call on, unet.forward is still the AoTI wrapper
        # patched by the previous call, so this capture runs through AoTI. Presumably
        # harmless because only the inputs are used; confirm.
        with spaces.aoti_capture(unet) as cap:
            model.process_input(batch)
        a, k = cap.args, cap.kwargs
        with torch.no_grad():
            eager = eager_forward(unet, *a, **k)[0].float()
        patch_unet_aoti(model)  # supplies THIS portrait's banks live
        with torch.no_grad():
            aoti = unet(*a, **k)[0].float()
        d = (eager - aoti).abs()
        # Max error relative to the eager output's magnitude (epsilon avoids /0).
        rel = (d.max() / (eager.abs().max() + 1e-6)).item()
        return (f"[{tag}] max|eager-aoti|={d.max().item():.4e} mean={d.mean().item():.4e} "
                f"rel={rel:.4e} | eager range [{eager.min():.3f},{eager.max():.3f}]")
    # demo-ref = the portrait the kernel was COMPILED against (baked-on-disk constants
    # happen to equal these banks). alt-ref = a different portrait; a small diff there
    # proves the 16 banks are supplied live per-fuse, not baked into the kernel.
    log = [check("demo-ref", Image.open("demo/ref_image.png").convert("RGB"))]
    log.append(check("alt-ref ", Image.fromarray(frames[len(frames) // 2])))
    # End-to-end AoTI generation + fps (fresh session, AoTI stays patched).
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3
    patch_unet_aoti(model)  # rebuild bank map for the fresh fuse
    outputs = []
    n_chunks = len(frames) // CHUNK  # a trailing partial chunk is skipped
    t0 = time.perf_counter()
    for c in range(n_chunks):
        batch = torch.cat([_frame_to_input(frames[c * CHUNK + i]) for i in range(CHUNK)], dim=0)
        video = model.process_input(batch)
        for img in video:
            outputs.append(Image.fromarray((img * 255.0).astype(np.uint8)))
    gen_s = time.perf_counter() - t0
    fps = len(outputs) / gen_s if gen_s > 0 else 0.0
    log.append(f"AoTI e2e: {len(outputs)} frames in {gen_s:.1f}s ({fps:.1f} fps) "
               f"| per-chunk {gen_s/max(n_chunks,1):.2f}s")
    report = "\n".join(log)
    print("[aoti_test]\n" + report, flush=True)
    return outputs[:8], report
def aoti_fqns():
    """Diagnostic: which constant FQNs does the compiled AoTI package demand?

    Loads the raw ``.pt2`` with ``torch._inductor.aoti_load_package`` and compares
    its ``get_constant_fqns()`` against every name ``patch_unet_aoti`` supplies:
    parameters, buffers, the ``bank_constants.json`` FQNs and
    ``_tensor_constant{i}``. It reports counts, head/tail samples, how many FQNs
    use each bank naming scheme, and the first unsupplied FQNs. It does not inspect
    or fingerprint the baked-in constant values.

    Returns:
        str: the multi-line report (also printed).
    """
    import json
    from pathlib import Path

    from huggingface_hub import snapshot_download

    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3
    unet = model.denoising_unet

    repo = snapshot_download(AOTI_REPO, allow_patterns="package/*")
    pkg = Path(repo) / "package"
    bank_map = json.loads((pkg / "bank_constants.json").read_text())["bank_constants"]
    pt2 = str(pkg / "submodules" / "denoising_unet" / "package.pt2")
    cm = torch._inductor.aoti_load_package(pt2)
    fqns = list(cm.get_constant_fqns())

    # Mirror the name set patch_unet_aoti builds for LazyAOTIModel.with_weights.
    supplied = {n for n, _ in unet.named_parameters(remove_duplicate=False)}
    supplied.update(n for n, _ in unet.named_buffers(remove_duplicate=False))
    supplied.update(bank_map)
    supplied.update(f"_tensor_constant{i}" for i in range(len(bank_map)))
    missing = [f for f in fqns if f not in supplied]

    log = [
        f"#constant_fqns={len(fqns)} | #supplied_names={len(supplied)} | #missing={len(missing)}",
        f"sample fqns: {fqns[:8]}",
        f"tail fqns: {fqns[-8:]}",
        f"#_tensor_constant in fqns: {sum(1 for f in fqns if f.startswith('_tensor_constant'))}",
        f"#lifted_tensor in fqns: {sum(1 for f in fqns if 'lifted_tensor' in f)}",
        f"missing (first 20): {missing[:20]}",
    ]
    report = "\n".join(log)
    print("[aoti_fqns]\n" + report, flush=True)
    return report
| import multiprocessing as _mp # noqa: E402 | |
| import threading as _threading # noqa: E402 | |
# The spaces lib forks its GPU worker with multiprocessing.get_context('fork'). Queues
# built here, in the main process at import time, already exist when that fork
# happens. The forked worker inherits their live pipe fds, so the parent (WS handler)
# can keep feeding a held GPU worker after it has started.
_QCTX = _mp.get_context("fork")
PROBE_IN, PROBE_OUT = _QCTX.Queue(), _QCTX.Queue()
def _queue_worker():
    """Held-GPU-worker half of the fork-queue probe.

    Blocks (up to 25 s) for an item the parent pushes after the fork, then reports
    what arrived on PROBE_OUT. Receiving it shows that module-level fork queues
    bridge the ZeroGPU fork boundary. That is the whole basis for feeding live WS
    frames into a held @spaces.GPU session.
    """
    start = time.time()
    received = PROBE_IN.get(timeout=25)
    waited = time.time() - start
    PROBE_OUT.put(f"worker pid={os.getpid()} got={received!r} after {waited:.1f}s")
    return "ok"
def queue_probe():
    """Parent half of the fork-queue probe.

    Starts the GPU worker in a thread (so the parent is not blocked while the
    worker waits on PROBE_IN), pushes a live item after a short delay, and reports
    whether the worker echoed it back across the fork boundary.
    """
    worker = _threading.Thread(target=_queue_worker, daemon=True)
    worker.start()
    time.sleep(4)  # give the worker time to fork and block in PROBE_IN.get()
    PROBE_IN.put(f"live-{int(time.time())}-from-pid-{os.getpid()}")
    worker.join(timeout=30)
    try:
        outcome = PROBE_OUT.get(timeout=8)
    except Exception as exc:
        outcome = f"NO OUTPUT FROM WORKER: {type(exc).__name__} {exc}"
    report = f"parent pid={os.getpid()}\n{outcome}"
    print("[queue_probe]\n" + report, flush=True)
    return report
# ---- Held-session machinery (the real WS backend runs on these) ----
# Fork-inherited queues linking the parent (WS handler) and the held GPU worker:
#   SESS_IN   parent -> worker : driving RGB frames (np.uint8 HxWx3); None = stop
#   SESS_OUT  worker -> parent : (kind, payload), kind in "ready" / "frame" / "error" / "done"
#   SESS_REF  parent -> worker : reference PIL.Image to fuse
SESS_IN = _QCTX.Queue()
SESS_OUT = _QCTX.Queue()
SESS_REF = _QCTX.Queue()
def session_worker():
    """Held GPU session fed live through the fork queues.

    This is the original generate_process loop, AoTI-accelerated, fed via
    fork-queues instead of a raw mp.Process. It builds the model, waits for a
    reference on SESS_REF, fuses it and patches in the AoTI kernel, then sends
    ("ready", pid). After that, every CHUNK driving frames from SESS_IN become
    generated frames on SESS_OUT.

    A None on SESS_IN ends the loop, so the GPU is held only while the session
    runs. Any failure is reported as ("error", traceback tail), and
    ("done", None) is always sent last.

    NOTE(review): no @spaces.GPU decorator is visible on this function in this
    view; the "held GPU" wording assumes one is applied. Confirm.
    """
    import traceback

    try:
        model = get_model()
        model.reset()
        model.fuse_reference(SESS_REF.get())  # blocks until the parent supplies a reference
        model.num_khf = 3  # pin keyframes off (AoTI bank set is fixed)
        patch_unet_aoti(model)
        SESS_OUT.put(("ready", os.getpid()))

        pending = []
        while (frame := SESS_IN.get()) is not None:  # None is the stop sentinel
            pending.append(frame)
            if len(pending) < CHUNK:
                continue
            chunk, pending = pending[:CHUNK], pending[CHUNK:]
            video = model.process_input(torch.cat([_frame_to_input(f) for f in chunk], dim=0))
            for img in video:
                SESS_OUT.put(("frame", (img * 255.0).astype(np.uint8)))
    except Exception:
        SESS_OUT.put(("error", traceback.format_exc()[-2000:]))
    finally:
        SESS_OUT.put(("done", None))
def gpu_fuse():
    """Cross-call state experiment, step 1.

    Initializes the model and fuses the demo reference inside one @spaces.GPU
    call, then returns, releasing the GPU. The report records the worker pid and
    the model object's id, so step 2 (gpu_step) can tell whether the worker and
    its CUDA state persist between calls.
    """
    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3
    has_latents = hasattr(model, "ref_image_latents")
    motion = "set" if getattr(model, "motion_bank", None) is not None else "None"
    report = (f"pid={os.getpid()} id(_model)={id(_model)} "
              f"ref_latents={has_latents} motion_bank={motion}")
    print("[gpu_fuse]", report, flush=True)
    return report
def gpu_step():
    """Cross-call state experiment, step 2.

    In a separate @spaces.GPU call, runs process_input WITHOUT re-fusing. If the
    worker persisted, _model is already fused and a valid chunk comes back: state
    survives across calls, so a per-chunk @spaces.GPU architecture is viable. If
    the worker is fresh, _model is None or unfused, which shows a held session is
    needed instead.
    """
    import traceback

    log = [f"pid={os.getpid()} _model_was_none={_model is None}"]
    try:
        model = get_model()
        log.append(f"id(_model)={id(_model)} ref_latents={hasattr(model, 'ref_image_latents')}")
        frames = _read_frames(CHUNK)
        batch = torch.cat([_frame_to_input(frames[i]) for i in range(CHUNK)], dim=0)
        video = model.process_input(batch)
        log.append(
            f"process_ok frames={len(video)} mean={float(np.mean(video)):.4f} "
            f"range=[{float(np.min(video)):.3f},{float(np.max(video)):.3f}]"
        )
    except Exception:
        log.append("PROCESS FAILED:\n" + traceback.format_exc()[-1500:])
    report = "\n".join(log)
    print("[gpu_step]\n" + report, flush=True)
    return report
| # ==================== Real-time WebSocket backend ==================== | |
| # Keeps PersonaLive's own Svelte frontend + WS transport. The held GPU session | |
| # (session_worker) runs in a parent-side thread fed by the fork-queues above; the | |
| # visitor's X-IP-Token (minted by the huggingface.co parent via postMessage and | |
| # forwarded over the WS) is injected into LocalContext.request so spaces.schedule | |
| # bills the right HF account instead of falling back to the Space's IP quota. | |
| import asyncio # noqa: E402 | |
| import json # noqa: E402 | |
| import queue as _queue # noqa: E402 | |
| from io import BytesIO # noqa: E402 | |
| from fastapi import File, UploadFile, WebSocket, WebSocketDisconnect # noqa: E402 | |
| from fastapi.responses import FileResponse, HTMLResponse, JSONResponse # noqa: E402 | |
| from fastapi.staticfiles import StaticFiles # noqa: E402 | |
| from gradio.context import LocalContext # noqa: E402 | |
| class _IPHeaders(dict): | |
| """spaces._get_headers requires request.headers to have BOTH __dict__ (a plain | |
| dict lacks it) AND .get -> a dict SUBCLASS satisfies both. See | |
| reference_zerogpu_xiptoken_handshake.""" | |
| def _ip_request(token): | |
| if not token: | |
| return None | |
| req = SimpleNamespace() | |
| req.headers = _IPHeaders() | |
| req.headers["x-ip-token"] = token | |
| return req | |
| def _drain(q): | |
| try: | |
| while True: | |
| q.get_nowait() | |
| except Exception: | |
| pass | |
| _PENDING_REF = None # set by /api/upload_reference_image, fused at WS open | |
| _session_busy = _threading.Lock() | |
def _run_session(token):
    """Background-thread target that runs one held GPU session for a visitor.

    contextvars do not propagate into new threads. The visitor's X-IP-Token is
    therefore installed into LocalContext.request here, in the same thread that
    enters session_worker, so GPU scheduling bills the right account.

    session_worker reports its own failures on SESS_OUT. Anything raised by the
    call itself (for example by a GPU-scheduling wrapper around session_worker,
    if one is applied) would otherwise just kill this thread. The WS handler
    would then wait forever for a "ready" that never arrives. Such errors are
    forwarded using the same ("error", ...) then ("done", None) protocol.
    """
    import traceback

    req = _ip_request(token)
    if req is not None:
        LocalContext.request.set(req)
    try:
        session_worker()
    except Exception:
        tb = traceback.format_exc()
        print(f"[session] {tb}", flush=True)
        SESS_OUT.put(("error", tb[-2000:]))
        SESS_OUT.put(("done", None))
# gr.Server (gradio 6) subclasses FastAPI, so .get/.post/.websocket/.mount can be
# used on it directly while Gradio's queue and ZeroGPU scheduling keep working.
# It is exported as `demo` below and launched by the HF runtime. No separate
# uvicorn is started, which avoids binding 7860 twice.
app = gr.Server(title="PersonaLive")
async def settings():
    """Return the minimal settings payload the Svelte frontend reads.

    The payload selects image input mode, declares no extra input parameters, and
    sets a queue capacity of one.
    """
    # NOTE(review): no route decorator is visible for this handler in this view;
    # confirm it is registered on `app`.
    payload = {
        "info": {"properties": {"input_mode": {"default": "image"}}},
        "input_params": {"properties": {}},
        "max_queue_size": 1,
        "page_content": "",
    }
    return JSONResponse(payload)
async def queue_size():
    """Report 1 while the single GPU session slot is taken, otherwise 0."""
    busy = _session_busy.locked()
    return JSONResponse({"queue_size": int(busy)})
async def upload_reference_image(ref_image: UploadFile = File(...)):
    """Store an uploaded portrait as the reference for the next WS session.

    The upload is decoded with PIL and converted to RGB. The result goes into the
    single module-level slot, which all visitors share (the latest upload wins).
    Uploads PIL cannot decode are rejected with HTTP 400 and a JSON error instead
    of bubbling up as a 500. In that case the previously stored reference is kept.
    """
    global _PENDING_REF
    data = await ref_image.read()
    try:
        img = Image.open(BytesIO(data)).convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # OSError covers UnidentifiedImageError and truncated files.
        return JSONResponse(
            {"status": "error", "message": f"could not read image: {exc}"},
            status_code=400,
        )
    _PENDING_REF = img
    return JSONResponse({"status": "ok"})
async def reset():
    """Forget the pending reference portrait."""
    global _PENDING_REF
    _PENDING_REF = None
    ok = {"status": "ok"}
    return JSONResponse(ok)
async def ws_endpoint(websocket: WebSocket, user_id: str):
    """Run one real-time PersonaLive session over a WebSocket.

    Protocol, as implemented below:
      1. The first text frame may carry ``{"x_ip_token": ...}``, used for GPU quota.
      2. The pending reference is handed to a held GPU session started in a
         background thread. ``{"status": "loading"}`` heartbeats are sent until it
         reports ready, then ``connected`` and ``send_frame``.
      3. Generated frames are streamed back as JPEG bytes.
      4. Incoming JPEG bytes are queued as driving frames. A ``{"status": "resume"}``
         text frame is answered with ``send_frame``.

    Only one session runs at a time; a concurrent connection gets
    ``{"status": "timeout"}``. ``user_id`` comes from the route and is unused here.

    NOTE(review): no route decorator is visible for this handler in this view;
    confirm it is registered on `app`.
    """
    await websocket.accept()
    if not _session_busy.acquire(blocking=False):
        await websocket.send_text(json.dumps({"status": "timeout"}))
        await websocket.close()
        return
    loop = asyncio.get_running_loop()
    send_task = None
    session_thread = None
    try:
        # 1. first text frame carries the visitor's X-IP-Token (for GPU quota).
        token = None
        try:
            first = await asyncio.wait_for(websocket.receive_text(), timeout=15)
            token = json.loads(first).get("x_ip_token")
        except Exception:
            pass  # best effort: without a token the Space's own quota is used
        ref = _PENDING_REF
        if ref is None:
            await websocket.send_text(json.dumps(
                {"status": "error", "message": "Upload a reference portrait first."}))
            return

        # 2. start the held GPU session and wait for it to fuse + be ready.
        _drain(SESS_IN)
        _drain(SESS_OUT)
        _drain(SESS_REF)
        SESS_REF.put(ref)
        session_thread = _threading.Thread(target=_run_session, args=(token,), daemon=True)
        session_thread.start()
        # Cold start loads the model in-worker (~21s). Send heartbeats while we
        # wait so the client's WS keepalive ping doesn't time out on the silence.
        ready = None
        while ready is None:
            # Sampled before the wait. If the thread was already gone and a whole
            # timeout then passes with nothing queued, the session can never become
            # ready, so stop instead of sending heartbeats forever.
            thread_alive = session_thread.is_alive()
            try:
                ready = await loop.run_in_executor(
                    None, lambda: SESS_OUT.get(timeout=2))
            except _queue.Empty:
                if thread_alive:
                    await websocket.send_text(json.dumps({"status": "loading"}))
                else:
                    ready = ("error", "GPU session ended before it was ready")
        if ready[0] != "ready":
            await websocket.send_text(json.dumps(
                {"status": "error", "message": f"session start failed: {ready}"}))
            return
        await websocket.send_text(json.dumps({"status": "connected"}))
        await websocket.send_text(json.dumps({"status": "send_frame"}))

        # 3. worker -> client: stream generated frames as JPEG.
        async def sender():
            while True:
                try:
                    # Bounded wait: once this task is cancelled, its executor thread
                    # returns within ~1s instead of staying blocked on SESS_OUT,
                    # where it could swallow a later session's messages.
                    kind, payload = await loop.run_in_executor(
                        None, lambda: SESS_OUT.get(timeout=1))
                except _queue.Empty:
                    continue
                if kind == "frame":
                    bgr = cv2.cvtColor(payload, cv2.COLOR_RGB2BGR)
                    ok, jpg = cv2.imencode(".jpg", bgr)
                    if ok:
                        await websocket.send_bytes(jpg.tobytes())
                elif kind == "error":
                    await websocket.send_text(json.dumps(
                        {"status": "error", "message": str(payload)[:500]}))
                    break
                elif kind == "done":
                    break

        send_task = asyncio.create_task(sender())

        # 4. client -> worker: decode driving JPEGs. While more than CHUNK * 2 frames
        # are already queued, NEW frames are dropped (before decoding) so the
        # ~20fps client can't outrun the ~6.8fps model and pile up latency.
        # The loop also ends once the sender has stopped (worker error/done), so a
        # dead session doesn't keep the single slot locked.
        while not send_task.done():
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes"):
                if SESS_IN.qsize() <= CHUNK * 2:
                    arr = np.frombuffer(msg["bytes"], np.uint8)
                    bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
                    if bgr is not None:
                        SESS_IN.put(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            elif msg.get("text"):
                try:
                    if json.loads(msg["text"]).get("status") == "resume":
                        await websocket.send_text(json.dumps({"status": "send_frame"}))
                except Exception:
                    pass
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[ws] error: {e}", flush=True)
    finally:
        try:
            SESS_IN.put(None)  # stop sentinel -> worker exits -> releases GPU slot
            if send_task is not None:
                send_task.cancel()
            if session_thread is not None:
                # Keep the slot locked until this session's worker has actually
                # stopped (bounded wait). Releasing earlier lets the next
                # connection's _drain(SESS_IN) eat this stop sentinel and start a
                # second session while this worker is still reading SESS_IN.
                await loop.run_in_executor(None, session_thread.join, 60.0)
        finally:
            _session_busy.release()
async def health():
    """Cheap liveness probe that needs no GPU: confirms the server and routes are up."""
    body = {"status": "ok", "session_busy": _session_busy.locked()}
    return JSONResponse(body)
async def selftest_route(num_frames: int = 16):
    """Debug-only: run the eager self-test and return its timing report as JSON.

    No X-IP-Token is attached, so the run falls back to the creator's IP quota.
    This lets the model path be validated with curl, without the frontend.
    """
    # get_running_loop is the documented way to reach the loop inside a coroutine.
    loop = asyncio.get_running_loop()
    # selftest blocks (model load + generation), so run it off the event loop.
    _, report = await loop.run_in_executor(None, selftest, num_frames)
    return JSONResponse({"report": report})
# Serve PersonaLive's Svelte build. gr.Server overrides a mount("/") with its own page
# route, so "/" is served by an explicit handler (per the gr.Server blueprint). Only the
# build's asset directories are mounted, at the sub-paths the build references.
_FRONTEND_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "frontend_public")
if os.path.isdir(_FRONTEND_DIR):
    for _sub in ("_app", "presets"):
        _p = os.path.join(_FRONTEND_DIR, _sub)
        if not os.path.isdir(_p):
            continue  # this build doesn't ship that asset dir
        app.mount(f"/{_sub}", StaticFiles(directory=_p), name=f"frontend_{_sub}")
async def _favicon():
    """Serve favicon.png from the frontend build directory."""
    icon_path = os.path.join(_FRONTEND_DIR, "favicon.png")
    return FileResponse(icon_path)
async def _homepage():
    """Serve the Svelte build's index.html as the page at "/".

    Returns an explicit HTMLResponse. FastAPI JSON-encodes a bare ``str`` unless
    the route declares ``response_class=HTMLResponse``, whereas a Response instance
    is passed through unchanged. The page therefore renders correctly however this
    handler is registered.
    """
    index_path = os.path.join(_FRONTEND_DIR, "index.html")
    with open(index_path, encoding="utf-8") as f:
        return HTMLResponse(f.read())
demo = app  # the HF runtime looks up and launches `demo`

if __name__ == "__main__":
    # ssr_mode is deliberately left unset. With SSR off, gr.Server registers its own
    # GET "/" page route; that collides with the explicit homepage route and aborts
    # gradio's route setup (startup-events 404). Letting gradio auto-detect avoids it.
    app.launch(show_error=True, server_name="0.0.0.0", server_port=7860)
# deploy: explicit / route, no mount("/") shadowing /gradio_api