File size: 6,429 Bytes
d638adc
 
 
 
 
b7610c9
 
617e5e5
 
 
 
 
d638adc
b7610c9
617e5e5
 
b7610c9
 
 
 
617e5e5
b7610c9
617e5e5
 
b7610c9
 
617e5e5
 
 
 
 
 
 
 
 
 
 
b7610c9
 
617e5e5
 
 
b7610c9
617e5e5
 
 
 
 
 
 
 
 
 
 
b7610c9
617e5e5
b7610c9
617e5e5
b7610c9
617e5e5
b7610c9
617e5e5
b7610c9
 
 
 
 
 
 
 
617e5e5
b7610c9
 
617e5e5
 
 
 
 
b7610c9
 
617e5e5
b7610c9
 
 
 
617e5e5
b7610c9
617e5e5
b7610c9
 
 
617e5e5
 
b7610c9
 
617e5e5
b7610c9
 
617e5e5
b7610c9
617e5e5
 
 
b7610c9
 
617e5e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d638adc
617e5e5
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import gradio as gr
from safetensors.torch import load_file
from model_loader import get_top_layers, load_model_summary, load_config
import tempfile
import os
import requests
import json
import logging
import traceback

# Configure root logging once at import time; INFO level so download/analysis
# progress messages are visible in the console.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def _parse_model_id(model_id):
    """Split a 'username/modelname[/path/to/file.safetensors]' ID.

    Returns (username, modelname, model_filename). The filename defaults to
    'model.safetensors'. Everything after the second slash — including nested
    paths — is kept as the filename (the old split-based parsing kept only the
    first path segment, so 'user/model/dir/f.safetensors' was silently wrong).
    """
    username, rest = model_id.split('/', 1)
    modelname, _, filename = rest.partition('/')
    model_filename = filename if filename.strip() else "model.safetensors"
    return username, modelname, model_filename


def _download_model(model_url):
    """Stream *model_url* into a temporary .safetensors file.

    Returns the temporary file's path; the caller is responsible for deleting
    it. Raises requests.HTTPError on a non-2xx response.
    """
    logger.info(f"Attempting to download model from: {model_url}")
    response = requests.get(model_url, stream=True, timeout=60)
    response.raise_for_status()

    total_size = int(response.headers.get('content-length', 0))
    logger.info(f"Model file size: {total_size/1024/1024:.2f} MB")

    log_step = 100 * 1024 * 1024  # emit a progress line roughly every 100 MB
    next_log = log_step
    downloaded = 0
    with tempfile.NamedTemporaryFile(delete=False, suffix=".safetensors") as tmp:
        # iter_content handles both known and unknown content-length, so a
        # separate response.content branch is unnecessary.
        for chunk in response.iter_content(chunk_size=8192):
            if chunk:
                tmp.write(chunk)
                downloaded += len(chunk)
                # Threshold comparison, not exact modulo: chunk sizes rarely
                # align with the 100 MB boundary, so `downloaded % step == 0`
                # almost never fires.
                if total_size > 0 and downloaded >= next_log:
                    logger.info(f"Downloaded {downloaded/1024/1024:.2f} MB / {total_size/1024/1024:.2f} MB")
                    next_log += log_step
        model_path = tmp.name
    logger.info(f"Model downloaded to temporary file: {model_path}")
    return model_path


def _load_config_data(username, modelname, config_file):
    """Return the model config as a dict, or {} when none is available.

    Prefers an uploaded *config_file*; otherwise fetches config.json from the
    Hub. Fetch/parse failures are logged as warnings and yield {} (best-effort
    by design — a missing config must not fail the whole inspection).
    """
    if config_file is not None:
        logger.info("Processing uploaded config file")
        # Gradio's binary File component may hand us raw bytes or an open
        # file object depending on version — accept both. TODO confirm against
        # the pinned gradio version.
        raw = config_file if isinstance(config_file, (bytes, bytearray)) else config_file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json") as tmp_cfg:
            tmp_cfg.write(raw)
            config_path = tmp_cfg.name
        try:
            logger.info(f"Loading config from uploaded file: {config_path}")
            return load_config(config_path)
        finally:
            # Remove the temp copy even if load_config raises.
            os.unlink(config_path)

    config_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/config.json"
    logger.info(f"Attempting to download config from: {config_url}")
    try:
        config_response = requests.get(config_url, timeout=30)
        config_response.raise_for_status()
        config_data = json.loads(config_response.content)
        logger.info("Config file downloaded and parsed successfully")
        return config_data
    except Exception as e:
        logger.warning(f"Could not download or parse config file: {str(e)}")
        return {}


def inspect_model(model_id, config_file=None):
    """Download a safetensors model from the HuggingFace Hub and analyze it.

    Parameters:
        model_id: "username/modelname" or
            "username/modelname/filename.safetensors".
        config_file: optional uploaded config.json (bytes or file-like from
            the Gradio File component).

    Returns:
        A (model_summary, config_str) pair of display strings. On any failure
        the first element carries the error text plus traceback and the second
        is "No config loaded." — this function never raises.
    """
    logger.info(f"Processing model ID: {model_id}")

    if not model_id or '/' not in model_id:
        return "Please provide a valid model ID in the format username/modelname", "No config loaded."

    try:
        username, modelname, model_filename = _parse_model_id(model_id)
        logger.info(f"Username: {username}, Model name: {modelname}")

        model_url = f"https://huggingface.co/{username}/{modelname}/resolve/main/{model_filename}"
        model_path = _download_model(model_url)
        try:
            logger.info("Loading model summary...")
            summary = load_model_summary(model_path)
            logger.info("Loading state dictionary... (This may take time for large models)")
            state_dict = load_file(model_path)
            logger.info("Analyzing top layers...")
            top_layers = get_top_layers(state_dict, summary["total_params"])
        finally:
            # Always delete the (potentially multi-GB) temp file, even when
            # analysis raises — the old code leaked it on any exception here.
            logger.info(f"Cleaning up temporary file: {model_path}")
            os.unlink(model_path)

        top_layers_str = "\n".join(
            f"{layer['name']}: shape={layer['shape']}, params={layer['params']:,} ({layer['percent']}%)"
            for layer in top_layers
        )

        config_data = _load_config_data(username, modelname, config_file)
        config_str = "\n".join(f"{k}: {v}" for k, v in config_data.items()) if config_data else "No config loaded."

        model_summary = (
            f" Total tensors: {summary['num_tensors']}\n"
            f" Total parameters: {summary['total_params']:,}\n\n"
            f" Top Layers:\n{top_layers_str}"
        )
        logger.info("Model inspection completed successfully")

        return model_summary, config_str

    except Exception as e:
        error_msg = f"Error: {str(e)}\n\nTraceback:\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, "No config loaded."

# --- Gradio UI -------------------------------------------------------------
# Component creation order defines the layout, so the structure below is
# deliberately kept flat and sequential.
with gr.Blocks(title="Model Inspector") as demo:
    gr.Markdown("# Model Inspector")
    gr.Markdown("Enter a HuggingFace model ID in the format username/modelname to analyze its structure, parameter count, and configuration.")
    gr.Markdown("You can specify a custom safetensors file by using username/modelname/filename.safetensors")

    with gr.Row():
        with gr.Column():
            model_id = gr.Textbox(
                label="Model ID from HuggingFace",
                placeholder="username/modelname",
                lines=1
            )
            config_file = gr.File(
                label="Upload config.json (optional)",
                type="binary"
            )
            submit_btn = gr.Button("Analyze Model", variant="primary")
            status = gr.Markdown("Ready. Enter a model ID and click 'Analyze Model'")

        with gr.Column():
            model_summary = gr.Textbox(label="Model Summary", lines=15)
            config_output = gr.Textbox(label="Config", lines=10)

    def on_submit(model_id, config_file):
        """Generator event handler for the Analyze button.

        First yield flips the status to 'busy' so the user sees immediate
        feedback; the second yield delivers the results. inspect_model
        catches its own errors, but the defensive except below keeps an
        unexpected failure from leaving the status stuck on 'Processing...'.
        (The former update_status() identity wrapper added no behavior and
        has been removed.)
        """
        yield "Processing... This may take some time for large models.", None, None

        try:
            summary, config = inspect_model(model_id, config_file)
            yield "Analysis complete!", summary, config
        except Exception as e:
            error_msg = f"Error during analysis: {str(e)}"
            yield f"❌ {error_msg}", error_msg, "No config loaded."

    submit_btn.click(
        fn=on_submit,
        inputs=[model_id, config_file],
        outputs=[status, model_summary, config_output],
        show_progress=True
    )

demo.launch()