Create README.md
Browse files
README.md
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## How to load the model
|
| 2 |
+
|
| 3 |
+
```
|
| 4 |
+
from safetensors.torch import load_file
|
| 5 |
+
import torch
|
| 6 |
+
import json
|
| 7 |
+
from models.sae.multilayer import MultiLayerSAEBase
|
| 8 |
+
from config.sae.training import SAEConfig, LossCoefficients # Adjust based on your actual config classes
|
| 9 |
+
import os
|
| 10 |
+
|
| 11 |
+
def load_sae_from_huggingface(save_dir: str, model_name: str = "multi_sae", device: str = "cuda"):
|
| 12 |
+
"""
|
| 13 |
+
Load a MultiLayerSAEBase model from Hugging Face format using safetensors.
|
| 14 |
+
|
| 15 |
+
Args:
|
| 16 |
+
save_dir: Directory where the model is saved
|
| 17 |
+
model_name: Name of the model file (default: "multi_sae")
|
| 18 |
+
device: Device to load the model onto (default: "cuda")
|
| 19 |
+
|
| 20 |
+
Returns:
|
| 21 |
+
MultiLayerSAEBase: Loaded model instance
|
| 22 |
+
"""
|
| 23 |
+
# Load configuration
|
| 24 |
+
config_path = os.path.join(save_dir, "config.json")
|
| 25 |
+
with open(config_path, "r") as f:
|
| 26 |
+
config_dict = json.load(f)
|
| 27 |
+
|
| 28 |
+
# Reconstruct gpt_config, converting device string back to torch.device if needed
|
| 29 |
+
gpt_config_dict = config_dict["gpt_config"]
|
| 30 |
+
if "device" in gpt_config_dict:
|
| 31 |
+
gpt_config_dict["device"] = torch.device(gpt_config_dict["device"]) # Convert string back to torch.device
|
| 32 |
+
|
| 33 |
+
# Reconstruct SAEConfig (adjust based on your actual SAEConfig class)
|
| 34 |
+
gpt_config = type(sae_train_config.sae_config.gpt_config)(**gpt_config_dict) # Assuming a dataclass
|
| 35 |
+
sae_config = SAEConfig(
|
| 36 |
+
gpt_config=gpt_config,
|
| 37 |
+
n_features=config_dict["feature_size"],
|
| 38 |
+
# Add other required fields if necessary
|
| 39 |
+
)
|
| 40 |
+
|
| 41 |
+
# Reconstruct LossCoefficients if provided
|
| 42 |
+
loss_coefficients = LossCoefficients(sparsity=config_dict["l1_coefficient"]) if config_dict["l1_coefficient"] else None
|
| 43 |
+
|
| 44 |
+
# Initialize the model
|
| 45 |
+
sae = MultiLayerSAEBase(config=sae_config, loss_coefficients=loss_coefficients)
|
| 46 |
+
|
| 47 |
+
# Load the state dictionary
|
| 48 |
+
model_path = os.path.join(save_dir, f"{model_name}.safetensors")
|
| 49 |
+
state_dict = load_file(model_path)
|
| 50 |
+
|
| 51 |
+
# Load tensors into the model
|
| 52 |
+
sae.load_state_dict(state_dict)
|
| 53 |
+
sae.to(device)
|
| 54 |
+
sae.eval() # Set to evaluation mode
|
| 55 |
+
|
| 56 |
+
print(f"Model loaded from {model_path}")
|
| 57 |
+
return sae
|
| 58 |
+
|
| 59 |
+
# Example usage
|
| 60 |
+
save_dir = "../checkpoints/multi-layer.shakespeare_64x4"
|
| 61 |
+
loaded_sae = load_sae_from_huggingface(save_dir, model_name="sae", device="cuda")
|
| 62 |
+
```
|