"""Gradio app for downloading and inspecting individual weight tensors
from safetensors checkpoints hosted on the Hugging Face Hub."""

import glob
import io
import json
import os
import shutil
import sys
import tempfile
import threading
import time

import gradio as gr
import torch
from huggingface_hub import hf_hub_download, scan_cache_dir, whoami
from safetensors import safe_open

# Default token from HF_TOKEN environment variable (for HuggingFace Spaces)
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN")


def hf_login(token: str, session_token: str):
    """Login to Hugging Face with provided token (per-user session).

    Returns (login message, status label, new session token). On failure
    the previous session token is returned unchanged.
    """
    if not token:
        return "❌ Please provide a token", "Not logged in", session_token
    try:
        user_info = whoami(token=token)
        username = user_info.get('name', 'Unknown')
        return (
            f"✅ Successfully logged in as: {username}",
            f"✅ Logged in as {username}",
            token,
        )
    except Exception as e:
        return f"❌ Login failed: {str(e)}", "❌ Not logged in", session_token


def hf_logout(session_token: str):
    """Logout from Hugging Face (clear session token)."""
    # session_token is accepted (Gradio passes it) but the logout result
    # is unconditional: clear the stored token.
    return "✅ Successfully logged out", "Not logged in", None


def check_hf_status(session_token: str):
    """Check current HF login status for this session.

    Returns (status message, status label, session token unchanged).
    """
    # Check session token first, then fall back to default token
    token = session_token or DEFAULT_HF_TOKEN
    if not token:
        return "ℹ️ Not logged in", "Not logged in", session_token
    try:
        user_info = whoami(token=token)
        username = user_info.get('name', 'Unknown')
        source = "(session)" if session_token else "(default HF_TOKEN)"
        return (
            f"✅ Currently logged in as: {username} {source}",
            f"✅ Logged in as {username}",
            session_token,
        )
    except Exception:
        return "ℹ️ Not logged in", "Not logged in", session_token


def get_param(model_id: str, param_key: str, log_buffer: io.StringIO,
              progress: gr.Progress, token: str = None):
    """Download and return a specific parameter tensor from a Hugging Face model.

    Tries the sharded-model index first; if the repo has no index file,
    falls back to a single ``model.safetensors``. While running, stderr is
    redirected into ``log_buffer`` so hf_hub_download's tqdm progress is
    captured for the UI log.

    Raises:
        KeyError: if ``param_key`` is not in the model's weight map.
    """
    # Use session token or fall back to default token
    auth_token = token or DEFAULT_HF_TOKEN

    # Redirect stderr to log buffer for real-time tqdm updates
    original_stderr = sys.stderr
    sys.stderr = log_buffer
    try:
        # Try to download the index file (for sharded models)
        try:
            log_buffer.write(f"📥 Downloading index file for {model_id}...\n")
            progress(0.1, desc="Downloading index...")
            index_path = hf_hub_download(
                model_id, "model.safetensors.index.json", token=auth_token)
            log_buffer.write(f"✓ Index file found: {index_path}\n")
            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            weight_map = index["weight_map"]
            if param_key not in weight_map:
                raise KeyError(
                    f"Parameter '{param_key}' not found in model. "
                    f"Available keys: {list(weight_map.keys())[:10]}..."
                )
            shard_file = weight_map[param_key]
            log_buffer.write(f"✓ Parameter found in shard: {shard_file}\n")
        except Exception as e:
            # Best-effort detection of "index file missing" — hub errors
            # carry the HTTP status in their message.
            if "404" in str(e) or "not found" in str(e).lower():
                log_buffer.write("ℹ️ No index file, trying single model file...\n")
                shard_file = "model.safetensors"
            else:
                raise

        log_buffer.write(f"📥 Downloading shard: {shard_file}...\n")
        progress(0.3, desc=f"Downloading {shard_file}...")
        shard_path = hf_hub_download(model_id, shard_file, token=auth_token)
        log_buffer.write(f"\n✓ Shard downloaded: {shard_path}\n")

        progress(0.7, desc="Loading tensor...")
        log_buffer.write(f"🔍 Loading tensor '{param_key}'...\n")
        with safe_open(shard_path, framework="pt") as f:
            tensor = f.get_tensor(param_key)
        log_buffer.write("✓ Tensor loaded successfully\n")
        progress(0.9, desc="Finalizing...")
        return tensor
    finally:
        # Restore original stderr
        sys.stderr = original_stderr


def get_available_keys(model_id: str, token: str = None):
    """Get all available parameter keys from a model, sorted; [] on failure."""
    # Use session token or fall back to default token
    auth_token = token or DEFAULT_HF_TOKEN
    try:
        index_path = hf_hub_download(
            model_id, "model.safetensors.index.json", token=auth_token)
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        return sorted(index["weight_map"].keys())
    except Exception:
        # No index file — try single-file checkpoint
        try:
            shard_path = hf_hub_download(
                model_id, "model.safetensors", token=auth_token)
            with safe_open(shard_path, framework="pt") as f:
                return sorted(f.keys())
        except Exception:
            return []


def format_tensor_info(tensor: torch.Tensor) -> str:
    """Format tensor metadata and basic statistics as Markdown."""
    info = []
    info.append(f"**Shape:** {list(tensor.shape)}")
    info.append(f"**Dtype:** {tensor.dtype}")
    info.append(f"**Device:** {tensor.device}")
    info.append(f"**Numel:** {tensor.numel():,}")

    # Handle special dtypes that don't support statistical operations
    try:
        # Convert FP8 and other special dtypes to float32 for stats
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            stats_tensor = tensor.to(torch.float32)
        else:
            stats_tensor = tensor
        info.append(f"**Min:** {stats_tensor.min().item():.6f}")
        info.append(f"**Max:** {stats_tensor.max().item():.6f}")
        info.append(f"**Mean:** {stats_tensor.float().mean().item():.6f}")
        info.append(f"**Std:** {stats_tensor.float().std().item():.6f}")
    except Exception:
        info.append("**Stats:** Unable to compute (dtype not supported)")

    return "\n".join(info)


def fetch_param(model_id: str, param_key: str, session_token: str,
                progress=gr.Progress()):
    """Fetch parameter and yield formatted info, preview, download path and log.

    Generator for Gradio streaming: each yield is
    (info_markdown, preview_markdown, download_path, log_text). The
    download runs on a background thread while this generator polls the
    log buffer so the UI log updates live.
    """
    log_buffer = io.StringIO()
    last_log_value = ""

    if not model_id or not param_key:
        yield ("Please provide both model ID and parameter key.",
               "", None, "❌ Missing required inputs")
        return

    try:
        log_buffer.write(f"🚀 Starting download for {model_id}\n")
        log_buffer.write(f"🎯 Target parameter: {param_key}\n\n")
        progress(0, desc="Initializing...")
        yield "", "", None, log_buffer.getvalue()
        time.sleep(0.5)

        # Start download in background thread
        download_complete = threading.Event()
        download_error = [None]   # Use list to store exception from thread
        result_tensor = [None]    # Use list to store result from thread

        def download_thread():
            try:
                result_tensor[0] = get_param(
                    model_id, param_key, log_buffer, progress, session_token)
            except Exception as e:
                download_error[0] = e
            finally:
                download_complete.set()

        thread = threading.Thread(target=download_thread, daemon=True)
        thread.start()

        # Poll log buffer every 1 second while download is running
        while not download_complete.is_set():
            current_log = log_buffer.getvalue()
            if current_log != last_log_value:
                yield "", "", None, current_log
                last_log_value = current_log
            time.sleep(1)

        # Final log update after download completes
        current_log = log_buffer.getvalue()
        if current_log != last_log_value:
            yield "", "", None, current_log
            last_log_value = current_log

        # Check for errors
        if download_error[0]:
            raise download_error[0]

        tensor = result_tensor[0]
        info = format_tensor_info(tensor)

        # Create tensor preview (first few elements)
        log_buffer.write("\n📊 Creating preview...\n")
        yield "", "", None, log_buffer.getvalue()

        flat = tensor.flatten()
        preview_size = min(100, flat.numel())
        # Convert to float32 for FP8 types for display
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            preview = flat[:preview_size].to(torch.float32).tolist()
        else:
            preview = flat[:preview_size].tolist()

        # Format preview in multiple lines (10 values per line),
        # adapting the number format to the tensor's dtype.
        preview_lines = []
        for i in range(0, len(preview), 10):
            line_values = preview[i:i + 10]
            if (tensor.dtype in [torch.float32, torch.float64,
                                 torch.float16, torch.bfloat16]
                    or str(tensor.dtype) in ['torch.float8_e4m3fn',
                                             'torch.float8_e5m2']):
                preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
            elif tensor.dtype in [torch.int8, torch.int16, torch.int32,
                                  torch.int64, torch.uint8]:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            elif tensor.dtype == torch.bool:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            else:
                preview_lines.append(", ".join(str(v) for v in line_values))

        preview_str = (f"**First {preview_size} values:**\n```\n"
                       + "\n".join(preview_lines) + "\n```")

        # Save tensor for download
        log_buffer.write("💾 Saving tensor for download...\n")
        yield info, preview_str, None, log_buffer.getvalue()

        temp_dir = tempfile.gettempdir()
        safe_param_key = param_key.replace("/", "_").replace(".", "_")
        download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
        torch.save(tensor, download_path)
        log_buffer.write(f"✓ Saved to: {download_path}\n")

        progress(1.0, desc="Complete!")
        log_buffer.write("\n✅ All operations completed successfully!\n")
        yield info, preview_str, download_path, log_buffer.getvalue()

    except Exception as e:
        log_buffer.write(f"\n❌ Error: {str(e)}\n")
        yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()


def list_keys(model_id: str, session_token: str):
    """List all available keys for a model as newline-separated text."""
    if not model_id:
        return "Please provide a model ID."
    try:
        keys = get_available_keys(model_id, session_token)
        if not keys:
            return "No keys found or failed to load model."
        return "\n".join(keys)
    except Exception as e:
        return f"**Error:** {str(e)}"


def clear_temp_files():
    """Clear all .pt files from temp directory."""
    try:
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        count = len(pt_files)
        deleted_files = []
        for file in pt_files:
            try:
                os.remove(file)
                deleted_files.append(os.path.basename(file))
            except Exception:
                # Best-effort cleanup: skip files we cannot delete
                pass
        if deleted_files:
            files_list = "\n".join(deleted_files)
            return f"✅ Cleared {count} temporary file(s):\n\n{files_list}"
        else:
            return "✅ No temporary files to clear"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def _hf_hub_cache_dir() -> str:
    """Resolve the HF hub cache directory, honoring HF_HUB_CACHE / HF_HOME.

    scan_cache_dir() respects these env vars, so deletion must resolve the
    same path or the reported stats and the removed directory disagree.
    """
    explicit = os.environ.get("HF_HUB_CACHE")
    if explicit:
        return explicit
    hf_home = os.environ.get("HF_HOME",
                             os.path.expanduser("~/.cache/huggingface"))
    return os.path.join(hf_home, "hub")


def clear_hf_cache():
    """Clear Hugging Face cache directory."""
    try:
        cache_info = scan_cache_dir()
        total_size = cache_info.size_on_disk
        total_repos = len(cache_info.repos)

        if total_repos == 0:
            return "✅ Hugging Face cache is already empty"

        # Get cache directory and clear it
        cache_dir = _hf_hub_cache_dir()
        if os.path.exists(cache_dir):
            shutil.rmtree(cache_dir)
            os.makedirs(cache_dir)
            size_mb = total_size / (1024 * 1024)
            return (f"✅ Cleared Hugging Face cache: {total_repos} repo(s), "
                    f"{size_mb:.2f} MB freed")
        else:
            return "✅ Hugging Face cache directory not found"
    except Exception as e:
        return f"❌ Error: {str(e)}"


def get_cache_info():
    """Get size information about temp .pt files and the HF hub cache."""
    try:
        # Temp files
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        temp_size = sum(os.path.getsize(f) for f in pt_files
                        if os.path.exists(f))
        temp_size_mb = temp_size / (1024 * 1024)

        info = "📊 Cache Info:\n\n"
        info += (f"═══ Temp .pt files: {len(pt_files)} file(s), "
                 f"{temp_size_mb:.2f} MB ═══\n")
        if pt_files:
            for file in pt_files:
                size = os.path.getsize(file) / (1024 * 1024)
                filename = os.path.basename(file)
                # Bug fix: show the actual filename instead of "(unknown)"
                info += f"  • {filename} ({size:.2f} MB)\n"
        else:
            info += "  (empty)\n"

        # HF cache
        hf_size_mb = 0.0
        info += "\n═══ Hugging Face Cache ═══\n"
        try:
            cache_info = scan_cache_dir()
            hf_size_mb = cache_info.size_on_disk / (1024 * 1024)
            hf_repos = len(cache_info.repos)
            info += f"Total: {hf_repos} repo(s), {hf_size_mb:.2f} MB\n\n"
            if hf_repos > 0:
                for repo in cache_info.repos:
                    repo_size = repo.size_on_disk / (1024 * 1024)
                    info += f"  📦 {repo.repo_id}\n"
                    info += (f"     Size: {repo_size:.2f} MB, "
                             f"Revisions: {len(repo.revisions)}\n")
                    info += f"     Last accessed: {repo.last_accessed}\n"
            else:
                info += "  (empty)\n"
        except Exception as e:
            info += f"  Error reading HF cache: {str(e)}\n"

        info += f"\n═══ Total: {temp_size_mb + hf_size_mb:.2f} MB ═══"
        return info
    except Exception as e:
        return f"❌ Error: {str(e)}"


# Create Gradio interface
custom_css = """
* { font-family: Consolas, Monaco, 'Courier New', monospace !important; }
.compact-row { gap: 0.5rem !important; }
.tensor-preview pre { font-size: 0.75rem !important; line-height: 1.0 !important; }
.compact-file { max-height: 80px !important; }
.compact-file > div { min-height: 60px !important; }
"""

# Bug fix: css must be passed to gr.Blocks, not to demo.launch()
# (Blocks.launch() has no `css` parameter, so the styling was never applied).
with gr.Blocks(title="Hugging Face Model Weight Inspector",
               css=custom_css) as demo:
    gr.Markdown("# 🔍 Hugging Face Model Weight Inspector")

    # Session state for per-user token
    session_token = gr.State(None)

    # HF Login section
    with gr.Accordion(
            "🔐 Hugging Face Login (Per-User Session) "
            "[⚠️⚠️⚠️WIP, Do not use⚠️⚠️⚠️]",
            open=False):
        gr.Markdown("""
**Note:** This Space uses the default `HF_TOKEN` secret for all users if no session token is provided.
Login below with your own token for per-user authentication (affects only your session).
""")
        with gr.Row():
            with gr.Column(scale=3):
                hf_token_input = gr.Textbox(
                    label="HF Token",
                    placeholder="hf_...",
                    type="password",
                )
            with gr.Column(scale=2):
                initial_status = ("✅ Using default HF_TOKEN"
                                  if DEFAULT_HF_TOKEN else "Not logged in")
                hf_status = gr.Textbox(
                    label="Status",
                    value=initial_status,
                    interactive=False,
                )
        with gr.Row():
            login_btn = gr.Button("🔑 Login", variant="primary", scale=1)
            logout_btn = gr.Button("🚪 Logout", variant="secondary", scale=1)
            check_status_btn = gr.Button(
                "ℹ️ Check Status", variant="secondary", scale=1)
        login_output = gr.Textbox(
            label="Login Status", interactive=False, lines=2)

    with gr.Row():
        with gr.Column(scale=1):
            model_id_input = gr.Textbox(
                label="Model ID",
                placeholder="e.g., meta-llama/Llama-2-7b-hf",
                value="Qwen/Qwen3-Coder-Next-FP8",
            )
            param_key_input = gr.Textbox(
                label="Parameter Key",
                placeholder="e.g., model.norm.weight",
                value="model.norm.weight",
            )
            with gr.Row():
                list_keys_btn = gr.Button(
                    "📋 List Keys", variant="secondary", scale=1)
                fetch_btn = gr.Button("🔎 Fetch", variant="primary", scale=1)
        with gr.Column(scale=1):
            keys_output = gr.Textbox(
                label="Available Parameter Keys",
                lines=5,
                max_lines=8,
            )

    with gr.Tabs():
        with gr.Tab("Results"):
            with gr.Row():
                with gr.Column(scale=3):
                    preview_output = gr.Markdown(
                        label="Tensor Preview", elem_classes="tensor-preview")
                with gr.Column(scale=1):
                    info_output = gr.Markdown(label="Tensor Info")
                    download_output = gr.File(
                        label="Download Tensor (.pt file)",
                        elem_classes="compact-file")
            log_output = gr.Textbox(
                label="📋 Download Log", lines=1, interactive=False)
        with gr.Tab("Cache Management"):
            with gr.Row():
                get_info_btn = gr.Button(
                    "📊 Get Cache Info", variant="secondary", scale=1)
                clear_temp_btn = gr.Button(
                    "🗑️ Clear Temp Folder", variant="secondary", scale=1)
                clear_hf_btn = gr.Button(
                    "🗑️ Clear HF Cache", variant="secondary", scale=1)
            clear_status = gr.Textbox(
                label="Status", interactive=False, lines=6)

    # Event handlers
    login_btn.click(
        fn=hf_login,
        inputs=[hf_token_input, session_token],
        outputs=[login_output, hf_status, session_token],
    )
    logout_btn.click(
        fn=hf_logout,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    check_status_btn.click(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    list_keys_btn.click(
        fn=list_keys,
        inputs=[model_id_input, session_token],
        outputs=[keys_output],
    )
    fetch_btn.click(
        fn=fetch_param,
        inputs=[model_id_input, param_key_input, session_token],
        outputs=[info_output, preview_output, download_output, log_output],
    )
    clear_temp_btn.click(
        fn=clear_temp_files,
        inputs=[],
        outputs=[clear_status],
    )
    clear_hf_btn.click(
        fn=clear_hf_cache,
        inputs=[],
        outputs=[clear_status],
    )
    get_info_btn.click(
        fn=get_cache_info,
        inputs=[],
        outputs=[clear_status],
    )

    # Auto-check status on load
    demo.load(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")