|
|
import os |
|
|
import re |
|
|
import torch |
|
|
from collections import OrderedDict |
|
|
|
|
|
def create_mlp_router_state_dict(router_files_dir):
    """
    Collect per-layer MLP router weights from a directory into one state dict.

    Every file whose name matches ``mlp_router_<layer>-<a>-<b>-<c>.pt`` (the
    three trailing fields consist of digits and dots) is loaded onto the CPU
    with ``torch.load``. Each entry of the loaded dict is stored under the key
    ``transformer.layers.{layer_num}.mlp_router.{param_name}``.

    Files are processed in ascending layer order, with ties broken by file
    name, so the result does not depend on the order returned by
    ``os.listdir``. If several files exist for the same layer, the first one
    that loads successfully is kept and the others are skipped with a warning.
    Files that fail to load, or that do not contain a ``dict``, are reported
    and skipped.

    Args:
        router_files_dir (str): Path to the directory containing
            mlp_router_*.pt files.

    Returns:
        OrderedDict: Mapping from the re-keyed parameter names to the loaded
        values. ``None`` is returned instead when the directory does not
        exist or contains no matching file.
    """
    router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')

    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Match every name exactly once and keep the parsed layer number with it,
    # so the loop below never has to re-run the regex.
    candidates = []
    for file_name in all_files:
        match = router_file_pattern.match(file_name)
        if match:
            candidates.append((int(match.group(1)), file_name))

    if not candidates:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None

    router_state_dict = OrderedDict()
    loaded_layers = set()

    # Tuples sort by layer number first, then by file name: numeric layer
    # order with a deterministic winner among duplicate files.
    for layer_num, file_name in sorted(candidates):
        if layer_num in loaded_layers:
            print(
                f"Warning: Multiple router files found for layer {layer_num}. "
                f"Skipping '{file_name}'."
            )
            continue

        file_path = os.path.join(router_files_dir, file_name)

        try:
            # NOTE: torch.load unpickles the file; only use this on
            # checkpoint directories from a trusted source.
            router_weights = torch.load(file_path, map_location='cpu')
        except Exception as e:
            # Best effort: one unreadable file must not abort the others.
            print(f"Error loading '{file_path}': {e}")
            continue

        if not isinstance(router_weights, dict):
            print(
                f"Warning: The file '{file_path}' does not contain "
                f"a state dictionary. Skipping."
            )
            continue

        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}"
            router_state_dict[new_key] = param_tensor

        # Only mark the layer as done after a successful load, so a broken
        # file does not block a valid duplicate for the same layer.
        loaded_layers.add(layer_num)

    # Count layers, not parameters: files may hold any number of entries.
    print(f"Total routers loaded: {len(loaded_layers)}")
    return router_state_dict
|
|
|
|
|
|
|
|
def create_attn_router_state_dict(router_files_dir):
    """
    Collect per-layer attention router weights from a directory into one state dict.

    Every file whose name matches ``attn_router_<layer>-<a>-<b>.pt`` (the two
    trailing fields consist of digits and dots) is loaded onto the CPU with
    ``torch.load``. Each entry of the loaded dict is stored under the key
    ``transformer.layers.{layer_num}.mha_router.{param_name}``.

    Files are processed in ascending layer order, with ties broken by file
    name, so the result does not depend on the order returned by
    ``os.listdir``. If several files exist for the same layer, the first one
    that loads successfully is kept and the others are skipped with a warning.
    Files that fail to load, or that do not contain a ``dict``, are reported
    and skipped.

    Args:
        router_files_dir (str): Path to the directory containing
            attn_router_*.pt files.

    Returns:
        OrderedDict: Mapping from the re-keyed parameter names to the loaded
        values. ``None`` is returned instead when the directory does not
        exist or contains no matching file.
    """
    router_file_pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')

    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Match every name exactly once and keep the parsed layer number with it,
    # so the loop below never has to re-run the regex.
    candidates = []
    for file_name in all_files:
        match = router_file_pattern.match(file_name)
        if match:
            candidates.append((int(match.group(1)), file_name))

    if not candidates:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    router_state_dict = OrderedDict()
    loaded_layers = set()

    # Tuples sort by layer number first, then by file name: numeric layer
    # order with a deterministic winner among duplicate files.
    for layer_num, file_name in sorted(candidates):
        if layer_num in loaded_layers:
            print(
                f"Warning: Multiple router files found for layer {layer_num}. "
                f"Skipping '{file_name}'."
            )
            continue

        file_path = os.path.join(router_files_dir, file_name)

        try:
            # NOTE: torch.load unpickles the file; only use this on
            # checkpoint directories from a trusted source.
            router_weights = torch.load(file_path, map_location='cpu')
        except Exception as e:
            # Best effort: one unreadable file must not abort the others.
            print(f"Error loading '{file_path}': {e}")
            continue

        if not isinstance(router_weights, dict):
            print(
                f"Warning: The file '{file_path}' does not contain "
                f"a state dictionary. Skipping."
            )
            continue

        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mha_router.{param_name}"
            router_state_dict[new_key] = param_tensor

        # Only mark the layer as done after a successful load, so a broken
        # file does not block a valid duplicate for the same layer.
        loaded_layers.add(layer_num)

    print(f"Total MHA routers loaded: {len(loaded_layers)}")
    return router_state_dict
|
|
|
|
|
|
|
|
|