# Copyright (C) 2026 Hengzhe Zhao. All rights reserved.
# Licensed under dual license: AGPL-3.0 (open-source) or commercial. See LICENSE.
"""Concurrent user queue for Prefero on shared hosting (e.g. HF Spaces).
Uses a module-level dict (shared across all Streamlit sessions within the
same process) to track active users. Thread-safe via a lock.
Toggle with PREFERO_QUEUE_ENABLED env var ("true" to enable).
Max concurrent users controlled by PREFERO_MAX_CONCURRENT (default 2).
"""
from __future__ import annotations
import os
import threading
import time
import uuid
import streamlit as st
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Seat limit, read once at import time from PREFERO_MAX_CONCURRENT
# (default 2). A value that int() cannot parse raises ValueError on import.
_MAX_CONCURRENT = int(os.environ.get("PREFERO_MAX_CONCURRENT", "2"))
# Seconds without a heartbeat before a session is evicted (30 minutes).
_SESSION_TIMEOUT = 30 * 60
def _queue_enabled() -> bool:
    """Report whether the waiting-room queue is switched on.

    The environment is re-read on every call. Only the value ``"true"``
    (in any letter case) enables the queue; anything else, including an
    unset variable, disables it.
    """
    flag = os.getenv("PREFERO_QUEUE_ENABLED") or ""
    return flag.lower() == "true"
# ---------------------------------------------------------------------------
# Shared state (module-level, shared across all sessions in one process)
# ---------------------------------------------------------------------------
# Guards both dicts below; every access to them in this module happens
# while holding it.
_lock = threading.Lock()
# session_id -> time.time() of that session's last heartbeat
_active_sessions: dict[str, float] = {}
# session_id -> username registered by that session
_session_usernames: dict[str, str] = {}
def _cleanup_stale() -> None:
    """Evict every session whose last heartbeat is older than ``_SESSION_TIMEOUT``.

    Each expired id is dropped from ``_active_sessions`` together with its
    username mapping, if it has one. This function does not take ``_lock``
    itself; every call site visible here already holds it.

    NOTE(review): usernames whose session id is not present in
    ``_active_sessions`` at all are left untouched by this sweep.
    """
    now = time.time()
    expired: list[str] = []
    for session_id, last_seen in _active_sessions.items():
        if now - last_seen > _SESSION_TIMEOUT:
            expired.append(session_id)
    # Mutate only after iteration finishes; deleting while iterating a dict raises.
    for session_id in expired:
        _active_sessions.pop(session_id)
        _session_usernames.pop(session_id, None)
def _ensure_session_id() -> str:
    """Return this Streamlit session's queue id, minting one on first use.

    The id is a UUID4 string kept under ``st.session_state["_queue_session_id"]``,
    so later calls in the same session return the same value.
    """
    state = st.session_state
    if "_queue_session_id" not in state:
        state["_queue_session_id"] = str(uuid.uuid4())
    return state["_queue_session_id"]
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def try_enter() -> bool:
    """Claim (or renew) a seat for the current session.

    Stale sessions are swept first. A session that already holds a seat has
    its timestamp refreshed. A new session is seated only while fewer than
    ``_MAX_CONCURRENT`` seats are taken.

    Returns:
        True if the session holds a seat after the call, False if the
        server is full.
    """
    sid = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        seated = sid in _active_sessions
        if not seated and len(_active_sessions) >= _MAX_CONCURRENT:
            return False
        _active_sessions[sid] = time.time()
        return True
def heartbeat() -> None:
    """Stamp the current session's seat with the current time.

    This is a no-op when the session holds no seat. It never claims a seat
    and does not sweep stale sessions.
    """
    sid = _ensure_session_id()
    with _lock:
        if sid not in _active_sessions:
            return
        _active_sessions[sid] = time.time()
def leave() -> None:
    """Give up the current session's seat and forget its username.

    It is safe to call this when the session holds no seat or has no username.
    """
    session_id = _ensure_session_id()
    with _lock:
        for registry in (_active_sessions, _session_usernames):
            registry.pop(session_id, None)
def register_username(username: str) -> None:
    """Record *username* as the user of the current session.

    Any earlier mapping for this session is overwritten. The call neither
    claims nor refreshes a seat.
    """
    session_id = _ensure_session_id()
    with _lock:
        _session_usernames[session_id] = username
def is_username_active(username: str) -> bool:
    """Report whether *username* is mapped to a session other than this one.

    Stale sessions are swept first, so usernames held by timed-out seats do
    not count.
    """
    me = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        return any(
            other != me and name == username
            for other, name in _session_usernames.items()
        )
def force_evict_username(username: str) -> None:
    """Kick every *other* session that is registered as *username*.

    Stale sessions are swept first. Each matching session loses both its seat
    and its username mapping. The current session is left untouched.
    """
    me = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        # Collect first: the dicts cannot be mutated while iterating them.
        victims = {
            other
            for other, name in _session_usernames.items()
            if other != me and name == username
        }
        for other in victims:
            _active_sessions.pop(other, None)
            _session_usernames.pop(other, None)
def active_count() -> int:
    """Return the number of seats currently taken, counted after sweeping stale sessions."""
    with _lock:
        _cleanup_stale()
        taken = len(_active_sessions)
    return taken
def spots_available() -> int:
    """Return how many seats are free right now, clamped at zero."""
    with _lock:
        _cleanup_stale()
        free = _MAX_CONCURRENT - len(_active_sessions)
    return free if free > 0 else 0
def is_session_active() -> bool:
    """Return True if the current session still holds a seat.

    Stale sessions are swept first. The result is therefore False once this
    session has gone more than ``_SESSION_TIMEOUT`` seconds without a
    heartbeat, or once its seat was released or evicted. Callers should then
    clear auth state and send the user back to login.
    """
    current = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        still_seated = current in _active_sessions
    return still_seated
# ---------------------------------------------------------------------------
# Waiting-room gate (Streamlit UI)
# ---------------------------------------------------------------------------
# PokeAPI official-artwork sprite 80.png.
# NOTE(review): not referenced in this part of the module; confirm whether
# another module still imports it before removing.
_SLOWBRO_IMG = (
    "https://raw.githubusercontent.com/PokeAPI/sprites/master"
    "/sprites/pokemon/other/official-artwork/80.png"
)
def queue_gate() -> bool:
    """Admit the current session, or show the waiting room if the server is full.

    When the queue is disabled (``PREFERO_QUEUE_ENABLED`` is not ``"true"``),
    this is a no-op that returns True. Otherwise it refreshes the session's
    heartbeat and tries to claim a seat.

    If every seat is taken, it renders a waiting-room page, sleeps for
    5 seconds, and calls ``st.rerun()``. The seat check is then retried on
    the next script run.

    Returns:
        True when the caller may go on to render the app. On the waiting
        path ``st.rerun()`` ends the current script run, so the final
        ``return False`` exists only for type checkers.
    """
    if not _queue_enabled():
        return True

    # Refresh before try_enter(). try_enter() sweeps stale sessions first,
    # so renewing here rescues an overdue-but-not-yet-swept seat.
    heartbeat()
    if try_enter():
        return True

    # ── Waiting room UI ─────────────────────────────────────────
    # Imported here, so these modules are only loaded on the waiting path.
    from waiting_facts import WAITING_FACTS
    from utils import language_banner
    import random

    # Scrolling multilingual banner
    language_banner()

    slowpoke_img = (
        "https://raw.githubusercontent.com/PokeAPI/sprites/master/sprites"
        "/pokemon/other/official-artwork/79.png"
    )

    # Derive both metrics from one locked read. Two separate reads could
    # disagree if a seat changed hands between them.
    n_active = active_count()
    spots_left = max(0, _MAX_CONCURRENT - n_active)

    # ── Slowpoke waiting illustration ──
    st.markdown(
        "<div style='text-align:center; margin-top:20px;'>"
        f"<img src='{slowpoke_img}' width='120' />"
        "</div>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "<h3 style='text-align:center;'>Slowbro is busy crunching numbers...</h3>"
        "<p style='text-align:center; color:gray;'>"
        "All seats are taken! But don't worry — Slowpoke is keeping "
        "your spot warm. You'll get in as soon as someone finishes.</p>",
        unsafe_allow_html=True,
    )

    # ── Queue status ──
    q1, q2 = st.columns(2)
    with q1:
        st.metric("Active users", f"{n_active} / {_MAX_CONCURRENT}")
    with q2:
        st.metric("Seats available", str(spots_left))

    # ── Session policy note ──
    # Keep "30 minutes" in sync with _SESSION_TIMEOUT (1800 s).
    # NOTE(review): in the code visible here, heartbeat() is only called from
    # this gate. Confirm that long-running estimation refreshes the seat so
    # the "seat is safe" promise holds past the timeout.
    st.warning(
        "**How the queue works:** Each user gets a seat for as long as "
        "they're active. Sessions expire after **30 minutes** of inactivity "
        "to keep things moving — but if you're running a model, your seat "
        "is safe until estimation completes."
    )

    # ── Rolling cultural facts ──
    st.markdown("---")
    st.markdown(
        "<p style='text-align:center; font-weight:600; margin-bottom:4px;'>"
        "While you wait — queuing around the world</p>",
        unsafe_allow_html=True,
    )
    # Seed from an 8-second time bucket. Reruns inside one bucket show the
    # same three facts; a new selection is drawn when the bucket changes.
    rng = random.Random(int(time.time()) // 8)
    order = list(range(len(WAITING_FACTS)))
    rng.shuffle(order)
    for idx in order[:3]:
        st.info(WAITING_FACTS[idx])

    # Poll: wait 5 seconds, then rerun the script so try_enter() is retried.
    time.sleep(5)
    st.rerun()
    return False  # unreachable: st.rerun() ends the current script run