remove example pt file
- 725159424.pt +0 -3
- README.md +31 -0
725159424.pt
DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:eeb05bde43937f346fae4d7cf6152021187dc7dcaef6471506b255e0fd5ef647
-size 1610891985
README.md
CHANGED
@@ -37,6 +37,37 @@ Training logs are available [via wandb](https://wandb.ai/lewington/ViT-L-14-laio
 
 ## Usage
 
+```python
+import PIL
+from clipscope import ConfiguredViT, TopKSAE
+
+device='cpu'
+filename_in_hf_repo = "725159424.pt"
+sae = TopKSAE.from_pretrained(repo_id="lewington/CLIP-ViT-L-scope", filename=filename_in_hf_repo, device=device)
+
+transformer_name='laion/CLIP-ViT-L-14-laion2B-s32B-b82K'
+locations = [(22, 'resid')]
+transformer = ConfiguredViT(locations, transformer_name, device=device)
+
+input = PIL.Image.new("RGB", (224, 224), (0, 0, 0)) # black image for testing
+
+activations = transformer.all_activations(input)[locations[0]] # (1, 257, 1024)
+assert activations.shape == (1, 257, 1024)
+
+activations = activations[:, 0] # just the cls token
+# alternatively flatten the activations
+# activations = activations.flatten(1)
+
+print('activations shape', activations.shape)
+
+output = sae.forward_verbose(activations)
+
+print('output keys', output.keys())
+
+print('latent shape', output['latent'].shape) # (1, 65536)
+print('reconstruction shape', output['reconstruction'].shape) # (1, 1024)
+```
+
 ## Error Formulae
 
 We calculate MSE as `(batch - reconstruction).pow(2).sum(dim=-1).mean()` i.e. The MSE between the batch and the un-normalized reconstruction, summed across features. We use batch norm to bring all activations into a similar range.
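The MSE expression quoted in the README can be checked by hand on a tiny batch. The following is a minimal sketch in plain Python (no torch dependency), assuming 2-D inputs given as lists of rows. `summed_mse` is a hypothetical helper name used only for illustration and is not part of clipscope.

```python
def summed_mse(batch, reconstruction):
    """Squared error summed over features (last dim), then averaged over examples.

    Plain-Python counterpart of
    (batch - reconstruction).pow(2).sum(dim=-1).mean() for 2-D inputs.
    """
    per_example = [
        sum((b - r) ** 2 for b, r in zip(b_row, r_row))
        for b_row, r_row in zip(batch, reconstruction)
    ]
    return sum(per_example) / len(per_example)


# Two examples with two features each (illustrative values only).
batch = [[1.0, 2.0], [3.0, 4.0]]
reconstruction = [[1.0, 1.0], [1.0, 1.0]]

# Row 1: 0^2 + 1^2 = 1; row 2: 2^2 + 3^2 = 13; mean over rows = 7.0
result = summed_mse(batch, reconstruction)
print(result)  # 7.0
```

Because the error is summed rather than averaged over the feature dimension, the value grows with the number of features. Dividing by the feature count (2 in this toy case, giving 3.5) would recover a conventional per-element MSE.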