Update README.md
Browse files
README.md
CHANGED
|
@@ -27,9 +27,13 @@ ESM_DIM = 1280
|
|
| 27 |
SAE_DIM = 4096
|
| 28 |
LAYER = 24
|
| 29 |
|
|
|
|
|
|
|
| 30 |
# Load ESM model
|
| 31 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 32 |
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
|
|
|
|
|
|
| 33 |
|
| 34 |
# Load SAE model
|
| 35 |
checkpoint_path = hf_hub_download(
|
|
@@ -38,11 +42,13 @@ checkpoint_path = hf_hub_download(
|
|
| 38 |
)
|
| 39 |
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
|
| 40 |
sae_model.load_state_dict(load_file(checkpoint_path))
|
|
|
|
|
|
|
| 41 |
```
|
| 42 |
|
| 43 |
ESM -> SAE inference on an amino acid sequence of length `L`
|
| 44 |
```
|
| 45 |
-
seq = "
|
| 46 |
|
| 47 |
# Tokenize sequence and run ESM inference
|
| 48 |
inputs = tokenizer(seq, padding=True, return_tensors="pt")
|
|
|
|
| 27 |
SAE_DIM = 4096
|
| 28 |
LAYER = 24
|
| 29 |
|
| 30 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 31 |
+
|
| 32 |
# Load ESM model
|
| 33 |
tokenizer = AutoTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 34 |
esm_model = EsmModel.from_pretrained("facebook/esm2_t33_650M_UR50D")
|
| 35 |
+
esm_model.to(device)
|
| 36 |
+
esm_model.eval()
|
| 37 |
|
| 38 |
# Load SAE model
|
| 39 |
checkpoint_path = hf_hub_download(
|
|
|
|
| 42 |
)
|
| 43 |
sae_model = SparseAutoencoder(ESM_DIM, SAE_DIM)
|
| 44 |
sae_model.load_state_dict(load_file(checkpoint_path))
|
| 45 |
+
sae_model.to(device)
|
| 46 |
+
sae_model.eval()
|
| 47 |
```
|
| 48 |
|
| 49 |
ESM -> SAE inference on an amino acid sequence of length `L`
|
| 50 |
```
|
| 51 |
+
seq = "TTCCPSIVARSNFNVCRLPGTPEALCATYTGCIIIPGATCPGDYAN"
|
| 52 |
|
| 53 |
# Tokenize sequence and run ESM inference
|
| 54 |
inputs = tokenizer(seq, padding=True, return_tensors="pt")
|