davidquarel committed on
Commit
a23de93
·
verified ·
1 Parent(s): 5a4d928

Create README.md

Browse files
Files changed (1) hide show
  1. README.md +62 -0
README.md ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ How to load the model
2
+
3
+ ```
4
+ from safetensors.torch import load_file
5
+ import torch
6
+ import json
7
+ from models.sae.multilayer import MultiLayerSAEBase
8
+ from config.sae.training import SAEConfig, LossCoefficients # Adjust based on your actual config classes
9
+ import os
10
+
11
def load_sae_from_huggingface(
    save_dir: str,
    model_name: str = "multi_sae",
    device: str = "cuda",
    gpt_config_cls=None,
):
    """
    Load a MultiLayerSAEBase model from Hugging Face format using safetensors.

    Reads ``config.json`` from ``save_dir`` to rebuild the model configuration,
    instantiates the model, then loads the weights stored in
    ``<save_dir>/<model_name>.safetensors``.

    Args:
        save_dir: Directory where the model is saved
        model_name: Name of the model file, without the ".safetensors"
            extension (default: "multi_sae")
        device: Device to load the model onto (default: "cuda")
        gpt_config_cls: Class used to rebuild the GPT config; it is called as
            ``gpt_config_cls(**config_dict["gpt_config"])``. When omitted, the
            class is taken from a module-level ``sae_train_config`` object if
            one exists (the behaviour of the original snippet).

    Returns:
        MultiLayerSAEBase: Loaded model instance, moved to ``device`` and
        switched to evaluation mode

    Raises:
        FileNotFoundError: If ``config.json`` does not exist in ``save_dir``.
        ValueError: If ``gpt_config_cls`` is not given and no module-level
            ``sae_train_config`` is available to infer it from.
    """
    # Load configuration
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "r", encoding="utf-8") as f:
        config_dict = json.load(f)

    # Resolve the GPT config class up front so a missing class is reported
    # clearly instead of surfacing as a NameError halfway through loading.
    if gpt_config_cls is None:
        legacy_train_config = globals().get("sae_train_config")
        if legacy_train_config is None:
            raise ValueError(
                "gpt_config_cls must be provided: pass the class used to build "
                "the GPT config (it is called with the 'gpt_config' entries "
                "from config.json)"
            )
        gpt_config_cls = type(legacy_train_config.sae_config.gpt_config)

    # Reconstruct gpt_config, converting device string back to torch.device if needed.
    # Work on a shallow copy so the parsed config itself is left untouched.
    gpt_config_dict = dict(config_dict["gpt_config"])
    if "device" in gpt_config_dict:
        gpt_config_dict["device"] = torch.device(gpt_config_dict["device"])

    # Reconstruct SAEConfig (adjust based on your actual SAEConfig class)
    gpt_config = gpt_config_cls(**gpt_config_dict)  # Assuming a dataclass
    sae_config = SAEConfig(
        gpt_config=gpt_config,
        n_features=config_dict["feature_size"],
        # Add other required fields if necessary
    )

    # Reconstruct LossCoefficients if provided. A missing key is treated the
    # same way as a null/zero coefficient: no loss coefficients are passed.
    l1_coefficient = config_dict.get("l1_coefficient")
    loss_coefficients = LossCoefficients(sparsity=l1_coefficient) if l1_coefficient else None

    # Initialize the model
    sae = MultiLayerSAEBase(config=sae_config, loss_coefficients=loss_coefficients)

    # Load the state dictionary
    model_path = os.path.join(save_dir, f"{model_name}.safetensors")
    state_dict = load_file(model_path)

    # Load tensors into the model
    sae.load_state_dict(state_dict)
    sae.to(device)
    sae.eval()  # Set to evaluation mode

    print(f"Model loaded from {model_path}")
    return sae
58
+
59
# Example usage: load "sae.safetensors" (plus config.json) from save_dir onto the GPU
save_dir = "../checkpoints/multi-layer.shakespeare_64x4"
loaded_sae = load_sae_from_huggingface(
    save_dir=save_dir,
    model_name="sae",
    device="cuda",
)
62
+ ```