File size: 2,770 Bytes
e39ff3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import argparse
import json
from pathlib import Path
from safetensors import safe_open


def _find_weight_files(model_dir: Path) -> list:
    """Return the safetensors weight file(s) found in *model_dir* (possibly empty)."""
    single = model_dir / "model.safetensors"
    if single.exists():
        return [single]
    # No single-file checkpoint: fall back to sharded checkpoints such as
    # model-00001-of-00002.safetensors, in a deterministic (sorted) order.
    return sorted(model_dir.glob("model*.safetensors"))


def _is_dual_branch_key(key: str) -> bool:
    """Heuristic: an MLP weight key that is not a standard SwiGLU projection."""
    return "mlp" in key and not any(
        proj in key for proj in ("gate_proj", "up_proj", "down_proj")
    )


def check_model_shape(model_path: str):
    """Inspect a model's config and weights to determine its MLP structure.

    Reads ``config.json`` and the safetensors weights in *model_path*, prints
    what it finds, and prints a conclusion about whether the model looks like
    a dual-branch or a single-branch MLP variant. Both single-file
    (``model.safetensors``) and sharded (``model-*.safetensors``) checkpoints
    are scanned.

    Args:
        model_path: Directory containing ``config.json`` and the weights.

    Returns:
        None. Results and errors are reported on stdout; on any error the
        function prints a message and returns early.
    """
    model_path = Path(model_path)
    config_path = model_path / "config.json"

    if not config_path.exists():
        print(f"Error: config.json not found in {model_path}")
        return

    weight_files = _find_weight_files(model_path)
    if not weight_files:
        print(f"Error: model.safetensors not found in {model_path}")
        return

    print(f"--- Checking model shape in {model_path} ---")

    # 1. Inspect config.json
    try:
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
    except (OSError, ValueError) as e:
        # ValueError covers json.JSONDecodeError and UnicodeDecodeError.
        print(f"Error: could not parse config.json in {model_path}: {e}")
        return
    if not isinstance(config, dict):
        print(f"Error: config.json in {model_path} is not a JSON object")
        return

    # A missing or null value means "no second MLP branch"; only a positive
    # number enables it (a bare `None > 0` would raise TypeError).
    mlp_size = config.get("intermediate_size_mlp")
    has_dual_mlp_config = isinstance(mlp_size, (int, float)) and mlp_size > 0
    print(f"Config has 'intermediate_size_mlp': {has_dual_mlp_config}")

    # 2. Inspect weight keys across every weight file (shards included).
    # This is a name-based heuristic, not foolproof since naming varies,
    # but a good indicator.
    has_dual_mlp_weights = False
    for weights_file in weight_files:
        try:
            with safe_open(weights_file, framework="mlx") as f:
                dual_key = next(
                    (key for key in f.keys() if _is_dual_branch_key(key)), None
                )
        except Exception as e:
            # Top-level CLI boundary: report unreadable weights and stop.
            print(f"Could not read weights from {weights_file.name}: {e}")
            return
        if dual_key is not None:
            print(f"Found potential dual-branch weight: {dual_key}")
            has_dual_mlp_weights = True
            break

    print(f"Found potential dual-branch MLP weights: {has_dual_mlp_weights}")

    # 3. Report conclusion
    print("\n--- Conclusion ---")
    if has_dual_mlp_config and has_dual_mlp_weights:
        print("✅ The model appears to be a DUAL-BRANCH MLP variant.")
    elif has_dual_mlp_config and not has_dual_mlp_weights:
        print(
            "⚠️ The model configuration suggests a dual-branch MLP, but no corresponding weights were found."
        )
        print("   It will likely run as a SINGLE-BRANCH model.")
    else:
        print("✅ The model appears to be a SINGLE-BRANCH MLP variant.")
    print("--------------------\n")


if __name__ == "__main__":
    # Command-line entry point. The model directory is an optional positional
    # argument that falls back to the current working directory.
    cli = argparse.ArgumentParser(description="Check the MLP shape of a model variant.")
    cli.add_argument(
        "model_path",
        nargs="?",
        default=".",
        type=str,
        help="Path to the model directory to check.",
    )
    check_model_shape(cli.parse_args().model_path)