# HuggingFace Space: Model Inspector (app.py)
| import gradio as gr | |
| from safetensors.torch import load_file | |
| from model_loader import get_top_layers, load_model_summary, load_config | |
| import tempfile | |
| import os | |
| import requests | |
| import json | |
| import logging | |
| import traceback | |
# Configure root logging once at import time so download progress is visible
# in the Space's container logs; use a module-scoped logger for this app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def inspect_model(model_id, config_file=None):
    """Download a safetensors checkpoint from the HuggingFace Hub and summarize it.

    Parameters
    ----------
    model_id : str
        ``username/modelname``, or ``username/modelname/filename.safetensors``
        to point at a specific weights file (defaults to ``model.safetensors``).
    config_file : bytes or file-like, optional
        An uploaded ``config.json``. When absent, ``config.json`` is fetched
        from the same repo on a best-effort basis.

    Returns
    -------
    tuple[str, str]
        ``(model_summary_text, config_text)``. On failure the first element
        carries the error message plus traceback and the second is
        ``"No config loaded."``.
    """
    logger.info(f"Processing model ID: {model_id}")
    if not model_id or '/' not in model_id:
        return "Please provide a valid model ID in the format username/modelname", "No config loaded."

    username, modelname = model_id.split('/', 1)
    logger.info(f"Username: {username}, Model name: {modelname}")
    config_str = "No config loaded."
    model_path = None  # set once the temp file exists, so `finally` can clean it up
    try:
        # Optional third path segment selects a custom weights file.
        model_filename = "model.safetensors"
        if "/" in modelname:
            parts = modelname.split("/")
            modelname = parts[0]
            if len(parts) > 1 and parts[1].strip():
                model_filename = parts[1]

        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        logger.info(f"Attempting to download model from: {model_url}")
        response = requests.get(model_url, stream=True)
        response.raise_for_status()

        total_size = int(response.headers.get('content-length', 0))
        logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")

        # Stream to a named temp file so safetensors can mmap it from disk.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
            model_path = tmp.name
            if total_size > 0:
                downloaded = 0
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:
                        tmp.write(chunk)
                        downloaded += len(chunk)
                        # Progress log roughly every 100 MB (chunk size divides
                        # 100 MB evenly, so the modulo check does fire).
                        if downloaded % (100 * 1024 * 1024) == 0:
                            logger.info(f"Downloaded {downloaded/1024/1024:.2f} MB / {total_size/1024/1024:.2f} MB")
            else:
                # No content-length header: fall back to a single read.
                tmp.write(response.content)
        logger.info(f"Model downloaded to temporary file: {model_path}")

        logger.info("Loading model summary...")
        summary = load_model_summary(model_path)
        logger.info(f"Loading state dictionary... (This may take time for large models)")
        state_dict = load_file(model_path)
        logger.info("Analyzing top layers...")
        top_layers = get_top_layers(state_dict, summary["total_params"])
        top_layers_str = "\n".join([
            f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
            for layer in top_layers
        ])

        config_data = {}
        if config_file is not None:
            logger.info("Processing uploaded config file")
            # Gradio's File(type="binary") hands us raw bytes; older versions
            # passed a file-like object — accept either.
            raw = config_file if isinstance(config_file, (bytes, bytearray)) else config_file.read()
            config_path = None
            try:
                with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
                    tmp_cfg.write(raw)
                    config_path = tmp_cfg.name
                logger.info(f"Loading config from uploaded file: {config_path}")
                config_data = load_config(config_path)
            finally:
                # Remove the config temp file even if load_config raises.
                if config_path is not None:
                    os.unlink(config_path)
        else:
            # Best-effort fetch of the repo's config.json; failure is non-fatal.
            config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
            logger.info(f"Attempting to download config from: {config_url}")
            try:
                config_response = requests.get(config_url)
                config_response.raise_for_status()
                config_data = json.loads(config_response.content)
                logger.info("Config file downloaded and parsed successfully")
            except Exception as e:
                logger.warning(f"Could not download or parse config file: {str(e)}")

        config_str = "\n".join([f"{k}: {v}" for k, v in config_data.items()]) if config_data else "No config loaded."

        model_summary = (
            f" Total tensors: {summary['num_tensors']}\n"
            f" Total parameters: {summary['total_params']:,}\n\n"
            f" Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")
        return model_summary, config_str
    except Exception as e:
        # Surface the full traceback in the UI so Space users can report it.
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."
    finally:
        # Always remove the (potentially multi-GB) weights temp file, even on
        # failure — the original leaked it whenever analysis raised.
        if model_path is not None and os.path.exists(model_path):
            logger.info(f"Cleaning up temporary file: {model_path}")
            os.unlink(model_path)
# Build the Gradio UI. Component creation order inside each container is the
# render order, so the layout below is load-bearing — don't reorder.
with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors")
    with gr.Row():
        # Left column: inputs and status line.
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")
        # Right column: analysis outputs.
        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)

    def update_status(text):
        # Identity helper: returns the text that replaces the status Markdown.
        return text

    def on_submit(model_id, config_file):
        # Generator handler: first yield shows an immediate "working" status,
        # later yields replace it once inspect_model finishes or fails.
        # NOTE(review): inspect_model already catches exceptions internally,
        # so the except branch here is a second line of defense.
        status_update = update_status("Processing... This may take some time for large models.")
        yield status_update, None, None
        try:
            summary, config = inspect_model(model_id, config_file)
            status_update = update_status("Analysis complete!")
            yield status_update, summary, config
        except Exception as e:
            error_msg = f"Error during analysis: {str(e)}"
            status_update = update_status(f"❌ {error_msg}")
            yield status_update, error_msg, "No config loaded."

    # Wire the button: inputs are the two components above, outputs are the
    # status line plus the two result textboxes (matching on_submit's yields).
    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )

# Start the Gradio server (blocking call; entry point of the Space).
demo.launch()