# qstn_gui_test/src/gui_elements/session_cache.py
# Commit 157e0f8 (AhmedASalem): Improve login UI design - cleaner, centered layout
import streamlit as st
import pickle
import os
import shutil
import secrets
import urllib.parse
import requests
import re
from pathlib import Path
from datetime import datetime
# OAuth helper functions
def get_space_url() -> str:
    """Return the Space's base URL without a trailing slash.

    ``SPACE_BASE_URL`` takes precedence when it is set; otherwise the URL is
    built from ``SPACE_HOST``. An empty string is returned when neither
    variable provides a value.
    """
    explicit_url = os.getenv("SPACE_BASE_URL")
    if explicit_url:
        return explicit_url.rstrip("/")

    # No explicit base URL configured: fall back to the hf.space host name.
    host = os.getenv("SPACE_HOST", "")
    return f"https://{host}".rstrip("/") if host else ""
def oidc_config():
    """Fetch the provider's OpenID Connect discovery document.

    The provider base URL comes from ``OPENID_PROVIDER_URL`` and defaults to
    ``https://huggingface.co``. If the request or JSON decoding fails, an
    error is shown in the Streamlit UI and an empty dict is returned.
    """
    base_url = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co").rstrip("/")
    discovery_url = f"{base_url}/.well-known/openid-configuration"
    try:
        response = requests.get(discovery_url, timeout=10)
        return response.json()
    except Exception as e:
        st.error(f"Failed to fetch OIDC config: {e}")
        return {}
def login_button():
    """Render a sign-in link button that starts the Hugging Face OAuth flow.

    Nothing but an error message is rendered when ``OAUTH_CLIENT_ID`` is not
    set or when the OpenID configuration lacks an authorization endpoint.
    Otherwise a fresh random ``state`` value is stored in
    ``st.session_state["oauth_state"]`` and a link button pointing at the
    authorization endpoint is shown.
    """
    if not os.environ.get("OAUTH_CLIENT_ID"):
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return
    try:
        cfg = oidc_config()
        if not cfg:
            st.error("Could not fetch OpenID configuration. Please try again later.")
            return
        auth_endpoint = cfg.get("authorization_endpoint")
        if not auth_endpoint:
            st.error("Could not get authorization endpoint from OpenID configuration.")
            return

        # Random value echoed back by the provider; checked in the callback.
        state = secrets.token_urlsafe(24)
        st.session_state["oauth_state"] = state

        # Test the bare Space URL: once "/" is appended the string is never
        # empty, which previously made this warning unreachable.
        space_url = get_space_url()
        if not space_url:
            st.warning("Could not determine Space URL. OAuth may not work correctly.")
        redirect_uri = space_url + "/"

        params = {
            "client_id": os.environ["OAUTH_CLIENT_ID"],
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": os.getenv("OAUTH_SCOPES", "openid profile"),
            "state": state,
        }
        url = auth_endpoint + "?" + urllib.parse.urlencode(params)
        # Make button larger and more prominent
        st.link_button("🔐 Sign in with Hugging Face", url, use_container_width=True, type="primary")
    except Exception as e:
        st.error(f"Error setting up OAuth login: {e}")
        st.info("Please check the Space logs for more details.")
def _decode_id_token_claims(id_token) -> dict:
    """Return the payload claims of a JWT ID token, or ``{}`` if undecodable.

    The signature is NOT verified; the payload segment is only
    base64url-decoded and parsed as JSON.
    """
    import base64
    import json

    try:
        # A JWT is three base64url segments separated by "."; the payload is
        # the second one.
        parts = id_token.split(".")
        if len(parts) < 2:
            return {}
        payload = parts[1]
        # Restore the padding JWTs strip: 0-3 "=" characters, never 4.
        payload += "=" * (-len(payload) % 4)
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except (ValueError, TypeError, AttributeError):
        return {}
    # Callers treat the profile as a mapping, so reject any other JSON type.
    return claims if isinstance(claims, dict) else {}


def handle_oauth_callback():
    """Handle the OAuth redirect and exchange the code for an access token.

    Does nothing unless the current URL carries a ``code`` query parameter.
    On success the user profile and access token are stored in
    ``st.session_state`` under ``"hf_profile"`` and ``"hf_access_token"``,
    the query parameters are cleared and the script is rerun. Failures are
    reported through ``st.error`` and the function returns without changing
    the session.
    """
    qp = st.query_params
    if "code" not in qp:
        return

    # CSRF protection
    # NOTE(review): the comparison is skipped when this session holds no
    # stored state, so a callback arriving in a fresh session is accepted
    # unchecked. Confirm that this leniency is intentional.
    returned_state = qp.get("state")
    expected_state = st.session_state.get("oauth_state")
    if expected_state and returned_state != expected_state:
        st.error("OAuth state mismatch. Please try signing in again.")
        return

    # Read the credentials defensively: a missing variable used to raise an
    # unhandled KeyError here, outside any try block.
    client_id = os.environ.get("OAUTH_CLIENT_ID")
    client_secret = os.environ.get("OAUTH_CLIENT_SECRET")
    if not client_id or not client_secret:
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return

    cfg = oidc_config()
    token_endpoint = cfg.get("token_endpoint")
    userinfo_endpoint = cfg.get("userinfo_endpoint")
    if not token_endpoint:
        st.error("Could not get token endpoint")
        return

    # Built the same way as in login_button(), so both requests send the
    # same redirect_uri.
    redirect_uri = get_space_url() + "/"
    data = {
        "grant_type": "authorization_code",
        "code": qp["code"],
        "redirect_uri": redirect_uri,
        "client_id": client_id,
        "client_secret": client_secret,
    }
    try:
        token_response = requests.post(token_endpoint, data=data, timeout=10)
        token_response.raise_for_status()
        token = token_response.json()
        if "access_token" not in token:
            st.error(f"Token exchange failed: {token}")
            return

        # Fetch user profile
        profile = {}
        if userinfo_endpoint:
            profile_response = requests.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {token['access_token']}"},
                timeout=10,
            )
            profile_response.raise_for_status()
            profile = profile_response.json()
        elif "id_token" in token:
            # Decode ID token if userinfo not available
            profile = _decode_id_token_claims(token["id_token"])

        st.session_state["hf_profile"] = profile
        st.session_state["hf_access_token"] = token["access_token"]
        # Clean the URL (remove code/state)
        st.query_params.clear()
        st.rerun()
    except Exception as e:
        st.error(f"OAuth callback failed: {e}")
def get_user_id() -> str | None:
    """Return the signed-in user's identifier from the stored OAuth profile.

    The first truthy value among the ``preferred_username`` and ``name``
    claims is used; otherwise whatever the ``sub`` claim holds is returned
    (``None`` when it is absent).
    """
    profile = st.session_state.get("hf_profile") or {}
    for claim in ("preferred_username", "name"):
        value = profile.get(claim)
        if value:
            return value
    return profile.get("sub")
def safe_user_id(uid: str) -> str:
    """Make a user ID safe to use as a single directory or file name.

    Every character outside ``[a-zA-Z0-9_.-]`` is replaced with ``_`` and
    the result is truncated to 80 characters. A result that is empty or
    consists only of dots (such as ``"."`` or ``".."``) is replaced by the
    same number of underscores (at least one), so it can never name the
    current or parent directory.
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_.-]", "_", uid)[:80]
    if not cleaned.strip("."):
        # "", "." and ".." would resolve to the cache root or its parent
        # when joined onto a path.
        return "_" * max(len(cleaned), 1)
    return cleaned
def is_huggingface_space() -> bool:
    """
    Report whether the app appears to run on Hugging Face Spaces.
    True when either the SPACE_ID or the SPACE_HOST environment variable
    holds a non-empty value.
    """
    return any(os.environ.get(var) for var in ("SPACE_ID", "SPACE_HOST"))
def is_user_logged_in() -> bool:
    """
    Report whether get_user_id() yields a value, i.e. is not None.
    """
    current_user = get_user_id()
    return current_user is not None
def require_hf_login() -> tuple[bool, str]:
    """
    Decide whether the user must log in before using the app.
    Returns (True, message) when running on Hugging Face Spaces without a
    logged-in user; returns (False, "") in every other case.
    """
    # The login check only runs when the app is actually on a Space.
    login_needed = is_huggingface_space() and not is_user_logged_in()
    if not login_needed:
        return False, ""

    message = (
        "🔒 **Authentication Required**\n\n"
        "This application requires you to be logged in to Hugging Face to use it. "
        "Please log in using the button in the top right corner of this page."
    )
    return True, message
def _get_local_user_identifier_file() -> Path:
"""Get the path to the file storing the local user identifier."""
return Path(".session_cache") / ".local_user_id"
def get_user_identifier() -> str:
    """
    Return an identifier for the current user.
    - With an OAuth profile: the filesystem-safe form of the OAuth user ID
    - Otherwise: an identifier persisted in a local file, created on first
      use, which is also mirrored into st.session_state.user_identifier
    """
    oauth_user = get_user_id()
    if oauth_user:
        return safe_user_id(oauth_user)

    id_path = _get_local_user_identifier_file()

    # Reuse a previously stored identifier when the file is readable and
    # non-empty; any failure falls through to generating a new one.
    if id_path.exists():
        try:
            stored_id = id_path.read_text().strip()
            if stored_id:
                st.session_state.user_identifier = stored_id
                return stored_id
        except Exception:
            pass

    new_id = f"local_{secrets.token_hex(8)}"
    # Persisting is best-effort: the identifier is still returned if the
    # write fails.
    try:
        id_path.parent.mkdir(parents=True, exist_ok=True)
        id_path.write_text(new_id)
    except Exception:
        pass
    st.session_state.user_identifier = new_id
    return new_id
def get_user_cache_dir() -> Path:
    """
    Return the cache directory for the current user, creating it if needed.
    The base is /data when that path exists, otherwise .session_cache.
    With an OAuth user ID the directory is <base>/user_cache/<safe id>;
    without one it is the shared <base>/session_cache.
    """
    persistent_root = Path("/data")
    base_dir = persistent_root if persistent_root.exists() else Path(".session_cache")

    user_id = get_user_id()
    if user_id:
        target = base_dir / "user_cache" / safe_user_id(user_id)
    else:
        # No OAuth user: everyone shares one cache directory.
        target = base_dir / "session_cache"

    target.mkdir(parents=True, exist_ok=True)
    return target
# Directory to store session cache files (now user-specific)
def get_cache_dir() -> Path:
    """Return the current user's cache directory (see get_user_cache_dir)."""
    cache_dir = get_user_cache_dir()
    return cache_dir
def get_session_file_path(session_id: str = "default") -> Path:
    """Return the pickle file path for *session_id* inside the user's cache directory."""
    cache_dir = get_cache_dir()
    file_name = f"session_{session_id}.pkl"
    return cache_dir / file_name
def generate_session_id() -> str:
    """Return a session ID built from the current local time (YYYYMMDD_HHMMSS)."""
    return f"{datetime.now():%Y%m%d_%H%M%S}"
def get_current_session_id() -> str:
    """Return the active session ID, generating and storing one on first use."""
    state = st.session_state
    if "current_session_id" not in state:
        state.current_session_id = generate_session_id()
    return state.current_session_id
def save_session_state(session_id: str | None = None, session_name: str | None = None) -> bool:
    """
    Save the current session state to the user's cache directory.

    Args:
        session_id: Session to save under; defaults to the current active
            session ID.
        session_name: Display name to store; defaults to the current session
            name, or "Session <id>" when none is set. The chosen name is also
            written to st.session_state.current_session_name.

    Returns:
        True on success. On any error the message is shown via st.error and
        False is returned.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()
        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")
        st.session_state.current_session_name = session_name

        session_data = _prepare_session_data(session_id, session_name)

        # Serialize BEFORE opening the file: opening with "wb" truncates it,
        # so a pickling error used to wipe the previously saved session.
        payload = pickle.dumps(session_data)
        get_session_file_path(session_id).write_bytes(payload)
        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False
def load_session_state(session_id: str = "default") -> bool:
    """
    Restore a saved session from the user's cache directory.
    The loaded session becomes the current active session. Returns True on
    success; returns False when the file does not exist or when loading
    fails (the error is then shown via st.error).
    """
    try:
        source = get_session_file_path(session_id)
        if not source.exists():
            return False

        with source.open("rb") as handle:
            session_data = pickle.load(handle)

        # Mark the loaded session as the active one.
        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = session_data.get("_session_name", f"Session {session_id}")

        # Keys starting with "_" are internal metadata and are not restored.
        for key, value in session_data.items():
            if key.startswith("_"):
                continue
            st.session_state[key] = value
        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False
def list_available_sessions() -> list[dict]:
    """Return the current user's saved sessions, newest timestamp first.

    Each entry is a dict with the keys "id", "name", "timestamp" and "file".
    Files that cannot be read or unpickled are skipped.
    """
    found = []
    for path in get_cache_dir().glob("session_*.pkl"):
        try:
            with path.open("rb") as handle:
                data = pickle.load(handle)
            # Fall back to the ID embedded in the file name when the stored
            # metadata lacks one.
            sid = data.get("_session_id", path.stem.replace("session_", ""))
            found.append({
                "id": sid,
                "name": data.get("_session_name", f"Session {sid}"),
                "timestamp": data.get("_timestamp", "Unknown"),
                "file": path,
            })
        except Exception:
            continue
    found.sort(key=lambda entry: entry["timestamp"], reverse=True)
    return found
def delete_session(session_id: str = "default") -> bool:
    """Delete the current user's saved session file.

    Returns True when a file was removed, False when it did not exist or
    when an error occurred (the error is shown via st.error).
    """
    try:
        target = get_session_file_path(session_id)
        if not target.exists():
            return False
        target.unlink()
        return True
    except Exception as e:
        st.error(f"Error deleting session: {e}")
        return False
def clear_all_sessions():
    """Remove every saved session file of the current user; failures are ignored."""
    for stale_file in get_cache_dir().glob("session_*.pkl"):
        try:
            stale_file.unlink()
        except Exception:
            # Best effort: a file that cannot be removed is left in place.
            pass
def rename_session(session_id: str, new_name: str) -> bool:
    """Rename a saved session by rewriting its stored metadata.

    Args:
        session_id: ID of the saved session to rename.
        new_name: New display name.

    Returns:
        True on success; False when the session file does not exist or an
        error occurred (the error is shown via st.error).
    """
    try:
        cache_file = get_session_file_path(session_id)
        if not cache_file.exists():
            return False

        with open(cache_file, 'rb') as f:
            session_data = pickle.load(f)

        session_data["_session_name"] = new_name

        # Serialize before opening for writing: "wb" truncates the file, so
        # a pickling failure would otherwise destroy the saved session.
        payload = pickle.dumps(session_data)
        cache_file.write_bytes(payload)

        # Keep the in-memory name in sync when the active session was renamed.
        if st.session_state.get("current_session_id") == session_id:
            st.session_state.current_session_name = new_name
        return True
    except Exception as e:
        st.error(f"Error renaming session: {e}")
        return False
def _prepare_session_data(session_id: str, session_name: str) -> dict:
    """Collect the session-state entries to persist, plus save metadata.

    Only keys currently present in st.session_state are copied. The
    metadata keys "_timestamp", "_session_id" and "_session_name" are added
    last.
    """
    # NOTE(review): "api_key" is in this list, so callers that pickle the
    # result write the key to disk in clear text. Confirm this is acceptable.
    persisted_keys = (
        # Dataframes
        "df_questionnaire", "df_population",
        # Questionnaires (LLMPrompt objects - need pickle)
        "questionnaires",
        # Inference configs
        "client_config", "inference_config",
        # Survey options
        "survey_options",
        # Other important state
        "model_name", "temperature", "max_tokens", "top_p", "seed",
        "api_key", "base_url", "organization", "project",
        "advanced_client_params_str", "advanced_inference_params_str",
        "timeout", "max_retries",
    )
    state = st.session_state
    session_data = {key: state[key] for key in persisted_keys if key in state}
    session_data.update({
        "_timestamp": datetime.now().isoformat(),
        "_session_id": session_id,
        "_session_name": session_name,
    })
    return session_data
def save_session_state_to_path(
    session_id: str | None = None,
    session_name: str | None = None,
    save_path: Path | None = None,
) -> bool:
    """
    Save the current session state to a specific path.

    Args:
        session_id: Session to save; defaults to the current active session
            ID.
        session_name: Display name to store; defaults to the current session
            name, or "Session <id>" when none is set. The chosen name is also
            written to st.session_state.current_session_name.
        save_path: Destination (a Path or a plain string). None uses the
            default cache location. An existing directory gets
            "session_<id>.pkl" appended; any other path is given a ".pkl"
            suffix if it lacks one.

    Returns:
        True on success. On any error the message is shown via st.error and
        False is returned.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()
        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")
        st.session_state.current_session_name = session_name

        session_data = _prepare_session_data(session_id, session_name)

        if save_path is None:
            target = get_session_file_path(session_id)
        else:
            # Accept plain strings as well as Path objects.
            target = Path(save_path)
            target.parent.mkdir(parents=True, exist_ok=True)
            if target.is_dir():
                target = target / f"session_{session_id}.pkl"
            elif target.suffix != ".pkl":
                target = target.with_suffix(".pkl")

        # Serialize before opening the file: "wb" truncates it, so a
        # pickling error used to leave an empty or corrupt file behind.
        payload = pickle.dumps(session_data)
        target.write_bytes(payload)
        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False
def load_session_state_from_path(file_path: Path) -> bool:
    """
    Load session state from a specific file and make it the active session.

    Args:
        file_path: Location of the pickle file (a Path or a plain string).

    Returns:
        True on success; False when the file does not exist, does not
        contain a saved session, or loading fails (errors are shown via
        st.error).

    On success the file is also copied, best-effort, into the default cache
    directory under the loaded session's ID.
    """
    try:
        # Accept plain strings as well as Path objects.
        file_path = Path(file_path)
        if not file_path.exists():
            return False

        # SECURITY: unpickling can execute arbitrary code. Only call this
        # with files from a trusted source.
        # NOTE(review): confirm callers never pass untrusted uploads here.
        with open(file_path, 'rb') as f:
            session_data = pickle.load(f)

        if not isinstance(session_data, dict):
            st.error("Error loading session: file does not contain a saved session")
            return False

        # Get session ID from metadata or generate one
        session_id = session_data.get("_session_id", generate_session_id())

        # Set this as the current active session
        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = session_data.get("_session_name", f"Session {session_id}")

        # Restore session state (skip internal metadata)
        for key, value in session_data.items():
            if not key.startswith("_"):
                st.session_state[key] = value

        # Copy into the default cache directory for easy access later.
        default_path = get_session_file_path(session_id)
        if file_path != default_path:
            try:
                shutil.copy2(file_path, default_path)
            except Exception:
                pass  # If copy fails, that's okay - session is still loaded
        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False