# PersonaLive / app.py
# Source: Hugging Face Space file view (author: multimodalart, HF Staff).
# Commit 14c1e5d (verified):
#   "fix: explicit / route to stop StaticFiles shadowing /gradio_api/startup-events"
# File size: 29.1 kB
import os
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")
import spaces # noqa: E402 -- must precede torch / CUDA imports
import torch # noqa: E402
# PersonaLive ships full-pickle .pth checkpoints, while torch>=2.6 makes
# weights_only=True the default for torch.load. Wrap torch.load so callers that
# do not pass weights_only get weights_only=False.
_orig_torch_load = torch.load


def _patched_torch_load(*args, **kwargs):
    """Forward to the original torch.load, defaulting ``weights_only`` to False.

    A ``weights_only`` value passed explicitly by the caller is left untouched.
    NOTE(review): weights_only=False unpickles arbitrary objects; only load
    checkpoints from trusted sources.
    """
    if "weights_only" not in kwargs:
        kwargs["weights_only"] = False
    return _orig_torch_load(*args, **kwargs)


torch.load = _patched_torch_load
# diffusers 0.27 / transformers 4.36 reference huggingface_hub symbols removed in
# hub>=1.0 (forced upon us by gradio 6). Re-inject them before those libs import.
import huggingface_hub as _hub # noqa: E402
# Re-create the two huggingface_hub names the older libraries look up, but only
# when the installed hub no longer provides them.
if not hasattr(_hub, "HfFolder"):
    class _HfFolder:
        """Minimal stand-in that exposes only ``get_token``."""

        @staticmethod
        def get_token():
            """Return whatever ``huggingface_hub.get_token()`` reports."""
            return _hub.get_token()

    _hub.HfFolder = _HfFolder

if not hasattr(_hub, "cached_download"):
    # Aliased so the attribute lookup succeeds at import time.
    # NOTE(review): call-signature compatibility with the removed function is
    # not checked here - confirm nothing actually calls it.
    _hub.cached_download = _hub.hf_hub_download
import sys # noqa: E402
import time # noqa: E402
from types import SimpleNamespace # noqa: E402
import cv2 # noqa: E402
import numpy as np # noqa: E402
import gradio as gr # noqa: E402
from PIL import Image # noqa: E402
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
# ---- weights (downloaded once at startup, no GPU needed) ----
# The three prepare_* helpers come from the project's tools.download_weights module
# and are called at import time, before any model object is constructed.
from tools.download_weights import prepare_base_model, prepare_vae, prepare_personalive # noqa: E402
prepare_base_model()
prepare_vae()
prepare_personalive()
# Config path and constructor arguments handed to PersonaLive in get_model().
CONFIG_PATH = "./configs/prompts/personalive_online.yaml"
ARGS = SimpleNamespace(config_path=CONFIG_PATH, acceleration="none")
# Every caller in this file batches exactly CHUNK frames per model.process_input call.
CHUNK = 4 # temporal_window_size: frames consumed/produced per diffusion call
_model = None


def get_model():
    """Return the per-process PersonaLive instance, constructing it on first use.

    Construction is deferred to the first call so it happens inside the GPU
    worker rather than at module import: this is a held-session model, so the
    load cost is paid once per cold worker and amortized over the session, and
    importing the module never initializes CUDA in the main process.
    """
    global _model
    if _model is not None:
        return _model

    from src.wrapper import PersonaLive

    started = time.perf_counter()
    _model = PersonaLive(ARGS, device="cuda")
    elapsed = time.perf_counter() - started
    print(f"[model] loaded in {elapsed:.1f}s", flush=True)
    return _model
AOTI_REPO = "multimodalart/PersonaLive-aoti"
_aoti_lazy = None
_aoti_bank_map = None


def _load_aoti():
    """Return ``(lazy_kernel, bank_map)``, downloading and opening them once.

    The first call snapshots ``package/*`` from AOTI_REPO, reads the
    ``bank_constants`` mapping from ``bank_constants.json`` and wraps the
    denoising_unet ``package.pt2`` in a LazyAOTIModel. Per the original note,
    the .pt2 holds only hardware-specific kernels (no weights); constants are
    supplied per fuse. Later calls return the cached pair.
    """
    global _aoti_lazy, _aoti_bank_map
    if _aoti_lazy is not None:
        return _aoti_lazy, _aoti_bank_map

    import json
    from pathlib import Path
    from huggingface_hub import snapshot_download
    from spaces.zero.torch.aoti import LazyAOTIModel

    snapshot_root = snapshot_download(AOTI_REPO, allow_patterns="package/*")
    package_dir = Path(snapshot_root) / "package"
    bank_meta = json.loads((package_dir / "bank_constants.json").read_text())
    kernel_path = package_dir / "submodules" / "denoising_unet" / "package.pt2"

    _aoti_bank_map = bank_meta["bank_constants"]
    _aoti_lazy = LazyAOTIModel(str(kernel_path))
    print(f"[aoti] loaded kernels, {len(_aoti_bank_map)} bank constants", flush=True)
    return _aoti_lazy, _aoti_bank_map
def patch_unet_aoti(model):
    """Replace ``model.denoising_unet.forward`` with the AoTI kernel.

    Must be called AFTER fuse_reference (so the reference banks are populated)
    and with keyframes disabled (model.num_khf=3) so the bank set matches the
    compiled graph. Each block's ``bank[0]`` is read at call time and handed to
    the kernel as a constant, which keeps the kernel portrait-agnostic.
    """
    lazy, bank_map = _load_aoti()
    unet = model.denoising_unet

    def live_bank(block_path):
        # Current reference bank tensor of the attention block at block_path.
        return unet.get_submodule(block_path).bank[0]

    # Parameters first, then buffers (a buffer overrides a same-named parameter).
    weights = dict(unet.named_parameters(remove_duplicate=False))
    weights.update(unet.named_buffers(remove_duplicate=False))

    # The compiled model refers to the banks as flat `_tensor_constant{i}` names
    # (NOT the export-time `lifted_tensor_N` FQNs). Both spellings are supplied;
    # LazyAOTIModel keeps whichever names the kernel actually asks for. Sorting
    # by the lifted index makes `_tensor_constant{i}` receive block i's bank.
    fqns_by_index = sorted(
        bank_map, key=lambda fqn: int(fqn.rsplit("lifted_tensor_", 1)[1]))
    ordered_blocks = [bank_map[fqn] for fqn in fqns_by_index]

    for fqn, block_path in bank_map.items():
        weights[fqn] = live_bank(block_path)
    for index, block_path in enumerate(ordered_blocks):
        weights[f"_tensor_constant{index}"] = live_bank(block_path)

    unet.forward = lazy.with_weights(weights)
def _frame_to_input(rgb: np.ndarray) -> torch.Tensor:
    """Convert an RGB uint8 HxWx3 frame to a (1,3,512,512) float CUDA tensor in [-1,1].

    Mirrors the transform the original websocket server applied to driving
    frames: area-resize to 512x512, scale to [0,1], shift to [-1,1], then move
    channels first and add a batch axis.
    """
    resized = cv2.resize(rgb, (512, 512), interpolation=cv2.INTER_AREA)
    unit = torch.from_numpy(resized).to("cuda").float() / 255.0
    signed = unit * 2.0 - 1.0
    channels_first = signed.permute(2, 0, 1)
    return channels_first.unsqueeze(0)
@spaces.GPU(duration=180)
def selftest(num_frames: int):
    """Run the eager model end-to-end on the bundled demo assets.

    Fuses ``demo/ref_image.png``, then animates it with up to ``num_frames``
    frames of the demo driving video in batches of CHUNK frames (a trailing
    partial chunk is dropped).

    Returns:
        ``(images, report)``: the first 8 generated frames as PIL images and a
        one-line timing summary (fuse time, frames, fps, per-chunk time) used
        to size the GPU duration.
    """
    model = get_model()
    model.reset()

    ref = Image.open("demo/ref_image.png").convert("RGB")
    t0 = time.perf_counter()
    model.fuse_reference(ref)
    fuse_s = time.perf_counter() - t0

    # Shared helper (also used by aoti_test / gpu_step) instead of a private
    # copy of the capture loop.
    frames = _read_frames(num_frames)

    outputs = []
    gen_t0 = time.perf_counter()
    n_chunks = len(frames) // CHUNK
    for c in range(n_chunks):
        chunk = frames[c * CHUNK:(c + 1) * CHUNK]
        batch = torch.cat([_frame_to_input(frame) for frame in chunk], dim=0)
        video = model.process_input(batch)  # (CHUNK,h,w,c) in [0,1]
        outputs.extend(
            Image.fromarray((img * 255.0).astype(np.uint8)) for img in video)
    gen_s = time.perf_counter() - gen_t0

    produced = len(outputs)
    fps = produced / gen_s if gen_s > 0 else 0.0
    report = (
        f"fuse: {fuse_s:.1f}s | generated {produced} frames in {gen_s:.1f}s "
        f"({fps:.1f} fps) | per-chunk {gen_s / max(n_chunks,1):.2f}s"
    )
    print("[selftest]", report, flush=True)
    return outputs[:8], report
def _describe(x):
    """Summarize a (possibly nested) captured argument for logging.

    Tensors become ``"T<shape>:<dtype>"``; lists and tuples become lists of
    summaries; dicts keep their keys with summarized values; anything else is
    rendered with ``repr``.
    """
    if isinstance(x, dict):
        return {key: _describe(value) for key, value in x.items()}
    if isinstance(x, (list, tuple)):
        return list(map(_describe, x))
    if torch.is_tensor(x):
        shape = tuple(x.shape)
        return f"T{shape}:{x.dtype}"
    return repr(x)
@spaces.GPU(duration=300)
def export_probe():
    """Feasibility probe: capture real denoising_unet inputs, then try torch.export.

    denoising_unet is the per-chunk bottleneck and the AoTI target; its hacked
    attention reads reference banks (self.bank / self.kv_bank), so whether it
    traces has to be established empirically. The probe fuses the demo
    portrait, runs one CHUNK of demo frames while capturing the unet call,
    exports the unet with those inputs, and compares every lifted tensor
    constant's shape with the live ``bank[0]`` of the block it belongs to.

    Returns:
        A multi-line text report (also printed). An export failure is reported
        in the text as a truncated traceback rather than raised.
    """
    import traceback

    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))

    # Shared helper instead of a private copy of the capture loop.
    frames = _read_frames(CHUNK)
    batch = torch.cat([_frame_to_input(frames[i]) for i in range(CHUNK)], dim=0)

    unet = model.denoising_unet
    with spaces.aoti_capture(unet) as call:
        model.process_input(batch)
    args, kwargs = call.args, call.kwargs

    log = [
        f"torch {torch.__version__}",
        f"captured args: {_describe(list(args))}",
        f"captured kwargs: {_describe(kwargs)}",
    ]
    try:
        ep = torch.export.export(unet, args, kwargs)
        n_nodes = len(list(ep.graph.nodes))
        log.append(f"EXPORT OK — graph nodes: {n_nodes}")

        # Lifted constants: non-buffer tensor attributes (the reference banks live
        # here). For one .pt2 to serve any portrait these must be swappable at load.
        consts = dict(getattr(ep, "constants", {}) or {})
        lifted = {k: v for k, v in consts.items()
                  if torch.is_tensor(v) and k.rsplit(".", 1)[-1].startswith("lifted_tensor")}
        log.append(f"#constants: {len(consts)} | #lifted_tensor: {len(lifted)}")

        # For each lifted constant, derive its block path and compare against that
        # block's live self.bank — this validates the loader's FQN->bank mapping.
        match = mismatch = 0
        for fqn, ct in lifted.items():
            block_path = fqn.rsplit(".", 1)[0]
            try:
                blk = unet.get_submodule(block_path)
            except Exception as e:
                log.append(f" {fqn}: get_submodule FAIL {e}")
                mismatch += 1
                continue
            bank = getattr(blk, "bank", None)
            kvb = getattr(blk, "kv_bank", None)
            blen = len(bank) if isinstance(bank, list) else ("None" if bank is None else "?")
            bshape = tuple(bank[0].shape) if isinstance(bank, list) and bank else None
            # Describe kv_bank defensively: an object without .shape must not
            # raise here, or a successful export would be logged as a failure.
            if kvb is None:
                kv_desc = "None"
            elif hasattr(kvb, "shape"):
                kv_desc = tuple(kvb.shape)
            else:
                kv_desc = f"<{type(kvb).__name__}>"
            ok = bshape == tuple(ct.shape)
            match += int(ok)
            mismatch += int(not ok)
            # Only the first few rows are logged in detail to keep the report short.
            if len(lifted) <= 4 or mismatch <= 3 or match <= 2:
                log.append(f" {fqn} const{tuple(ct.shape)} | bank len={blen} "
                           f"bank0={bshape} kv_bank={kv_desc} "
                           f"{'OK' if ok else 'MISMATCH'}")
        log.append(f"bank-shape match: {match}/{len(lifted)}")
    except Exception:
        log.append("EXPORT FAILED:\n" + traceback.format_exc()[-3500:])

    report = "\n".join(log)
    print("[export_probe]\n" + report, flush=True)
    return report
def _read_frames(n, path="demo/driving_video.mp4"):
    """Read up to ``n`` frames from a video file as RGB arrays.

    Args:
        n: maximum number of frames to return.
        path: video file to read; defaults to the bundled demo driving video.

    Returns:
        A list of frames converted from BGR to RGB, in file order. The list is
        shorter than ``n`` (possibly empty) when the capture stops yielding
        frames first.
    """
    cap = cv2.VideoCapture(path)
    frames = []
    try:
        while len(frames) < n:
            ok, bgr = cap.read()
            if not ok:
                break
            frames.append(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
    finally:
        # Release the capture even if reading or color conversion raises.
        cap.release()
    return frames
@spaces.GPU(duration=300)
def aoti_test(num_frames: int):
    """Verify the AoTI kernel numerically and end-to-end.

    (1) For two portraits, capture one real unet call and compare the eager
    output with the AoTI-patched output on the same inputs. (2) Run a fresh
    session with AoTI patched in and report frames / fps.

    Returns:
        ``(images, report)``: the first 8 generated frames as PIL images and
        the multi-line text report (also printed).
    """
    def load_demo_ref():
        return Image.open("demo/ref_image.png").convert("RGB")

    def fuse(ref_img):
        # Fresh session on this portrait with keyframes disabled, so the bank
        # set matches what the kernel was compiled against.
        model.reset()
        model.fuse_reference(ref_img)
        model.num_khf = 3

    def chunk_batch(start):
        return torch.cat(
            [_frame_to_input(frames[start + k]) for k in range(CHUNK)], dim=0)

    model = get_model()
    fuse(load_demo_ref())
    frames = _read_frames(max(num_frames, CHUNK))
    unet = model.denoising_unet
    # Once patch_unet_aoti overwrites unet.forward (an instance attribute), the
    # eager path is reachable only through the class method. Holding it here
    # guarantees each portrait's eager reference is the real network and not a
    # stale AoTI wrapper from a previous portrait.
    eager_forward = type(unet).forward

    def check(tag, ref_img):
        fuse(ref_img)
        batch = chunk_batch(0)
        with spaces.aoti_capture(unet) as captured:
            model.process_input(batch)
        call_args, call_kwargs = captured.args, captured.kwargs
        with torch.no_grad():
            eager = eager_forward(unet, *call_args, **call_kwargs)[0].float()
        patch_unet_aoti(model)  # supplies THIS portrait's banks live
        with torch.no_grad():
            aoti = unet(*call_args, **call_kwargs)[0].float()
        d = (eager - aoti).abs()
        rel = (d.max() / (eager.abs().max() + 1e-6)).item()
        return (f"[{tag}] max|eager-aoti|={d.max().item():.4e} mean={d.mean().item():.4e} "
                f"rel={rel:.4e} | eager range [{eager.min():.3f},{eager.max():.3f}]")

    # demo-ref is the portrait the kernel was COMPILED against (the constants
    # baked on disk happen to equal these banks). alt-ref is a different
    # portrait; a small diff there shows the 16 banks are supplied live per
    # fuse rather than baked into the kernel.
    log = []
    log.append(check("demo-ref", load_demo_ref()))
    log.append(check("alt-ref ", Image.fromarray(frames[len(frames) // 2])))

    # End-to-end AoTI generation + fps (fresh session, AoTI stays patched).
    fuse(load_demo_ref())
    patch_unet_aoti(model)  # rebuild the bank map for the fresh fuse
    outputs = []
    n_chunks = len(frames) // CHUNK
    t0 = time.perf_counter()
    for c in range(n_chunks):
        video = model.process_input(chunk_batch(c * CHUNK))
        outputs.extend(
            Image.fromarray((img * 255.0).astype(np.uint8)) for img in video)
    gen_s = time.perf_counter() - t0

    fps = len(outputs) / gen_s if gen_s > 0 else 0.0
    log.append(f"AoTI e2e: {len(outputs)} frames in {gen_s:.1f}s ({fps:.1f} fps) "
               f"| per-chunk {gen_s/max(n_chunks,1):.2f}s")
    report = "\n".join(log)
    print("[aoti_test]\n" + report, flush=True)
    return outputs[:8], report
@spaces.GPU(duration=180)
def aoti_fqns():
    """Diagnostic: compare the constant FQNs the compiled package asks for with
    the names this file supplies (parameters, buffers, lifted FQNs and
    ``_tensor_constant{i}``).

    Returns:
        A multi-line text report (also printed) with counts, sample FQNs and
        the first 20 FQNs that are not covered by the supplied names.
    """
    import json
    from pathlib import Path
    from huggingface_hub import snapshot_download

    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3
    unet = model.denoising_unet

    package_dir = Path(snapshot_download(AOTI_REPO, allow_patterns="package/*")) / "package"
    bank_meta = json.loads((package_dir / "bank_constants.json").read_text())
    bank_map = bank_meta["bank_constants"]
    pt2 = str(package_dir / "submodules" / "denoising_unet" / "package.pt2")
    compiled = torch._inductor.aoti_load_package(pt2)
    fqns = list(compiled.get_constant_fqns())

    # Every name patch_unet_aoti would offer to the kernel.
    supplied = {name for name, _ in unet.named_parameters(remove_duplicate=False)}
    supplied.update(name for name, _ in unet.named_buffers(remove_duplicate=False))
    supplied.update(bank_map)
    supplied.update(f"_tensor_constant{i}" for i in range(len(bank_map)))

    missing = [fqn for fqn in fqns if fqn not in supplied]
    n_flat = len([fqn for fqn in fqns if fqn.startswith('_tensor_constant')])
    n_lifted = len([fqn for fqn in fqns if 'lifted_tensor' in fqn])

    log = [
        f"#constant_fqns={len(fqns)} | #supplied_names={len(supplied)} | #missing={len(missing)}",
        f"sample fqns: {fqns[:8]}",
        f"tail fqns: {fqns[-8:]}",
        f"#_tensor_constant in fqns: {n_flat}",
        f"#lifted_tensor in fqns: {n_lifted}",
        f"missing (first 20): {missing[:20]}",
    ]
    report = "\n".join(log)
    print("[aoti_fqns]\n" + report, flush=True)
    return report
import multiprocessing as _mp # noqa: E402
import threading as _threading # noqa: E402
# Fork-context queues created at import (in the main process). The spaces lib forks its
# GPU worker with multiprocessing.get_context('fork'), so a worker forked AFTER these
# exist inherits the live pipe fds -> the parent (WS handler) can feed a held GPU worker.
_QCTX = _mp.get_context("fork")
# Probe-only pair used by _queue_worker / queue_probe: PROBE_IN carries the item the
# parent pushes, PROBE_OUT carries the worker's acknowledgement string back.
PROBE_IN = _QCTX.Queue()
PROBE_OUT = _QCTX.Queue()
@spaces.GPU(duration=40)
def _queue_worker():
    """Held GPU worker half of the queue probe.

    Blocks (up to 25 s) for an item the PARENT pushes after this worker has
    been forked, then acknowledges it on PROBE_OUT with this process's pid and
    the wait time. Receiving the item shows that the module-global fork queues
    bridge the ZeroGPU fork boundary, which is the basis for feeding live WS
    frames into a held @spaces.GPU session.
    """
    started = time.time()
    got = PROBE_IN.get(timeout=25)
    ack = f"worker pid={os.getpid()} got={got!r} after {time.time()-started:.1f}s"
    PROBE_OUT.put(ack)
    return "ok"
def queue_probe():
    """Parent half of the queue probe.

    Starts ``_queue_worker`` in a daemon thread (so the blocking GPU call does
    not block this function), waits 4 s for the worker to fork and reach its
    ``PROBE_IN.get()``, pushes a live item, and then reports whatever the
    worker sent back on PROBE_OUT.

    Returns:
        A text report (also printed) with the parent pid and either the
        worker's acknowledgement or a "NO OUTPUT FROM WORKER" line.
    """
    worker_thread = _threading.Thread(target=_queue_worker, daemon=True)
    worker_thread.start()
    time.sleep(4)  # let the worker fork and reach PROBE_IN.get()

    payload = f"live-{int(time.time())}-from-pid-{os.getpid()}"
    PROBE_IN.put(payload)
    worker_thread.join(timeout=30)

    try:
        res = PROBE_OUT.get(timeout=8)
    except Exception as e:
        res = f"NO OUTPUT FROM WORKER: {type(e).__name__} {e}"

    report = f"parent pid={os.getpid()}\n{res}"
    print("[queue_probe]\n" + report, flush=True)
    return report
# ---- Held-session machinery (the real WS backend runs on these) ----
# Parent (WS handler) <-> held GPU worker, bridged by fork-inherited queues.
# SESS_OUT carries (kind, payload) tuples; session_worker emits "ready" (payload: worker
# pid), "frame" (uint8 image array), "error" (traceback text) and always a final "done".
# A None item on SESS_IN is the stop sentinel that ends the worker loop.
SESS_IN = _QCTX.Queue() # parent->worker: driving RGB frames (np.uint8 HxWx3)
SESS_OUT = _QCTX.Queue() # worker->parent: ("ready"|"frame"|"error"|"done", payload)
SESS_REF = _QCTX.Queue() # parent->worker: reference PIL.Image to fuse
@spaces.GPU(duration=120)
def session_worker():
    """Held GPU session fed through the fork queues.

    Loads the model, blocks on SESS_REF for a reference portrait, fuses it,
    pins keyframes off and patches in the AoTI kernel, then announces
    ``("ready", pid)`` on SESS_OUT. After that it accumulates driving frames
    from SESS_IN and, each time CHUNK frames are available, runs one
    ``process_input`` call and emits every result as ``("frame", uint8 array)``.
    A ``None`` item stops the loop, so the GPU is held only while the session
    runs. Any exception is reported as ``("error", traceback tail)`` and a
    final ``("done", None)`` is always emitted.
    """
    import traceback

    try:
        model = get_model()
        model.reset()
        reference = SESS_REF.get()  # block until the parent supplies a reference
        model.fuse_reference(reference)
        model.num_khf = 3  # pin keyframes off (the AoTI bank set is fixed)
        patch_unet_aoti(model)
        SESS_OUT.put(("ready", os.getpid()))

        pending = []
        while True:
            item = SESS_IN.get()
            if item is None:  # stop sentinel
                break
            pending.append(item)
            if len(pending) < CHUNK:
                continue
            chunk, pending = pending[:CHUNK], pending[CHUNK:]
            batch = torch.cat([_frame_to_input(frame) for frame in chunk], dim=0)
            for img in model.process_input(batch):
                SESS_OUT.put(("frame", (img * 255.0).astype(np.uint8)))
    except Exception:
        SESS_OUT.put(("error", traceback.format_exc()[-2000:]))
    finally:
        SESS_OUT.put(("done", None))
@spaces.GPU(duration=120)
def gpu_fuse():
    """Cross-call state experiment, step 1.

    Initializes the model and fuses the demo reference inside one @spaces.GPU
    call, then RETURNS (releasing the GPU). The report records the worker pid
    and the model object's id so that step 2 (``gpu_step``) can tell whether
    the worker and its state persisted.
    """
    model = get_model()
    model.reset()
    model.fuse_reference(Image.open("demo/ref_image.png").convert("RGB"))
    model.num_khf = 3

    has_latents = hasattr(model, 'ref_image_latents')
    bank_state = 'set' if getattr(model, 'motion_bank', None) is not None else 'None'
    report = " ".join([
        f"pid={os.getpid()}",
        f"id(_model)={id(_model)}",
        f"ref_latents={has_latents}",
        f"motion_bank={bank_state}",
    ])
    print("[gpu_fuse]", report, flush=True)
    return report
@spaces.GPU(duration=120)
def gpu_step():
    """Cross-call state experiment, step 2.

    In a SEPARATE @spaces.GPU call, try ``process_input`` WITHOUT re-fusing.
    If the worker persisted, ``_model`` is already fused and a valid chunk
    comes out, so a per-chunk @spaces.GPU design would be viable. If the worker
    is fresh, ``_model`` starts as None / unfused, which shows that a held
    session is needed instead.

    Returns:
        A text report (also printed); a failure is included as a truncated
        traceback rather than raised.
    """
    import traceback

    log = [f"pid={os.getpid()} _model_was_none={_model is None}"]
    try:
        model = get_model()
        log.append(f"id(_model)={id(_model)} ref_latents={hasattr(model, 'ref_image_latents')}")
        frames = _read_frames(CHUNK)
        batch = torch.cat([_frame_to_input(frames[i]) for i in range(CHUNK)], dim=0)
        video = model.process_input(batch)
        mean_v = float(np.mean(video))
        min_v = float(np.min(video))
        max_v = float(np.max(video))
        log.append(f"process_ok frames={len(video)} mean={mean_v:.4f} "
                   f"range=[{min_v:.3f},{max_v:.3f}]")
    except Exception:
        log.append("PROCESS FAILED:\n" + traceback.format_exc()[-1500:])

    report = "\n".join(log)
    print("[gpu_step]\n" + report, flush=True)
    return report
# ==================== Real-time WebSocket backend ====================
# Keeps PersonaLive's own Svelte frontend + WS transport. The held GPU session
# (session_worker) runs in a parent-side thread fed by the fork-queues above; the
# visitor's X-IP-Token (minted by the huggingface.co parent via postMessage and
# forwarded over the WS) is injected into LocalContext.request so spaces.schedule
# bills the right HF account instead of falling back to the Space's IP quota.
import asyncio # noqa: E402
import json # noqa: E402
import queue as _queue # noqa: E402
from io import BytesIO # noqa: E402
from fastapi import File, UploadFile, WebSocket, WebSocketDisconnect # noqa: E402
from fastapi.responses import FileResponse, HTMLResponse, JSONResponse # noqa: E402
from fastapi.staticfiles import StaticFiles # noqa: E402
from gradio.context import LocalContext # noqa: E402
class _IPHeaders(dict):
    """Header mapping for the synthetic request built by ``_ip_request``.

    Per the original note, spaces._get_headers needs ``request.headers`` to
    offer BOTH a ``__dict__`` (which a plain dict instance lacks) AND ``.get``;
    an otherwise empty dict SUBCLASS provides both. See
    reference_zerogpu_xiptoken_handshake.
    """
def _ip_request(token):
    """Build a minimal request stand-in carrying the visitor's X-IP-Token.

    Returns ``None`` for a missing/empty token; otherwise an object whose
    ``headers`` is an ``_IPHeaders`` mapping with the single key
    ``"x-ip-token"``.
    """
    if token:
        headers = _IPHeaders()
        headers["x-ip-token"] = token
        return SimpleNamespace(headers=headers)
    return None
def _drain(q):
    """Discard everything currently retrievable from ``q`` without blocking.

    Best-effort: the first exception raised by ``get_nowait`` (normally the
    queue reporting empty) ends the drain and is not propagated.
    """
    while True:
        try:
            q.get_nowait()
        except Exception:
            return
# Most recently uploaded reference portrait, or None. This is one module-wide slot
# (not keyed by user); /api/reset clears it.
_PENDING_REF = None # set by /api/upload_reference_image, fused at WS open
# Acquired non-blocking by the WS handler so only one session runs at a time; its
# locked() state is what /api/queue and /api/health report.
_session_busy = _threading.Lock()
def _run_session(token):
    """Thread target: enter the held @spaces.GPU ``session_worker``.

    The GPU-quota context is set for THIS thread first, because contextvars do
    not propagate across threads. When ``token`` is empty, no request is
    installed and the call runs without a visitor token.

    ``session_worker`` reports its own in-worker failures on SESS_OUT. If the
    call itself raises on this side, nothing would reach SESS_OUT and the WS
    handler would poll for "ready" indefinitely while holding the session lock,
    so the failure is forwarded here as ``("error", ...)`` followed by
    ``("done", None)``.
    """
    req = _ip_request(token)
    if req is not None:
        LocalContext.request.set(req)
    try:
        session_worker()
    except Exception as exc:
        # NOTE(review): presumably a scheduling/quota rejection or an aborted
        # worker - confirm against the spaces library's behaviour.
        summary = f"{type(exc).__name__}: {exc}"
        print(f"[session] worker call failed: {summary}", flush=True)
        SESS_OUT.put(("error", summary[:2000]))
        SESS_OUT.put(("done", None))
# gr.Server (gradio 6) inherits from FastAPI, so .get/.post/.websocket/.mount work
# directly while keeping Gradio's queue + ZeroGPU scheduling. demo=app below; the HF
# runtime launches it (no separate uvicorn -> avoids the 7860 double-bind).
app = gr.Server(title="PersonaLive")
@app.get("/api/settings")
async def settings():
    """Return the minimal settings document the Svelte frontend reads
    (image input mode, no extra parameters)."""
    payload = {
        "info": {"properties": {"input_mode": {"default": "image"}}},
        "input_params": {"properties": {}},
        "max_queue_size": 1,
        "page_content": "",
    }
    return JSONResponse(payload)
@app.get("/api/queue")
async def queue_size():
    """Report a queue size of 1 while a session holds the lock, else 0."""
    occupied = 1 if _session_busy.locked() else 0
    return JSONResponse({"queue_size": occupied})
@app.post("/api/upload_reference_image")
async def upload_reference_image(ref_image: UploadFile = File(...)):
    """Store the uploaded portrait as the pending reference for the next session.

    The upload is decoded to an RGB PIL image and kept in the module-wide
    ``_PENDING_REF`` slot, which the WS handler reads when a session opens.

    Returns:
        ``{"status": "ok"}`` on success. If the bytes cannot be decoded as an
        image, a 400 response with ``{"status": "error", "message": ...}`` is
        returned and the previously stored reference is left unchanged.
    """
    global _PENDING_REF
    data = await ref_image.read()
    try:
        image = Image.open(BytesIO(data)).convert("RGB")
    except (OSError, ValueError, Image.DecompressionBombError) as exc:
        # Untrusted upload: answer with a client error instead of letting the
        # decode failure escape as an unhandled 500.
        return JSONResponse(
            {"status": "error",
             "message": f"Could not read the uploaded image: {type(exc).__name__}"},
            status_code=400,
        )
    _PENDING_REF = image
    return JSONResponse({"status": "ok"})
@app.post("/api/reset")
async def reset():
    """Forget the pending reference portrait."""
    global _PENDING_REF
    _PENDING_REF = None
    response = JSONResponse({"status": "ok"})
    return response
@app.websocket("/api/ws/{user_id}")
async def ws_endpoint(websocket: WebSocket, user_id: str):
    """Drive one live session over a WebSocket.

    ``user_id`` comes from the path and is not otherwise used here. Only one
    session runs at a time: if the session lock is taken, the client receives
    ``{"status": "timeout"}`` and the socket is closed.

    Protocol, in order:
      1. The first text frame may carry ``x_ip_token`` (JSON). A missing,
         late (15 s) or malformed frame means "no token"; a client that
         disconnects during this step ends the handler without starting a GPU
         session.
      2. The pending reference is handed to a ``session_worker`` started in a
         background thread. ``{"status": "loading"}`` heartbeats are sent every
         2 s until the worker's first message arrives; anything other than
         "ready" is reported as an error.
      3. A sender task forwards worker frames to the client as JPEG bytes.
      4. Incoming binary frames are decoded and queued for the worker, dropping
         frames while more than ``CHUNK * 2`` are already queued; a text frame
         with status "resume" is answered with "send_frame".

    On exit the stop sentinel is queued, the sender is cancelled, and the
    session thread is joined (bounded) BEFORE the lock is released, so the next
    connection cannot drain the sentinel while the old worker still reads the
    shared queues.
    """
    await websocket.accept()
    if not _session_busy.acquire(blocking=False):
        await websocket.send_text(json.dumps({"status": "timeout"}))
        await websocket.close()
        return

    loop = asyncio.get_running_loop()
    send_task = None
    session_thread = None
    # Upper bound (seconds) on waiting for the worker to exit at teardown.
    # NOTE(review): chosen to cover a cold model load; tune if needed.
    join_timeout_s = 60
    try:
        # 1. first text frame carries the visitor's X-IP-Token (for GPU quota).
        token = None
        try:
            first = await asyncio.wait_for(websocket.receive_text(), timeout=15)
            token = json.loads(first).get("x_ip_token")
        except WebSocketDisconnect:
            # The client is gone: do not spend a GPU session on it.
            raise
        except Exception:
            # Best-effort: timeout or malformed handshake -> continue without token.
            pass

        ref = _PENDING_REF
        if ref is None:
            await websocket.send_text(json.dumps(
                {"status": "error", "message": "Upload a reference portrait first."}))
            return

        # 2. start the held GPU session and wait for it to fuse + be ready.
        _drain(SESS_IN)
        _drain(SESS_OUT)
        _drain(SESS_REF)
        SESS_REF.put(ref)
        session_thread = _threading.Thread(
            target=_run_session, args=(token,), daemon=True)
        session_thread.start()

        # Cold start loads the model in-worker (~21s). Send heartbeats while we
        # wait so the client's WS keepalive ping doesn't time out on the silence.
        ready = None
        while ready is None:
            try:
                ready = await loop.run_in_executor(
                    None, lambda: SESS_OUT.get(timeout=2))
            except _queue.Empty:
                await websocket.send_text(json.dumps({"status": "loading"}))
        if ready[0] != "ready":
            await websocket.send_text(json.dumps(
                {"status": "error", "message": f"session start failed: {ready}"}))
            return
        await websocket.send_text(json.dumps({"status": "connected"}))
        await websocket.send_text(json.dumps({"status": "send_frame"}))

        # 3. worker -> client: stream generated frames as JPEG.
        async def sender():
            while True:
                kind, payload = await loop.run_in_executor(None, SESS_OUT.get)
                if kind == "frame":
                    bgr = cv2.cvtColor(payload, cv2.COLOR_RGB2BGR)
                    ok, jpg = cv2.imencode(".jpg", bgr)
                    if ok:
                        await websocket.send_bytes(jpg.tobytes())
                elif kind == "error":
                    await websocket.send_text(json.dumps(
                        {"status": "error", "message": str(payload)[:500]}))
                    break
                elif kind == "done":
                    break

        send_task = asyncio.create_task(sender())

        # 4. client -> worker: decode driving JPEGs, prefer-latest (drop backlog so the
        #    ~20fps client doesn't outrun the ~6.8fps model and pile up latency).
        while True:
            msg = await websocket.receive()
            if msg.get("type") == "websocket.disconnect":
                break
            if msg.get("bytes"):
                arr = np.frombuffer(msg["bytes"], np.uint8)
                bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
                if bgr is not None and SESS_IN.qsize() <= CHUNK * 2:
                    SESS_IN.put(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB))
            elif msg.get("text"):
                try:
                    if json.loads(msg["text"]).get("status") == "resume":
                        await websocket.send_text(json.dumps({"status": "send_frame"}))
                except Exception:
                    # Ignore malformed control messages; keep the session alive.
                    pass
    except WebSocketDisconnect:
        pass
    except Exception as e:
        print(f"[ws] error: {e}", flush=True)
    finally:
        SESS_IN.put(None)  # stop sentinel -> worker exits -> releases GPU slot
        if send_task is not None:
            send_task.cancel()
        try:
            if session_thread is not None:
                # Wait for the worker to leave the shared queues before a new
                # session may start; the join runs in the executor so the event
                # loop is not blocked.
                await loop.run_in_executor(
                    None, session_thread.join, join_timeout_s)
        finally:
            # Always release, even if the wait above is cancelled.
            _session_busy.release()
@app.get("/api/health")
async def health():
    """Liveness probe (no GPU): confirms the server and routes are up and
    reports whether a session currently holds the lock."""
    status = {"status": "ok", "session_busy": _session_busy.locked()}
    return JSONResponse(status)
@app.get("/api/selftest")
async def selftest_route(num_frames: int = 16):
    """Debug-only route: run the eager ``selftest`` off the event loop.

    No X-IP-Token is supplied (-> creator IP quota), so the model path can be
    validated via curl without the frontend. Only the text report is returned;
    the generated images are discarded.
    """
    result = await asyncio.get_event_loop().run_in_executor(
        None, selftest, num_frames)
    report = result[1]
    return JSONResponse({"report": report})
# Serve PersonaLive's Svelte build. gr.Server replaces a mount("/") with its own
# page route, so the root is served by an explicit handler (per the gr.Server
# blueprint) and only the build's asset directories are mounted, each at the
# sub-path the build references.
_FRONTEND_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "frontend_public")
if os.path.isdir(_FRONTEND_DIR):
    for _sub in ("_app", "presets"):
        _p = os.path.join(_FRONTEND_DIR, _sub)
        if not os.path.isdir(_p):
            # Asset directory absent from this build: nothing to mount.
            continue
        app.mount(f"/{_sub}", StaticFiles(directory=_p), name=f"frontend_{_sub}")
@app.get("/favicon.png")
async def _favicon():
    """Serve ``favicon.png`` from the frontend build directory."""
    icon_path = os.path.join(_FRONTEND_DIR, "favicon.png")
    return FileResponse(icon_path)
@app.get("/", response_class=HTMLResponse)
async def _homepage():
    """Serve the frontend's ``index.html`` (read from disk on each request)."""
    index_path = os.path.join(_FRONTEND_DIR, "index.html")
    with open(index_path, encoding="utf-8") as fh:
        html = fh.read()
    return html
demo = app # HF runtime launches `demo`
if __name__ == "__main__":
# Don't force ssr_mode: with SSR off, gr.Server registers its own GET "/"
# page route, which collides with the explicit homepage route above and
# aborts gradio's route setup (startup-events 404). Let gradio auto-detect.
app.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
# deploy: explicit / route, no mount("/") shadowing /gradio_api