Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import pickle | |
| import os | |
| import shutil | |
| import secrets | |
| import urllib.parse | |
| import requests | |
| import re | |
| from pathlib import Path | |
| from datetime import datetime | |
| # OAuth helper functions | |
def get_space_url() -> str:
    """Return the Space's public base URL without a trailing slash.

    ``SPACE_BASE_URL`` wins when it is set and non-empty; otherwise the URL is
    built from ``SPACE_HOST`` as ``https://<host>``. Returns ``""`` when
    neither variable provides a value.
    """
    explicit = os.getenv("SPACE_BASE_URL")
    if explicit:
        return explicit.rstrip("/")
    host = os.getenv("SPACE_HOST", "")
    # Fall back to the hf.space host name when no explicit base URL is configured.
    return f"https://{host}".rstrip("/") if host else ""
def oidc_config() -> dict:
    """Fetch the OpenID Connect discovery document of the OAuth provider.

    The provider defaults to ``https://huggingface.co`` and can be overridden
    with the ``OPENID_PROVIDER_URL`` environment variable.

    Returns:
        The parsed discovery document. Returns ``{}`` (after showing an error
        via ``st.error``) if the request fails, the provider answers with a
        non-2xx status, or the body is not a JSON object.
    """
    provider = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co").rstrip("/")
    try:
        response = requests.get(f"{provider}/.well-known/openid-configuration", timeout=10)
        # Without this check, an error response (4xx/5xx) that happens to
        # carry a JSON body would be returned as if it were the configuration.
        response.raise_for_status()
        config = response.json()
    except Exception as e:
        st.error(f"Failed to fetch OIDC config: {e}")
        return {}
    if not isinstance(config, dict):
        # Callers use ``.get`` on the result; a JSON list/string would crash them.
        st.error(f"Failed to fetch OIDC config: unexpected response type {type(config).__name__}")
        return {}
    return config
def login_button():
    """Render a "Sign in with Hugging Face" button that starts the OAuth flow.

    Builds the provider's authorization URL (client id, redirect URI, scopes
    from ``OAUTH_SCOPES`` or ``"openid profile"``, and a fresh ``state``) and
    shows it as a link button. The ``state`` is stored in
    ``st.session_state["oauth_state"]`` for ``handle_oauth_callback`` to
    compare against. Configuration problems are reported through
    ``st.error``/``st.warning`` rather than raised.
    """
    if not os.environ.get("OAUTH_CLIENT_ID"):
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return
    try:
        cfg = oidc_config()
        if not cfg:
            st.error("Could not fetch OpenID configuration. Please try again later.")
            return
        auth_endpoint = cfg.get("authorization_endpoint")
        if not auth_endpoint:
            st.error("Could not get authorization endpoint from OpenID configuration.")
            return
        # Fresh CSRF token on every render; handle_oauth_callback() compares it
        # with the ``state`` echoed back by the provider.
        state = secrets.token_urlsafe(24)
        st.session_state["oauth_state"] = state
        # get_space_url() returns "" when neither SPACE_BASE_URL nor SPACE_HOST
        # is set. Test it *before* appending "/" (the suffixed value is never
        # empty). Keep redirect_uri identical to the one handle_oauth_callback()
        # sends to the token endpoint; OAuth 2.0 requires the two to match.
        space_url = get_space_url()
        if not space_url:
            st.warning("Could not determine Space URL. OAuth may not work correctly.")
        redirect_uri = space_url + "/"
        params = {
            "client_id": os.environ["OAUTH_CLIENT_ID"],
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": os.getenv("OAUTH_SCOPES", "openid profile"),
            "state": state,
        }
        url = auth_endpoint + "?" + urllib.parse.urlencode(params)
        # Make button larger and more prominent
        st.link_button("🔐 Sign in with Hugging Face", url, use_container_width=True, type="primary")
    except Exception as e:
        st.error(f"Error setting up OAuth login: {e}")
        st.info("Please check the Space logs for more details.")
def _decode_id_token_claims(id_token: str) -> dict:
    """Return the payload claims of a JWT *without* verifying its signature.

    Returns ``{}`` if the token is not a string, is malformed, or its payload
    is not a JSON object.
    """
    import base64
    import json

    if not isinstance(id_token, str):
        return {}
    parts = id_token.split(".")
    if len(parts) < 2:
        return {}
    payload = parts[1]
    # base64url without padding: add exactly enough "=" to reach a multiple of 4.
    payload += "=" * (-len(payload) % 4)
    try:
        claims = json.loads(base64.urlsafe_b64decode(payload))
    except (ValueError, TypeError):
        # binascii.Error, JSONDecodeError and UnicodeDecodeError are ValueErrors.
        return {}
    return claims if isinstance(claims, dict) else {}


def handle_oauth_callback():
    """Complete the OAuth flow when the page is loaded with ``?code=...``.

    Steps:
    1. Verify the ``state`` parameter against ``st.session_state["oauth_state"]``.
    2. Exchange the authorization code at the token endpoint.
    3. Fetch the profile from the userinfo endpoint, or fall back to the ID
       token claims.
    4. Store ``hf_profile`` and ``hf_access_token`` in ``st.session_state``,
       clear the query string, and rerun the script.

    Does nothing when there is no ``code`` query parameter. Failures are
    reported via ``st.error``.
    """
    qp = st.query_params
    if "code" not in qp:
        return
    # CSRF protection. The check only runs when a state was stored in this
    # Streamlit session. NOTE(review): presumably skipped because the redirect
    # back can land in a fresh session; this weakens CSRF protection.
    returned_state = qp.get("state")
    expected_state = st.session_state.get("oauth_state")
    if expected_state and returned_state != expected_state:
        st.error("OAuth state mismatch. Please try signing in again.")
        return
    # Use .get(): a missing variable would otherwise raise KeyError outside the
    # try below and crash the whole page.
    client_id = os.environ.get("OAUTH_CLIENT_ID")
    client_secret = os.environ.get("OAUTH_CLIENT_SECRET")
    if not client_id or not client_secret:
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return
    cfg = oidc_config()
    token_endpoint = cfg.get("token_endpoint")
    userinfo_endpoint = cfg.get("userinfo_endpoint")
    if not token_endpoint:
        st.error("Could not get token endpoint")
        return
    # Must be identical to the redirect_uri login_button() put in the authorize URL.
    redirect_uri = get_space_url() + "/"
    data = {
        "grant_type": "authorization_code",
        "code": qp["code"],
        "redirect_uri": redirect_uri,
        "client_id": client_id,
        "client_secret": client_secret,
    }
    try:
        token_response = requests.post(token_endpoint, data=data, timeout=10)
        token_response.raise_for_status()
        token = token_response.json()
        if "access_token" not in token:
            st.error(f"Token exchange failed: {token}")
            return
        # Fetch user profile
        profile = {}
        if userinfo_endpoint:
            profile_response = requests.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {token['access_token']}"},
                timeout=10,
            )
            profile_response.raise_for_status()
            profile = profile_response.json()
        elif "id_token" in token:
            # The ID token was received directly from the token endpoint over
            # TLS, so its claims are read without signature verification.
            profile = _decode_id_token_claims(token["id_token"])
        st.session_state["hf_profile"] = profile
        st.session_state["hf_access_token"] = token["access_token"]
        # Clean the URL (remove code/state)
        st.query_params.clear()
        # st.rerun() ends this run by raising a Streamlit control-flow
        # exception. NOTE(review): assumed not to derive from Exception, so the
        # handler below does not swallow it; confirm for the pinned version.
        st.rerun()
    except Exception as e:
        st.error(f"OAuth callback failed: {e}")
def get_user_id() -> str | None:
    """Return an identifier for the signed-in OAuth user, or ``None``.

    Reads the profile stored in ``st.session_state["hf_profile"]`` and
    prefers, in order, ``preferred_username``, ``sub`` and ``name``.
    """
    profile = st.session_state.get("hf_profile") or {}
    # ``sub`` is the unique, stable subject identifier in OpenID Connect.
    # ``name`` is a free-form display name that several users can share, so it
    # is only a last resort. The result selects the per-user cache directory,
    # so a collision would expose one user's saved sessions to another.
    return profile.get("preferred_username") or profile.get("sub") or profile.get("name")
def safe_user_id(uid: str) -> str:
    """Make a user ID safe to use as a single directory name.

    Every character outside ``[a-zA-Z0-9_.-]`` becomes ``_`` and the result
    is truncated to 80 characters. Results that are empty or made only of
    dots (``""``, ``"."``, ``".."``) are turned into underscores. Those names
    would resolve to the base or parent directory when joined to a path.
    """
    safe = re.sub(r"[^a-zA-Z0-9_.-]", "_", uid)[:80]
    if not safe.strip("."):
        safe = "_" * max(len(safe), 1)
    return safe
def is_huggingface_space() -> bool:
    """Return True when the app appears to run on Hugging Face Spaces.

    Spaces export ``SPACE_ID`` and ``SPACE_HOST``; a non-empty value in
    either one is taken as the signal.
    """
    return any(os.environ.get(name) for name in ("SPACE_ID", "SPACE_HOST"))
def is_user_logged_in() -> bool:
    """Return True if the stored OAuth profile yields a user ID."""
    current_user = get_user_id()
    return current_user is not None
def require_hf_login() -> tuple[bool, str]:
    """Decide whether the UI must be blocked until the user signs in.

    Returns:
        ``(True, message)`` when running on Hugging Face Spaces and no OAuth
        user is signed in; ``message`` is Markdown to show the user.
        ``(False, "")`` otherwise, i.e. when the user is signed in or the app
        is not running on Spaces (local use needs no login).
    """
    # Login is only enforced on Spaces; the login check is skipped elsewhere.
    if not is_huggingface_space() or is_user_logged_in():
        return False, ""
    return True, (
        "🔒 **Authentication Required**\n\n"
        "This application requires you to be logged in to Hugging Face to use it. "
        "Please log in using the button in the top right corner of this page."
    )
| def _get_local_user_identifier_file() -> Path: | |
| """Get the path to the file storing the local user identifier.""" | |
| return Path(".session_cache") / ".local_user_id" | |
def get_user_identifier() -> str:
    """Return a stable identifier for the current user.

    - With an OAuth login: the filesystem-safe form of the OAuth user ID.
    - Locally: an ID of the form ``local_<16 hex chars>``, persisted in the
      file from ``_get_local_user_identifier_file()`` so it survives page
      refreshes and app restarts. Local IDs are also stored in
      ``st.session_state.user_identifier``.

    Persistence is best effort. If the file cannot be read or written, the ID
    already held in ``st.session_state.user_identifier`` is reused. The
    identifier therefore stays stable for the session instead of changing on
    every call.
    """
    user_id = get_user_id()
    if user_id:
        return safe_user_id(user_id)

    identifier_file = _get_local_user_identifier_file()
    try:
        stored_id = identifier_file.read_text(encoding="utf-8").strip()
    except (OSError, UnicodeDecodeError):
        # Missing or unreadable file: fall through and (re)create it.
        stored_id = ""
    if stored_id:
        st.session_state.user_identifier = stored_id
        return stored_id

    # Reuse this session's ID if one was already minted (e.g. an earlier write
    # failed); otherwise generate a new one.
    new_id = st.session_state.get("user_identifier") or f"local_{secrets.token_hex(8)}"
    try:
        identifier_file.parent.mkdir(parents=True, exist_ok=True)
        identifier_file.write_text(new_id, encoding="utf-8")
    except OSError:
        pass  # Best effort: the ID still lives in session state.
    st.session_state.user_identifier = new_id
    return new_id
def get_user_cache_dir() -> Path:
    """Return the cache directory for the current user, creating it if needed.

    The base is ``/data`` when that directory exists (persistent storage),
    otherwise ``.session_cache`` relative to the working directory. A
    signed-in OAuth user gets ``<base>/user_cache/<safe user id>``. Without an
    OAuth user, every caller shares ``<base>/session_cache``.
    """
    data_dir = Path("/data")
    base_dir = data_dir if data_dir.exists() else Path(".session_cache")
    user_id = get_user_id()
    if user_id:
        target = base_dir / "user_cache" / safe_user_id(user_id)
    else:
        # NOTE(review): this shared directory is also what anonymous visitors
        # get on a Space; confirm callers gate on require_hf_login() first.
        target = base_dir / "session_cache"
    target.mkdir(parents=True, exist_ok=True)
    return target
# Session cache files live in a per-user directory (see get_user_cache_dir).
def get_cache_dir() -> Path:
    """Return the session cache directory for the current user.

    Alias of ``get_user_cache_dir()``, which creates the directory if needed.
    """
    directory = get_user_cache_dir()
    return directory
def get_session_file_path(session_id: str = "default") -> Path:
    """Return the pickle path for ``session_id`` in the current user's cache dir.

    The file name is ``session_<session_id>.pkl``. Characters outside
    ``[A-Za-z0-9_.-]`` in the ID are replaced by ``_``. An ID that comes from
    a loaded session file could otherwise contain path separators and point
    outside the cache directory. Timestamp IDs from ``generate_session_id()``
    and ``"default"`` are unaffected.
    """
    safe_id = re.sub(r"[^A-Za-z0-9_.-]", "_", str(session_id))
    return get_cache_dir() / f"session_{safe_id}.pkl"
def generate_session_id() -> str:
    """Return a new session ID: the current local time as ``YYYYMMDD_HHMMSS``.

    NOTE(review): the resolution is one second, so two IDs generated within
    the same second are equal and map to the same session file.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"
def get_current_session_id() -> str:
    """Return the active session ID, creating one on first use.

    The ID is kept in ``st.session_state.current_session_id``.
    """
    state = st.session_state
    if "current_session_id" in state:
        return state.current_session_id
    state.current_session_id = generate_session_id()
    return state.current_session_id
def save_session_state(session_id: str | None = None, session_name: str | None = None) -> bool:
    """Save the current session state to the user's session cache.

    Args:
        session_id: ID to save under. Defaults to the active session ID.
        session_name: Display name. Defaults to the active session name, or
            ``"Session <id>"``. It also becomes the active session name.

    Returns:
        True on success. On failure an error is shown via ``st.error`` and
        False is returned.

    The pickle is written to a temporary file and then moved into place.
    If pickling fails part-way (e.g. an unpicklable value in session state),
    the previously saved file is left intact instead of being truncated.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()
        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")
        st.session_state.current_session_name = session_name
        session_data = _prepare_session_data(session_id, session_name)
        # Pickle is used because the data holds complex objects (e.g. LLMPrompt).
        cache_file = get_session_file_path(session_id)
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(session_data, f)
            os.replace(tmp_file, cache_file)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_file.unlink(missing_ok=True)
        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False
def load_session_state(session_id: str = "default") -> bool:
    """Load a saved session from the user's cache and make it the active one.

    Sets ``current_session_id`` and ``current_session_name``, then copies every
    saved key that does not start with ``_`` into ``st.session_state``.

    Returns:
        True if the session was loaded. False if no file exists for
        ``session_id`` or loading failed; the error is shown via ``st.error``.
    """
    try:
        path = get_session_file_path(session_id)
        if not path.exists():
            return False
        # NOTE(review): pickle.load can execute code embedded in a crafted file.
        # Files here come from the save/rename helpers, but
        # load_session_state_from_path also copies external files into this
        # directory.
        with open(path, "rb") as fh:
            data = pickle.load(fh)
        state = st.session_state
        state.current_session_id = session_id
        state.current_session_name = data.get("_session_name", f"Session {session_id}")
        for key, value in data.items():
            if key.startswith("_"):
                continue  # internal metadata (_timestamp, _session_id, ...)
            state[key] = value
        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False
def list_available_sessions() -> list[dict]:
    """Describe every saved session in the current user's cache directory.

    Each entry has ``id``, ``name``, ``timestamp`` (the saved ISO string, or
    ``"Unknown"``) and ``file`` (the Path). Files that cannot be unpickled or
    lack the expected structure are skipped. Entries are sorted by
    ``timestamp`` in descending order.
    """
    def _describe(path):
        with open(path, "rb") as fh:
            data = pickle.load(fh)
        sid = data.get("_session_id", path.stem.replace("session_", ""))
        label = data.get("_session_name", f"Session {sid}")
        saved_at = data.get("_timestamp", "Unknown")
        return {"id": sid, "name": label, "timestamp": saved_at, "file": path}

    found = []
    for path in get_cache_dir().glob("session_*.pkl"):
        try:
            found.append(_describe(path))
        except Exception:
            continue  # unreadable/corrupt file: leave it out of the listing
    found.sort(key=lambda entry: entry["timestamp"], reverse=True)
    return found
def delete_session(session_id: str = "default") -> bool:
    """Delete the saved file for ``session_id`` from the current user's cache.

    Returns:
        True if a file was removed. False if none existed or deletion failed;
        the error is shown via ``st.error``.
    """
    try:
        target = get_session_file_path(session_id)
        if not target.exists():
            return False
        target.unlink()
        return True
    except Exception as e:
        st.error(f"Error deleting session: {e}")
        return False
def clear_all_sessions():
    """Delete every ``session_*.pkl`` file in the current user's cache directory.

    Best effort: files that cannot be removed are silently skipped.
    """
    for pkl in get_cache_dir().glob("session_*.pkl"):
        try:
            pkl.unlink()
        except Exception:
            pass
def rename_session(session_id: str, new_name: str) -> bool:
    """Change the display name stored in a saved session file.

    Rewrites the ``_session_name`` metadata of ``session_id``'s file. If that
    session is the active one, also updates
    ``st.session_state.current_session_name``.

    Returns:
        True on success. False if the file does not exist or the rename failed;
        the error is shown via ``st.error``.

    The updated pickle is written to a temporary file and moved into place,
    so a failure while writing cannot corrupt the existing session file.
    """
    try:
        cache_file = get_session_file_path(session_id)
        if not cache_file.exists():
            return False
        with open(cache_file, "rb") as f:
            session_data = pickle.load(f)
        session_data["_session_name"] = new_name
        tmp_file = cache_file.with_name(cache_file.name + ".tmp")
        try:
            with open(tmp_file, "wb") as f:
                pickle.dump(session_data, f)
            os.replace(tmp_file, cache_file)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_file.unlink(missing_ok=True)
        if st.session_state.get("current_session_id") == session_id:
            st.session_state.current_session_name = new_name
        return True
    except Exception as e:
        st.error(f"Error renaming session: {e}")
        return False
def _prepare_session_data(session_id: str, session_name: str) -> dict:
    """Collect the persisted subset of ``st.session_state`` into a new dict.

    Only keys currently present in session state are copied, in the order
    listed below. Metadata is added under ``_timestamp`` (local time, ISO
    8601), ``_session_id`` and ``_session_name``. The loaders in this module
    skip ``_``-prefixed keys when restoring.
    """
    persisted_keys = (
        # DataFrames
        "df_questionnaire", "df_population",
        # Questionnaires (LLMPrompt objects, hence pickle)
        "questionnaires",
        # Inference configuration
        "client_config", "inference_config",
        # Survey options
        "survey_options",
        # Model / client settings
        "model_name", "temperature", "max_tokens", "top_p", "seed",
        # NOTE(review): "api_key" ends up on disk in plaintext (possibly on
        # persistent /data storage); confirm this is intended.
        "api_key", "base_url", "organization", "project",
        "advanced_client_params_str", "advanced_inference_params_str",
        "timeout", "max_retries",
    )
    state = st.session_state
    data = {key: state[key] for key in persisted_keys if key in state}
    data["_timestamp"] = datetime.now().isoformat()
    data["_session_id"] = session_id
    data["_session_name"] = session_name
    return data
def save_session_state_to_path(
    session_id: str | None = None,
    session_name: str | None = None,
    save_path: Path | str | None = None,
) -> bool:
    """Save the current session state to ``save_path`` or the default cache file.

    Args:
        session_id: ID to save under. Defaults to the active session ID.
        session_name: Display name. Defaults to the active session name, or
            ``"Session <id>"``. It also becomes the active session name.
        save_path: Target file or directory (``Path`` or ``str``). ``None``
            means ``get_session_file_path(session_id)``. If the path is an
            existing directory, ``session_<session_id>.pkl`` is appended;
            otherwise its suffix is forced to ``.pkl``. Missing parent
            directories are created.

    Returns:
        True on success. On failure an error is shown via ``st.error`` and
        False is returned.

    The pickle is written to a temporary file and moved into place, so a
    failed save never truncates an existing file at the target path.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()
        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")
        st.session_state.current_session_name = session_name
        session_data = _prepare_session_data(session_id, session_name)
        if save_path is None:
            save_path = get_session_file_path(session_id)
        else:
            # Accept plain strings too; Path-only methods are used below.
            save_path = Path(save_path)
            save_path.parent.mkdir(parents=True, exist_ok=True)
            if save_path.is_dir():
                save_path = save_path / f"session_{session_id}.pkl"
            elif save_path.suffix != ".pkl":
                save_path = save_path.with_suffix(".pkl")
        tmp_path = save_path.with_name(save_path.name + ".tmp")
        try:
            with open(tmp_path, "wb") as f:
                pickle.dump(session_data, f)
            os.replace(tmp_path, save_path)
        finally:
            # No-op after a successful replace; removes a partial file otherwise.
            tmp_path.unlink(missing_ok=True)
        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False
def load_session_state_from_path(file_path: Path | str) -> bool:
    """Load a session from an arbitrary file and make it the active session.

    The session ID is taken from the file's ``_session_id`` metadata, or
    freshly generated if absent. Every key not starting with ``_`` is copied
    into ``st.session_state``. The file is then copied to
    ``get_session_file_path(<session id>)`` so it appears among the user's
    saved sessions; an existing cached file with the same ID is overwritten.
    A failed copy is ignored because the session is already loaded.

    Args:
        file_path: Path to the pickle (``Path`` or ``str``).

    Returns:
        True on success. False if the file does not exist or loading failed;
        the error is shown via ``st.error``.
    """
    try:
        file_path = Path(file_path)
        if not file_path.exists():
            return False
        # SECURITY NOTE(review): pickle.load executes code embedded in a crafted
        # file. Only pass files produced by this app, never untrusted uploads.
        with open(file_path, "rb") as f:
            session_data = pickle.load(f)
        session_id = session_data.get("_session_id", generate_session_id())
        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = session_data.get("_session_name", f"Session {session_id}")
        for key, value in session_data.items():
            if not key.startswith("_"):
                st.session_state[key] = value
        default_path = get_session_file_path(session_id)
        try:
            # Compare resolved paths: a relative and an absolute spelling of the
            # same file would otherwise look different.
            if file_path.resolve() != default_path.resolve():
                shutil.copy2(file_path, default_path)
        except (OSError, RuntimeError):
            pass  # Best effort: the session is already loaded.
        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False