File size: 4,079 Bytes
e39ff3a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import sys
from pathlib import Path

# Add the current directory to the python path to import model.py
sys.path.append(str(Path.cwd()))

from model import load_model
from mlx.utils import tree_flatten


def run_diagnostic_checks():
    """
    Run a series of diagnostic checks against the model defined in model.py.

    Loads the model from the current directory and then verifies, in order:
      1. the model definition loads without error (aborts if not),
      2. the total parameter count,
      3. the MLP projection weight shapes match the configuration,
      4. the token-embedding shape matches the configuration,
      5. the expected MLP weight attributes are present.

    All results and failures are reported to stdout; nothing is raised to
    the caller — each check traps its own exceptions so the remaining
    checks still run.
    """
    print("--- Running Diagnostic Checks ---")

    # 1. Load model and check for errors. Abort early on failure: every
    # subsequent check needs a loaded model.
    try:
        model = load_model(".")
        print("Successfully loaded model definition.")
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    _check_parameter_count(model)
    _check_mlp_shapes(model)
    _check_embedding_shape(model)
    _check_weight_presence(model)

    print("--- Diagnostic Checks Complete ---")


def _check_parameter_count(model):
    """Print the model's total parameter count in millions."""
    try:
        params = model.parameters()
        num_params = sum(p.size for _, p in tree_flatten(params))
        print(f"Total number of parameters: {num_params / 1e6:.2f}M")
    except Exception as e:
        print(f"Error calculating parameters: {e}")


def _check_mlp_shapes(model):
    """Verify the first block's MLP weight shapes against the model config."""
    print("--- Verifying MLP Weight Shapes ---")
    try:
        first_block = model.layers[0]
        args = model.args
        print(f"use_dual_mlp detected: {args.use_dual_mlp}")

        if args.use_dual_mlp:
            # Dual-branch MLP: a gated branch (g_*) plus a plain branch (p_*),
            # each with its own intermediate size.
            g_up_shape = first_block.feed_forward.g_up.weight.shape
            p_up_shape = first_block.feed_forward.p_up.weight.shape
            print(f"Gated MLP branch (g_up) weight shape: {g_up_shape}")
            print(f"Plain MLP branch (p_up) weight shape: {p_up_shape}")
            assert g_up_shape == (args.intermediate_size, args.hidden_size)
            assert p_up_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("DualMLP weight shapes are correct.")
        else:
            # Standard SwiGLU MLP: gate and up projections share one
            # intermediate size.
            gate_proj_shape = first_block.feed_forward.gate_proj.weight.shape
            up_proj_shape = first_block.feed_forward.up_proj.weight.shape
            print(f"SwiGLUMLP gate_proj weight shape: {gate_proj_shape}")
            print(f"SwiGLUMLP up_proj weight shape: {up_proj_shape}")
            assert gate_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            assert up_proj_shape == (args.intermediate_size_mlp, args.hidden_size)
            print("SwiGLUMLP weight shapes are correct.")

    except AttributeError as e:
        print(
            f"Error accessing MLP weights. It seems the structure is not as expected: {e}"
        )
    except AssertionError:
        print("Error: MLP weight shapes do not match the configuration.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying shapes: {e}")


def _check_embedding_shape(model):
    """Verify the token-embedding weight shape against the model config."""
    print("--- Verifying Embedding Shape ---")
    try:
        embedding_shape = model.tok_embeddings.weight.shape
        print(f"Embedding weight shape: {embedding_shape}")

        args = model.args
        print(f"Expected embedding shape: ({args.vocab_size}, {args.hidden_size})")

        assert embedding_shape == (args.vocab_size, args.hidden_size)
        print("Embedding shape is correct.")
    except Exception as e:
        print(f"An unexpected error occurred while verifying embedding shape: {e}")


def _check_weight_presence(model):
    """Confirm the MLP weight attributes expected for the architecture exist."""
    print("--- Sanity Checking Loaded Weights ---")
    try:
        # Touch each expected attribute; an AttributeError here means the
        # loaded weights don't match the declared architecture.
        if model.args.use_dual_mlp:
            _ = model.layers[0].feed_forward.g_gate.weight
            _ = model.layers[0].feed_forward.g_up.weight
            _ = model.layers[0].feed_forward.g_down.weight
            _ = model.layers[0].feed_forward.p_up.weight
            _ = model.layers[0].feed_forward.p_down.weight
            print("Found dual-branch MLP weights in the model.")
        else:
            _ = model.layers[0].feed_forward.gate_proj.weight
            _ = model.layers[0].feed_forward.up_proj.weight
            _ = model.layers[0].feed_forward.down_proj.weight
            print("Found SwiGLU MLP weights in the model.")
        print("Weight presence sanity check passed.")
    except Exception as e:
        print(f"An error occurred during sanity check: {e}")


# Run the diagnostics only when executed as a script, not when imported.
if __name__ == "__main__":
    run_diagnostic_checks()