model-inspect-tool / model_loader.py
Malum0x's picture
Initial Commit: hf space
d638adc
import torch
from safetensors.torch import load_file
import pandas as pd
import json
import os
def load_model_summary(path: str):
state_dict = load_file(path)
print(f'loaded {len(state_dict)} tensors from {path}')
total_params = sum(tensor.numel() for tensor in state_dict.values())
print(f"total parameters: {total_params} ")
print("sample layers: ")
for key in list(state_dict.keys())[:5]:
print(f"key {state_dict[key].shape}")
for key in state_dict:
if "embedding" in key:
print(f"[Embedding] {key}: {state_dict[key].shape}")
return {
"total_params": total_params,
"num_tensors": len(state_dict),
"example_keys": list(state_dict.keys())[:5]
}
def get_top_layers(state_dict: dict, total_params: int, top_k: int = 5) -> list:
sorted_layers = sorted(state_dict.items(), key=lambda x: x[1].numel(), reverse=True)
top = []
for name, tensor in sorted_layers[:top_k]:
percent = (tensor.numel() / total_params) * 100
top.append({
"name": name,
"shape": list(tensor.shape),
"params": tensor.numel(),
"percent": round(percent, 2)
})
return top
def export_top_layers_to_csv(top_layers: list, output_path: str):
df = pd.DataFrame(top_layers)
df.to_csv(output_path, index=False)
print(f"Top layers exported to {output_path}")
def load_config(path: str) -> dict:
if not os.path.exists(path):
print(f"Config file not found at {path}")
return {}
try:
with open(path, "r") as f:
config = json.load(f)
print(f"Loaded config from {path}")
return config
except json.JSONDecodeError:
print("Failed to parse config.json - invalid JSON format")
return {}
def print_config(config: dict):
if not config:
print("⚠️ No config data to display.")
return
print("\n📘 Config hyperparameters:\n")
max_key_len = max(len(str(k)) for k in config.keys())
for k, v in config.items():
print(f"{k.ljust(max_key_len)} : {v}")
if __name__ == "__main__":
path = "./model/model.safetensors"
load_model_summary(path)
info=load_model_summary(path)
state_dict = load_file(path)
top_layers = get_top_layers(state_dict, total_params=info["total_params"])
print("top 5: ")
for layer in top_layers:
print(f"{layer['name']:60}, {layer['shape']}, ({layer['params']:,} params), {layer['percent']}%)")