yujiepan committed on
Commit
1581622
·
1 Parent(s): 65c2b88

Implement per-user session-based authentication

Browse files
Files changed (1) hide show
  1. app.py +63 -52
app.py CHANGED
@@ -10,58 +10,55 @@ import sys
10
 
11
  import gradio as gr
12
  import torch
13
- from huggingface_hub import hf_hub_download, scan_cache_dir, login, whoami
14
  from safetensors import safe_open
15
 
16
- # Auto-login with HF_TOKEN if available (for HuggingFace Spaces)
17
- HF_TOKEN = os.environ.get("HF_TOKEN")
18
- if HF_TOKEN:
19
- try:
20
- login(token=HF_TOKEN, add_to_git_credential=False)
21
- print("✅ Automatically logged in using HF_TOKEN")
22
- except Exception as e:
23
- print(f"⚠️ Auto-login failed: {str(e)}")
24
 
25
 
26
- def hf_login(token: str):
27
- """Login to Hugging Face with provided token."""
28
  if not token:
29
- return "❌ Please provide a token", "Not logged in"
30
 
31
  try:
32
- login(token=token, add_to_git_credential=False)
33
- user_info = whoami()
34
  username = user_info.get('name', 'Unknown')
35
- return f"✅ Successfully logged in as: {username}", f"✅ Logged in as {username}"
36
  except Exception as e:
37
- return f"❌ Login failed: {str(e)}", "❌ Not logged in"
38
 
39
 
40
- def hf_logout():
41
- """Logout from Hugging Face."""
42
- try:
43
- # Clear token by logging in with empty token
44
- from huggingface_hub import logout
45
- logout()
46
- return "✅ Successfully logged out", "Not logged in"
47
- except Exception as e:
48
- return f"❌ Logout failed: {str(e)}", "Status unknown"
49
 
50
 
51
- def check_hf_status():
52
- """Check current HF login status."""
 
 
 
 
 
 
53
  try:
54
- user_info = whoami()
55
  username = user_info.get('name', 'Unknown')
56
- return f"✅ Currently logged in as: {username}", f"✅ Logged in as {username}"
 
57
  except Exception:
58
- return "ℹ️ Not logged in", "Not logged in"
59
 
60
 
61
- def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress: gr.Progress):
62
  """
63
  Download and return a specific parameter tensor from a Hugging Face model.
64
  """
 
 
 
65
  # Redirect stderr to log buffer for real-time tqdm updates
66
  original_stderr = sys.stderr
67
  sys.stderr = log_buffer
@@ -73,7 +70,7 @@ def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress:
73
  progress(0.1, desc="Downloading index...")
74
 
75
  index_path = hf_hub_download(
76
- model_id, "model.safetensors.index.json")
77
 
78
  log_buffer.write(f"✓ Index file found: {index_path}\n")
79
 
@@ -96,7 +93,7 @@ def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress:
96
  log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
97
  progress(0.3, desc=f"Downloading {shard_file}...")
98
 
99
- shard_path = hf_hub_download(model_id, shard_file)
100
 
101
  log_buffer.write(f"\n✓ Shard downloaded: {shard_path}\n")
102
  progress(0.7, desc="Loading tensor...")
@@ -113,17 +110,20 @@ def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress:
113
  sys.stderr = original_stderr
114
 
115
 
116
- def get_available_keys(model_id: str):
117
  """Get all available parameter keys from a model."""
 
 
 
118
  try:
119
- index_path = hf_hub_download(model_id, "model.safetensors.index.json")
120
  with open(index_path, "r", encoding="utf-8") as f:
121
  index = json.load(f)
122
  return sorted(index["weight_map"].keys())
123
  except Exception:
124
  # Try single file
125
  try:
126
- shard_path = hf_hub_download(model_id, "model.safetensors")
127
  with safe_open(shard_path, framework="pt") as f:
128
  return sorted(f.keys())
129
  except Exception as e:
@@ -156,7 +156,7 @@ def format_tensor_info(tensor: torch.Tensor) -> str:
156
  return "<br>".join(info)
157
 
158
 
159
- def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
160
  """Fetch parameter and return formatted info and tensor preview."""
161
  log_buffer = io.StringIO()
162
  last_log_value = ""
@@ -179,7 +179,7 @@ def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
179
 
180
  def download_thread():
181
  try:
182
- result_tensor[0] = get_param(model_id, param_key, log_buffer, progress)
183
  except Exception as e:
184
  download_error[0] = e
185
  finally:
@@ -260,13 +260,13 @@ def fetch_param(model_id: str, param_key: str, progress=gr.Progress()):
260
  yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
261
 
262
 
263
- def list_keys(model_id: str):
264
  """List all available keys for a model."""
265
  if not model_id:
266
  return "Please provide a model ID."
267
 
268
  try:
269
- keys = get_available_keys(model_id)
270
  if not keys:
271
  return "No keys found or failed to load model."
272
  return "\n".join(keys)
@@ -390,11 +390,14 @@ custom_css = """
390
  with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
391
  gr.Markdown("# πŸ” Hugging Face Model Weight Inspector")
392
 
 
 
 
393
  # HF Login section
394
- with gr.Accordion("🔐 Hugging Face Login (Optional - auto-login with HF_TOKEN)", open=False):
395
  gr.Markdown("""
396
- **Note:** This Space automatically uses the `HF_TOKEN` secret if configured.
397
- Manual login below is only needed if you want to use a different token.
398
  """)
399
  with gr.Row():
400
  with gr.Column(scale=3):
@@ -404,9 +407,10 @@ with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
404
  type="password",
405
  )
406
  with gr.Column(scale=2):
 
407
  hf_status = gr.Textbox(
408
  label="Status",
409
- value="Not logged in",
410
  interactive=False,
411
  )
412
  with gr.Row():
@@ -464,31 +468,31 @@ with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
464
  # Event handlers
465
  login_btn.click(
466
  fn=hf_login,
467
- inputs=[hf_token_input],
468
- outputs=[login_output, hf_status],
469
  )
470
 
471
  logout_btn.click(
472
  fn=hf_logout,
473
- inputs=[],
474
- outputs=[login_output, hf_status],
475
  )
476
 
477
  check_status_btn.click(
478
  fn=check_hf_status,
479
- inputs=[],
480
- outputs=[login_output, hf_status],
481
  )
482
 
483
  list_keys_btn.click(
484
  fn=list_keys,
485
- inputs=[model_id_input],
486
  outputs=[keys_output],
487
  )
488
 
489
  fetch_btn.click(
490
  fn=fetch_param,
491
- inputs=[model_id_input, param_key_input],
492
  outputs=[info_output, preview_output, download_output, log_output],
493
  )
494
 
@@ -509,6 +513,13 @@ with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
509
  inputs=[],
510
  outputs=[clear_status],
511
  )
 
 
 
 
 
 
 
512
 
513
 
514
  if __name__ == "__main__":
 
10
 
11
  import gradio as gr
12
  import torch
13
+ from huggingface_hub import hf_hub_download, scan_cache_dir, whoami
14
  from safetensors import safe_open
15
 
16
+ # Default token from HF_TOKEN environment variable (for HuggingFace Spaces)
17
+ DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN")
 
 
 
 
 
 
18
 
19
 
20
+ def hf_login(token: str, session_token: str):
21
+ """Login to Hugging Face with provided token (per-user session)."""
22
  if not token:
23
+ return "❌ Please provide a token", "Not logged in", session_token
24
 
25
  try:
26
+ user_info = whoami(token=token)
 
27
  username = user_info.get('name', 'Unknown')
28
+ return f"✅ Successfully logged in as: {username}", f"✅ Logged in as {username}", token
29
  except Exception as e:
30
+ return f"❌ Login failed: {str(e)}", "❌ Not logged in", session_token
31
 
32
 
33
+ def hf_logout(session_token: str):
34
+ """Logout from Hugging Face (clear session token)."""
35
+ return "✅ Successfully logged out", "Not logged in", None
 
 
 
 
 
 
36
 
37
 
38
+ def check_hf_status(session_token: str):
39
+ """Check current HF login status for this session."""
40
+ # Check session token first, then fall back to default token
41
+ token = session_token or DEFAULT_HF_TOKEN
42
+
43
+ if not token:
44
+ return "ℹ️ Not logged in", "Not logged in", session_token
45
+
46
  try:
47
+ user_info = whoami(token=token)
48
  username = user_info.get('name', 'Unknown')
49
+ source = "(session)" if session_token else "(default HF_TOKEN)"
50
+ return f"✅ Currently logged in as: {username} {source}", f"✅ Logged in as {username}", session_token
51
  except Exception:
52
+ return "ℹ️ Not logged in", "Not logged in", session_token
53
 
54
 
55
+ def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress: gr.Progress, token: str = None):
56
  """
57
  Download and return a specific parameter tensor from a Hugging Face model.
58
  """
59
+ # Use session token or fall back to default token
60
+ auth_token = token or DEFAULT_HF_TOKEN
61
+
62
  # Redirect stderr to log buffer for real-time tqdm updates
63
  original_stderr = sys.stderr
64
  sys.stderr = log_buffer
 
70
  progress(0.1, desc="Downloading index...")
71
 
72
  index_path = hf_hub_download(
73
+ model_id, "model.safetensors.index.json", token=auth_token)
74
 
75
  log_buffer.write(f"✓ Index file found: {index_path}\n")
76
 
 
93
  log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
94
  progress(0.3, desc=f"Downloading {shard_file}...")
95
 
96
+ shard_path = hf_hub_download(model_id, shard_file, token=auth_token)
97
 
98
  log_buffer.write(f"\n✓ Shard downloaded: {shard_path}\n")
99
  progress(0.7, desc="Loading tensor...")
 
110
  sys.stderr = original_stderr
111
 
112
 
113
+ def get_available_keys(model_id: str, token: str = None):
114
  """Get all available parameter keys from a model."""
115
+ # Use session token or fall back to default token
116
+ auth_token = token or DEFAULT_HF_TOKEN
117
+
118
  try:
119
+ index_path = hf_hub_download(model_id, "model.safetensors.index.json", token=auth_token)
120
  with open(index_path, "r", encoding="utf-8") as f:
121
  index = json.load(f)
122
  return sorted(index["weight_map"].keys())
123
  except Exception:
124
  # Try single file
125
  try:
126
+ shard_path = hf_hub_download(model_id, "model.safetensors", token=auth_token)
127
  with safe_open(shard_path, framework="pt") as f:
128
  return sorted(f.keys())
129
  except Exception as e:
 
156
  return "<br>".join(info)
157
 
158
 
159
+ def fetch_param(model_id: str, param_key: str, session_token: str, progress=gr.Progress()):
160
  """Fetch parameter and return formatted info and tensor preview."""
161
  log_buffer = io.StringIO()
162
  last_log_value = ""
 
179
 
180
  def download_thread():
181
  try:
182
+ result_tensor[0] = get_param(model_id, param_key, log_buffer, progress, session_token)
183
  except Exception as e:
184
  download_error[0] = e
185
  finally:
 
260
  yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
261
 
262
 
263
+ def list_keys(model_id: str, session_token: str):
264
  """List all available keys for a model."""
265
  if not model_id:
266
  return "Please provide a model ID."
267
 
268
  try:
269
+ keys = get_available_keys(model_id, session_token)
270
  if not keys:
271
  return "No keys found or failed to load model."
272
  return "\n".join(keys)
 
390
  with gr.Blocks(title="Hugging Face Model Weight Inspector") as demo:
391
  gr.Markdown("# πŸ” Hugging Face Model Weight Inspector")
392
 
393
+ # Session state for per-user token
394
+ session_token = gr.State(None)
395
+
396
  # HF Login section
397
+ with gr.Accordion("🔐 Hugging Face Login (Per-User Session)", open=False):
398
  gr.Markdown("""
399
+ **Note:** This Space uses the default `HF_TOKEN` secret for all users if no session token is provided.
400
+ Login below with your own token for per-user authentication (affects only your session).
401
  """)
402
  with gr.Row():
403
  with gr.Column(scale=3):
 
407
  type="password",
408
  )
409
  with gr.Column(scale=2):
410
+ initial_status = "✅ Using default HF_TOKEN" if DEFAULT_HF_TOKEN else "Not logged in"
411
  hf_status = gr.Textbox(
412
  label="Status",
413
+ value=initial_status,
414
  interactive=False,
415
  )
416
  with gr.Row():
 
468
  # Event handlers
469
  login_btn.click(
470
  fn=hf_login,
471
+ inputs=[hf_token_input, session_token],
472
+ outputs=[login_output, hf_status, session_token],
473
  )
474
 
475
  logout_btn.click(
476
  fn=hf_logout,
477
+ inputs=[session_token],
478
+ outputs=[login_output, hf_status, session_token],
479
  )
480
 
481
  check_status_btn.click(
482
  fn=check_hf_status,
483
+ inputs=[session_token],
484
+ outputs=[login_output, hf_status, session_token],
485
  )
486
 
487
  list_keys_btn.click(
488
  fn=list_keys,
489
+ inputs=[model_id_input, session_token],
490
  outputs=[keys_output],
491
  )
492
 
493
  fetch_btn.click(
494
  fn=fetch_param,
495
+ inputs=[model_id_input, param_key_input, session_token],
496
  outputs=[info_output, preview_output, download_output, log_output],
497
  )
498
 
 
513
  inputs=[],
514
  outputs=[clear_status],
515
  )
516
+
517
+ # Auto-check status on load
518
+ demo.load(
519
+ fn=check_hf_status,
520
+ inputs=[session_token],
521
+ outputs=[login_output, hf_status, session_token],
522
+ )
523
 
524
 
525
  if __name__ == "__main__":