Spaces:
Runtime error
Runtime error
| import json | |
| import tempfile | |
| import os | |
| import glob | |
| import shutil | |
| import io | |
| import time | |
| import threading | |
| import sys | |
| import gradio as gr | |
| import torch | |
| from huggingface_hub import hf_hub_download, scan_cache_dir, whoami | |
| from safetensors import safe_open | |
# Default token from HF_TOKEN environment variable (for HuggingFace Spaces).
# Used as the fallback whenever a caller has no per-session token; None when
# the environment variable is unset.
DEFAULT_HF_TOKEN = os.environ.get("HF_TOKEN")
def hf_login(token: str, session_token: str):
    """Validate a user-supplied Hugging Face token for this session.

    On success the supplied token becomes the new session token; on a
    missing or invalid token the previous session token is kept.
    Returns a (message, status_label, token) triple for the Gradio outputs.
    """
    if not token:
        return "โ Please provide a token", "Not logged in", session_token
    try:
        user_info = whoami(token=token)
    except Exception as e:
        return f"โ Login failed: {str(e)}", "โ Not logged in", session_token
    username = user_info.get('name', 'Unknown')
    return f"โ Successfully logged in as: {username}", f"โ Logged in as {username}", token
def hf_logout(session_token: str):
    """Log out of Hugging Face by discarding the per-session token."""
    # Third element (None) clears the gr.State holding the session token.
    message = "โ Successfully logged out"
    return message, "Not logged in", None
def check_hf_status(session_token: str):
    """Report the current Hugging Face login state for this session.

    Prefers the per-session token; otherwise falls back to the Space's
    default HF_TOKEN. Returns a (message, status_label, token) triple.
    """
    # Session token wins; the `or` short-circuit avoids touching the
    # default when a session token is present.
    token = session_token or DEFAULT_HF_TOKEN
    if not token:
        return "โน๏ธ Not logged in", "Not logged in", session_token
    try:
        user_info = whoami(token=token)
    except Exception:
        # Any validation failure is reported simply as "not logged in".
        return "โน๏ธ Not logged in", "Not logged in", session_token
    username = user_info.get('name', 'Unknown')
    source = "(session)" if session_token else "(default HF_TOKEN)"
    return f"โ Currently logged in as: {username} {source}", f"โ Logged in as {username}", session_token
def get_param(model_id: str, param_key: str, log_buffer: io.StringIO, progress: gr.Progress, token: str = None) -> torch.Tensor:
    """
    Download and return a specific parameter tensor from a Hugging Face model.

    Args:
        model_id: Hub repo id, e.g. "org/model".
        param_key: Fully-qualified tensor name inside the safetensors weight map.
        log_buffer: In-memory buffer receiving progress messages (and, while
            this function runs, everything written to stderr, e.g. tqdm bars).
        progress: Gradio progress callback used to update the UI bar.
        token: Optional HF access token; falls back to DEFAULT_HF_TOKEN.

    Returns:
        The requested tensor loaded from the downloaded shard.

    Raises:
        KeyError: If param_key is not present in a sharded model's weight map.
        Exception: Download errors other than a missing index file propagate.
    """
    # Use session token or fall back to default token
    auth_token = token or DEFAULT_HF_TOKEN
    # Redirect stderr to log buffer for real-time tqdm updates.
    # NOTE(review): sys.stderr is process-global, so concurrent calls would
    # interleave their logs; acceptable for a single-user Space.
    original_stderr = sys.stderr
    sys.stderr = log_buffer
    try:
        # Try to download the index file (for sharded models)
        try:
            log_buffer.write(f"๐ฅ Downloading index file for {model_id}...\n")
            progress(0.1, desc="Downloading index...")
            index_path = hf_hub_download(
                model_id, "model.safetensors.index.json", token=auth_token)
        except Exception as e:
            # No index file means a single-file model; any other failure
            # (auth, network) is fatal and re-raised.
            if "404" in str(e) or "not found" in str(e).lower():
                log_buffer.write("โน๏ธ No index file, trying single model file...\n")
                shard_file = "model.safetensors"
            else:
                raise
        else:
            # BUG FIX: the KeyError below used to be raised *inside* the try
            # above; because its message contains "not found", the except
            # branch swallowed it and silently fell back to a bogus
            # "model.safetensors" download. The else-clause lets it propagate.
            log_buffer.write(f"โ Index file found: {index_path}\n")
            with open(index_path, "r", encoding="utf-8") as f:
                index = json.load(f)
            # Map of parameter name -> shard filename.
            weight_map = index["weight_map"]
            if param_key not in weight_map:
                raise KeyError(
                    f"Parameter '{param_key}' not found in model. Available keys: {list(weight_map.keys())[:10]}..."
                )
            shard_file = weight_map[param_key]
            log_buffer.write(f"โ Parameter found in shard: {shard_file}\n")
        log_buffer.write(f"๐ฅ Downloading shard: {shard_file}...\n")
        progress(0.3, desc=f"Downloading {shard_file}...")
        shard_path = hf_hub_download(model_id, shard_file, token=auth_token)
        log_buffer.write(f"\nโ Shard downloaded: {shard_path}\n")
        progress(0.7, desc="Loading tensor...")
        log_buffer.write(f"๐ Loading tensor '{param_key}'...\n")
        # safe_open lazily reads only the requested tensor from the shard.
        with safe_open(shard_path, framework="pt") as f:
            tensor = f.get_tensor(param_key)
        log_buffer.write("โ Tensor loaded successfully\n")
        progress(0.9, desc="Finalizing...")
        return tensor
    finally:
        # Restore original stderr
        sys.stderr = original_stderr
def get_available_keys(model_id: str, token: str = None):
    """Get all available parameter keys from a model.

    Tries the sharded-model index first; if that fails, reads the key list
    directly from a single ``model.safetensors`` file. Returns a sorted list
    of key names, or an empty list when neither source can be loaded.
    """
    # Use session token or fall back to default token
    auth_token = token or DEFAULT_HF_TOKEN
    try:
        index_path = hf_hub_download(model_id, "model.safetensors.index.json", token=auth_token)
        with open(index_path, "r", encoding="utf-8") as f:
            index = json.load(f)
        return sorted(index["weight_map"].keys())
    except Exception:
        # No index (single-file model) or download failure: try the
        # monolithic safetensors file instead.
        try:
            shard_path = hf_hub_download(model_id, "model.safetensors", token=auth_token)
            with safe_open(shard_path, framework="pt") as f:
                return sorted(f.keys())
        except Exception:
            # Model unavailable / gated / not safetensors: report no keys.
            return []
def format_tensor_info(tensor: torch.Tensor) -> str:
    """Format tensor metadata and summary statistics as Markdown.

    Returns a ``<br>``-joined string with shape, dtype, device, element
    count, and (when the dtype supports it) min/max/mean/std.
    """
    info = []
    info.append(f"**Shape:** {list(tensor.shape)}")
    info.append(f"**Dtype:** {tensor.dtype}")
    info.append(f"**Device:** {tensor.device}")
    info.append(f"**Numel:** {tensor.numel():,}")
    # Handle special dtypes that don't support statistical operations
    try:
        # FP8 tensors don't implement min/max/mean/std; widen to float32 first.
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            stats_tensor = tensor.to(torch.float32)
        else:
            stats_tensor = tensor
        info.append(f"**Min:** {stats_tensor.min().item():.6f}")
        info.append(f"**Max:** {stats_tensor.max().item():.6f}")
        # .float() keeps mean/std working for integer and half-precision dtypes.
        info.append(f"**Mean:** {stats_tensor.float().mean().item():.6f}")
        info.append(f"**Std:** {stats_tensor.float().std().item():.6f}")
    except Exception:
        info.append("**Stats:** Unable to compute (dtype not supported)")
    return "<br>".join(info)
def fetch_param(model_id: str, param_key: str, session_token: str, progress=gr.Progress()):
    """Fetch parameter and return formatted info and tensor preview.

    Generator consumed by Gradio streaming: each ``yield`` is a 4-tuple
    (info_markdown, preview_markdown, download_path_or_None, log_text).
    The download runs in a background thread while this generator polls the
    shared log buffer so the UI log updates in near-real time.
    """
    log_buffer = io.StringIO()
    last_log_value = ""
    if not model_id or not param_key:
        yield "Please provide both model ID and parameter key.", "", None, "โ Missing required inputs"
        return
    try:
        log_buffer.write(f"๐ Starting download for {model_id}\n")
        log_buffer.write(f"๐ฏ Target parameter: {param_key}\n\n")
        progress(0, desc="Initializing...")
        yield "", "", None, log_buffer.getvalue()
        time.sleep(0.5)
        # Start download in background thread
        download_complete = threading.Event()
        download_error = [None]  # Use list to store exception from thread
        result_tensor = [None]  # Use list to store result from thread
        def download_thread():
            # Worker: runs the blocking download; results/errors are passed
            # back through the one-element lists captured by closure.
            try:
                result_tensor[0] = get_param(model_id, param_key, log_buffer, progress, session_token)
            except Exception as e:
                download_error[0] = e
            finally:
                download_complete.set()
        thread = threading.Thread(target=download_thread, daemon=True)
        thread.start()
        # Poll log buffer every 1 second while download is running
        while not download_complete.is_set():
            current_log = log_buffer.getvalue()
            if current_log != last_log_value:
                yield "", "", None, current_log
                last_log_value = current_log
            time.sleep(1)
        # Final log update after download completes
        current_log = log_buffer.getvalue()
        if current_log != last_log_value:
            yield "", "", None, current_log
            last_log_value = current_log
        # Check for errors
        if download_error[0]:
            # Re-raise the worker's exception so the outer handler formats it.
            raise download_error[0]
        tensor = result_tensor[0]
        info = format_tensor_info(tensor)
        # Create tensor preview (first few elements)
        log_buffer.write(f"\n๐ Creating preview...\n")
        yield "", "", None, log_buffer.getvalue()
        flat = tensor.flatten()
        preview_size = min(100, flat.numel())
        # Convert to float32 for FP8 types for display
        if str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
            preview = flat[:preview_size].to(torch.float32).tolist()
        else:
            preview = flat[:preview_size].tolist()
        # Format preview in multiple lines (10 values per line)
        # Adapt to different data types
        preview_lines = []
        for i in range(0, len(preview), 10):
            line_values = preview[i:i+10]
            # Float-like dtypes get fixed 6-decimal formatting; integers and
            # bools print natively; anything else falls back to str().
            if tensor.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16] or str(tensor.dtype) in ['torch.float8_e4m3fn', 'torch.float8_e5m2']:
                preview_lines.append(", ".join(f"{v:.6f}" for v in line_values))
            elif tensor.dtype in [torch.int8, torch.int16, torch.int32, torch.int64, torch.uint8]:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            elif tensor.dtype == torch.bool:
                preview_lines.append(", ".join(f"{v}" for v in line_values))
            else:
                preview_lines.append(", ".join(str(v) for v in line_values))
        preview_str = f"**First {preview_size} values:**\n```\n" + \
            "\n".join(preview_lines) + "\n```"
        # if flat.numel() > preview_size:
        #     preview_str += f"\n\n... and {flat.numel() - preview_size:,} more values"
        # Save tensor for download
        log_buffer.write(f"๐พ Saving tensor for download...\n")
        yield info, preview_str, None, log_buffer.getvalue()
        temp_dir = tempfile.gettempdir()
        # Sanitize the key so it is a safe flat filename.
        safe_param_key = param_key.replace("/", "_").replace(".", "_")
        download_path = os.path.join(temp_dir, f"{safe_param_key}.pt")
        torch.save(tensor, download_path)
        log_buffer.write(f"โ Saved to: {download_path}\n")
        progress(1.0, desc="Complete!")
        log_buffer.write(f"\nโ All operations completed successfully!\n")
        yield info, preview_str, download_path, log_buffer.getvalue()
    except Exception as e:
        log_buffer.write(f"\nโ Error: {str(e)}\n")
        yield f"**Error:** {str(e)}", "", None, log_buffer.getvalue()
def list_keys(model_id: str, session_token: str):
    """Return a newline-separated listing of a model's parameter keys."""
    if not model_id:
        return "Please provide a model ID."
    try:
        keys = get_available_keys(model_id, session_token)
    except Exception as e:
        return f"**Error:** {str(e)}"
    if not keys:
        return "No keys found or failed to load model."
    return "\n".join(keys)
def clear_temp_files():
    """Delete every ``.pt`` tensor file from the system temp directory.

    Returns a human-readable status message listing the removed files.
    """
    try:
        pattern = os.path.join(tempfile.gettempdir(), "*.pt")
        pt_files = glob.glob(pattern)
        count = len(pt_files)
        deleted_files = []
        for path in pt_files:
            try:
                os.remove(path)
            except Exception:
                # Best effort: a file locked or already gone is skipped.
                continue
            deleted_files.append(os.path.basename(path))
        if not deleted_files:
            return "โ No temporary files to clear"
        files_list = "\n".join(deleted_files)
        return f"โ Cleared {count} temporary file(s):\n\n{files_list}"
    except Exception as e:
        return f"โ Error: {str(e)}"
def clear_hf_cache():
    """Wipe the Hugging Face hub cache directory and report the space freed."""
    try:
        cache_info = scan_cache_dir()
        total_size = cache_info.size_on_disk
        total_repos = len(cache_info.repos)
        if total_repos == 0:
            return "โ Hugging Face cache is already empty"
        # Get cache directory and clear it
        cache_dir = os.path.expanduser("~/.cache/huggingface/hub")
        if not os.path.exists(cache_dir):
            return "โ Hugging Face cache directory not found"
        shutil.rmtree(cache_dir)
        # Recreate the empty directory so later downloads still work.
        os.makedirs(cache_dir)
        size_mb = total_size / (1024 * 1024)
        return f"โ Cleared Hugging Face cache: {total_repos} repo(s), {size_mb:.2f} MB freed"
    except Exception as e:
        return f"โ Error: {str(e)}"
def get_cache_info():
    """Summarize disk usage of temp ``.pt`` files and the Hugging Face cache.

    Returns a human-readable multi-line report; per-file and per-repo lines
    are included when present.
    """
    try:
        # Temp files
        temp_dir = tempfile.gettempdir()
        pt_files = glob.glob(os.path.join(temp_dir, "*.pt"))
        temp_size = sum(os.path.getsize(f)
                        for f in pt_files if os.path.exists(f))
        temp_size_mb = temp_size / (1024 * 1024)
        info = "๐ Cache Info:\n\n"
        info += f"โโโ Temp .pt files: {len(pt_files)} file(s), {temp_size_mb:.2f} MB โโโ\n"
        if pt_files:
            for file in pt_files:
                size = os.path.getsize(file) / (1024 * 1024)
                filename = os.path.basename(file)
                # BUG FIX: previously printed the literal "(unknown)" instead
                # of the computed filename (which was assigned but unused).
                info += f"  โข {filename} ({size:.2f} MB)\n"
        else:
            info += "  (empty)\n"
        # HF cache. Initialize up-front so the grand total below is always
        # defined (replaces the previous `'hf_size_mb' in locals()` hack).
        hf_size_mb = 0.0
        info += "\nโโโ Hugging Face Cache โโโ\n"
        try:
            cache_info = scan_cache_dir()
            hf_size_mb = cache_info.size_on_disk / (1024 * 1024)
            hf_repos = len(cache_info.repos)
            info += f"Total: {hf_repos} repo(s), {hf_size_mb:.2f} MB\n\n"
            if hf_repos > 0:
                for repo in cache_info.repos:
                    repo_size = repo.size_on_disk / (1024 * 1024)
                    info += f"  ๐ฆ {repo.repo_id}\n"
                    info += f"    Size: {repo_size:.2f} MB, Revisions: {len(repo.revisions)}\n"
                    info += f"    Last accessed: {repo.last_accessed}\n"
            else:
                info += "  (empty)\n"
        except Exception as e:
            info += f"  Error reading HF cache: {str(e)}\n"
        info += f"\nโโโ Total: {temp_size_mb + hf_size_mb:.2f} MB โโโ"
        return info
    except Exception as e:
        return f"โ Error: {str(e)}"
# Create Gradio interface
# App-wide CSS: monospace font everywhere, plus compact spacing for the
# preview pane and the download-file widget.
custom_css = """
* {
    font-family: Consolas, Monaco, 'Courier New', monospace !important;
}
.compact-row {
    gap: 0.5rem !important;
}
.tensor-preview pre {
    font-size: 0.75rem !important;
    line-height: 1.0 !important;
}
.compact-file {
    max-height: 80px !important;
}
.compact-file > div {
    min-height: 60px !important;
}
"""
# BUG FIX: custom_css was previously passed to demo.launch(css=...), but
# Blocks.launch() has no `css` parameter, so the app crashed at startup
# with a TypeError. Custom CSS belongs on the gr.Blocks constructor.
with gr.Blocks(title="Hugging Face Model Weight Inspector", css=custom_css) as demo:
    gr.Markdown("# ๐ Hugging Face Model Weight Inspector")
    # Session state for per-user token
    session_token = gr.State(None)
    # HF Login section
    with gr.Accordion("๐ Hugging Face Login (Per-User Session) [โ ๏ธโ ๏ธโ ๏ธWIP, Do not useโ ๏ธโ ๏ธโ ๏ธ]", open=False):
        gr.Markdown("""
**Note:** This Space uses the default `HF_TOKEN` secret for all users if no session token is provided.
Login below with your own token for per-user authentication (affects only your session).
""")
        with gr.Row():
            with gr.Column(scale=3):
                hf_token_input = gr.Textbox(
                    label="HF Token",
                    placeholder="hf_...",
                    type="password",
                )
            with gr.Column(scale=2):
                initial_status = "โ Using default HF_TOKEN" if DEFAULT_HF_TOKEN else "Not logged in"
                hf_status = gr.Textbox(
                    label="Status",
                    value=initial_status,
                    interactive=False,
                )
        with gr.Row():
            login_btn = gr.Button("๐ Login", variant="primary", scale=1)
            logout_btn = gr.Button("๐ช Logout", variant="secondary", scale=1)
            check_status_btn = gr.Button("โน๏ธ Check Status", variant="secondary", scale=1)
        login_output = gr.Textbox(label="Login Status", interactive=False, lines=2)
    # Main input row: model/parameter selection on the left, key listing on
    # the right.
    with gr.Row():
        with gr.Column(scale=1):
            model_id_input = gr.Textbox(
                label="Model ID",
                placeholder="e.g., meta-llama/Llama-2-7b-hf",
                value="Qwen/Qwen3-Coder-Next-FP8",
            )
            param_key_input = gr.Textbox(
                label="Parameter Key",
                placeholder="e.g., model.norm.weight",
                value="model.norm.weight",
            )
            with gr.Row():
                list_keys_btn = gr.Button(
                    "๐ List Keys", variant="secondary", scale=1)
                fetch_btn = gr.Button("๐ Fetch", variant="primary", scale=1)
        with gr.Column(scale=1):
            keys_output = gr.Textbox(
                label="Available Parameter Keys",
                lines=5,
                max_lines=8,
            )
    with gr.Tabs():
        with gr.Tab("Results"):
            with gr.Row():
                with gr.Column(scale=3):
                    preview_output = gr.Markdown(label="Tensor Preview", elem_classes="tensor-preview")
                with gr.Column(scale=1):
                    info_output = gr.Markdown(label="Tensor Info")
                    download_output = gr.File(label="Download Tensor (.pt file)", elem_classes="compact-file")
            log_output = gr.Textbox(
                label="๐ Download Log", lines=1, interactive=False)
        with gr.Tab("Cache Management"):
            with gr.Row():
                get_info_btn = gr.Button(
                    "๐ Get Cache Info", variant="secondary", scale=1)
                clear_temp_btn = gr.Button(
                    "๐๏ธ Clear Temp Folder", variant="secondary", scale=1)
                clear_hf_btn = gr.Button(
                    "๐๏ธ Clear HF Cache", variant="secondary", scale=1)
            clear_status = gr.Textbox(
                label="Status", interactive=False, lines=6)
    # Event handlers
    login_btn.click(
        fn=hf_login,
        inputs=[hf_token_input, session_token],
        outputs=[login_output, hf_status, session_token],
    )
    logout_btn.click(
        fn=hf_logout,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    check_status_btn.click(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
    list_keys_btn.click(
        fn=list_keys,
        inputs=[model_id_input, session_token],
        outputs=[keys_output],
    )
    fetch_btn.click(
        fn=fetch_param,
        inputs=[model_id_input, param_key_input, session_token],
        outputs=[info_output, preview_output, download_output, log_output],
    )
    clear_temp_btn.click(
        fn=clear_temp_files,
        inputs=[],
        outputs=[clear_status],
    )
    clear_hf_btn.click(
        fn=clear_hf_cache,
        inputs=[],
        outputs=[clear_status],
    )
    get_info_btn.click(
        fn=get_cache_info,
        inputs=[],
        outputs=[clear_status],
    )
    # Auto-check status on load
    demo.load(
        fn=check_hf_status,
        inputs=[session_token],
        outputs=[login_output, hf_status, session_token],
    )
if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable inside the Space
    # container; CSS is applied via the gr.Blocks constructor above.
    demo.launch(server_name="0.0.0.0")