File size: 2,770 Bytes
e39ff3a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 |
import argparse
import json
from pathlib import Path
from safetensors import safe_open
# Weight-name fragments that belong to a standard (single-branch) SwiGLU MLP.
_STANDARD_MLP_PROJECTIONS = ("gate_proj", "up_proj", "down_proj")


def _find_weight_files(model_path: Path) -> "list[Path]":
    """Return the safetensors files to inspect for *model_path*.

    Prefers a single ``model.safetensors``; otherwise falls back to sharded
    checkpoints named ``model-*.safetensors`` (sorted by name). Returns an
    empty list when neither is present.
    """
    single_file = model_path / "model.safetensors"
    if single_file.exists():
        return [single_file]
    return sorted(model_path.glob("model-*.safetensors"))


def _config_has_dual_mlp(config: dict) -> bool:
    """Return True if ``intermediate_size_mlp`` is set to a positive number.

    A missing key, JSON ``null`` or any non-numeric value counts as "not set"
    rather than raising ``TypeError`` on the ``> 0`` comparison.
    """
    value = config.get("intermediate_size_mlp")
    # bool is a subclass of int; a JSON `true` is not a size.
    if isinstance(value, bool) or not isinstance(value, (int, float)):
        return False
    return value > 0


def _is_candidate_dual_branch_key(key: str) -> bool:
    """Heuristic: an MLP weight that is not a standard SwiGLU projection.

    Not foolproof, since names vary between architectures, but a useful hint.
    """
    return "mlp" in key and not any(
        proj in key for proj in _STANDARD_MLP_PROJECTIONS
    )


def check_model_shape(model_path: str):
    """Inspect a model's config and weights to determine its MLP structure.

    Reads ``config.json`` and the weight keys of ``model.safetensors`` (or,
    if absent, sharded ``model-*.safetensors`` files), then prints a report
    and a conclusion to stdout. Problems (missing files, unparsable config,
    unreadable weights) are reported on stdout and the function returns
    early.

    Args:
        model_path: Directory containing the model files.

    Returns:
        None. All results are printed.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return
    weight_files = _find_weight_files(model_path)
    if not weight_files:
        print(
            "Error: model.safetensors (or model-*.safetensors shards) "
            f"not found in {model_path}"
        )
        return

    print(f"--- Checking model shape in {model_path} ---")

    # 1. Inspect config.json
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:  # ValueError covers JSON/Unicode decode errors
        print(f"Error: could not read config.json: {e}")
        return
    has_dual_mlp_config = _config_has_dual_mlp(config)
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # 2. Inspect weight keys across every weight file (single file or shards).
    has_dual_mlp_weights = False
    weights_path = weight_files[0]
    try:
        for weights_path in weight_files:
            with safe_open(weights_path, framework="mlx") as f:
                candidate = next(
                    (k for k in f.keys() if _is_candidate_dual_branch_key(k)),
                    None,
                )
            if candidate is not None:
                print(f"Found potential dual-branch weight: {candidate}")
                has_dual_mlp_weights = True
                break
    except Exception as e:  # top-level report boundary: any read failure ends the check
        print(f"Could not read weights from {weights_path.name}: {e}")
        return
    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # 3. Report conclusion
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print(" It will likely run as a SINGLE-BRANCH model.")
    elif has_dual_mlp_weights:
        # Previously this case fell through to the SINGLE-BRANCH verdict.
        print(
            "⚠️ Potential dual-branch MLP weights were found, but the config does not set 'intermediate_size_mlp'."
        )
        print(
            " The key-name heuristic may be a false positive; verify the config and weight names."
        )
    else:
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")
if __name__ == "__main__":
    # CLI entry point: a single optional positional argument naming the model
    # directory; it falls back to the current directory when omitted.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        help="Path to the model directory to check.",
        default=".",
        nargs="?",
        type=str,
    )
    parsed = cli.parse_args()
    check_model_shape(parsed.model_path)
|