File size: 5,269 Bytes
b3a3b15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
import os
import re
import torch
from collections import OrderedDict

def create_mlp_router_state_dict(router_files_dir):
    """
    Loads all mlp_router weight files from the specified directory and creates a router_state_dict
    with keys formatted as 'transformer.layers.{layer_num}.mlp_router.{param_name}'.

    Args:
        router_files_dir (str): Path to the directory containing mlp_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer model,
            or None if the directory is missing or contains no matching files.
    """
    # Filename pattern: mlp_router_{layer_num}-{v1}-{v2}-{v3}.pt
    router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')

    router_state_dict = OrderedDict()

    # List all files in the directory
    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Filter files matching the pattern
    router_files = [f for f in all_files if router_file_pattern.match(f)]

    if not router_files:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None

    # Guard against two files claiming the same layer: without this, the later
    # file would silently overwrite the earlier one's parameters (matches the
    # duplicate handling in create_attn_router_state_dict).
    loaded_layers = set()

    for file_name in sorted(router_files, key=lambda x: int(router_file_pattern.match(x).group(1))):
        layer_num = int(router_file_pattern.match(file_name).group(1))
        if layer_num in loaded_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue

        file_path = os.path.join(router_files_dir, file_name)

        try:
            # Load the router's state dict.
            # NOTE(review): torch.load unpickles data — only load trusted files.
            router_weights = torch.load(file_path, map_location='cpu')
            if not isinstance(router_weights, dict):
                print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue

        # Re-key each parameter under the transformer layer namespace.
        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}"
            router_state_dict[new_key] = param_tensor

        loaded_layers.add(layer_num)

    # Count distinct layers rather than dividing the key count by an assumed
    # params-per-router (the previous `len(...) // 2` contradicted its own
    # "4 params per router" comment and broke for any other param count).
    print(f"Total routers loaded: {len(loaded_layers)}")
    return router_state_dict


def create_attn_router_state_dict(router_files_dir):
    """
    Collect every attn_router checkpoint in *router_files_dir* into a single
    state dict keyed as 'transformer.layers.{layer_num}.mha_router.{param_name}'.

    Args:
        router_files_dir (str): Path to the directory containing attn_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer
            model, or None if the directory is missing or has no matching files.
    """
    # Expected filename shape: attn_router_{layer_num}-{value1}-{value2}.pt
    name_pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')

    merged = OrderedDict()

    try:
        directory_entries = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Keep only the filenames that look like attn_router checkpoints.
    candidates = [name for name in directory_entries if name_pattern.match(name)]
    if not candidates:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    # Process in ascending layer order; remember layers already consumed so a
    # second file for the same layer is rejected instead of overwriting.
    candidates.sort(key=lambda name: int(name_pattern.match(name).group(1)))
    consumed_layers = set()

    for name in candidates:
        matched = name_pattern.match(name)
        if matched is None:
            print(f"Skipping file '{name}' as it does not match the pattern.")
            continue

        layer = int(matched.group(1))
        if layer in consumed_layers:
            print(f"Warning: Multiple router files found for layer {layer}. Skipping '{name}'.")
            continue

        checkpoint_path = os.path.join(router_files_dir, name)
        try:
            payload = torch.load(checkpoint_path, map_location='cpu')
            if not isinstance(payload, dict):
                print(f"Warning: The file '{checkpoint_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as err:
            print(f"Error loading '{checkpoint_path}': {err}")
            continue

        # Prefix every parameter with its transformer-layer namespace.
        for original_key, tensor in payload.items():
            merged[f"transformer.layers.{layer}.mha_router.{original_key}"] = tensor

        consumed_layers.add(layer)

    print(f"Total MHA routers loaded: {len(consumed_layers)}")
    return merged