Add per-user model saving, forum persistence, sidebar/header sticky, session timeout 30min
Browse files- model_store.py: serialize/deserialize EstimationResult & LatentClassResult to HF Dataset (per-user, max 10)
- Model page: Save to Profile button + Saved Models section with Load/Delete
- utils.py: auto-load saved models on login, show in sidebar with delete, sticky sidebar & header CSS
- community_db.py: persist forum posts/replies to HF Dataset (no longer ephemeral)
- session_queue.py: extend session timeout from 2min to 30min
- utils.py: language banner tooltip explaining multilingual translations
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
- app/community_db.py +128 -3
- app/model_store.py +483 -0
- app/pages/2_⚙️_Model.py +65 -0
- app/session_queue.py +2 -2
- app/utils.py +113 -3
app/community_db.py
CHANGED
|
@@ -3,9 +3,9 @@
|
|
| 3 |
Stores users (username, email, join date) and posts (author, title, body,
|
| 4 |
replies, timestamps). Designed for single-instance deployment (HF Spaces).
|
| 5 |
|
| 6 |
-
User accounts
|
| 7 |
-
(Wil2200/prefero-data) so they survive container
|
| 8 |
-
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
|
@@ -302,6 +302,126 @@ def _sync_users_from_hf() -> None:
|
|
| 302 |
conn.commit()
|
| 303 |
logger.info("Loaded %d users from HF dataset into SQLite", len(users))
|
| 304 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
# ---------------------------------------------------------------------------
|
| 306 |
# Database path — persistent on HF Spaces at /data or fallback to app dir
|
| 307 |
# ---------------------------------------------------------------------------
|
|
@@ -386,6 +506,7 @@ def init_db() -> None:
|
|
| 386 |
# Load persisted data from HF on first startup
|
| 387 |
_sync_users_from_hf()
|
| 388 |
_sync_activity_from_hf()
|
|
|
|
| 389 |
|
| 390 |
|
| 391 |
# ---------------------------------------------------------------------------
|
|
@@ -571,6 +692,8 @@ def create_post(author_id: int, title: str, body: str) -> Post:
|
|
| 571 |
)
|
| 572 |
conn.commit()
|
| 573 |
user = get_user_by_id(author_id)
|
|
|
|
|
|
|
| 574 |
return Post(
|
| 575 |
id=cur.lastrowid,
|
| 576 |
author_id=author_id,
|
|
@@ -665,6 +788,8 @@ def create_reply(post_id: int, author_id: int, body: str) -> Reply:
|
|
| 665 |
)
|
| 666 |
conn.commit()
|
| 667 |
user = get_user_by_id(author_id)
|
|
|
|
|
|
|
| 668 |
return Reply(
|
| 669 |
id=cur.lastrowid,
|
| 670 |
post_id=post_id,
|
|
|
|
| 3 |
Stores users (username, email, join date) and posts (author, title, body,
|
| 4 |
replies, timestamps). Designed for single-instance deployment (HF Spaces).
|
| 5 |
|
| 6 |
+
User accounts, activity logs, and forum posts/replies are persisted to a
|
| 7 |
+
private HF Dataset repo (Wil2200/prefero-data) so they survive container
|
| 8 |
+
restarts.
|
| 9 |
"""
|
| 10 |
|
| 11 |
from __future__ import annotations
|
|
|
|
| 302 |
conn.commit()
|
| 303 |
logger.info("Loaded %d users from HF dataset into SQLite", len(users))
|
| 304 |
|
| 305 |
+
|
| 306 |
+
# ---------------------------------------------------------------------------
|
| 307 |
+
# HF Dataset persistence for forum posts/replies
|
| 308 |
+
# ---------------------------------------------------------------------------
|
| 309 |
+
|
| 310 |
+
_posts_synced = False # only sync once per process
|
| 311 |
+
|
| 312 |
+
|
| 313 |
+
def _save_posts_to_hf() -> None:
|
| 314 |
+
"""Persist all posts and replies to the HF dataset repo."""
|
| 315 |
+
token = _hf_token()
|
| 316 |
+
if not token:
|
| 317 |
+
return
|
| 318 |
+
try:
|
| 319 |
+
from huggingface_hub import HfApi
|
| 320 |
+
import tempfile
|
| 321 |
+
|
| 322 |
+
conn = _get_conn()
|
| 323 |
+
post_rows = conn.execute(
|
| 324 |
+
"SELECT id, author_id, title, body, created_at, updated_at "
|
| 325 |
+
"FROM posts ORDER BY id"
|
| 326 |
+
).fetchall()
|
| 327 |
+
posts = [dict(r) for r in post_rows]
|
| 328 |
+
|
| 329 |
+
reply_rows = conn.execute(
|
| 330 |
+
"SELECT id, post_id, author_id, body, created_at "
|
| 331 |
+
"FROM replies ORDER BY id"
|
| 332 |
+
).fetchall()
|
| 333 |
+
replies = [dict(r) for r in reply_rows]
|
| 334 |
+
|
| 335 |
+
data = {"posts": posts, "replies": replies}
|
| 336 |
+
tmp = os.path.join(tempfile.gettempdir(), "prefero_forum_posts.json")
|
| 337 |
+
with open(tmp, "w") as f:
|
| 338 |
+
json.dump(data, f, indent=2)
|
| 339 |
+
|
| 340 |
+
api = HfApi(token=token)
|
| 341 |
+
api.upload_file(
|
| 342 |
+
path_or_fileobj=tmp, path_in_repo="forum_posts.json",
|
| 343 |
+
repo_id=_HF_DATASET_REPO, repo_type="dataset",
|
| 344 |
+
)
|
| 345 |
+
logger.info("Synced %d posts and %d replies to HF dataset", len(posts), len(replies))
|
| 346 |
+
except Exception as exc:
|
| 347 |
+
logger.warning("Failed to save posts to HF: %s", exc)
|
| 348 |
+
|
| 349 |
+
|
| 350 |
+
def _load_posts_from_hf() -> dict:
|
| 351 |
+
"""Load posts and replies from HF dataset repo. Returns {} on failure."""
|
| 352 |
+
token = _hf_token()
|
| 353 |
+
if not token:
|
| 354 |
+
logger.debug("No HF token — skipping posts sync")
|
| 355 |
+
return {}
|
| 356 |
+
try:
|
| 357 |
+
from huggingface_hub import hf_hub_download
|
| 358 |
+
path = hf_hub_download(
|
| 359 |
+
repo_id=_HF_DATASET_REPO, filename="forum_posts.json",
|
| 360 |
+
repo_type="dataset", token=token,
|
| 361 |
+
)
|
| 362 |
+
with open(path) as f:
|
| 363 |
+
data = json.load(f)
|
| 364 |
+
return data
|
| 365 |
+
except Exception as exc:
|
| 366 |
+
logger.debug("Failed to load posts from HF: %s", exc)
|
| 367 |
+
return {}
|
| 368 |
+
|
| 369 |
+
|
| 370 |
+
def _sync_posts_from_hf() -> None:
|
| 371 |
+
"""Restore forum posts and replies from HF dataset on startup (once)."""
|
| 372 |
+
global _posts_synced
|
| 373 |
+
if _posts_synced:
|
| 374 |
+
return
|
| 375 |
+
_posts_synced = True
|
| 376 |
+
|
| 377 |
+
data = _load_posts_from_hf()
|
| 378 |
+
if not data:
|
| 379 |
+
return
|
| 380 |
+
|
| 381 |
+
conn = _get_conn()
|
| 382 |
+
posts = data.get("posts", [])
|
| 383 |
+
replies = data.get("replies", [])
|
| 384 |
+
|
| 385 |
+
restored_posts = 0
|
| 386 |
+
for p in posts:
|
| 387 |
+
try:
|
| 388 |
+
existing = conn.execute(
|
| 389 |
+
"SELECT 1 FROM posts WHERE id = ?", (p["id"],)
|
| 390 |
+
).fetchone()
|
| 391 |
+
if existing:
|
| 392 |
+
continue
|
| 393 |
+
conn.execute(
|
| 394 |
+
"INSERT INTO posts (id, author_id, title, body, created_at, updated_at) "
|
| 395 |
+
"VALUES (?, ?, ?, ?, ?, ?)",
|
| 396 |
+
(p["id"], p["author_id"], p["title"], p["body"],
|
| 397 |
+
p["created_at"], p["updated_at"]),
|
| 398 |
+
)
|
| 399 |
+
restored_posts += 1
|
| 400 |
+
except Exception:
|
| 401 |
+
pass
|
| 402 |
+
|
| 403 |
+
restored_replies = 0
|
| 404 |
+
for r in replies:
|
| 405 |
+
try:
|
| 406 |
+
existing = conn.execute(
|
| 407 |
+
"SELECT 1 FROM replies WHERE id = ?", (r["id"],)
|
| 408 |
+
).fetchone()
|
| 409 |
+
if existing:
|
| 410 |
+
continue
|
| 411 |
+
conn.execute(
|
| 412 |
+
"INSERT INTO replies (id, post_id, author_id, body, created_at) "
|
| 413 |
+
"VALUES (?, ?, ?, ?, ?)",
|
| 414 |
+
(r["id"], r["post_id"], r["author_id"], r["body"],
|
| 415 |
+
r["created_at"]),
|
| 416 |
+
)
|
| 417 |
+
restored_replies += 1
|
| 418 |
+
except Exception:
|
| 419 |
+
pass
|
| 420 |
+
|
| 421 |
+
conn.commit()
|
| 422 |
+
logger.info("Restored %d posts and %d replies from HF dataset", restored_posts, restored_replies)
|
| 423 |
+
|
| 424 |
+
|
| 425 |
# ---------------------------------------------------------------------------
|
| 426 |
# Database path — persistent on HF Spaces at /data or fallback to app dir
|
| 427 |
# ---------------------------------------------------------------------------
|
|
|
|
| 506 |
# Load persisted data from HF on first startup
|
| 507 |
_sync_users_from_hf()
|
| 508 |
_sync_activity_from_hf()
|
| 509 |
+
_sync_posts_from_hf()
|
| 510 |
|
| 511 |
|
| 512 |
# ---------------------------------------------------------------------------
|
|
|
|
| 692 |
)
|
| 693 |
conn.commit()
|
| 694 |
user = get_user_by_id(author_id)
|
| 695 |
+
# Persist to HF dataset repo (non-blocking)
|
| 696 |
+
threading.Thread(target=_save_posts_to_hf, daemon=True).start()
|
| 697 |
return Post(
|
| 698 |
id=cur.lastrowid,
|
| 699 |
author_id=author_id,
|
|
|
|
| 788 |
)
|
| 789 |
conn.commit()
|
| 790 |
user = get_user_by_id(author_id)
|
| 791 |
+
# Persist to HF dataset repo (non-blocking)
|
| 792 |
+
threading.Thread(target=_save_posts_to_hf, daemon=True).start()
|
| 793 |
return Reply(
|
| 794 |
id=cur.lastrowid,
|
| 795 |
post_id=post_id,
|
app/model_store.py
ADDED
|
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Per-user model persistence via HF Dataset.
|
| 2 |
+
|
| 3 |
+
Each user's saved models are stored as ``models/{username}.json`` in the
|
| 4 |
+
private HF Dataset repo ``Wil2200/prefero-data``. Models survive container
|
| 5 |
+
restarts **and** redeployments.
|
| 6 |
+
|
| 7 |
+
Public API
|
| 8 |
+
----------
|
| 9 |
+
- save_model(username, model_entry) -> bool
|
| 10 |
+
- load_models(username) -> list[dict]
|
| 11 |
+
- delete_saved_model(username, index) -> bool
|
| 12 |
+
- serialize_model_entry(entry) -> dict
|
| 13 |
+
- deserialize_model_entry(data) -> dict
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
import logging
|
| 20 |
+
import math
|
| 21 |
+
import os
|
| 22 |
+
import tempfile
|
| 23 |
+
import threading
|
| 24 |
+
from datetime import datetime, timezone
|
| 25 |
+
|
| 26 |
+
import numpy as np
|
| 27 |
+
import pandas as pd
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def _sanitize_float(v):
|
| 33 |
+
"""Convert NaN / Inf to None so JSON stays spec-compliant."""
|
| 34 |
+
if isinstance(v, float) and (math.isnan(v) or math.isinf(v)):
|
| 35 |
+
return None
|
| 36 |
+
return v
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def _sanitize_list(lst):
|
| 40 |
+
"""Recursively sanitize a (possibly nested) list of floats."""
|
| 41 |
+
if lst is None:
|
| 42 |
+
return None
|
| 43 |
+
out = []
|
| 44 |
+
for item in lst:
|
| 45 |
+
if isinstance(item, list):
|
| 46 |
+
out.append(_sanitize_list(item))
|
| 47 |
+
elif isinstance(item, float):
|
| 48 |
+
out.append(_sanitize_float(item))
|
| 49 |
+
else:
|
| 50 |
+
out.append(item)
|
| 51 |
+
return out
|
| 52 |
+
|
| 53 |
+
_HF_DATASET_REPO = "Wil2200/prefero-data"
|
| 54 |
+
_MAX_MODELS_PER_USER = 10
|
| 55 |
+
_lock = threading.Lock()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
# ---------------------------------------------------------------------------
|
| 59 |
+
# HF helpers (mirrors community_db.py pattern)
|
| 60 |
+
# ---------------------------------------------------------------------------
|
| 61 |
+
|
| 62 |
+
def _hf_token() -> str | None:
|
| 63 |
+
return os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _load_user_models_from_hf(username: str) -> list[dict]:
|
| 67 |
+
"""Download ``models/{username}.json`` from the HF Dataset repo."""
|
| 68 |
+
token = _hf_token()
|
| 69 |
+
if not token:
|
| 70 |
+
logger.debug("No HF token -- skipping model load for %s", username)
|
| 71 |
+
return []
|
| 72 |
+
try:
|
| 73 |
+
from huggingface_hub import hf_hub_download
|
| 74 |
+
path = hf_hub_download(
|
| 75 |
+
repo_id=_HF_DATASET_REPO,
|
| 76 |
+
filename=f"models/{username}.json",
|
| 77 |
+
repo_type="dataset",
|
| 78 |
+
token=token,
|
| 79 |
+
)
|
| 80 |
+
with open(path) as f:
|
| 81 |
+
data = json.load(f)
|
| 82 |
+
return data.get("models", [])
|
| 83 |
+
except Exception as exc:
|
| 84 |
+
logger.debug("Failed to load models for %s from HF: %s", username, exc)
|
| 85 |
+
return []
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
def _save_user_models_to_hf(username: str, models: list[dict]) -> None:
|
| 89 |
+
"""Upload ``models/{username}.json`` to the HF Dataset repo."""
|
| 90 |
+
token = _hf_token()
|
| 91 |
+
if not token:
|
| 92 |
+
return
|
| 93 |
+
try:
|
| 94 |
+
from huggingface_hub import HfApi
|
| 95 |
+
|
| 96 |
+
data = {"models": models}
|
| 97 |
+
tmp = os.path.join(tempfile.gettempdir(), f"prefero_models_{username}.json")
|
| 98 |
+
with open(tmp, "w") as f:
|
| 99 |
+
json.dump(data, f)
|
| 100 |
+
|
| 101 |
+
api = HfApi(token=token)
|
| 102 |
+
api.upload_file(
|
| 103 |
+
path_or_fileobj=tmp,
|
| 104 |
+
path_in_repo=f"models/{username}.json",
|
| 105 |
+
repo_id=_HF_DATASET_REPO,
|
| 106 |
+
repo_type="dataset",
|
| 107 |
+
)
|
| 108 |
+
logger.info("Saved %d models for user %s to HF", len(models), username)
|
| 109 |
+
except Exception as exc:
|
| 110 |
+
logger.warning("Failed to save models for %s to HF: %s", username, exc)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
# ---------------------------------------------------------------------------
|
| 114 |
+
# Serialization helpers
|
| 115 |
+
# ---------------------------------------------------------------------------
|
| 116 |
+
|
| 117 |
+
def _serialize_dataframe(df: pd.DataFrame | None) -> dict | None:
|
| 118 |
+
if df is None:
|
| 119 |
+
return None
|
| 120 |
+
data = {}
|
| 121 |
+
for col in df.columns:
|
| 122 |
+
vals = df[col].tolist()
|
| 123 |
+
data[col] = [_sanitize_float(v) if isinstance(v, float) else v for v in vals]
|
| 124 |
+
return {
|
| 125 |
+
"columns": list(df.columns),
|
| 126 |
+
"data": data,
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _deserialize_dataframe(d: dict | None) -> pd.DataFrame | None:
|
| 131 |
+
if d is None:
|
| 132 |
+
return None
|
| 133 |
+
return pd.DataFrame(d["data"], columns=d["columns"])
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def _serialize_ndarray(arr: np.ndarray | None) -> list | None:
|
| 137 |
+
if arr is None:
|
| 138 |
+
return None
|
| 139 |
+
return _sanitize_list(arr.tolist())
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _deserialize_ndarray(lst: list | None) -> np.ndarray | None:
|
| 143 |
+
if lst is None:
|
| 144 |
+
return None
|
| 145 |
+
return np.array(lst)
|
| 146 |
+
|
| 147 |
+
|
| 148 |
+
def _serialize_variable_spec(vs) -> dict:
|
| 149 |
+
return {"name": vs.name, "column": vs.column, "distribution": vs.distribution}
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def _serialize_dummy_coding(dc) -> dict:
|
| 153 |
+
ref = dc.ref_level
|
| 154 |
+
# ref_level can be int/str/float -- store as-is (JSON-safe for primitives)
|
| 155 |
+
return {"column": dc.column, "ref_level": ref}
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
def _serialize_interaction_term(it) -> dict:
|
| 159 |
+
return {"columns": list(it.columns)}
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _serialize_model_spec(spec) -> dict | None:
|
| 163 |
+
if spec is None:
|
| 164 |
+
return None
|
| 165 |
+
from dce_analyzer.config import ModelSpec
|
| 166 |
+
return {
|
| 167 |
+
"id_col": spec.id_col,
|
| 168 |
+
"task_col": spec.task_col,
|
| 169 |
+
"alt_col": spec.alt_col,
|
| 170 |
+
"choice_col": spec.choice_col,
|
| 171 |
+
"variables": [_serialize_variable_spec(v) for v in spec.variables],
|
| 172 |
+
"n_draws": spec.n_draws,
|
| 173 |
+
"n_classes": getattr(spec, "n_classes", 2),
|
| 174 |
+
"membership_cols": getattr(spec, "membership_cols", None),
|
| 175 |
+
}
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
def _serialize_full_model_spec(spec) -> dict | None:
|
| 179 |
+
if spec is None:
|
| 180 |
+
return None
|
| 181 |
+
return {
|
| 182 |
+
"id_col": spec.id_col,
|
| 183 |
+
"task_col": spec.task_col,
|
| 184 |
+
"alt_col": spec.alt_col,
|
| 185 |
+
"choice_col": spec.choice_col,
|
| 186 |
+
"variables": [_serialize_variable_spec(v) for v in spec.variables],
|
| 187 |
+
"model_type": spec.model_type,
|
| 188 |
+
"dummy_codings": [_serialize_dummy_coding(dc) for dc in spec.dummy_codings],
|
| 189 |
+
"interactions": [_serialize_interaction_term(it) for it in spec.interactions],
|
| 190 |
+
"correlated": spec.correlated,
|
| 191 |
+
"correlation_groups": spec.correlation_groups,
|
| 192 |
+
"bws_worst_col": spec.bws_worst_col,
|
| 193 |
+
"estimate_lambda_w": spec.estimate_lambda_w,
|
| 194 |
+
"gmnl_variant": spec.gmnl_variant,
|
| 195 |
+
"n_classes": spec.n_classes,
|
| 196 |
+
"membership_cols": spec.membership_cols,
|
| 197 |
+
"lc_method": spec.lc_method,
|
| 198 |
+
"n_draws": spec.n_draws,
|
| 199 |
+
"maxiter": spec.maxiter,
|
| 200 |
+
"seed": spec.seed,
|
| 201 |
+
"n_starts": spec.n_starts,
|
| 202 |
+
"custom_start": spec.custom_start,
|
| 203 |
+
}
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
def _deserialize_variable_spec(d: dict):
|
| 207 |
+
from dce_analyzer.config import VariableSpec
|
| 208 |
+
return VariableSpec(name=d["name"], column=d["column"], distribution=d.get("distribution", "fixed"))
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def _deserialize_model_spec(d: dict | None):
|
| 212 |
+
if d is None:
|
| 213 |
+
return None
|
| 214 |
+
from dce_analyzer.config import ModelSpec
|
| 215 |
+
return ModelSpec(
|
| 216 |
+
id_col=d["id_col"],
|
| 217 |
+
task_col=d["task_col"],
|
| 218 |
+
alt_col=d["alt_col"],
|
| 219 |
+
choice_col=d["choice_col"],
|
| 220 |
+
variables=[_deserialize_variable_spec(v) for v in d["variables"]],
|
| 221 |
+
n_draws=d.get("n_draws", 200),
|
| 222 |
+
n_classes=d.get("n_classes", 2),
|
| 223 |
+
membership_cols=d.get("membership_cols"),
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
|
| 227 |
+
def _deserialize_full_model_spec(d: dict | None):
|
| 228 |
+
if d is None:
|
| 229 |
+
return None
|
| 230 |
+
from dce_analyzer.config import FullModelSpec, DummyCoding, InteractionTerm
|
| 231 |
+
return FullModelSpec(
|
| 232 |
+
id_col=d["id_col"],
|
| 233 |
+
task_col=d["task_col"],
|
| 234 |
+
alt_col=d["alt_col"],
|
| 235 |
+
choice_col=d["choice_col"],
|
| 236 |
+
variables=[_deserialize_variable_spec(v) for v in d["variables"]],
|
| 237 |
+
model_type=d.get("model_type", "mixed"),
|
| 238 |
+
dummy_codings=[DummyCoding(column=dc["column"], ref_level=dc["ref_level"]) for dc in d.get("dummy_codings", [])],
|
| 239 |
+
interactions=[InteractionTerm(columns=tuple(it["columns"])) for it in d.get("interactions", [])],
|
| 240 |
+
correlated=d.get("correlated", False),
|
| 241 |
+
correlation_groups=d.get("correlation_groups"),
|
| 242 |
+
bws_worst_col=d.get("bws_worst_col"),
|
| 243 |
+
estimate_lambda_w=d.get("estimate_lambda_w", True),
|
| 244 |
+
gmnl_variant=d.get("gmnl_variant", "general"),
|
| 245 |
+
n_classes=d.get("n_classes", 2),
|
| 246 |
+
membership_cols=d.get("membership_cols"),
|
| 247 |
+
lc_method=d.get("lc_method", "em"),
|
| 248 |
+
n_draws=d.get("n_draws", 200),
|
| 249 |
+
maxiter=d.get("maxiter", 300),
|
| 250 |
+
seed=d.get("seed", 123),
|
| 251 |
+
n_starts=d.get("n_starts", 10),
|
| 252 |
+
custom_start=d.get("custom_start"),
|
| 253 |
+
)
|
| 254 |
+
|
| 255 |
+
|
| 256 |
+
def _serialize_estimation_result(est) -> dict:
|
| 257 |
+
"""Serialize an EstimationResult to a JSON-safe dict."""
|
| 258 |
+
return {
|
| 259 |
+
"_type": "EstimationResult",
|
| 260 |
+
"success": est.success,
|
| 261 |
+
"message": est.message,
|
| 262 |
+
"log_likelihood": _sanitize_float(est.log_likelihood),
|
| 263 |
+
"aic": _sanitize_float(est.aic),
|
| 264 |
+
"bic": _sanitize_float(est.bic),
|
| 265 |
+
"n_parameters": est.n_parameters,
|
| 266 |
+
"n_observations": est.n_observations,
|
| 267 |
+
"n_individuals": est.n_individuals,
|
| 268 |
+
"optimizer_iterations": est.optimizer_iterations,
|
| 269 |
+
"runtime_seconds": _sanitize_float(est.runtime_seconds),
|
| 270 |
+
"estimates": _serialize_dataframe(est.estimates),
|
| 271 |
+
"vcov_matrix": _serialize_ndarray(est.vcov_matrix),
|
| 272 |
+
"covariance_matrix": _serialize_ndarray(est.covariance_matrix),
|
| 273 |
+
"correlation_matrix": _serialize_ndarray(est.correlation_matrix),
|
| 274 |
+
"random_param_names": est.random_param_names,
|
| 275 |
+
"covariance_se": _serialize_ndarray(est.covariance_se),
|
| 276 |
+
"correlation_se": _serialize_ndarray(est.correlation_se),
|
| 277 |
+
"correlation_test": _serialize_dataframe(est.correlation_test),
|
| 278 |
+
"raw_theta": _serialize_ndarray(est.raw_theta),
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def _serialize_latent_class_result(est) -> dict:
|
| 283 |
+
"""Serialize a LatentClassResult to a JSON-safe dict.
|
| 284 |
+
|
| 285 |
+
``posterior_probs`` is skipped (too large) and restored as an empty
|
| 286 |
+
DataFrame on deserialization.
|
| 287 |
+
"""
|
| 288 |
+
return {
|
| 289 |
+
"_type": "LatentClassResult",
|
| 290 |
+
"success": est.success,
|
| 291 |
+
"message": est.message,
|
| 292 |
+
"log_likelihood": _sanitize_float(est.log_likelihood),
|
| 293 |
+
"aic": _sanitize_float(est.aic),
|
| 294 |
+
"bic": _sanitize_float(est.bic),
|
| 295 |
+
"n_parameters": est.n_parameters,
|
| 296 |
+
"n_observations": est.n_observations,
|
| 297 |
+
"n_individuals": est.n_individuals,
|
| 298 |
+
"optimizer_iterations": est.optimizer_iterations,
|
| 299 |
+
"runtime_seconds": _sanitize_float(est.runtime_seconds),
|
| 300 |
+
"estimates": _serialize_dataframe(est.estimates),
|
| 301 |
+
"n_classes": est.n_classes,
|
| 302 |
+
"class_probabilities": _sanitize_list(list(est.class_probabilities)),
|
| 303 |
+
"class_estimates": _serialize_dataframe(est.class_estimates),
|
| 304 |
+
# posterior_probs intentionally skipped
|
| 305 |
+
"vcov_matrix": _serialize_ndarray(est.vcov_matrix),
|
| 306 |
+
"membership_estimates": _serialize_dataframe(est.membership_estimates),
|
| 307 |
+
"n_starts_attempted": est.n_starts_attempted,
|
| 308 |
+
"n_starts_succeeded": est.n_starts_succeeded,
|
| 309 |
+
"all_start_lls": _sanitize_list(list(est.all_start_lls)),
|
| 310 |
+
"best_start_index": est.best_start_index,
|
| 311 |
+
"optimizer_method": est.optimizer_method,
|
| 312 |
+
"em_iterations": est.em_iterations,
|
| 313 |
+
"em_ll_history": _sanitize_list(list(est.em_ll_history)),
|
| 314 |
+
"em_converged": est.em_converged,
|
| 315 |
+
"raw_theta": _serialize_ndarray(est.raw_theta),
|
| 316 |
+
}
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def _deserialize_estimation_result(d: dict):
|
| 320 |
+
from dce_analyzer.model import EstimationResult
|
| 321 |
+
return EstimationResult(
|
| 322 |
+
success=d["success"],
|
| 323 |
+
message=d["message"],
|
| 324 |
+
log_likelihood=d["log_likelihood"],
|
| 325 |
+
aic=d["aic"],
|
| 326 |
+
bic=d["bic"],
|
| 327 |
+
n_parameters=d["n_parameters"],
|
| 328 |
+
n_observations=d["n_observations"],
|
| 329 |
+
n_individuals=d["n_individuals"],
|
| 330 |
+
optimizer_iterations=d["optimizer_iterations"],
|
| 331 |
+
runtime_seconds=d["runtime_seconds"],
|
| 332 |
+
estimates=_deserialize_dataframe(d["estimates"]),
|
| 333 |
+
vcov_matrix=_deserialize_ndarray(d.get("vcov_matrix")),
|
| 334 |
+
covariance_matrix=_deserialize_ndarray(d.get("covariance_matrix")),
|
| 335 |
+
correlation_matrix=_deserialize_ndarray(d.get("correlation_matrix")),
|
| 336 |
+
random_param_names=d.get("random_param_names"),
|
| 337 |
+
covariance_se=_deserialize_ndarray(d.get("covariance_se")),
|
| 338 |
+
correlation_se=_deserialize_ndarray(d.get("correlation_se")),
|
| 339 |
+
correlation_test=_deserialize_dataframe(d.get("correlation_test")),
|
| 340 |
+
raw_theta=_deserialize_ndarray(d.get("raw_theta")),
|
| 341 |
+
)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
def _deserialize_latent_class_result(d: dict):
|
| 345 |
+
from dce_analyzer.latent_class import LatentClassResult
|
| 346 |
+
return LatentClassResult(
|
| 347 |
+
success=d["success"],
|
| 348 |
+
message=d["message"],
|
| 349 |
+
log_likelihood=d["log_likelihood"],
|
| 350 |
+
aic=d["aic"],
|
| 351 |
+
bic=d["bic"],
|
| 352 |
+
n_parameters=d["n_parameters"],
|
| 353 |
+
n_observations=d["n_observations"],
|
| 354 |
+
n_individuals=d["n_individuals"],
|
| 355 |
+
optimizer_iterations=d["optimizer_iterations"],
|
| 356 |
+
runtime_seconds=d["runtime_seconds"],
|
| 357 |
+
estimates=_deserialize_dataframe(d["estimates"]),
|
| 358 |
+
n_classes=d["n_classes"],
|
| 359 |
+
class_probabilities=d["class_probabilities"],
|
| 360 |
+
class_estimates=_deserialize_dataframe(d["class_estimates"]),
|
| 361 |
+
posterior_probs=pd.DataFrame(), # skipped on serialize
|
| 362 |
+
vcov_matrix=_deserialize_ndarray(d.get("vcov_matrix")),
|
| 363 |
+
membership_estimates=_deserialize_dataframe(d.get("membership_estimates")),
|
| 364 |
+
n_starts_attempted=d.get("n_starts_attempted", 0),
|
| 365 |
+
n_starts_succeeded=d.get("n_starts_succeeded", 0),
|
| 366 |
+
all_start_lls=d.get("all_start_lls", []),
|
| 367 |
+
best_start_index=d.get("best_start_index", -1),
|
| 368 |
+
optimizer_method=d.get("optimizer_method", "L-BFGS-B"),
|
| 369 |
+
em_iterations=d.get("em_iterations", 0),
|
| 370 |
+
em_ll_history=d.get("em_ll_history", []),
|
| 371 |
+
em_converged=d.get("em_converged", False),
|
| 372 |
+
raw_theta=_deserialize_ndarray(d.get("raw_theta")),
|
| 373 |
+
)
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
def _serialize_estimation(est) -> dict:
|
| 377 |
+
"""Route to the correct serializer based on type."""
|
| 378 |
+
from dce_analyzer.latent_class import LatentClassResult
|
| 379 |
+
if isinstance(est, LatentClassResult):
|
| 380 |
+
return _serialize_latent_class_result(est)
|
| 381 |
+
return _serialize_estimation_result(est)
|
| 382 |
+
|
| 383 |
+
|
| 384 |
+
def _deserialize_estimation(d: dict):
|
| 385 |
+
"""Route to the correct deserializer based on ``_type`` tag."""
|
| 386 |
+
if d.get("_type") == "LatentClassResult":
|
| 387 |
+
return _deserialize_latent_class_result(d)
|
| 388 |
+
return _deserialize_estimation_result(d)
|
| 389 |
+
|
| 390 |
+
|
| 391 |
+
# ---------------------------------------------------------------------------
|
| 392 |
+
# Public API
|
| 393 |
+
# ---------------------------------------------------------------------------
|
| 394 |
+
|
| 395 |
+
def serialize_model_entry(entry: dict) -> dict:
|
| 396 |
+
"""Convert a session_state model_history entry to a JSON-serializable dict.
|
| 397 |
+
|
| 398 |
+
Expected keys in *entry*: label, model_type, spec, full_spec, estimation.
|
| 399 |
+
"""
|
| 400 |
+
est = entry["estimation"]
|
| 401 |
+
return {
|
| 402 |
+
"label": entry.get("label", "unnamed"),
|
| 403 |
+
"model_type": entry.get("model_type", "mixed"),
|
| 404 |
+
"saved_at": datetime.now(timezone.utc).isoformat(),
|
| 405 |
+
"stats": {
|
| 406 |
+
"log_likelihood": _sanitize_float(est.log_likelihood),
|
| 407 |
+
"aic": _sanitize_float(est.aic),
|
| 408 |
+
"bic": _sanitize_float(est.bic),
|
| 409 |
+
"n_parameters": est.n_parameters,
|
| 410 |
+
"n_observations": est.n_observations,
|
| 411 |
+
"n_individuals": est.n_individuals,
|
| 412 |
+
"runtime_seconds": _sanitize_float(est.runtime_seconds),
|
| 413 |
+
},
|
| 414 |
+
"spec": _serialize_model_spec(entry.get("spec")),
|
| 415 |
+
"full_spec": _serialize_full_model_spec(entry.get("full_spec")),
|
| 416 |
+
"estimation_data": _serialize_estimation(est),
|
| 417 |
+
}
|
| 418 |
+
|
| 419 |
+
|
| 420 |
+
def deserialize_model_entry(data: dict) -> dict:
|
| 421 |
+
"""Reconstruct a model_history-compatible dict from stored JSON.
|
| 422 |
+
|
| 423 |
+
Returns a dict with keys: label, model_type, spec, full_spec, estimation.
|
| 424 |
+
"""
|
| 425 |
+
return {
|
| 426 |
+
"label": data.get("label", "unnamed"),
|
| 427 |
+
"model_type": data.get("model_type", "mixed"),
|
| 428 |
+
"saved_at": data.get("saved_at"),
|
| 429 |
+
"spec": _deserialize_model_spec(data.get("spec")),
|
| 430 |
+
"full_spec": _deserialize_full_model_spec(data.get("full_spec")),
|
| 431 |
+
"estimation": _deserialize_estimation(data["estimation_data"]),
|
| 432 |
+
}
|
| 433 |
+
|
| 434 |
+
|
| 435 |
+
def save_model(username: str, model_entry: dict) -> bool:
|
| 436 |
+
"""Save a model to the user's profile.
|
| 437 |
+
|
| 438 |
+
Enforces a maximum of ``_MAX_MODELS_PER_USER`` models.
|
| 439 |
+
Returns ``True`` on success.
|
| 440 |
+
"""
|
| 441 |
+
with _lock:
|
| 442 |
+
models = _load_user_models_from_hf(username)
|
| 443 |
+
if len(models) >= _MAX_MODELS_PER_USER:
|
| 444 |
+
logger.warning(
|
| 445 |
+
"User %s already has %d models (limit %d)",
|
| 446 |
+
username, len(models), _MAX_MODELS_PER_USER,
|
| 447 |
+
)
|
| 448 |
+
return False
|
| 449 |
+
serialized = serialize_model_entry(model_entry)
|
| 450 |
+
models.append(serialized)
|
| 451 |
+
_save_user_models_to_hf(username, models)
|
| 452 |
+
return True
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
def load_models(username: str) -> list[dict]:
|
| 456 |
+
"""Load all saved models for a user.
|
| 457 |
+
|
| 458 |
+
Returns a list of deserialized model_history-compatible dicts.
|
| 459 |
+
"""
|
| 460 |
+
with _lock:
|
| 461 |
+
raw_models = _load_user_models_from_hf(username)
|
| 462 |
+
result = []
|
| 463 |
+
for m in raw_models:
|
| 464 |
+
try:
|
| 465 |
+
result.append(deserialize_model_entry(m))
|
| 466 |
+
except Exception as exc:
|
| 467 |
+
logger.warning("Skipping corrupt model for %s: %s", username, exc)
|
| 468 |
+
return result
|
| 469 |
+
|
| 470 |
+
|
| 471 |
+
def delete_saved_model(username: str, index: int) -> bool:
|
| 472 |
+
"""Delete a saved model by its index (0-based).
|
| 473 |
+
|
| 474 |
+
Returns ``True`` on success.
|
| 475 |
+
"""
|
| 476 |
+
with _lock:
|
| 477 |
+
models = _load_user_models_from_hf(username)
|
| 478 |
+
if index < 0 or index >= len(models):
|
| 479 |
+
logger.warning("Invalid index %d for user %s (has %d models)", index, username, len(models))
|
| 480 |
+
return False
|
| 481 |
+
models.pop(index)
|
| 482 |
+
_save_user_models_to_hf(username, models)
|
| 483 |
+
return True
|
app/pages/2_⚙️_Model.py
CHANGED
|
@@ -1348,6 +1348,22 @@ if st.button("Run Estimation", type="primary", use_container_width=True):
|
|
| 1348 |
"estimation": estimation,
|
| 1349 |
})
|
| 1350 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1351 |
# Also store LC-specific result
|
| 1352 |
if model_type == "latent_class":
|
| 1353 |
st.session_state.lc_result = {
|
|
@@ -1495,4 +1511,53 @@ if st.session_state.model_history:
|
|
| 1495 |
st.session_state.model_results = None
|
| 1496 |
st.rerun()
|
| 1497 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1498 |
slowbro_next_step("pages/2_⚙️_Model.py")
|
|
|
|
| 1348 |
"estimation": estimation,
|
| 1349 |
})
|
| 1350 |
|
| 1351 |
+
# Save to profile button
|
| 1352 |
+
_username = st.session_state.get("username", "")
|
| 1353 |
+
if _username:
|
| 1354 |
+
_n_saved = len(st.session_state.get("saved_models", []))
|
| 1355 |
+
if _n_saved >= 10:
|
| 1356 |
+
st.warning("You've reached the limit of 10 saved models. Delete one to save more.")
|
| 1357 |
+
else:
|
| 1358 |
+
if st.button("Save to Profile", key=f"_save_profile_{run_label}", type="secondary"):
|
| 1359 |
+
from model_store import save_model, load_models
|
| 1360 |
+
_entry = st.session_state.model_history[-1]
|
| 1361 |
+
if save_model(_username, _entry):
|
| 1362 |
+
st.session_state.saved_models = load_models(_username)
|
| 1363 |
+
st.success(f"Model '{run_label}' saved to your profile!")
|
| 1364 |
+
else:
|
| 1365 |
+
st.error("Could not save -- you may have reached the 10-model limit.")
|
| 1366 |
+
|
| 1367 |
# Also store LC-specific result
|
| 1368 |
if model_type == "latent_class":
|
| 1369 |
st.session_state.lc_result = {
|
|
|
|
| 1511 |
st.session_state.model_results = None
|
| 1512 |
st.rerun()
|
| 1513 |
|
| 1514 |
+
# ── Saved models from profile ──────────────────────────────────
|
| 1515 |
+
_saved = st.session_state.get("saved_models", [])
|
| 1516 |
+
if _saved:
|
| 1517 |
+
st.divider()
|
| 1518 |
+
st.subheader("Saved Models (Profile)")
|
| 1519 |
+
st.caption("These models are saved to your profile and persist across sessions.")
|
| 1520 |
+
_profile_delete_idx: int | None = None
|
| 1521 |
+
_profile_load_idx: int | None = None
|
| 1522 |
+
for i, sm in enumerate(_saved):
|
| 1523 |
+
est = sm.get("estimation")
|
| 1524 |
+
_info_c, _load_c, _del_c = st.columns([5, 1, 1])
|
| 1525 |
+
with _info_c:
|
| 1526 |
+
_saved_at = sm.get("saved_at", "")
|
| 1527 |
+
if _saved_at:
|
| 1528 |
+
_saved_at = _saved_at[:10] # just date
|
| 1529 |
+
st.markdown(
|
| 1530 |
+
f"**{i+1}. {sm.get('label', 'model')}** ({sm.get('model_type', '?')}) "
|
| 1531 |
+
f"-- LL: {est.log_likelihood:.3f}, AIC: {est.aic:.2f}, BIC: {est.bic:.2f} "
|
| 1532 |
+
f"<span style='color:gray;font-size:0.8em;'>({_saved_at})</span>",
|
| 1533 |
+
unsafe_allow_html=True,
|
| 1534 |
+
)
|
| 1535 |
+
with _load_c:
|
| 1536 |
+
if st.button("Load", key=f"_load_saved_{i}"):
|
| 1537 |
+
_profile_load_idx = i
|
| 1538 |
+
with _del_c:
|
| 1539 |
+
if st.button("Delete", key=f"_del_saved_{i}"):
|
| 1540 |
+
_profile_delete_idx = i
|
| 1541 |
+
|
| 1542 |
+
if _profile_load_idx is not None:
|
| 1543 |
+
_loaded = _saved[_profile_load_idx]
|
| 1544 |
+
st.session_state.model_results = {
|
| 1545 |
+
"spec": _loaded.get("spec"),
|
| 1546 |
+
"full_spec": _loaded.get("full_spec"),
|
| 1547 |
+
"model_type": _loaded.get("model_type"),
|
| 1548 |
+
"estimation": _loaded["estimation"],
|
| 1549 |
+
"label": _loaded.get("label"),
|
| 1550 |
+
"expanded_df": None,
|
| 1551 |
+
}
|
| 1552 |
+
if _loaded not in st.session_state.model_history:
|
| 1553 |
+
st.session_state.model_history.append(_loaded)
|
| 1554 |
+
st.rerun()
|
| 1555 |
+
|
| 1556 |
+
if _profile_delete_idx is not None:
|
| 1557 |
+
_username = st.session_state.get("username", "")
|
| 1558 |
+
from model_store import delete_saved_model, load_models
|
| 1559 |
+
if delete_saved_model(_username, _profile_delete_idx):
|
| 1560 |
+
st.session_state.saved_models = load_models(_username)
|
| 1561 |
+
st.rerun()
|
| 1562 |
+
|
| 1563 |
slowbro_next_step("pages/2_⚙️_Model.py")
|
app/session_queue.py
CHANGED
|
@@ -21,7 +21,7 @@ import streamlit as st
|
|
| 21 |
# ---------------------------------------------------------------------------
|
| 22 |
|
| 23 |
_MAX_CONCURRENT = int(os.environ.get("PREFERO_MAX_CONCURRENT", "5"))
|
| 24 |
-
_SESSION_TIMEOUT =
|
| 25 |
|
| 26 |
|
| 27 |
def _queue_enabled() -> bool:
|
|
@@ -212,7 +212,7 @@ def queue_gate() -> bool:
|
|
| 212 |
# ── Session policy note ──
|
| 213 |
st.warning(
|
| 214 |
"**How the queue works:** Each user gets a seat for as long as "
|
| 215 |
-
"they're active. Sessions expire after **
|
| 216 |
"to keep things moving — but if you're running a model, your seat "
|
| 217 |
"is safe until estimation completes."
|
| 218 |
)
|
|
|
|
| 21 |
# ---------------------------------------------------------------------------
|
| 22 |
|
| 23 |
_MAX_CONCURRENT = int(os.environ.get("PREFERO_MAX_CONCURRENT", "5"))
|
| 24 |
+
_SESSION_TIMEOUT = 1800 # 30 minutes of inactivity → evicted
|
| 25 |
|
| 26 |
|
| 27 |
def _queue_enabled() -> bool:
|
|
|
|
| 212 |
# ── Session policy note ──
|
| 213 |
st.warning(
|
| 214 |
"**How the queue works:** Each user gets a seat for as long as "
|
| 215 |
+
"they're active. Sessions expire after **30 minutes** of inactivity "
|
| 216 |
"to keep things moving — but if you're running a model, your seat "
|
| 217 |
"is safe until estimation completes."
|
| 218 |
)
|
app/utils.py
CHANGED
|
@@ -26,6 +26,7 @@ _SESSION_DEFAULTS: dict[str, object] = {
|
|
| 26 |
"lc_result": None,
|
| 27 |
"lc_bic_comparison": None,
|
| 28 |
"lc_best_q": None,
|
|
|
|
| 29 |
"authenticated": False,
|
| 30 |
"auth_email": "",
|
| 31 |
"username": "",
|
|
@@ -139,6 +140,12 @@ def init_session_state() -> None:
|
|
| 139 |
st.stop()
|
| 140 |
require_queue_slot()
|
| 141 |
queue_heartbeat()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 142 |
st.session_state["_queue_admitted"] = True
|
| 143 |
_inject_activity_heartbeat()
|
| 144 |
import inspect
|
|
@@ -194,6 +201,26 @@ def slowbro_status() -> None:
|
|
| 194 |
[data-testid="stStatusWidget"] {{
|
| 195 |
display: none !important;
|
| 196 |
}}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
{_hide_admin_css}
|
| 198 |
/* Slowbro pill — injected into Streamlit's fixed header via ::after */
|
| 199 |
[data-testid="stHeader"]::after {{
|
|
@@ -237,6 +264,28 @@ def sidebar_branding() -> None:
|
|
| 237 |
else:
|
| 238 |
st.sidebar.info("No data loaded yet.")
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
# Developer mode prompt (admin users only)
|
| 241 |
if _is_admin_user() and not st.session_state.get("_dev_mode_active"):
|
| 242 |
import os
|
|
@@ -290,11 +339,14 @@ def language_banner() -> None:
|
|
| 290 |
0% { transform: translateX(0%); }
|
| 291 |
100% { transform: translateX(-50%); }
|
| 292 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
| 293 |
.scroll-banner {
|
| 294 |
overflow: hidden;
|
| 295 |
white-space: nowrap;
|
| 296 |
padding: 12px 0;
|
| 297 |
-
margin-bottom: 8px;
|
| 298 |
border-top: 1px solid rgba(128,128,128,0.2);
|
| 299 |
border-bottom: 1px solid rgba(128,128,128,0.2);
|
| 300 |
}
|
|
@@ -309,6 +361,52 @@ def language_banner() -> None:
|
|
| 309 |
padding: 0 16px;
|
| 310 |
}
|
| 311 |
.scroll-inner .zh { font-weight: 700; opacity: 1.0; font-size: 1.2rem; }
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 312 |
</style>
|
| 313 |
""",
|
| 314 |
unsafe_allow_html=True,
|
|
@@ -322,8 +420,20 @@ def language_banner() -> None:
|
|
| 322 |
|
| 323 |
st.markdown(
|
| 324 |
f"""
|
| 325 |
-
<div class="scroll-banner">
|
| 326 |
-
<div class="scroll-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 327 |
</div>
|
| 328 |
""",
|
| 329 |
unsafe_allow_html=True,
|
|
|
|
| 26 |
"lc_result": None,
|
| 27 |
"lc_bic_comparison": None,
|
| 28 |
"lc_best_q": None,
|
| 29 |
+
"saved_models": [],
|
| 30 |
"authenticated": False,
|
| 31 |
"auth_email": "",
|
| 32 |
"username": "",
|
|
|
|
| 140 |
st.stop()
|
| 141 |
require_queue_slot()
|
| 142 |
queue_heartbeat()
|
| 143 |
+
if not st.session_state.get("_saved_models_loaded"):
|
| 144 |
+
username = st.session_state.get("username", "")
|
| 145 |
+
if username:
|
| 146 |
+
from model_store import load_models
|
| 147 |
+
st.session_state.saved_models = load_models(username)
|
| 148 |
+
st.session_state._saved_models_loaded = True
|
| 149 |
st.session_state["_queue_admitted"] = True
|
| 150 |
_inject_activity_heartbeat()
|
| 151 |
import inspect
|
|
|
|
| 201 |
[data-testid="stStatusWidget"] {{
|
| 202 |
display: none !important;
|
| 203 |
}}
|
| 204 |
+
/* Keep header fixed at top when scrolling */
|
| 205 |
+
[data-testid="stHeader"] {{
|
| 206 |
+
position: fixed !important;
|
| 207 |
+
top: 0 !important;
|
| 208 |
+
left: 0 !important;
|
| 209 |
+
right: 0 !important;
|
| 210 |
+
z-index: 1000 !important;
|
| 211 |
+
}}
|
| 212 |
+
/* Keep sidebar fixed when scrolling */
|
| 213 |
+
section[data-testid="stSidebar"] {{
|
| 214 |
+
position: fixed !important;
|
| 215 |
+
height: 100vh !important;
|
| 216 |
+
top: 0 !important;
|
| 217 |
+
left: 0 !important;
|
| 218 |
+
z-index: 999 !important;
|
| 219 |
+
}}
|
| 220 |
+
section[data-testid="stSidebar"] > div {{
|
| 221 |
+
height: 100vh !important;
|
| 222 |
+
overflow-y: auto !important;
|
| 223 |
+
}}
|
| 224 |
{_hide_admin_css}
|
| 225 |
/* Slowbro pill — injected into Streamlit's fixed header via ::after */
|
| 226 |
[data-testid="stHeader"]::after {{
|
|
|
|
| 264 |
else:
|
| 265 |
st.sidebar.info("No data loaded yet.")
|
| 266 |
|
| 267 |
+
# ── Saved models in sidebar ──
|
| 268 |
+
_saved = st.session_state.get("saved_models", [])
|
| 269 |
+
if _saved:
|
| 270 |
+
with st.sidebar.expander(f"Saved Models ({len(_saved)}/10)", expanded=False):
|
| 271 |
+
_sb_delete_idx: int | None = None
|
| 272 |
+
for i, sm in enumerate(_saved):
|
| 273 |
+
est = sm.get("estimation")
|
| 274 |
+
_ll = f"{est.log_likelihood:.1f}" if est else "?"
|
| 275 |
+
_col_info, _col_del = st.columns([5, 1])
|
| 276 |
+
with _col_info:
|
| 277 |
+
st.caption(f"**{sm.get('label', 'model')}** ({sm.get('model_type', '?')}) LL:{_ll}")
|
| 278 |
+
with _col_del:
|
| 279 |
+
if st.button("✕", key=f"_sb_del_saved_{i}", help="Delete"):
|
| 280 |
+
_sb_delete_idx = i
|
| 281 |
+
if _sb_delete_idx is not None:
|
| 282 |
+
_uname = st.session_state.get("username", "")
|
| 283 |
+
if _uname:
|
| 284 |
+
from model_store import delete_saved_model, load_models
|
| 285 |
+
if delete_saved_model(_uname, _sb_delete_idx):
|
| 286 |
+
st.session_state.saved_models = load_models(_uname)
|
| 287 |
+
st.rerun()
|
| 288 |
+
|
| 289 |
# Developer mode prompt (admin users only)
|
| 290 |
if _is_admin_user() and not st.session_state.get("_dev_mode_active"):
|
| 291 |
import os
|
|
|
|
| 339 |
0% { transform: translateX(0%); }
|
| 340 |
100% { transform: translateX(-50%); }
|
| 341 |
}
|
| 342 |
+
.scroll-banner-wrap {
|
| 343 |
+
position: relative;
|
| 344 |
+
margin-bottom: 8px;
|
| 345 |
+
}
|
| 346 |
.scroll-banner {
|
| 347 |
overflow: hidden;
|
| 348 |
white-space: nowrap;
|
| 349 |
padding: 12px 0;
|
|
|
|
| 350 |
border-top: 1px solid rgba(128,128,128,0.2);
|
| 351 |
border-bottom: 1px solid rgba(128,128,128,0.2);
|
| 352 |
}
|
|
|
|
| 361 |
padding: 0 16px;
|
| 362 |
}
|
| 363 |
.scroll-inner .zh { font-weight: 700; opacity: 1.0; font-size: 1.2rem; }
|
| 364 |
+
.banner-help {
|
| 365 |
+
position: absolute;
|
| 366 |
+
right: 6px;
|
| 367 |
+
top: 50%;
|
| 368 |
+
transform: translateY(-50%);
|
| 369 |
+
z-index: 10;
|
| 370 |
+
}
|
| 371 |
+
.banner-help-icon {
|
| 372 |
+
display: flex;
|
| 373 |
+
align-items: center;
|
| 374 |
+
justify-content: center;
|
| 375 |
+
width: 20px;
|
| 376 |
+
height: 20px;
|
| 377 |
+
border-radius: 50%;
|
| 378 |
+
background: rgba(128,128,128,0.18);
|
| 379 |
+
color: rgba(128,128,128,0.7);
|
| 380 |
+
font-size: 0.75rem;
|
| 381 |
+
font-weight: 600;
|
| 382 |
+
cursor: default;
|
| 383 |
+
user-select: none;
|
| 384 |
+
line-height: 1;
|
| 385 |
+
}
|
| 386 |
+
.banner-help-icon:hover {
|
| 387 |
+
background: rgba(128,128,128,0.30);
|
| 388 |
+
color: rgba(128,128,128,0.95);
|
| 389 |
+
}
|
| 390 |
+
.banner-help-tooltip {
|
| 391 |
+
display: none;
|
| 392 |
+
position: absolute;
|
| 393 |
+
right: 0;
|
| 394 |
+
top: 28px;
|
| 395 |
+
width: 260px;
|
| 396 |
+
padding: 12px 14px;
|
| 397 |
+
background: var(--background-color, #fff);
|
| 398 |
+
border: 1px solid rgba(128,128,128,0.2);
|
| 399 |
+
border-radius: 8px;
|
| 400 |
+
box-shadow: 0 4px 16px rgba(0,0,0,0.10);
|
| 401 |
+
white-space: normal;
|
| 402 |
+
font-size: 0.82rem;
|
| 403 |
+
line-height: 1.5;
|
| 404 |
+
color: var(--text-color, #444);
|
| 405 |
+
z-index: 100;
|
| 406 |
+
}
|
| 407 |
+
.banner-help:hover .banner-help-tooltip {
|
| 408 |
+
display: block;
|
| 409 |
+
}
|
| 410 |
</style>
|
| 411 |
""",
|
| 412 |
unsafe_allow_html=True,
|
|
|
|
| 420 |
|
| 421 |
st.markdown(
|
| 422 |
f"""
|
| 423 |
+
<div class="scroll-banner-wrap">
|
| 424 |
+
<div class="scroll-banner">
|
| 425 |
+
<div class="scroll-inner">{doubled}</div>
|
| 426 |
+
</div>
|
| 427 |
+
<div class="banner-help">
|
| 428 |
+
<div class="banner-help-icon">?</div>
|
| 429 |
+
<div class="banner-help-tooltip">
|
| 430 |
+
The name <b>Prefero</b> is shown in many languages
|
| 431 |
+
to celebrate the wonderful diversity of our users.
|
| 432 |
+
Translations may not be perfectly accurate —
|
| 433 |
+
if you spot an error, we would love to hear from you
|
| 434 |
+
on the <b>Community</b> page!
|
| 435 |
+
</div>
|
| 436 |
+
</div>
|
| 437 |
</div>
|
| 438 |
""",
|
| 439 |
unsafe_allow_html=True,
|