File size: 5,269 Bytes
b3a3b15 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
import os
import re
import torch
from collections import OrderedDict
def create_mlp_router_state_dict(router_files_dir):
    """
    Load all mlp_router weight files from *router_files_dir* and build a state dict
    with keys formatted as 'transformer.layers.{layer_num}.mlp_router.{param_name}'.

    Expected filenames: mlp_router_{layer_num}-{v1}-{v2}-{v3}.pt
    (the three trailing dash-separated values are treated as opaque metadata).

    Args:
        router_files_dir (str): Path to the directory containing mlp_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer model,
        or None when the directory is missing or contains no matching files.
    """
    # Regular expression to extract the layer number from the filename.
    router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')
    router_state_dict = OrderedDict()
    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None
    # Keep only files matching the expected naming scheme.
    router_files = [f for f in all_files if router_file_pattern.match(f)]
    if not router_files:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None
    # Track layers already loaded so duplicate files for the same layer are
    # skipped instead of silently overwriting earlier parameters (mirrors
    # create_attn_router_state_dict), and so the final count is accurate.
    loaded_layers = set()
    for file_name in sorted(router_files, key=lambda x: int(router_file_pattern.match(x).group(1))):
        layer_num = int(router_file_pattern.match(file_name).group(1))
        if layer_num in loaded_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue
        file_path = os.path.join(router_files_dir, file_name)
        try:
            # Load the router's state dict onto CPU regardless of save device.
            router_weights = torch.load(file_path, map_location='cpu')
            if not isinstance(router_weights, dict):
                print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        # Re-key each parameter under the transformer layer's namespace.
        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}"
            router_state_dict[new_key] = param_tensor
        loaded_layers.add(layer_num)
    # Count layers, not tensors: the previous len(router_state_dict) // 2
    # assumed exactly two parameters per router and miscounted otherwise.
    print(f"Total routers loaded: {len(loaded_layers)}")
    return router_state_dict
def create_attn_router_state_dict(router_files_dir):
    """
    Build a router state dict from every attn_router weight file in *router_files_dir*.

    Keys are formatted as 'transformer.layers.{layer_num}.mha_router.{param_name}'.

    Args:
        router_files_dir (str): Path to the directory containing attn_router_*.pt files.

    Returns:
        OrderedDict: A state dictionary suitable for loading into a transformer model,
        or None when the directory is missing or holds no matching files.
    """
    # Filenames follow: attn_router_{layer_num}-{value1}-{value2}.pt
    pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')

    def layer_of(name):
        # Layer index encoded at the front of the filename; only called on
        # names that already matched the pattern.
        return int(pattern.match(name).group(1))

    try:
        candidates = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    matching = [name for name in candidates if pattern.match(name)]
    if not matching:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    state = OrderedDict()
    # First file for a layer wins; later duplicates are skipped with a warning.
    seen_layers = set()
    for name in sorted(matching, key=layer_of):
        layer_num = layer_of(name)
        if layer_num in seen_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{name}'.")
            continue
        file_path = os.path.join(router_files_dir, name)
        try:
            # Always load onto CPU regardless of the device used at save time.
            weights = torch.load(file_path, map_location='cpu')
            if not isinstance(weights, dict):
                print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        # Re-key each parameter under the transformer layer's MHA router namespace.
        for param_name, tensor in weights.items():
            state[f"transformer.layers.{layer_num}.mha_router.{param_name}"] = tensor
        seen_layers.add(layer_num)

    print(f"Total MHA routers loaded: {len(seen_layers)}")
    return state
|