| How to load the model | |
| ```python | |
| from safetensors.torch import load_file | |
| import torch | |
| import json | |
| from models.sae.multilayer import MultiLayerSAEBase | |
| from config.sae.training import SAEConfig, LossCoefficients # Adjust based on your actual config classes | |
| import os | |
def load_sae_from_huggingface(
    save_dir: str,
    model_name: str = "multi_sae",
    device: str = "cuda",
    gpt_config_cls=None,
):
    """Load a MultiLayerSAEBase model saved in Hugging Face format (safetensors).

    Args:
        save_dir: Directory containing ``config.json`` and the weights file.
        model_name: Stem of the ``.safetensors`` file (default: "multi_sae").
        device: Device to move the model onto (default: "cuda").
        gpt_config_cls: Class used to rebuild ``gpt_config`` from the saved
            dictionary. When omitted, the class is taken from the type
            annotation of ``SAEConfig.gpt_config``.

    Returns:
        MultiLayerSAEBase: Loaded model instance, set to evaluation mode.

    Raises:
        KeyError: If ``config.json`` lacks ``"gpt_config"`` or
            ``"feature_size"``.
        TypeError: If no concrete class for ``gpt_config`` can be determined.
    """
    # Local import keeps this snippet self-contained.
    from typing import get_type_hints

    # Load configuration
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Work on a copy so the parsed configuration is not mutated.
    gpt_config_dict = dict(config_dict["gpt_config"])
    if "device" in gpt_config_dict:
        # The device was serialized as a string; convert it back.
        gpt_config_dict["device"] = torch.device(gpt_config_dict["device"])

    # Determine which class rebuilds gpt_config. The previous code read it
    # from an undefined `sae_train_config` variable and raised NameError.
    if gpt_config_cls is None:
        try:
            hints = get_type_hints(SAEConfig)
        except (NameError, TypeError):
            hints = {}
        gpt_config_cls = hints.get("gpt_config")
    if not isinstance(gpt_config_cls, type):
        raise TypeError(
            "Could not determine the gpt_config class from SAEConfig; "
            "pass it explicitly via gpt_config_cls."
        )
    gpt_config = gpt_config_cls(**gpt_config_dict)

    # Reconstruct SAEConfig (adjust based on your actual SAEConfig class)
    sae_config = SAEConfig(
        gpt_config=gpt_config,
        n_features=config_dict["feature_size"],
        # Add other required fields if necessary
    )

    # Reconstruct LossCoefficients only if a coefficient was saved. A stored
    # value of 0 is still a provided value, so test against None explicitly.
    l1_coefficient = config_dict.get("l1_coefficient")
    if l1_coefficient is not None:
        loss_coefficients = LossCoefficients(sparsity=l1_coefficient)
    else:
        loss_coefficients = None

    # Initialize the model
    sae = MultiLayerSAEBase(config=sae_config, loss_coefficients=loss_coefficients)

    # Load the saved tensors into the model
    model_path = os.path.join(save_dir, f"{model_name}.safetensors")
    state_dict = load_file(model_path)
    sae.load_state_dict(state_dict)
    sae.to(device)
    sae.eval()  # Set to evaluation mode

    print(f"Model loaded from {model_path}")
    return sae
# Example usage: load the "sae" checkpoint from a local directory onto the GPU.
save_dir = "../checkpoints/multi-layer.shakespeare_64x4"
loaded_sae = load_sae_from_huggingface(
    save_dir=save_dir,
    model_name="sae",
    device="cuda",
)
| ``` |