YAML Metadata Warning: empty or missing yaml metadata in repo card (https://huggingface.co/docs/hub/model-cards#model-card-metadata)

How to load the model

from safetensors.torch import load_file
import torch
import json
from models.sae.multilayer import MultiLayerSAEBase
from config.sae.training import SAEConfig, LossCoefficients  # Adjust based on your actual config classes
import os

def load_sae_from_huggingface(
    save_dir: str,
    model_name: str = "multi_sae",
    device: str = "cuda",
    *,
    gpt_config_cls=None,
):
    """
    Load a MultiLayerSAEBase model from Hugging Face format using safetensors.

    Reads ``config.json`` from ``save_dir``, rebuilds the SAE configuration
    from it, instantiates the model, and then loads the weights stored in
    ``<model_name>.safetensors`` in the same directory.

    Args:
        save_dir: Directory where the model is saved
        model_name: Name of the model file, without the ``.safetensors``
            extension (default: "multi_sae")
        device: Device to load the model onto (default: "cuda")
        gpt_config_cls: Class used to rebuild the GPT config; it is called as
            ``gpt_config_cls(**config_dict["gpt_config"])``. When omitted, the
            loader falls back to the type of
            ``sae_train_config.sae_config.gpt_config`` if a module-level
            ``sae_train_config`` object exists.

    Returns:
        MultiLayerSAEBase: Loaded model instance, moved to ``device`` and
        switched to evaluation mode

    Raises:
        NameError: If ``gpt_config_cls`` is omitted and no ``sae_train_config``
            is defined in the enclosing module.
        KeyError: If ``config.json`` has no "gpt_config" or "feature_size" key.
    """
    # Load configuration
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Work on a copy so the parsed config is not mutated in place.
    gpt_config_dict = dict(config_dict["gpt_config"])
    if "device" in gpt_config_dict:
        # JSON stores the device as a string; convert it back to torch.device
        gpt_config_dict["device"] = torch.device(gpt_config_dict["device"])

    # Resolve the GPT config class. The published snippet relied on an
    # undefined global `sae_train_config`; keep that as a fallback for callers
    # that do define it, but fail with an actionable message otherwise.
    if gpt_config_cls is None:
        try:
            gpt_config_cls = type(sae_train_config.sae_config.gpt_config)
        except NameError as err:
            raise NameError(
                "Cannot determine the GPT config class: pass `gpt_config_cls=` "
                "or define a module-level `sae_train_config`."
            ) from err

    # Reconstruct SAEConfig (adjust based on your actual SAEConfig class)
    gpt_config = gpt_config_cls(**gpt_config_dict)  # Assuming a dataclass
    sae_config = SAEConfig(
        gpt_config=gpt_config,
        n_features=config_dict["feature_size"],
        # Add other required fields if necessary
    )

    # Reconstruct LossCoefficients if provided. A missing, null, or zero
    # coefficient yields no LossCoefficients object.
    l1_coefficient = config_dict.get("l1_coefficient")
    loss_coefficients = (
        LossCoefficients(sparsity=l1_coefficient) if l1_coefficient else None
    )

    # Initialize the model
    sae = MultiLayerSAEBase(config=sae_config, loss_coefficients=loss_coefficients)

    # Load the state dictionary
    model_path = os.path.join(save_dir, f"{model_name}.safetensors")
    state_dict = load_file(model_path)

    # Load tensors into the model
    sae.load_state_dict(state_dict)
    sae.to(device)
    sae.eval()  # Set to evaluation mode

    print(f"Model loaded from {model_path}")
    return sae

# Example usage: load the "sae" checkpoint from a local directory onto the GPU
save_dir = "../checkpoints/multi-layer.shakespeare_64x4"
loaded_sae = load_sae_from_huggingface(
    save_dir=save_dir,
    model_name="sae",
    device="cuda",
)
Downloads last month
6
Inference Providers NEW
This model isn't deployed by any Inference Provider. 🙋 Ask for provider support