File size: 7,700 Bytes
247642a
 
 
5ed1762
 
 
 
 
 
b3cf03c
5ed1762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3cf03c
ecacd2c
5ed1762
 
 
 
 
 
 
 
 
 
 
 
23238b5
5ed1762
 
 
 
 
 
 
 
23238b5
5ed1762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23238b5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed1762
 
c62aef1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5ed1762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a9253e5
 
 
 
 
 
 
 
 
5ed1762
 
a9253e5
5ed1762
 
 
 
a9253e5
 
 
 
5ed1762
 
 
a9253e5
 
 
 
 
 
 
 
 
 
ecacd2c
a9253e5
 
5ed1762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
# Copyright (C) 2026 Hengzhe Zhao. All rights reserved.
# Licensed under dual license: AGPL-3.0 (open-source) or commercial. See LICENSE.

"""Concurrent user queue for Prefero on shared hosting (e.g. HF Spaces).

Uses a module-level dict (shared across all Streamlit sessions within the
same process) to track active users.  Thread-safe via a lock.

Toggle with PREFERO_QUEUE_ENABLED env var ("true" to enable).
Max concurrent users controlled by PREFERO_MAX_CONCURRENT (default 2).
"""

from __future__ import annotations

import os
import threading
import time
import uuid

import streamlit as st

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------

# Number of seats handed out at once; override via PREFERO_MAX_CONCURRENT.
_MAX_CONCURRENT = int(os.environ.get("PREFERO_MAX_CONCURRENT", "2"))
# Seconds without a heartbeat before a session loses its seat (30 minutes).
_SESSION_TIMEOUT = 30 * 60


def _queue_enabled() -> bool:
    """Return True when PREFERO_QUEUE_ENABLED is set to "true" (case-insensitive).

    Surrounding whitespace is ignored, so a value that was pasted with a
    trailing space or newline still enables the queue instead of silently
    leaving it disabled.  Any other value, or an unset variable, disables it.
    """
    return os.environ.get("PREFERO_QUEUE_ENABLED", "").strip().lower() == "true"


# ---------------------------------------------------------------------------
# Shared state (module-level, shared across all sessions in one process)
# ---------------------------------------------------------------------------

# Guards both registries below; every public helper takes it before touching them.
_lock = threading.Lock()

# session_id -> time.time() of that session's most recent heartbeat
_active_sessions: dict[str, float] = {}

# session_id -> username the session registered via register_username()
_session_usernames: dict[str, str] = {}


def _cleanup_stale() -> None:
    """Evict every session whose last heartbeat is older than _SESSION_TIMEOUT.

    Both the seat and any registered username of an evicted session are
    dropped.  This function does not lock; every caller in this module
    invokes it while already holding ``_lock``.
    """
    now = time.time()
    expired = [
        session_id
        for session_id, last_seen in _active_sessions.items()
        if now - last_seen > _SESSION_TIMEOUT
    ]
    for session_id in expired:
        _active_sessions.pop(session_id)
        _session_usernames.pop(session_id, None)


def _ensure_session_id() -> str:
    """Return this Streamlit session's queue id, minting a UUID4 string on first use.

    The id lives in ``st.session_state["_queue_session_id"]`` so it persists
    across reruns of the same browser session.
    """
    state = st.session_state
    if "_queue_session_id" not in state:
        state["_queue_session_id"] = str(uuid.uuid4())
    return state["_queue_session_id"]


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def try_enter() -> bool:
    """Claim or re-confirm a seat for the current session.

    Stale sessions are swept first.  A session that already holds a seat is
    always admitted; a new one is admitted only while fewer than
    ``_MAX_CONCURRENT`` sessions are seated.  Admission stamps the session's
    heartbeat with the current time.

    Returns:
        True if the session now holds a seat, False if it must wait.
    """
    sid = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        admitted = sid in _active_sessions or len(_active_sessions) < _MAX_CONCURRENT
        if admitted:
            _active_sessions[sid] = time.time()
        return admitted


def heartbeat() -> None:
    """Stamp the current session's last-seen time, if it holds a seat.

    Sessions without a seat are ignored; this never claims one.
    """
    sid = _ensure_session_id()
    with _lock:
        if sid not in _active_sessions:
            return
        _active_sessions[sid] = time.time()


def leave() -> None:
    """Give up the current session's seat and forget its registered username.

    Harmless when the session holds neither: missing keys are tolerated.
    """
    sid = _ensure_session_id()
    with _lock:
        for registry in (_active_sessions, _session_usernames):
            registry.pop(sid, None)


def register_username(username: str) -> None:
    """Record *username* as the login of the current session.

    Overwrites any name this session registered earlier.

    NOTE(review): _cleanup_stale only walks _active_sessions, so an entry
    registered by a session that never holds a seat (for example when the
    queue is disabled and queue_gate returns before try_enter) is never
    swept.  Confirm callers only register after admission.
    """
    session_key = _ensure_session_id()
    with _lock:
        _session_usernames.update({session_key: username})


def is_username_active(username: str) -> bool:
    """Report whether some *other* session has registered *username*.

    Stale sessions are swept first so abandoned logins are not counted; the
    current session's own registration is ignored.
    """
    sid = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        return any(
            name == username and owner != sid
            for owner, name in _session_usernames.items()
        )


def force_evict_username(username: str) -> None:
    """Kick every other session logged in as *username* so the caller can take over.

    Each evicted session loses both its seat and its username entry, so
    is_session_active() returns False for it afterwards.  The current session
    is never evicted.
    """
    sid = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        victims = [
            owner
            for owner, name in _session_usernames.items()
            if name == username and owner != sid
        ]
        for owner in victims:
            for registry in (_active_sessions, _session_usernames):
                registry.pop(owner, None)


def active_count() -> int:
    """Return the number of seated sessions after sweeping stale ones."""
    with _lock:
        _cleanup_stale()
        seated = len(_active_sessions)
    return seated


def spots_available() -> int:
    """Return how many seats are free after sweeping stale sessions (never negative)."""
    with _lock:
        _cleanup_stale()
        free = _MAX_CONCURRENT - len(_active_sessions)
    return free if free > 0 else 0


def is_session_active() -> bool:
    """Report whether the current session still holds a seat.

    Stale sessions are swept first, so this is False once the session has
    gone quiet for longer than the timeout, or was removed via leave() or
    force_evict_username().  Callers should then clear auth state and send
    the user back to login.
    """
    sid = _ensure_session_id()
    with _lock:
        _cleanup_stale()
        still_seated = sid in _active_sessions
    return still_seated


# ---------------------------------------------------------------------------
# Waiting-room gate (Streamlit UI)
# ---------------------------------------------------------------------------

# PokeAPI official artwork, id 80 (Slowbro).
# NOTE(review): queue_gate shows the Slowpoke artwork instead; confirm whether
# anything else imports this constant.
_SLOWBRO_IMG = (
    "https://raw.githubusercontent.com/PokeAPI/sprites/master/"
    "sprites/pokemon/other/official-artwork/80.png"
)


def _pick_waiting_facts(facts, now: float, count: int = 3) -> list:
    """Pick up to *count* entries of *facts*, reshuffled every 8 seconds.

    The RNG is seeded with ``int(now) // 8``, so every rerun inside the same
    8-second window shows the same selection and it rotates afterwards.
    *facts* must support ``len()`` and integer indexing.
    """
    import random

    rng = random.Random(int(now) // 8)
    order = list(range(len(facts)))
    rng.shuffle(order)
    return [facts[i] for i in order[:count]]


def queue_gate() -> bool:
    """Admit the current session, or render the waiting room if all seats are taken.

    When the queue is disabled (see ``_queue_enabled``) this returns True
    immediately.  Otherwise the session's heartbeat is refreshed and it tries
    to claim a seat.  If that fails, the waiting-room page is drawn, the
    script sleeps 5 seconds and calls ``st.rerun()`` to retry admission.

    Returns:
        True if the session is admitted.  The final ``return False`` follows
        ``st.rerun()``, which is not expected to return; it is kept for
        type checkers.
    """
    if not _queue_enabled():
        return True

    # Refresh our own timestamp first so a seated session is not swept by the
    # stale-session cleanup that try_enter() runs before checking seats.
    heartbeat()

    if try_enter():
        return True

    # ── Waiting room UI ─────────────────────────────────────────────────────
    # These are only imported on the waiting path.
    from waiting_facts import WAITING_FACTS
    from utils import language_banner

    # Scrolling multilingual banner
    language_banner()

    _SLOWPOKE_IMG = (
        "https://raw.githubusercontent.com/PokeAPI/sprites/master/sprites"
        "/pokemon/other/official-artwork/79.png"
    )

    # Read the active count once and derive free seats from it, so the two
    # metrics below always agree.  Two separately locked reads could straddle
    # another session joining or leaving.
    n_active = active_count()
    spots_left = max(0, _MAX_CONCURRENT - n_active)

    # ── Slowpoke waiting illustration ──
    st.markdown(
        "<div style='text-align:center; margin-top:20px;'>"
        f"<img src='{_SLOWPOKE_IMG}' width='120' />"
        "</div>",
        unsafe_allow_html=True,
    )
    st.markdown(
        "<h3 style='text-align:center;'>Slowbro is busy crunching numbers...</h3>"
        "<p style='text-align:center; color:gray;'>"
        "All seats are taken! But don't worry — Slowpoke is keeping "
        "your spot warm. You'll get in as soon as someone finishes.</p>",
        unsafe_allow_html=True,
    )

    # ── Queue status ──
    q1, q2 = st.columns(2)
    with q1:
        st.metric("Active users", f"{n_active} / {_MAX_CONCURRENT}")
    with q2:
        st.metric("Seats available", str(spots_left))

    # ── Session policy note ──
    # The minute count is derived from _SESSION_TIMEOUT so the text cannot
    # drift from the real eviction threshold.
    # NOTE(review): seats are only refreshed by heartbeat()/try_enter(); confirm
    # the estimation path keeps a seat alive for runs longer than the timeout.
    timeout_minutes = _SESSION_TIMEOUT // 60
    st.warning(
        "**How the queue works:** Each user gets a seat for as long as "
        f"they're active. Sessions expire after **{timeout_minutes} minutes** "
        "of inactivity to keep things moving — but if you're running a model, "
        "your seat is safe until estimation completes."
    )

    # ── Rolling cultural facts ──
    st.markdown("---")
    st.markdown(
        "<p style='text-align:center; font-weight:600; margin-bottom:4px;'>"
        "While you wait — queuing around the world</p>",
        unsafe_allow_html=True,
    )

    for fact in _pick_waiting_facts(WAITING_FACTS, time.time()):
        st.info(fact)

    # Poll: wait 5 seconds, then re-run the script to retry admission.
    time.sleep(5)
    st.rerun()

    return False  # not expected to run after st.rerun(); kept for type checkers