import gradio as gr
from safetensors.torch import load_file
from model_loader import get_top_layers, load_model_summary, load_config
import tempfile
import os
import requests
import json
import logging
import traceback
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
def _download_to_tempfile(url):
    """Stream *url* into a temporary ``.safetensors`` file and return its path.

    The caller owns the returned file and is responsible for deleting it.
    Raises ``requests.HTTPError`` for non-2xx responses.
    """
    # Timeout added: the original call could block indefinitely on a stalled
    # connection.
    response = requests.get(url, stream=True, timeout=60)
    try:
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")
        progress_step = 100 * 1024 * 1024  # log roughly every 100 MB
        next_report = progress_step
        downloaded = 0
        with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    tmp.write(chunk)
                    downloaded += len(chunk)
                    # The original used `downloaded % 100MB == 0`, which only
                    # fires if the running total lands exactly on a multiple;
                    # a threshold check reports progress reliably.
                    if total_size > 0 and downloaded >= next_report:
                        logger.info(
                            f"Downloaded {downloaded/1024/1024:.2f} MB / "
                            f"{total_size/1024/1024:.2f} MB"
                        )
                        next_report += progress_step
            return tmp.name
    finally:
        response.close()


def _load_config_data(username, modelname, config_file):
    """Return the model config as a dict, or ``{}`` when none is available.

    Prefers an uploaded file object; otherwise fetches ``config.json`` from
    the HuggingFace repo. Download or parse failures are logged and tolerated.
    """
    if config_file is not None:
        logger.info("Processing uploaded config file")
        # load_config expects a path, so persist the uploaded bytes first.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
            tmp_cfg.write(config_file.read())
            config_path = tmp_cfg.name
        try:
            logger.info(f"Loading config from uploaded file: {config_path}")
            return load_config(config_path)
        finally:
            # Original leaked this temp file when load_config raised.
            os.unlink(config_path)
    config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
    logger.info(f"Attempting to download config from: {config_url}")
    try:
        config_response = requests.get(config_url, timeout=30)
        config_response.raise_for_status()
        config_data = json.loads(config_response.content)
        logger.info("Config file downloaded and parsed successfully")
        return config_data
    except Exception as e:
        # Best-effort: a missing/unparseable config is not fatal.
        logger.warning(f"Could not download or parse config file: {str(e)}")
        return {}


def inspect_model(model_id, config_file=None):
    """Download a safetensors model from the HuggingFace Hub and summarize it.

    Args:
        model_id: ``"username/modelname"`` or
            ``"username/modelname/path/to/file.safetensors"``.
        config_file: optional uploaded ``config.json`` file object (binary);
            when ``None``, ``config.json`` is fetched from the model repo.

    Returns:
        ``(model_summary, config_str)`` strings. On failure the first element
        is an error message with traceback and the second is
        ``"No config loaded."``.
    """
    logger.info(f"Processing model ID: {model_id}")
    if not model_id or '/' not in model_id:
        return "Please provide a valid model ID in the format username/modelname", "No config loaded."
    username, modelname = model_id.split('/', 1)
    logger.info(f"Username: {username}, Model name: {modelname}")
    try:
        model_filename = "model.safetensors"
        if "/" in modelname:
            # Keep everything after the first slash as the file path, so
            # "user/model/dir/file.safetensors" is no longer truncated to
            # just "dir" (the original split kept only parts[1]).
            modelname, _, remainder = modelname.partition("/")
            if remainder.strip():
                model_filename = remainder
        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        logger.info(f"Attempting to download model from: {model_url}")
        model_path = _download_to_tempfile(model_url)
        try:
            logger.info(f"Model downloaded to temporary file: {model_path}")
            logger.info("Loading model summary...")
            summary = load_model_summary(model_path)
            logger.info("Loading state dictionary... (This may take time for large models)")
            state_dict = load_file(model_path)
            logger.info("Analyzing top layers...")
            top_layers = get_top_layers(state_dict, summary["total_params"])
        finally:
            # Always remove the temp file — the original only deleted it on
            # the success path, leaking it whenever loading/analysis failed.
            logger.info(f"Cleaning up temporary file: {model_path}")
            os.unlink(model_path)
        top_layers_str = "\n".join(
            f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
            for layer in top_layers
        )
        config_data = _load_config_data(username, modelname, config_file)
        config_str = "\n".join(f"{k}: {v}" for k, v in config_data.items()) if config_data else "No config loaded."
        model_summary = (
            f" Total tensors: {summary['num_tensors']}\n"
            f" Total parameters: {summary['total_params']:,}\n\n"
            f" Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")
        return model_summary, config_str
    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."
# Gradio UI: two-column layout — model-ID input, optional config upload, and a
# status line on the left; model summary and config text boxes on the right.
with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors")
    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"  # delivers the upload as a binary file object
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")
        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)

    def update_status(text):
        # Identity pass-through; exists only to label status-text updates.
        return text

    def on_submit(model_id, config_file):
        # Generator handler: the first yield shows an interim "Processing..."
        # status before the (potentially long) analysis completes.
        status_update = update_status("Processing... This may take some time for large models.")
        yield status_update, None, None
        try:
            summary, config = inspect_model(model_id, config_file)
            status_update = update_status("Analysis complete!")
            yield status_update, summary, config
        except Exception as e:
            # inspect_model already returns error strings for its own
            # failures; this guards against anything raised outside it.
            error_msg = f"Error during analysis: {str(e)}"
            status_update = update_status(f"❌ {error_msg}")
            yield status_update, error_msg, "No config loaded."

    # Wire the button: each yield from on_submit updates (status, summary,
    # config) in order.
    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )
demo.launch()