File size: 19,223 Bytes
bd33bd7
 
 
 
 
cf2a978
 
 
bd33bd7
 
 
cf2a978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8e17e8f
 
 
 
 
 
 
 
 
 
cf2a978
8e17e8f
 
cf2a978
8e17e8f
 
 
 
 
 
 
 
 
 
 
 
 
157e0f8
 
8e17e8f
 
 
cf2a978
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
bd33bd7
 
 
 
 
 
 
 
 
 
cf2a978
bd33bd7
cf2a978
bd33bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf2a978
bd33bd7
 
cf2a978
 
 
 
bd33bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf2a978
bd33bd7
cf2a978
 
 
 
 
 
 
 
 
bd33bd7
 
 
cf2a978
 
bd33bd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
import streamlit as st
import pickle
import os
import shutil
import secrets
import urllib.parse
import requests
import re
from pathlib import Path
from datetime import datetime

# OAuth helper functions
def get_space_url() -> str:
    """Return the public base URL of this Space, without a trailing slash.

    Resolution order:
      1. ``SPACE_BASE_URL`` (explicit override), trailing slashes stripped.
      2. ``"https://" + SPACE_HOST`` when ``SPACE_HOST`` is non-empty.
      3. ``""`` when neither variable yields a value.
    """
    explicit = os.getenv("SPACE_BASE_URL")
    if explicit:
        return explicit.rstrip("/")

    # Fall back to the host name the Spaces runtime exposes.
    host = os.getenv("SPACE_HOST", "")
    return f"https://{host}".rstrip("/") if host else ""

def oidc_config():
    """Fetch the OpenID Connect discovery document of the OAuth provider.

    The provider base URL is ``OPENID_PROVIDER_URL`` (default
    ``https://huggingface.co``). Any failure while requesting or JSON-decoding
    the document is reported with ``st.error`` and an empty dict is returned,
    so callers can treat a falsy result as "configuration unavailable".
    Note that the HTTP status code is not checked.
    """
    base = os.environ.get("OPENID_PROVIDER_URL", "https://huggingface.co").rstrip("/")
    discovery_url = f"{base}/.well-known/openid-configuration"
    try:
        response = requests.get(discovery_url, timeout=10)
        return response.json()
    except Exception as exc:
        st.error(f"Failed to fetch OIDC config: {exc}")
        return {}

def login_button():
    """Render a "Sign in with Hugging Face" button that starts the OAuth flow.

    Builds an authorization-code request for the provider's
    ``authorization_endpoint`` (from ``oidc_config()``) and renders it with
    ``st.link_button``. A fresh random ``state`` is stored in
    ``st.session_state["oauth_state"]`` for the CSRF check performed by
    ``handle_oauth_callback()``. All problems are reported in the UI
    (``st.error`` / ``st.warning``) rather than raised.
    """
    if not os.environ.get("OAUTH_CLIENT_ID"):
        st.error("OAuth is not configured. Please enable hf_oauth: true in README.md")
        return

    try:
        cfg = oidc_config()
        if not cfg:
            st.error("Could not fetch OpenID configuration. Please try again later.")
            return

        auth_endpoint = cfg.get("authorization_endpoint")
        if not auth_endpoint:
            st.error("Could not get authorization endpoint from OpenID configuration.")
            return

        # Random value the provider echoes back; compared in handle_oauth_callback().
        state = secrets.token_urlsafe(24)
        st.session_state["oauth_state"] = state

        # Check the base URL *before* appending "/". The old check ran on
        # `get_space_url() + "/"`, which is never empty, so this warning was
        # unreachable. The redirect URI is built exactly like the one
        # handle_oauth_callback() sends, because the token request must repeat
        # the same redirect_uri value.
        space_url = get_space_url()
        if not space_url:
            st.warning("Could not determine Space URL. OAuth may not work correctly.")
        redirect_uri = space_url + "/"

        params = {
            "client_id": os.environ["OAUTH_CLIENT_ID"],
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": os.getenv("OAUTH_SCOPES", "openid profile"),
            "state": state,
        }
        url = auth_endpoint + "?" + urllib.parse.urlencode(params)
        # Large, primary-styled button so the login action is prominent.
        st.link_button("🔐 Sign in with Hugging Face", url, use_container_width=True, type="primary")
    except Exception as e:
        st.error(f"Error setting up OAuth login: {e}")
        st.info("Please check the Space logs for more details.")

def _decode_jwt_payload(id_token: str) -> dict:
    """Decode the claims (payload segment) of a JWT without verifying its signature.

    Used only as a fallback profile source when the provider has no userinfo
    endpoint. The token comes straight from the token endpoint response.

    Returns:
        The claims dict, or ``{}`` if the token has fewer than two segments or
        the payload is not a JSON object.

    Raises:
        ValueError: if the payload is not valid base64url / UTF-8 / JSON
            (``binascii.Error`` and ``json.JSONDecodeError`` both subclass it).
    """
    import base64
    import json

    parts = id_token.split(".")
    if len(parts) < 2:
        return {}
    payload = parts[1]
    # JWT segments are unpadded base64url. Add exactly the missing padding
    # (0-3 chars); the old `4 - len % 4` appended "====" when none was needed.
    payload += "=" * (-len(payload) % 4)
    claims = json.loads(base64.urlsafe_b64decode(payload))
    return claims if isinstance(claims, dict) else {}


def handle_oauth_callback():
    """Finish the OAuth authorization-code flow when this request is a callback.

    Does nothing unless the URL has a ``code`` query parameter. Otherwise it:
    checks ``state``, exchanges the code at the provider's ``token_endpoint``,
    fetches the profile (userinfo endpoint, else the ID token's claims), stores
    ``hf_profile`` and ``hf_access_token`` in ``st.session_state``, clears the
    query string and reruns the script. Failures are reported via ``st.error``.
    """
    qp = st.query_params
    if "code" not in qp:
        return

    # CSRF protection. Only enforced when this session stored a state.
    # NOTE(review): the provider presumably redirects back into a new
    # browser tab / Streamlit session where "oauth_state" is absent, in which
    # case this check is skipped. Confirm that is acceptable.
    returned_state = qp.get("state")
    expected_state = st.session_state.get("oauth_state")
    if expected_state and returned_state != expected_state:
        st.error("OAuth state mismatch. Please try signing in again.")
        return

    cfg = oidc_config()
    token_endpoint = cfg.get("token_endpoint")
    userinfo_endpoint = cfg.get("userinfo_endpoint")

    if not token_endpoint:
        st.error("Could not get token endpoint")
        return

    try:
        # Built inside the try so a missing OAUTH_CLIENT_ID / OAUTH_CLIENT_SECRET
        # is reported in the UI instead of crashing the page with KeyError.
        # redirect_uri must equal the one login_button() sent.
        data = {
            "grant_type": "authorization_code",
            "code": qp["code"],
            "redirect_uri": get_space_url() + "/",
            "client_id": os.environ["OAUTH_CLIENT_ID"],
            "client_secret": os.environ["OAUTH_CLIENT_SECRET"],
        }
        token_response = requests.post(token_endpoint, data=data, timeout=10)
        token_response.raise_for_status()
        token = token_response.json()

        if "access_token" not in token:
            st.error(f"Token exchange failed: {token}")
            return

        profile = {}
        if userinfo_endpoint:
            profile_response = requests.get(
                userinfo_endpoint,
                headers={"Authorization": f"Bearer {token['access_token']}"},
                timeout=10,
            )
            profile_response.raise_for_status()
            profile = profile_response.json()
        elif "id_token" in token:
            try:
                profile = _decode_jwt_payload(token["id_token"])
            except (ValueError, AttributeError):
                # Best-effort fallback: an undecodable ID token leaves the profile empty.
                profile = {}
    except Exception as e:
        st.error(f"OAuth callback failed: {e}")
        return

    st.session_state["hf_profile"] = profile
    st.session_state["hf_access_token"] = token["access_token"]

    # Outside the try on purpose: st.rerun() works by raising a control-flow
    # exception, which a broad `except Exception` could swallow.
    st.query_params.clear()  # drop code/state from the URL
    st.rerun()

def get_user_id() -> str | None:
    """Return an identifier for the signed-in user from the stored OAuth profile.

    Looks at ``st.session_state["hf_profile"]`` and returns the first truthy
    value of ``preferred_username`` or ``name``. If neither is truthy, it
    returns the ``sub`` claim as-is, which is ``None`` when absent.
    """
    profile = st.session_state.get("hf_profile") or {}
    for claim in ("preferred_username", "name"):
        value = profile.get(claim)
        if value:
            return value
    return profile.get("sub")

def safe_user_id(uid: str) -> str:
    """Turn a user id into a single filesystem-safe path component.

    Every character outside ``[a-zA-Z0-9_.-]`` becomes ``_``, and the result
    is truncated to 80 characters. Results that would be empty or would name
    the current/parent directory (``""``, ``"."``, ``".."``) have their dots
    replaced by underscores. Otherwise a joined path such as
    ``user_cache / safe_user_id("..")`` would escape the per-user cache root.

    Note: the mapping is lossy, so distinct ids can collide (e.g. ``"a b"``
    and ``"a_b"``).
    """
    cleaned = re.sub(r"[^a-zA-Z0-9_.-]", "_", uid)[:80]
    if cleaned in {"", ".", ".."}:
        cleaned = cleaned.replace(".", "_") or "_"
    return cleaned

def is_huggingface_space() -> bool:
    """Return True when the app appears to run on Hugging Face Spaces.

    The check is that ``SPACE_ID`` or ``SPACE_HOST`` is set to a non-empty
    value in the environment.
    """
    return any(os.environ.get(var) for var in ("SPACE_ID", "SPACE_HOST"))

def is_user_logged_in() -> bool:
    """Return True when ``get_user_id()`` yields a value other than ``None``."""
    uid = get_user_id()
    return uid is not None

def require_hf_login() -> tuple[bool, str]:
    """Decide whether the UI must block until the user signs in.

    Returns:
        ``(True, message)`` when running on Hugging Face Spaces without a
        signed-in user, where ``message`` is a Markdown notice to display.
        ``(False, "")`` in every other case, including local runs, where
        login is never required.
    """
    needs_login = is_huggingface_space() and not is_user_logged_in()
    if not needs_login:
        return False, ""

    message = (
        "🔒 **Authentication Required**\n\n"
        "This application requires you to be logged in to Hugging Face to use it. "
        "Please log in using the button in the top right corner of this page."
    )
    return True, message

def _get_local_user_identifier_file() -> Path:
    """Get the path to the file storing the local user identifier."""
    return Path(".session_cache") / ".local_user_id"

def get_user_identifier() -> str:
    """Return a stable identifier for whoever is using the app.

    - Signed in via OAuth: ``safe_user_id(get_user_id())``.
    - Otherwise (local use): the id stored in
      ``.session_cache/.local_user_id``. If that file is missing, empty or
      unreadable, a new ``local_<16 hex chars>`` id is generated and written
      there; write failures are ignored. The local id is also stored in
      ``st.session_state.user_identifier``.
    """
    oauth_id = get_user_id()
    if oauth_id:
        return safe_user_id(oauth_id)

    id_file = _get_local_user_identifier_file()

    # Reuse a previously persisted id so it survives refreshes and restarts.
    if id_file.exists():
        try:
            persisted = id_file.read_text().strip()
            if persisted:
                st.session_state.user_identifier = persisted
                return persisted
        except Exception:
            pass  # unreadable file: fall through and mint a new id

    fresh_id = "local_" + secrets.token_hex(8)

    # Best-effort persistence; the id is still returned if writing fails.
    try:
        id_file.parent.mkdir(parents=True, exist_ok=True)
        id_file.write_text(fresh_id)
    except Exception:
        pass

    st.session_state.user_identifier = fresh_id
    return fresh_id

def get_user_cache_dir() -> Path:
    """Return the current user's cache directory, creating it if needed.

    The storage root is ``/data`` when that directory exists (persistent
    storage), otherwise ``.session_cache``. Users with an OAuth identity get
    ``<root>/user_cache/<safe user id>``. Without one, all callers share
    ``<root>/session_cache``.
    NOTE(review): the local id from get_user_identifier() is not used here.
    """
    data_root = Path("/data")
    root = data_root if data_root.exists() else Path(".session_cache")

    oauth_id = get_user_id()
    if oauth_id:
        target = root / "user_cache" / safe_user_id(oauth_id)
    else:
        target = root / "session_cache"

    target.mkdir(parents=True, exist_ok=True)
    return target

# Session cache files are stored in a per-user directory.
def get_cache_dir() -> Path:
    """Return the session cache directory for the current user.

    Delegates to ``get_user_cache_dir()``, which creates the directory if missing.
    """
    cache_dir = get_user_cache_dir()
    return cache_dir

def get_session_file_path(session_id: str = "default") -> Path:
    """Return the path ``<user cache dir>/session_<session_id>.pkl``.

    NOTE(review): ``session_id`` is interpolated into the file name unchecked,
    so an id containing path separators would point outside the cache dir.
    """
    file_name = f"session_{session_id}.pkl"
    return get_cache_dir().joinpath(file_name)

def generate_session_id() -> str:
    """Return a session id from the current local time, formatted ``YYYYmmdd_HHMMSS``.

    Resolution is one second, so ids generated within the same second are equal.
    """
    now = datetime.now()
    return f"{now:%Y%m%d_%H%M%S}"

def get_current_session_id() -> str:
    """Return the active session id, generating and storing one on first use.

    The id lives in ``st.session_state.current_session_id``.
    """
    if "current_session_id" in st.session_state:
        return st.session_state.current_session_id
    new_id = generate_session_id()
    st.session_state.current_session_id = new_id
    return new_id

def save_session_state(session_id: str | None = None, session_name: str | None = None) -> bool:
    """Persist the current session to ``session_<id>.pkl`` in the user's cache dir.

    Args:
        session_id: Id to save under. Defaults to the active session id,
            which is created on demand.
        session_name: Display name. Defaults to
            ``st.session_state.current_session_name``, or ``"Session <id>"``.
            The chosen name also becomes the current session name.

    Returns:
        True on success. False on any failure, after reporting it via ``st.error``.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()

        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")

        st.session_state.current_session_name = session_name

        session_data = _prepare_session_data(session_id, session_name)

        # Pickle is used because the data holds complex objects (e.g. LLMPrompt).
        # Serialize *before* opening the file. Opening with "wb" truncates it,
        # so a pickling error (e.g. an unpicklable object in session state)
        # used to destroy the previous good save of this session.
        # NOTE(review): session_data may include "api_key", stored unencrypted.
        payload = pickle.dumps(session_data)

        cache_file = get_session_file_path(session_id)
        with open(cache_file, 'wb') as f:
            f.write(payload)

        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False

def load_session_state(session_id: str = "default") -> bool:
    """Restore a saved session from the user's cache dir into ``st.session_state``.

    The loaded session becomes active: ``current_session_id`` and
    ``current_session_name`` are set. Every key of the saved data that does
    not start with ``_`` is copied into session state. Keys absent from the
    file are left untouched, so values from the previously active session
    can survive the load.

    Returns:
        True if the file existed and was loaded. False if it does not exist,
        or if loading failed (reported via ``st.error``).
    """
    try:
        source = get_session_file_path(session_id)
        if not source.exists():
            return False

        with open(source, 'rb') as fh:
            saved = pickle.load(fh)

        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = saved.get("_session_name", f"Session {session_id}")

        for key in saved:
            if key.startswith("_"):
                continue  # internal metadata such as _timestamp / _session_id
            st.session_state[key] = saved[key]

        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False

def list_available_sessions() -> list[dict]:
    """Describe every loadable ``session_*.pkl`` in the current user's cache dir.

    Each entry has the keys ``id``, ``name``, ``timestamp`` (the stored
    ``_timestamp`` or ``"Unknown"``) and ``file`` (the Path). Files that fail
    to load are skipped. Entries are sorted by ``timestamp`` compared as
    strings, in descending order.
    """
    found = []
    for path in get_cache_dir().glob("session_*.pkl"):
        try:
            with open(path, 'rb') as fh:
                meta = pickle.load(fh)
            sid = meta.get("_session_id", path.stem.replace("session_", ""))
            found.append({
                "id": sid,
                "name": meta.get("_session_name", f"Session {sid}"),
                "timestamp": meta.get("_timestamp", "Unknown"),
                "file": path,
            })
        except Exception:
            continue
    found.sort(key=lambda entry: entry["timestamp"], reverse=True)
    return found

def delete_session(session_id: str = "default") -> bool:
    """Remove the saved file for *session_id* from the current user's cache dir.

    Returns:
        True if a file was deleted. False if there was nothing to delete, or
        if deletion failed (reported via ``st.error``).
    """
    try:
        target = get_session_file_path(session_id)
        if not target.exists():
            return False
        target.unlink()
        return True
    except Exception as e:
        st.error(f"Error deleting session: {e}")
        return False

def clear_all_sessions():
    """Delete every ``session_*.pkl`` in the current user's cache dir.

    Best-effort: files that cannot be removed are silently skipped.
    """
    for stale in get_cache_dir().glob("session_*.pkl"):
        try:
            stale.unlink()
        except Exception:
            pass  # keep going with the remaining files

def rename_session(session_id: str, new_name: str) -> bool:
    """Change the stored display name of a saved session.

    Rewrites the ``_session_name`` field in ``session_<id>.pkl``. If that
    session is the active one, ``st.session_state.current_session_name`` is
    updated too.

    Returns:
        True on success. False if the file does not exist, or on any error
        (reported via ``st.error``).
    """
    try:
        path = get_session_file_path(session_id)
        if not path.exists():
            return False

        with open(path, 'rb') as reader:
            contents = pickle.load(reader)

        contents["_session_name"] = new_name

        with open(path, 'wb') as writer:
            pickle.dump(contents, writer)

        is_active = st.session_state.get("current_session_id") == session_id
        if is_active:
            st.session_state.current_session_name = new_name

        return True
    except Exception as e:
        st.error(f"Error renaming session: {e}")
        return False

def _prepare_session_data(session_id: str, session_name: str) -> dict:
    """Snapshot the persistable parts of ``st.session_state`` into a dict.

    Copies each key listed below that is currently present in session state
    (by reference, not deep-copied). It then adds the metadata fields
    ``_timestamp`` (local time, ISO 8601), ``_session_id`` and ``_session_name``.

    NOTE(review): ``api_key`` is included when present, so saved session
    files hold the API key in plain text.
    """
    persisted_keys = (
        # dataframes
        "df_questionnaire", "df_population",
        # questionnaires (LLMPrompt objects, which is why pickle is used)
        "questionnaires",
        # inference configs
        "client_config", "inference_config",
        # survey options
        "survey_options",
        # individual settings
        "model_name", "temperature", "max_tokens", "top_p", "seed",
        "api_key", "base_url", "organization", "project",
        "advanced_client_params_str", "advanced_inference_params_str",
        "timeout", "max_retries",
    )
    snapshot = {
        key: st.session_state[key]
        for key in persisted_keys
        if key in st.session_state
    }

    snapshot["_timestamp"] = datetime.now().isoformat()
    snapshot["_session_id"] = session_id
    snapshot["_session_name"] = session_name
    return snapshot

def save_session_state_to_path(
    session_id: str | None = None,
    session_name: str | None = None,
    save_path: Path | str | None = None,
) -> bool:
    """Persist the current session to an explicit location.

    Args:
        session_id: Id to save under. Defaults to the active session id.
        session_name: Display name. Defaults to
            ``st.session_state.current_session_name``, or ``"Session <id>"``.
            The chosen name also becomes the current session name.
        save_path: Target file or existing directory (``Path`` or ``str``).
            ``None`` means ``session_<id>.pkl`` in the user's cache dir. A
            directory gets ``session_<id>.pkl`` appended. Any other path has
            its suffix replaced with ``.pkl`` (e.g. ``backup.json`` ->
            ``backup.pkl``). Missing parent directories are created.

    Returns:
        True on success. False on any failure, after reporting it via ``st.error``.
    """
    try:
        if session_id is None:
            session_id = get_current_session_id()

        if session_name is None:
            session_name = st.session_state.get("current_session_name", f"Session {session_id}")

        st.session_state.current_session_name = session_name

        session_data = _prepare_session_data(session_id, session_name)

        if save_path is None:
            save_path = get_session_file_path(session_id)
        else:
            save_path = Path(save_path)  # also accept plain string paths
            save_path.parent.mkdir(parents=True, exist_ok=True)
            if save_path.is_dir():
                save_path = save_path / f"session_{session_id}.pkl"
            elif save_path.suffix != ".pkl":
                save_path = save_path.with_suffix(".pkl")

        # Serialize before opening: "wb" truncates the target, so a pickling
        # error would otherwise wipe an existing file at that path.
        payload = pickle.dumps(session_data)
        with open(save_path, 'wb') as f:
            f.write(payload)

        return True
    except Exception as e:
        st.error(f"Error saving session: {e}")
        return False

def load_session_state_from_path(file_path: Path | str) -> bool:
    """Restore a session from an arbitrary pickle file and make it active.

    Sets ``current_session_id`` and ``current_session_name`` from the file's
    metadata, then copies every non-underscore key into ``st.session_state``.
    The file is then copied into the user's cache dir as
    ``session_<id>.pkl``; copy failures are ignored.

    SECURITY: ``pickle.load`` can execute arbitrary code embedded in the file.
    Only pass files written by this app or another trusted source. Never use
    it on files uploaded by other users of a shared deployment.

    Args:
        file_path: Path (``Path`` or ``str``) of the pickle to load.

    Returns:
        True if loaded. False if the file does not exist, or if loading
        failed (reported via ``st.error``).
    """
    try:
        file_path = Path(file_path)
        if not file_path.exists():
            return False

        with open(file_path, 'rb') as f:
            session_data = pickle.load(f)

        # The id ends up in a file name below (get_session_file_path). Reject
        # missing, empty, non-string, or path-separator-containing ids, so a
        # crafted file cannot make copy2 write outside the cache dir.
        session_id = session_data.get("_session_id") or generate_session_id()
        if not isinstance(session_id, str) or "/" in session_id or "\\" in session_id:
            session_id = generate_session_id()

        st.session_state.current_session_id = session_id
        st.session_state.current_session_name = session_data.get("_session_name", f"Session {session_id}")

        # Restore session state (skip internal metadata such as _timestamp).
        for key, value in session_data.items():
            if not key.startswith("_"):
                st.session_state[key] = value

        # Keep a copy in the default cache dir so the session can be listed later.
        default_path = get_session_file_path(session_id)
        if file_path != default_path:
            try:
                shutil.copy2(file_path, default_path)
            except Exception:
                pass  # the session is already loaded; a failed copy is not fatal

        return True
    except Exception as e:
        st.error(f"Error loading session: {e}")
        return False