| import streamlit as st |
| import pickle |
| import os |
| import shutil |
| import secrets |
| import urllib.parse |
| import requests |
| import re |
| from pathlib import Path |
| from datetime import datetime |
|
|
| |
def get_space_url() -> str:
    """Return the public base URL of this Space, without a trailing slash.

    A non-empty ``SPACE_BASE_URL`` takes precedence. Otherwise the URL is built
    as ``https://<SPACE_HOST>``. When neither variable is set (or both are
    empty) the empty string is returned.
    """
    base = os.getenv("SPACE_BASE_URL") or ""
    if not base:
        host = os.getenv("SPACE_HOST", "")
        if not host:
            return ""
        base = f"https://{host}"
    return base.rstrip("/")
|
|
def oidc_config():
    """Fetch the OpenID Connect discovery document of the configured provider.

    The provider base URL comes from ``OPENID_PROVIDER_URL``; when it is unset
    or empty, ``https://huggingface.co`` is used.

    Returns:
        The parsed discovery document as a dict, or ``{}`` after showing an
        ``st.error`` when the request fails, the server answers with an HTTP
        error status, or the body is not a JSON object. Callers rely on the
        ``{}`` fallback (``if not cfg`` / ``cfg.get(...)``).

    NOTE(review): this performs a network request on every call (login_button
    calls it on each render); consider caching if that becomes a problem.
    """
    # `or` (not a .get default) so an empty env var still falls back.
    provider = (os.environ.get("OPENID_PROVIDER_URL") or "https://huggingface.co").rstrip("/")
    try:
        response = requests.get(f"{provider}/.well-known/openid-configuration", timeout=10)
        # Without this, an error page that happens to be JSON would be
        # returned as if it were the configuration.
        response.raise_for_status()
        config = response.json()
    except Exception as e:
        st.error(f"Failed to fetch OIDC config: {e}")
        return {}

    if not isinstance(config, dict):
        st.error("Failed to fetch OIDC config: unexpected response format")
        return {}
    return config
|
|
def login_button():
    """Render a "Sign in with Hugging Face" link button that starts the OAuth flow.

    Requires ``OAUTH_CLIENT_ID`` in the environment and a usable OpenID
    discovery document. Otherwise an ``st.error`` is shown instead of the
    button.

    Each call stores a fresh random ``state`` value in
    ``st.session_state["oauth_state"]``, which ``handle_oauth_callback``
    compares against.

    The redirect URI is the Space base URL (``get_space_url()``) plus ``/``.
    A warning is shown when that base URL cannot be determined.
    """
    if not os.environ.get("OAUTH_CLIENT_ID"):
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return

    try:
        cfg = oidc_config()
        if not cfg:
            st.error("Could not fetch OpenID configuration. Please try again later.")
            return

        auth_endpoint = cfg.get("authorization_endpoint")
        if not auth_endpoint:
            st.error("Could not get authorization endpoint from OpenID configuration.")
            return

        state = secrets.token_urlsafe(24)
        st.session_state["oauth_state"] = state

        # Check the base URL itself: once "/" is appended the string can never
        # be empty. The value must equal the redirect_uri handle_oauth_callback
        # sends in the token exchange (get_space_url() + "/"), so no different
        # fallback URL is substituted here.
        space_url = get_space_url()
        if not space_url:
            st.warning("Could not determine Space URL. OAuth may not work correctly.")
        redirect_uri = space_url + "/"

        params = {
            "client_id": os.environ["OAUTH_CLIENT_ID"],
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": os.getenv("OAUTH_SCOPES", "openid profile"),
            "state": state,
        }
        # Extend an existing query string instead of starting a second one.
        separator = "&" if "?" in auth_endpoint else "?"
        url = auth_endpoint + separator + urllib.parse.urlencode(params)

        st.link_button("🔐 Sign in with Hugging Face", url, use_container_width=True, type="primary")
    except Exception as e:
        st.error(f"Error setting up OAuth login: {e}")
        st.info("Please check the Space logs for more details.")
|
|
def _decode_id_token_claims(id_token) -> dict:
    """Return the claims of a compact JWT (``header.payload.signature``) without verifying it.

    Returns ``{}`` in any of these cases:
    - the token is not a string;
    - it has fewer than two dot-separated segments;
    - the payload is not base64url-encoded JSON;
    - the decoded JSON is not an object.
    """
    import base64
    import json

    if not isinstance(id_token, str):
        return {}
    segments = id_token.split(".")
    if len(segments) < 2:
        return {}

    payload = segments[1]
    # JWT segments are unpadded base64url; add exactly the missing padding
    # (zero characters when the length is already a multiple of 4).
    payload += "=" * (-len(payload) % 4)
    try:
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except ValueError:
        # binascii.Error, UnicodeDecodeError and JSONDecodeError are all ValueErrors.
        return {}
    return claims if isinstance(claims, dict) else {}


def handle_oauth_callback():
    """Finish the OAuth login when the page is opened with a ``code`` query parameter.

    Steps:
    1. Compare the returned ``state`` with ``st.session_state["oauth_state"]``
       (only when one is stored).
    2. Exchange the code at the provider's token endpoint.
    3. Load the user profile from the userinfo endpoint, or else from the
       ``id_token`` claims.

    On success, stores ``hf_profile`` and ``hf_access_token`` in
    ``st.session_state``, clears the query string and reruns the script. Every
    failure is reported with ``st.error`` and stores nothing.
    """
    qp = st.query_params
    if "code" not in qp:
        return

    # State is only enforced when this session stored one (see login_button).
    # NOTE(review): if the redirect arrives in a session without
    # "oauth_state", the CSRF check is skipped entirely — confirm acceptable.
    returned_state = qp.get("state")
    expected_state = st.session_state.get("oauth_state")
    if expected_state and returned_state != expected_state:
        st.error("OAuth state mismatch. Please try signing in again.")
        return

    # Checked up front so missing credentials produce a readable error rather
    # than an uncaught KeyError.
    client_id = os.environ.get("OAUTH_CLIENT_ID")
    client_secret = os.environ.get("OAUTH_CLIENT_SECRET")
    if not client_id or not client_secret:
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return

    cfg = oidc_config()
    token_endpoint = cfg.get("token_endpoint")
    userinfo_endpoint = cfg.get("userinfo_endpoint")
    if not token_endpoint:
        st.error("Could not get token endpoint")
        return

    data = {
        "grant_type": "authorization_code",
        "code": qp["code"],
        # Must equal the redirect_uri login_button sent (get_space_url() + "/").
        "redirect_uri": get_space_url() + "/",
        "client_id": client_id,
        "client_secret": client_secret,
    }

    try:
        token_response = requests.post(token_endpoint, data=data, timeout=10)
        token_response.raise_for_status()
        token = token_response.json()

        if "access_token" not in token:
            st.error(f"Token exchange failed: {token}")
            return

        profile = {}
        if userinfo_endpoint:
            profile_response = requests.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {token['access_token']}"},
                timeout=10,
            )
            profile_response.raise_for_status()
            profile = profile_response.json()
        elif "id_token" in token:
            # Fallback only: the token came straight from the token endpoint,
            # but its signature is not verified here.
            profile = _decode_id_token_claims(token["id_token"])
    except Exception as e:
        st.error(f"OAuth callback failed: {e}")
        return

    # get_user_id() calls .get() on the profile, so it must be a dict.
    if not isinstance(profile, dict):
        profile = {}

    st.session_state["hf_profile"] = profile
    st.session_state["hf_access_token"] = token["access_token"]

    # Drop ?code=&state= so the single-use code is not exchanged again, then
    # rerun. st.rerun() halts the script (internally by raising), so it is kept
    # outside the broad `except` above where it could otherwise be swallowed.
    st.query_params.clear()
    st.rerun()
|
|
def get_user_id() -> str | None:
    """Return the logged-in user's ID from the stored OAuth profile, or None.

    Reads ``st.session_state["hf_profile"]`` (written by
    ``handle_oauth_callback``). Returns the first non-empty claim among
    ``preferred_username``, ``sub`` and ``name``, converted to ``str``.
    Returns None when no profile is stored or none of those claims is set.

    The result keys per-user storage (``get_user_cache_dir``), so ``sub`` is
    tried before ``name``. Per OpenID Connect, ``sub`` uniquely identifies an
    account, whereas a display name can be shared by several users.
    """
    profile = st.session_state.get("hf_profile") or {}
    for claim in ("preferred_username", "sub", "name"):
        value = profile.get(claim)
        if value:
            return str(value)
    return None
|
|
def safe_user_id(uid: str) -> str:
    """Make a user ID safe to use as a single directory name.

    Every character outside ``[a-zA-Z0-9_.-]`` becomes ``_``, and the result is
    truncated to 80 characters.

    An empty result, ``.`` or ``..`` would make ``base / uid`` resolve to the
    shared parent directory (or above it). Such results are therefore replaced
    by underscores of the same length (minimum one).

    >>> safe_user_id("jane doe")
    'jane_doe'
    >>> safe_user_id("..")
    '__'
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_.-]", "_", uid)[:80]
    if cleaned in ("", ".", ".."):
        cleaned = "_" * max(len(cleaned), 1)
    return cleaned
|
|
def is_huggingface_space() -> bool:
    """Return True when running on Hugging Face Spaces.

    Detection is based on either ``SPACE_ID`` or ``SPACE_HOST`` being set to a
    non-empty value in the environment.
    """
    return any(os.environ.get(var) for var in ("SPACE_ID", "SPACE_HOST"))
|
|
def is_user_logged_in() -> bool:
    """Return True when the OAuth profile yields a user ID.

    Equivalent to ``get_user_id() is not None``.
    """
    user_id = get_user_id()
    logged_in = user_id is not None
    return logged_in
|
|
def require_hf_login() -> tuple[bool, str]:
    """Report whether the app must block the user until they log in.

    Returns ``(True, <markdown message>)`` only when running on a Hugging Face
    Space (``is_huggingface_space``) with no logged-in user. In every other
    case, including local runs, it returns ``(False, "")``.
    """
    login_missing = is_huggingface_space() and not is_user_logged_in()
    if not login_missing:
        return False, ""

    message = (
        "🔒 **Authentication Required**\n\n"
        "This application requires you to be logged in to Hugging Face to use it. "
        "Please log in using the button in the top right corner of this page."
    )
    return True, message
|
|
| def _get_local_user_identifier_file() -> Path: |
| """Get the path to the file storing the local user identifier.""" |
| return Path(".session_cache") / ".local_user_id" |
|
|
def get_user_identifier() -> str:
    """Return a stable identifier for whoever is using the app.

    - Logged-in (OAuth) users: their ``safe_user_id``-sanitised user ID.
    - Otherwise: the ID stored in the local identifier file
      (``_get_local_user_identifier_file``). On first use a new
      ``local_<16 hex chars>`` ID is generated and written to that file.

    For the local path the value is also stored in
    ``st.session_state.user_identifier``. Failures reading or writing the file
    are ignored, so an unwritable file means a fresh ID on every call.
    """
    oauth_id = get_user_id()
    if oauth_id:
        return safe_user_id(oauth_id)

    id_path = _get_local_user_identifier_file()
    local_id = ""
    if id_path.exists():
        try:
            local_id = id_path.read_text().strip()
        except Exception:
            local_id = ""

    if not local_id:
        local_id = f"local_{secrets.token_hex(8)}"
        try:
            id_path.parent.mkdir(parents=True, exist_ok=True)
            id_path.write_text(local_id)
        except Exception:
            pass

    st.session_state.user_identifier = local_id
    return local_id
|
|
def get_user_cache_dir() -> Path:
    """Return (creating it if needed) the directory that holds saved sessions.

    The root is ``/data`` when that directory exists, otherwise
    ``.session_cache``.

    - When ``get_user_id()`` returns an ID, the directory is
      ``<root>/user_cache/<safe_user_id(id)>``.
    - Otherwise everyone shares ``<root>/session_cache``.
    """
    persistent_root = Path("/data")
    root = persistent_root if persistent_root.exists() else Path(".session_cache")

    uid = get_user_id()
    if uid:
        target = root / "user_cache" / safe_user_id(uid)
    else:
        target = root / "session_cache"

    target.mkdir(parents=True, exist_ok=True)
    return target
|
|
| |
def get_cache_dir() -> Path:
    """Alias for ``get_user_cache_dir``: the current user's session directory (created on demand)."""
    user_dir = get_user_cache_dir()
    return user_dir
|
|
def get_session_file_path(session_id: str = "default") -> Path:
    """Return ``<user cache dir>/session_<session_id>.pkl`` for the given session ID."""
    filename = f"session_{session_id}.pkl"
    return get_cache_dir().joinpath(filename)
|
|
def generate_session_id() -> str:
    """Return a session ID of the form ``YYYYMMDD_HHMMSS`` from the local clock.

    The resolution is one second: two calls within the same second return the
    same ID.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
|
|
def get_current_session_id() -> str:
    """Return the active session ID, generating and storing one on first use.

    The ID is kept in ``st.session_state.current_session_id``.
    """
    state = st.session_state
    if "current_session_id" in state:
        return state.current_session_id
    state.current_session_id = generate_session_id()
    return state.current_session_id
|
|
def save_session_state(session_id: str = None, session_name: str = None) -> bool:
    """Persist the current session to the user's default session file.

    This is the no-explicit-path case of ``save_session_state_to_path``:
    - ``session_id`` defaults to the active session ID;
    - ``session_name`` defaults to the active name (or ``"Session <id>"``) and
      becomes the active name;
    - the data goes to ``get_session_file_path(session_id)``.

    Returns True on success. On failure, shows ``st.error`` and returns False.
    """
    return save_session_state_to_path(session_id=session_id, session_name=session_name)
|
|
def load_session_state(session_id: str = "default") -> bool:
    """Restore a session saved in the user's cache directory and make it active.

    Reads ``get_session_file_path(session_id)`` and returns False if that file
    does not exist. Otherwise it:
    - sets ``current_session_id`` and ``current_session_name``;
    - copies every key of the saved dict that does not start with ``_`` into
      ``st.session_state``;
    - clears the ``prompt_config_placeholders_initialized`` flag;
    - returns True.

    Any error is shown with ``st.error`` and yields False.
    """
    try:
        source = get_session_file_path(session_id)
        if not source.exists():
            return False

        # NOTE: pickle.load can execute arbitrary code — anything able to write
        # into the cache directory is implicitly trusted.
        with open(source, "rb") as handle:
            snapshot = pickle.load(handle)

        state = st.session_state
        state.current_session_id = session_id
        state.current_session_name = snapshot.get("_session_name", f"Session {session_id}")

        # "_"-prefixed keys are metadata written by _prepare_session_data.
        for key in snapshot:
            if key.startswith("_"):
                continue
            state[key] = snapshot[key]

        state.pop("prompt_config_placeholders_initialized", None)
        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False
|
|
def list_available_sessions() -> list[dict]:
    """List the current user's saved sessions, newest first.

    Each entry is a dict with keys ``"id"``, ``"name"``, ``"timestamp"`` and
    ``"file"``.

    ``id`` comes from the file name (``session_<id>.pkl``). This guarantees it
    maps back to the same file through ``get_session_file_path``, which is how
    ``load_session_state``, ``delete_session`` and ``rename_session`` locate a
    session.

    Files that cannot be read or unpickled are skipped. Entries are sorted by
    the stored ``_timestamp``, or ``"Unknown"`` when it is absent.

    NOTE: every file is fully unpickled just to read its metadata, and pickle
    can execute code. The cache directory must only contain trusted files.
    """
    sessions = []
    for cache_file in get_cache_dir().glob("session_*.pkl"):
        try:
            with open(cache_file, "rb") as f:
                session_data = pickle.load(f)
            # removeprefix strips only the leading tag; str.replace would also
            # strip a "session_" occurring inside the id itself.
            session_id = cache_file.stem.removeprefix("session_")
            sessions.append({
                "id": session_id,
                "name": session_data.get("_session_name", f"Session {session_id}"),
                "timestamp": session_data.get("_timestamp", "Unknown"),
                "file": cache_file,
            })
        except Exception:
            # Corrupt or unreadable files are skipped rather than breaking the list.
            continue
    # str() keeps the sort total even if a file stored a non-string timestamp.
    return sorted(sessions, key=lambda s: str(s["timestamp"]), reverse=True)
|
|
def delete_session(session_id: str = "default") -> bool:
    """Remove the saved file of ``session_id`` from the user's cache directory.

    Returns True if a file was deleted and False if none existed. On error,
    shows ``st.error`` and returns False.
    """
    try:
        target = get_session_file_path(session_id)
        if not target.exists():
            return False
        target.unlink()
        return True
    except Exception as e:
        st.error(f"Error deleting session: {e}")
        return False
|
|
def clear_all_sessions():
    """Delete every ``session_*.pkl`` file in the current user's cache directory.

    Files that cannot be removed are skipped silently. Returns None.
    """
    for session_file in get_cache_dir().glob("session_*.pkl"):
        try:
            session_file.unlink()
        except Exception:
            pass
|
|
def rename_session(session_id: str, new_name: str) -> bool:
    """Change the stored display name (``_session_name``) of a saved session.

    The session file is rewritten with the new name. If ``session_id`` is the
    active session, ``st.session_state.current_session_name`` is updated too.

    Returns:
        True on success, or False when the session file does not exist. Other
        errors are shown with ``st.error`` and also yield False.
    """
    tmp_file = None
    try:
        cache_file = get_session_file_path(session_id)
        if not cache_file.exists():
            return False

        with open(cache_file, "rb") as f:
            session_data = pickle.load(f)

        session_data["_session_name"] = new_name

        # Write a sibling temp file, then atomically swap it in. A failed write
        # (e.g. disk full) then leaves the original file intact instead of
        # truncating the only copy of the session.
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        with open(tmp_file, "wb") as f:
            pickle.dump(session_data, f)
        os.replace(tmp_file, cache_file)
        tmp_file = None

        if st.session_state.get("current_session_id") == session_id:
            st.session_state.current_session_name = new_name

        return True
    except Exception as e:
        st.error(f"Error renaming session: {e}")
        return False
    finally:
        # Remove a leftover temp file from a failed write.
        if tmp_file is not None:
            try:
                tmp_file.unlink(missing_ok=True)
            except OSError:
                pass
|
|
def _prepare_session_data(session_id: str, session_name: str) -> dict:
    """Collect the persistable part of ``st.session_state`` into a plain dict.

    Only the keys listed below that are present in session state are copied,
    in that order. Three metadata entries are then appended: ``_timestamp``
    (ISO format, local time), ``_session_id`` and ``_session_name``.

    NOTE(review): ``api_key`` is among the copied keys, and callers pickle this
    dict to disk unencrypted. Confirm that is acceptable for the storage
    location.
    """
    persisted_keys = (
        # objects stored wholesale
        "df_questionnaire", "df_population",
        "questionnaires",
        "client_config", "inference_config",
        "survey_options",
        # individual settings
        "model_name", "temperature", "max_tokens", "top_p", "seed",
        "api_key", "base_url", "organization", "project",
        "advanced_client_params_str", "advanced_inference_params_str",
        "timeout", "max_retries",
    )
    state = st.session_state
    snapshot = {key: state[key] for key in persisted_keys if key in state}

    snapshot["_timestamp"] = datetime.now().isoformat()
    snapshot["_session_id"] = session_id
    snapshot["_session_name"] = session_name
    return snapshot
|
|
def save_session_state_to_path(session_id: str = None, session_name: str = None, save_path: Path = None) -> bool:
    """Save the current session state to ``save_path`` or to the default cache file.

    Args:
        session_id: Session to save. Defaults to the active session ID.
        session_name: Display name to store. Defaults to the active session
            name, or ``"Session <id>"`` when none is set. Also becomes the
            active session name.
        save_path: Target file or existing directory (``str`` accepted).
            - ``None``: the user's default session file.
            - A directory: receives ``session_<id>.pkl``.
            - Any other path without a ``.pkl`` suffix: gets that suffix.
            Missing parent directories are created.

    Returns:
        True on success. On any error, shows ``st.error`` and returns False.
    """
    tmp_path = None
    try:
        if session_id is None:
            session_id = get_current_session_id()

        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")

        st.session_state.current_session_name = session_name

        session_data = _prepare_session_data(session_id, session_name)

        if save_path is None:
            save_path = get_session_file_path(session_id)
        else:
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            if save_path.is_dir():
                save_path = save_path / f"session_{session_id}.pkl"
            elif save_path.suffix != ".pkl":
                save_path = save_path.with_suffix(".pkl")

        # Dump to a sibling temp file and atomically swap it in. An unpicklable
        # value or a full disk then cannot truncate a previously saved session.
        # The ".pkl.tmp" name does not match the "session_*.pkl" listing glob.
        tmp_path = save_path.with_name(save_path.name + ".tmp")
        with open(tmp_path, "wb") as f:
            pickle.dump(session_data, f)
        os.replace(tmp_path, save_path)
        tmp_path = None

        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False
    finally:
        # Remove a partial temp file left by a failed dump.
        if tmp_path is not None:
            try:
                tmp_path.unlink(missing_ok=True)
            except OSError:
                pass
|
|
def load_session_state_from_path(file_path: Path) -> bool:
    """Load a session from an arbitrary pickle file and make it the active session.

    Steps:
    1. Set ``current_session_id`` from the file's ``_session_id`` (a new
       timestamp ID when missing) and ``current_session_name`` from
       ``_session_name``.
    2. Copy every key that does not start with ``_`` into ``st.session_state``.
    3. Clear ``prompt_config_placeholders_initialized``, as ``load_session_state``
       does.
    4. Copy the file into the user's cache directory, so it shows up in
       ``list_available_sessions`` (skipped if it already is that file).

    Args:
        file_path: Path (or ``str``) of the pickle file to load.

    Returns:
        True on success, or False when the file does not exist. Other errors
        are shown with ``st.error`` and also yield False.
    """
    try:
        file_path = Path(file_path)
        if not file_path.exists():
            return False

        # SECURITY: pickle.load executes arbitrary code embedded in the file.
        # NOTE(review): if file_path can come from an untrusted upload, this is
        # remote code execution — only load files from trusted sources.
        with open(file_path, "rb") as f:
            session_data = pickle.load(f)

        # `or` also covers a stored None, which would otherwise become "session_None.pkl".
        session_id = session_data.get("_session_id") or generate_session_id()

        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = session_data.get("_session_name", f"Session {session_id}")

        for key, value in session_data.items():
            if not key.startswith("_"):
                st.session_state[key] = value

        # Mirror load_session_state, which clears this flag after restoring
        # state (presumably so dependent widgets re-initialize).
        st.session_state.pop("prompt_config_placeholders_initialized", None)

        # Compare resolved paths so a relative and an absolute spelling of the
        # same file do not trigger a self-copy.
        default_path = get_session_file_path(session_id)
        if file_path.resolve() != default_path.resolve():
            try:
                shutil.copy2(file_path, default_path)
            except OSError as copy_err:
                # The session is loaded; only its cached copy is missing.
                st.warning(f"Session loaded, but it could not be added to your saved sessions: {copy_err}")

        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False
|
|
|
|