# PolarSparsity / HybridTensor / modules / SelectiveRouters.py
# Uploaded by Susav via huggingface_hub (commit b3a3b15, verified).
import os
import re
import torch
from collections import OrderedDict
def create_mlp_router_state_dict(router_files_dir):
    """
    Load all mlp_router weight files from a directory into one state dict.

    Filenames must match ``mlp_router_{layer}-{a}-{b}-{c}.pt``; each file is
    expected to contain a router's ``state_dict``. Every parameter key is
    re-prefixed as ``transformer.layers.{layer}.mlp_router.{param_name}``.

    Args:
        router_files_dir (str): Path to the directory containing
            mlp_router_*.pt files.

    Returns:
        OrderedDict | None: A state dictionary suitable for loading into a
        transformer model, or None when the directory is missing or contains
        no matching files.
    """
    # Regular expression to extract the layer number from the filename.
    router_file_pattern = re.compile(r'mlp_router_(\d+)-[\d.]+-[\d.]+-[\d.]+\.pt$')
    router_state_dict = OrderedDict()
    # List all files in the directory
    try:
        all_files = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None
    # Filter files matching the pattern
    router_files = [f for f in all_files if router_file_pattern.match(f)]
    if not router_files:
        print(f"No router files found in directory '{router_files_dir}'.")
        return None
    # Guard against multiple files for the same layer, mirroring
    # create_attn_router_state_dict; otherwise a later duplicate would
    # silently overwrite the earlier layer's parameters.
    loaded_layers = set()
    for file_name in sorted(router_files, key=lambda x: int(router_file_pattern.match(x).group(1))):
        # Files were pre-filtered with this pattern, so the match always succeeds.
        layer_num = int(router_file_pattern.match(file_name).group(1))
        if layer_num in loaded_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue
        file_path = os.path.join(router_files_dir, file_name)
        try:
            # Load the router's state dict on CPU so no GPU is required.
            router_weights = torch.load(file_path, map_location='cpu')
            if not isinstance(router_weights, dict):
                print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
                continue
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        # Re-key each parameter under the transformer layer's mlp_router scope.
        for param_name, param_tensor in router_weights.items():
            new_key = f"transformer.layers.{layer_num}.mlp_router.{param_name}"
            router_state_dict[new_key] = param_tensor
        loaded_layers.add(layer_num)
    # Count distinct layers directly instead of assuming a fixed number of
    # parameters per router (the old `// 2` was wrong for routers with more
    # or fewer than two parameters).
    print(f"Total routers loaded: {len(loaded_layers)}")
    return router_state_dict
def create_attn_router_state_dict(router_files_dir):
    """
    Build a model-ready state dict from all attn_router weight files in a directory.

    Each file named ``attn_router_{layer}-{v1}-{v2}.pt`` is loaded on CPU and
    its parameters are re-keyed as
    ``transformer.layers.{layer}.mha_router.{param_name}``.

    Args:
        router_files_dir (str): Path to the directory containing attn_router_*.pt files.

    Returns:
        OrderedDict | None: Combined state dict for ``load_state_dict``, or
        None when the directory is missing or holds no matching files.
    """
    # Pattern: attn_router_{layer_num}-{value1}-{value2}.pt
    pattern = re.compile(r'attn_router_(\d+)-[\d.]+-[\d.]+\.pt$')
    try:
        entries = os.listdir(router_files_dir)
    except FileNotFoundError:
        print(f"Error: Directory '{router_files_dir}' does not exist.")
        return None

    # Pair each matching filename with its parsed layer number up front so the
    # regex runs once per file.
    matched = []
    for name in entries:
        m = pattern.match(name)
        if m:
            matched.append((int(m.group(1)), name))
    if not matched:
        print(f"No attn_router files found in directory '{router_files_dir}'.")
        return None

    combined = OrderedDict()
    seen_layers = set()  # guards against duplicate files for a single layer
    for layer_num, file_name in sorted(matched, key=lambda pair: pair[0]):
        if layer_num in seen_layers:
            print(f"Warning: Multiple router files found for layer {layer_num}. Skipping '{file_name}'.")
            continue
        file_path = os.path.join(router_files_dir, file_name)
        try:
            weights = torch.load(file_path, map_location='cpu')
        except Exception as e:
            print(f"Error loading '{file_path}': {e}")
            continue
        if not isinstance(weights, dict):
            print(f"Warning: The file '{file_path}' does not contain a state dictionary. Skipping.")
            continue
        # Re-key every parameter under this layer's mha_router scope.
        for param_name, tensor in weights.items():
            combined[f"transformer.layers.{layer_num}.mha_router.{param_name}"] = tensor
        seen_layers.add(layer_num)
    print(f"Total MHA routers loaded: {len(seen_layers)}")
    return combined