# Scraped-page residue (kept as a comment so the module parses):
# author "yujiepan", commit c5f9cd8 ("add alert")
import json
import tempfile
import os
import glob
import shutil
import io
import time
import threading
import sys
import gradio as gr
import torch
from huggingface_hub import hf_hub_download, scan_cache_dir, whoami
from safetensors import safe_open
# Default token read once at import time from the HF_TOKEN environment
# variable (on HuggingFace Spaces this is typically a repository Secret).
# May be None; per-session tokens, when provided, take precedence over it.
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN")
def hf_login(token: str, session_token: str):
    """Authenticate against Hugging Face with the supplied token.

    Returns a (message, status_label, token) tuple. On success the new
    token becomes the session token; on failure the previous session
    token is passed back unchanged.
    """
    if not token:
        return "โŒ Please provide a token", "Not logged in", session_token
    try:
        username = whoami(token=token).get('name', 'Unknown')
    except Exception as e:
        # Invalid/expired token, network failure, etc. — keep old session.
        return f"โŒ Login failed: {str(e)}", "โŒ Not logged in", session_token
    return f"โœ… Successfully logged in as: {username}", f"โœ… Logged in as {username}", token
def hf_logout(session_token: str):
    """Drop the per-session Hugging Face token.

    The incoming session token is ignored; logout always clears it.
    """
    message = "โœ… Successfully logged out"
    status = "Not logged in"
    return message, status, None
def check_hf_status(session_token: str):
    """Report the login state for this session.

    Prefers the per-session token and falls back to the Space-wide
    DEFAULT_HF_TOKEN. The session token itself is always returned
    unchanged (this handler never mutates session state).
    """
    # Short-circuit keeps DEFAULT_HF_TOKEN unevaluated when a session
    # token is present.
    token = session_token if session_token else DEFAULT_HF_TOKEN
    if not token:
        return "โ„น๏ธ Not logged in", "Not logged in", session_token
    try:
        username = whoami(token=token).get('name', 'Unknown')
    except Exception:
        # Token invalid or hub unreachable — treat as logged out.
        return "โ„น๏ธ Not logged in", "Not logged in", session_token
    source = "(session)" if session_token else "(default HF_TOKEN)"
    return f"โœ… Currently logged in as: {username} {source}", f"โœ… Logged in as {username}", session_token
def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress: gr.Progress, token: str = None):
    """
    Download and return a specific parameter tensor from a Hugging Face model.

    Resolves the shard containing `param_key` via the safetensors index
    (falling back to a single `model.safetensors` file when no index
    exists), downloads the shard, and loads only the requested tensor.

    Raises:
        KeyError: if the index exists but does not contain `param_key`.
    """
    # Use session token or fall back to default token
    auth_token = token or DEFAULT_HF_TOKEN
    # Redirect stderr to log buffer so tqdm download bars appear in the UI.
    # NOTE(review): sys.stderr is process-global; concurrent calls will
    # interleave their progress output.
    original_stderr = sys.stderr
    sys.stderr = log_buffer
    try:
        # Try to download the index file (for sharded models)
        try:
            log_buffer.write(f"๐Ÿ“ฅ Downloading index file for {model_id}...\n")
            progress(0.1, desc="Downloading index...")
            index_path = hf_hub_download(
                model_id, "model.safetensors.index.json", token=auth_token)
            log_buffer.write(f"โœ“ Index file found: {index_path}\n")
            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]
            if param_key not in weight_map:
                raise KeyError(
                    f"Parameter '{param_key}' not found in model. Available keys: {list(weight_map.keys())[:10]}..."
                )
            shard_file = weight_map[param_key]
            log_buffer.write(f"โœ“ Parameter found in shard: {shard_file}\n")
        except KeyError:
            # BUGFIX: a missing parameter key previously fell into the
            # generic handler below, whose '"not found" in str(e)' check
            # matched the KeyError message and silently retried with
            # model.safetensors (masking the real error). Re-raise instead.
            raise
        except Exception as e:
            # Only a missing index file means "try the single-file layout".
            if "404" in str(e) or "not found" in str(e).lower():
                log_buffer.write("โ„น๏ธ No index file, trying single model file...\n")
                shard_file = "model.safetensors"
            else:
                raise
        log_buffer.write(f"๐Ÿ“ฅ Downloading shard: {shard_file}...\n")
        progress(0.3, desc=f"Downloading {shard_file}...")
        shard_path = hf_hub_download(model_id, shard_file, token=auth_token)
        log_buffer.write(f"\nโœ“ Shard downloaded: {shard_path}\n")
        progress(0.7, desc="Loading tensor...")
        log_buffer.write(f"๐Ÿ” Loading tensor '{param_key}'...\n")
        # safe_open memory-maps the shard, so only the requested tensor
        # is materialized.
        with safe_open(shard_path, framework="pt") as f:
            tensor = f.get_tensor(param_key)
        log_buffer.write("โœ“ Tensor loaded successfully\n")
        progress(0.9, desc="Finalizing...")
        return tensor
    finally:
        # Restore original stderr
        sys.stderr = original_stderr
def get_available_keys(model_id: str, token: str = None):
    """Return a sorted list of the model's parameter keys, or [] on failure."""
    auth_token = token or DEFAULT_HF_TOKEN
    # Sharded checkpoint: the index file maps every key to its shard.
    try:
        index_file = hf_hub_download(model_id, "model.safetensors.index.json", token=auth_token)
        with open(index_file, "r", encoding="utf-8") as fh:
            return sorted(json.load(fh)["weight_map"].keys())
    except Exception:
        pass
    # Single-file checkpoint: read the key list straight from the safetensors header.
    try:
        weights_file = hf_hub_download(model_id, "model.safetensors", token=auth_token)
        with safe_open(weights_file, framework="pt") as fh:
            return sorted(fh.keys())
    except Exception:
        return []
def format_tensor_info(tensor: torch.Tensor) -> str:
    """Render tensor metadata and basic statistics as <br>-joined markdown."""
    lines = [
        f"**Shape:** {list(tensor.shape)}",
        f"**Dtype:** {tensor.dtype}",
        f"**Device:** {tensor.device}",
        f"**Numel:** {tensor.numel():,}",
    ]
    try:
        # FP8 tensors cannot run reductions directly; widen them first.
        fp8_names = ('torch.float8_e4m3fn', 'torch.float8_e5m2')
        stats_tensor = tensor.to(torch.float32) if str(tensor.dtype) in fp8_names else tensor
        lines.append(f"**Min:** {stats_tensor.min().item():.6f}")
        lines.append(f"**Max:** {stats_tensor.max().item():.6f}")
        lines.append(f"**Mean:** {stats_tensor.float().mean().item():.6f}")
        lines.append(f"**Std:** {stats_tensor.float().std().item():.6f}")
    except Exception:
        # Some dtypes (e.g. certain quantized types) reject reductions.
        lines.append("**Stats:** Unable to compute (dtype not supported)")
    return "<br>".join(lines)
def fetch_param(model_id: str, param_key: str, session_token: str, progress=gr.Progress()):
    """Fetch parameter and return formatted info and tensor preview.

    Generator used as a Gradio event handler: yields tuples of
    (info_markdown, preview_markdown, download_path, log_text) so the UI
    refreshes incrementally while the download runs in a background
    thread. The final yield carries the tensor info, a preview of the
    first values, and the path of a .pt file saved for download.

    NOTE(review): progress=gr.Progress() in the signature is the
    documented Gradio pattern for progress tracking (the default is
    deliberately evaluated once at definition time).
    """
    log_buffer = io.StringIO()
    last_log_value = ""
    if not model_id or not param_key:
        yield "Please provide both model ID and parameter key.", "", None, "โŒ Missing required inputs"
        return
    try:
        log_buffer.write(f"๐Ÿš€ Starting download for {model_id}\n")
        log_buffer.write(f"๐ŸŽฏ Target parameter: {param_key}\n\n")
        progress(0, desc="Initializing...")
        yield "", "", None, log_buffer.getvalue()
        time.sleep(0.5)
        # Start download in background thread
        download_complete = threading.Event()
        download_error = [None]  # Use list to store exception from thread
        result_tensor = [None]  # Use list to store result from thread

        def download_thread():
            # Runs get_param off the event loop; result/exception are
            # handed back through the one-element lists above.
            try:
                result_tensor[0] = get_param(model_id, param_key, log_buffer, progress, session_token)
            except Exception as e:
                download_error[0] = e
            finally:
                download_complete.set()

        thread = threading.Thread(target=download_thread, daemon=True)
        thread.start()
        # Poll log buffer every 1 second while download is running;
        # only yield when the log actually changed to avoid redundant UI updates.
        while not download_complete.is_set():
            current_log = log_buffer.getvalue()
            if current_log != last_log_value:
                yield "", "", None, current_log
                last_log_value = current_log
            time.sleep(1)
        # Final log update after download completes
        current_log = log_buffer.getvalue()
        if current_log != last_log_value:
            yield "", "", None, current_log
            last_log_value = current_log
        # Check for errors (re-raise in this generator so the except
        # below formats them for the UI)
        if download_error[0]:
            raise download_error[0]
        tensor = result_tensor[0]
        info = format_tensor_info(tensor)
        # Create tensor preview (first few elements)
        log_buffer.write(f"\n๐Ÿ“Š Creating preview...\n")
        yield "", "", None, log_buffer.getvalue()
        flat = tensor.flatten()
        preview_size = min(100, flat.numel())
        # Convert to float32 for FP8 types for display
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            preview = flat[:preview_size].to(torch.float32).tolist()
        else:
            preview = flat[:preview_size].tolist()
        # Format preview in multiple lines (10 values per line)
        # Adapt to different data types
        preview_lines = []
        for i in range(0, len(preview), 10):
            line_values = preview[i:i+10]
            if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16] or str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
                preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
            elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            elif tensor.dtype == torch.bool:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            else:
                preview_lines.append(", ".join(str(v) for v in line_values))
        preview_str = f"**First {preview_size} values:**\n```\n" + \
            "\n".join(preview_lines) + "\n```"
        # if flat.numel() > preview_size:
        #     preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
        # Save tensor for download
        log_buffer.write(f"๐Ÿ’พ Saving tensor for download...\n")
        yield info, preview_str, None, log_buffer.getvalue()
        temp_dir = tempfile.gettempdir()
        # Sanitize the key so it is a safe flat filename in the temp dir.
        safe_param_key = param_key.replace("/", "_").replace(".", "_")
        download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
        torch.save(tensor, download_path)
        log_buffer.write(f"โœ“ Saved to: {download_path}\n")
        progress(1.0, desc="Complete!")
        log_buffer.write(f"\nโœ… All operations completed successfully!\n")
        yield info, preview_str, download_path, log_buffer.getvalue()
    except Exception as e:
        log_buffer.write(f"\nโŒ Error: {str(e)}\n")
        yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
def list_keys(model_id: str, session_token: str):
    """Return the model's parameter keys, one per line, for display."""
    if not model_id:
        return "Please provide a model ID."
    try:
        keys = get_available_keys(model_id, session_token)
    except Exception as e:
        return f"**Error:** {str(e)}"
    if not keys:
        return "No keys found or failed to load model."
    return "\n".join(keys)
def clear_temp_files():
    """Delete every .pt tensor dump left in the system temp directory."""
    try:
        pattern = os.path.join(tempfile.gettempdir(), "*.pt")
        matches = glob.glob(pattern)
        count = len(matches)
        deleted_files = []
        for path in matches:
            try:
                os.remove(path)
            except Exception:
                # Best-effort: skip files we cannot delete (in use, perms).
                continue
            deleted_files.append(os.path.basename(path))
        if not deleted_files:
            return "โœ… No temporary files to clear"
        files_list = "\n".join(deleted_files)
        return f"โœ… Cleared {count} temporary file(s):\n\n{files_list}"
    except Exception as e:
        return f"โŒ Error: {str(e)}"
def clear_hf_cache():
    """Wipe the Hugging Face hub cache and report how much space was freed."""
    try:
        cache_info = scan_cache_dir()
        total_size = cache_info.size_on_disk
        total_repos = len(cache_info.repos)
        if total_repos == 0:
            return "โœ… Hugging Face cache is already empty"
        # NOTE(review): hardcodes the default cache path; HF_HOME overrides
        # are not honored here — confirm against deployment config.
        cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        if not os.path.exists(cache_dir):
            return "โœ… Hugging Face cache directory not found"
        # Recreate an empty cache dir so later downloads still work.
        shutil.rmtree(cache_dir)
        os.makedirs(cache_dir)
        size_mb = total_size / (1024 * 1024)
        return f"โœ… Cleared Hugging Face cache: {total_repos} repo(s), {size_mb:.2f} MB freed"
    except Exception as e:
        return f"โŒ Error: {str(e)}"
def get_cache_info():
    """Summarize disk usage of temp .pt dumps and the Hugging Face cache.

    Returns a human-readable multi-line report; never raises — all
    errors are folded into the report text.
    """
    try:
        # Temp .pt files written by fetch_param.
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        temp_size = sum(os.path.getsize(f)
                        for f in pt_files if os.path.exists(f))
        temp_size_mb = temp_size / (1024 * 1024)
        info = "๐Ÿ“Š Cache Info:\n\n"
        info += f"โ•โ•โ• Temp .pt files: {len(pt_files)} file(s), {temp_size_mb:.2f} MB โ•โ•โ•\n"
        if pt_files:
            for file in pt_files:
                size = os.path.getsize(file) / (1024 * 1024)
                filename = os.path.basename(file)
                # BUGFIX: previously printed the literal "(unknown)"
                # instead of the computed filename.
                info += f"  โ€ข {filename} ({size:.2f} MB)\n"
        else:
            info += "  (empty)\n"
        # HF cache
        info += "\nโ•โ•โ• Hugging Face Cache โ•โ•โ•\n"
        # Initialize up front so the grand total below stays correct even
        # when the cache scan fails (replaces the old locals() check).
        hf_size_mb = 0.0
        try:
            cache_info = scan_cache_dir()
            hf_size_mb = cache_info.size_on_disk / (1024 * 1024)
            hf_repos = len(cache_info.repos)
            info += f"Total: {hf_repos} repo(s), {hf_size_mb:.2f} MB\n\n"
            if hf_repos > 0:
                for repo in cache_info.repos:
                    repo_size = repo.size_on_disk / (1024 * 1024)
                    info += f"  ๐Ÿ“ฆ {repo.repo_id}\n"
                    info += f"  Size: {repo_size:.2f} MB, Revisions: {len(repo.revisions)}\n"
                    info += f"  Last accessed: {repo.last_accessed}\n"
            else:
                info += "  (empty)\n"
        except Exception as e:
            info += f"  Error reading HF cache: {str(e)}\n"
        info += f"\nโ•โ•โ• Total: {temp_size_mb + hf_size_mb:.2f} MB โ•โ•โ•"
        return info
    except Exception as e:
        return f"โŒ Error: {str(e)}"
# Create Gradio interface
custom_css = """
* {
font-family: Consolas, Monaco, 'Courier New', monospace !important;
}
.compact-row {
gap: 0.5rem !important;
}
.tensor-preview pre {
font-size: 0.75rem !important;
line-height: 1.0 !important;
}
.compact-file {
max-height: 80px !important;
}
.compact-file > div {
min-height: 60px !important;
}
"""
# Build the Gradio UI.
# BUGFIX: custom CSS must be passed to the gr.Blocks constructor —
# Blocks.launch() does not accept a `css` keyword argument, so the
# previous launch(css=custom_css) call never applied the styling.
with gr.Blocks(title="Hugging Face Model Weight Inspector", css=custom_css) as demo:
    gr.Markdown("# ๐Ÿ” Hugging Face Model Weight Inspector")
    # Session state for per-user token
    session_token = gr.State(None)
    # HF Login section
    with gr.Accordion("๐Ÿ” Hugging Face Login (Per-User Session) [โš ๏ธโš ๏ธโš ๏ธWIP, Do not useโš ๏ธโš ๏ธโš ๏ธ]", open=False):
        gr.Markdown("""
**Note:** This Space uses the default `HF_TOKEN` secret for all users if no session token is provided.
Login below with your own token for per-user authentication (affects only your session).
""")
        with gr.Row():
            with gr.Column(scale=3):
                hf_token_input = gr.Textbox(
                    label="HF Token",
                    placeholder="hf_...",
                    type="password",
                )
            with gr.Column(scale=2):
                initial_status = "โœ… Using default HF_TOKEN" if DEFAULT_HF_TOKEN else "Not logged in"
                hf_status = gr.Textbox(
                    label="Status",
                    value=initial_status,
                    interactive=False,
                )
        with gr.Row():
            login_btn = gr.Button("๐Ÿ”‘ Login", variant="primary", scale=1)
            logout_btn = gr.Button("๐Ÿšช Logout", variant="secondary", scale=1)
            check_status_btn = gr.Button("โ„น๏ธ Check Status", variant="secondary", scale=1)
        login_output = gr.Textbox(label="Login Status", interactive=False, lines=2)
    # Model/parameter selection.
    with gr.Row():
        with gr.Column(scale=1):
            model_id_input = gr.Textbox(
                label="Model ID",
                placeholder="e.g., meta-llama/Llama-2-7b-hf",
                value="Qwen/Qwen3-Coder-Next-FP8",
            )
            param_key_input = gr.Textbox(
                label="Parameter Key",
                placeholder="e.g., model.norm.weight",
                value="model.norm.weight",
            )
            with gr.Row():
                list_keys_btn = gr.Button(
                    "๐Ÿ“‹ List Keys", variant="secondary", scale=1)
                fetch_btn = gr.Button("๐Ÿ”Ž Fetch", variant="primary", scale=1)
        with gr.Column(scale=1):
            keys_output = gr.Textbox(
                label="Available Parameter Keys",
                lines=5,
                max_lines=8,
            )
    with gr.Tabs():
        with gr.Tab("Results"):
            with gr.Row():
                with gr.Column(scale=3):
                    preview_output = gr.Markdown(label="Tensor Preview", elem_classes="tensor-preview")
                with gr.Column(scale=1):
                    info_output = gr.Markdown(label="Tensor Info")
                    download_output = gr.File(label="Download Tensor (.pt file)", elem_classes="compact-file")
            log_output = gr.Textbox(
                label="๐Ÿ“‹ Download Log", lines=1, interactive=False)
        with gr.Tab("Cache Management"):
            with gr.Row():
                get_info_btn = gr.Button(
                    "๐Ÿ“Š Get Cache Info", variant="secondary", scale=1)
                clear_temp_btn = gr.Button(
                    "๐Ÿ—‘๏ธ Clear Temp Folder", variant="secondary", scale=1)
                clear_hf_btn = gr.Button(
                    "๐Ÿ—‘๏ธ Clear HF Cache", variant="secondary", scale=1)
            clear_status = gr.Textbox(
                label="Status", interactive=False, lines=6)
    # Event handlers
    login_btn.click(
        fn=hf_login,
        inputs=[hf_token_input, session_token],
        outputs=[login_output, hf_status, session_token],
    )
    logout_btn.click(
        fn=hf_logout,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    check_status_btn.click(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    list_keys_btn.click(
        fn=list_keys,
        inputs=[model_id_input, session_token],
        outputs=[keys_output],
    )
    fetch_btn.click(
        fn=fetch_param,
        inputs=[model_id_input, param_key_input, session_token],
        outputs=[info_output, preview_output, download_output, log_output],
    )
    clear_temp_btn.click(
        fn=clear_temp_files,
        inputs=[],
        outputs=[clear_status],
    )
    clear_hf_btn.click(
        fn=clear_hf_cache,
        inputs=[],
        outputs=[clear_status],
    )
    get_info_btn.click(
        fn=get_cache_info,
        inputs=[],
        outputs=[clear_status],
    )
    # Auto-check status on load
    demo.load(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )

if __name__ == "__main__":
    # css dropped from launch(): not a valid launch() parameter (applied
    # via the gr.Blocks constructor instead).
    demo.launch(server_name="0.0.0.0")