gaunernst committed on
Commit
a1e9f8f
·
verified ·
1 Parent(s): b0f2ea3

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +7 -1
README.md CHANGED
@@ -30,7 +30,7 @@ import torch.nn.functional as F
30
  from torchaudio.compliance import kaldi
31
 
32
  # for fine-tuning, you can pass `num_classes={your number of classes}`
33
- model = timm.create_model("hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m_ft", pretrained=True)
34
  model = model.eval()
35
 
36
  MEAN = -4.2677393
@@ -48,6 +48,12 @@ melspec = (melspec - MEAN) / (STD * 2)
48
 
49
  melspec = melspec.view(1, 1, 1024, 128) # add batch dim and channel dim
50
  output = model(melspec) # embeddings with shape (1, 768)
 
 
 
 
 
 
51
  ```
52
 
53
  ## Citation
 
30
  from torchaudio.compliance import kaldi
31
 
32
  # for fine-tuning, you can pass `num_classes={your number of classes}`
33
+ model = timm.create_model("hf_hub:gaunernst/vit_base_patch16_1024_128.audiomae_as2m", pretrained=True)
34
  model = model.eval()
35
 
36
  MEAN = -4.2677393
 
48
 
49
  melspec = melspec.view(1, 1, 1024, 128) # add batch dim and channel dim
50
  output = model(melspec) # embeddings with shape (1, 768)
51
+
52
+ # to get frame-level embeddings
53
+ output = model.forward_features(melspec) # shape (1, 513, 768)
54
+ output = output[:, 1:] # remove [CLS] token
55
+ output = output.unflatten(1, (1024 // 16, 128 // 16)) # (1, 64, 8, 768) -> 2D patches
56
+ output = output.mean(2) # (1, 64, 768) -> mean pooling across mel dimension
57
  ```
58
 
59
  ## Citation