Update README.md
Browse files
README.md
CHANGED
|
@@ -30,7 +30,7 @@ import torch.nn.functional as F
|
|
| 30 |
from torchaudio.compliance import kaldi
|
| 31 |
|
| 32 |
# for fine-tuning, you can pass `num_classes={your number of classes}`
|
| 33 |
-
model = timm.create_model("hf_hub:gaunernst/vit_base_patch16_1024_128.
|
| 34 |
model = model.eval()
|
| 35 |
|
| 36 |
MEAN = -4.2677393
|
|
@@ -48,6 +48,12 @@ melspec = (melspec - MEAN) / (STD * 2)
|
|
| 48 |
|
| 49 |
melspec = melspec.view(1, 1, 1024, 128) # add batch dim and channel dim
|
| 50 |
output = model(melspec) # embeddings with shape (1, 768)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 51 |
```
|
| 52 |
|
| 53 |
## Citation
|
|
|
|
| 30 |
from torchaudio.compliance import kaldi
|
| 31 |
|
| 32 |
# for fine-tuning, you can pass `num_classes={your number of classes}`
|
| 33 |
+
model = timm.create_model("hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m", pretrained=True)
|
| 34 |
model = model.eval()
|
| 35 |
|
| 36 |
MEAN = -4.2677393
|
|
|
|
| 48 |
|
| 49 |
melspec = melspec.view(1, 1, 1024, 128) # add batch dim and channel dim
|
| 50 |
output = model(melspec) # embeddings with shape (1, 768)
|
| 51 |
+
|
| 52 |
+
# to get frame level embeddings
|
| 53 |
+
output = model.forward_features(melspec) # shape (1, 513, 768)
|
| 54 |
+
output = output[:, 1:] # remove [CLS] token
|
| 55 |
+
output = output.unflatten(1, (1024 // 16, 128 // 16)) # (1, 64, 8, 768) -> 2D patches
|
| 56 |
+
output = output.mean(2) # (1, 64, 768) -> mean pooling across mel dimension
|
| 57 |
```
|
| 58 |
|
| 59 |
## Citation
|