Commit a96dc72 (verified, 0 parents), committed by jhansss and dslee2601

Duplicate from hance-ai/audiomae

Co-authored-by: Daesoo Lee <dslee2601@users.noreply.huggingface.co>

.fig/sanity_check_result_audiomae.png ADDED

Git LFS Details

  • SHA256: 88b5eeae8610ba0899004d5b1e8a48041e9d08b12277e25a43221528f296270a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.05 MB
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ sanity_check_result_audiomae.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ __pycache__/
+ token.txt
.sample_sound/baby_coughing.wav ADDED
Binary file (882 kB).
 
README.md ADDED
@@ -0,0 +1,51 @@
+ ---
+ license: mit
+ tags:
+ - AudioMAE
+ - PyTorch
+ ---
+
+ # AudioMAE
+ This model card provides an easy-to-use API for the *pretrained encoder of AudioMAE* [1], whose weights come from [its original repository](https://github.com/facebookresearch/AudioMAE).
+ The provided model is designed to make it easy to obtain learned representations from an input audio file.
+ The resulting representation $z$ has a dimension of $(d, h, w)$ for a single audio file, where $d$, $h$, and $w$ denote the latent dimension size, the latent frequency dimension, and the latent temporal dimension, respectively.
+ [2] indicates that both the frequency and temporal semantics are preserved in $z$.
+
+ # Dependency
+ See `requirements.txt`.
+
+
+ # Usage
+ ```python
+ from transformers import AutoModel
+
+ device = 'cpu'  # 'cpu' or 'cuda'
+ model = AutoModel.from_pretrained("hance-ai/audiomae", trust_remote_code=True).to(device)  # load the pretrained model
+ z = model('path/audio_fname.wav')  # (768, 8, 64) = (latent_dim_size, latent_freq_dim, latent_temporal_dim)
+ ```
+
+ Depending on the task, a different pooling strategy may be appropriate.
+ For instance, global average pooling can be used for a classification task; [2] uses adaptive pooling.
+
+ ⚠️ AudioMAE accepts audio up to 10 s long (as described in [1]). Any longer audio is clipped to 10 s; everything beyond 10 s is discarded.
+
+
+ # Sanity Check Result
+ In the following, the spectrogram of an input audio clip and the corresponding $z$ are visualized.
+ The input audio is 10 s long and contains baby coughing, hiccuping, and adult sneezing.
+ The latent dimension size of $z$ is reduced to 8 using PCA for visualization.
+
+ <p align="center">
+ <img src=".fig/sanity_check_result_audiomae.png" alt="" width=100%>
+ </p>
+
+ The result shows that the presence of the labeled sounds is clearly captured in the 3rd principal component (PC).
+ While the baby coughing and hiccuping sounds are not well distinguishable up to the 5th PC, they are in the 6th PC.
+ This result briefly demonstrates the effectiveness of the pretrained AudioMAE.
+
+
+ # References
+
+ [1] Huang, Po-Yao, et al. "Masked autoencoders that listen." Advances in Neural Information Processing Systems 35 (2022): 28708-28720.
+
+ [2] Liu, Haohe, et al. "AudioLDM 2: Learning holistic audio generation with self-supervised pretraining." IEEE/ACM Transactions on Audio, Speech, and Language Processing (2024).
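To make the README's pooling note concrete, here is a minimal sketch (not part of the repository) of global average pooling and adaptive pooling on the returned representation. The shape (768, 8, 64) follows the usage example; the linear head and the class count are hypothetical placeholders.

```python
import torch

# z as returned by the model in the README usage example: (d, h, w) = (768, 8, 64)
z = torch.randn(768, 8, 64)

# global average pooling over the frequency and temporal dims -> one 768-d clip embedding
clip_embedding = z.mean(dim=(1, 2))             # (768,)

# hypothetical linear classifier head on top of the pooled embedding
num_classes = 10                                # placeholder number of target classes
head = torch.nn.Linear(768, num_classes)
logits = head(clip_embedding.unsqueeze(0))      # (1, num_classes)

# adaptive pooling, as used in [2], keeps a coarser grid instead of a single vector
pooled = torch.nn.functional.adaptive_avg_pool2d(z.unsqueeze(0), output_size=(4, 16))  # (1, 768, 4, 16)
```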
config.json ADDED
@@ -0,0 +1,18 @@
+ {
+   "architectures": [
+     "PretrainedAudioMAEEncoder"
+   ],
+   "auto_map": {
+     "AutoConfig": "model.AudioMAEConfig",
+     "AutoModel": "model.PretrainedAudioMAEEncoder"
+   },
+   "img_size": [
+     1024,
+     128
+   ],
+   "in_chans": 1,
+   "model_type": "audiomae",
+   "num_classes": 0,
+   "torch_dtype": "float32",
+   "transformers_version": "4.44.0"
+ }
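As an illustration of how the `auto_map` entries above are resolved: with `trust_remote_code=True`, `AutoConfig` and `AutoModel` load the custom classes from the repo's `model.py`. A minimal check (assuming access to the `hance-ai/audiomae` repo on the Hub) might look like this:

```python
from transformers import AutoConfig

# trust_remote_code=True lets AutoConfig resolve "model.AudioMAEConfig" from the repo's model.py
config = AutoConfig.from_pretrained("hance-ai/audiomae", trust_remote_code=True)
print(config.model_type)   # "audiomae"
print(config.img_size)     # [1024, 128], i.e. (temporal_length, n_freq_bins)
print(config.in_chans)     # 1 (mono)
```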
model.py ADDED
@@ -0,0 +1,130 @@
+ from typing import Tuple
+
+ import torch
+ import torchaudio
+ import torchaudio.transforms as transforms
+ from torchaudio.compliance import kaldi
+ from transformers import PretrainedConfig
+
+ from einops import rearrange
+
+ from timm.models.vision_transformer import VisionTransformer
+ from transformers import PreTrainedModel
+
+
+ # The Config class and the Model class should be located in the same file; otherwise, model loading seems to fail after pushing to HF.
+ class AudioMAEConfig(PretrainedConfig):
+     model_type = "audiomae"
+
+     def __init__(self,
+                  img_size:Tuple[int,int]=(1024,128),
+                  in_chans:int=1,
+                  num_classes:int=0,
+                  **kwargs,):
+         super().__init__(**kwargs)
+         self.img_size = img_size
+         self.in_chans = in_chans
+         self.num_classes = num_classes
+
+
+ class AudioMAEEncoder(VisionTransformer):
+     def __init__(self, *args, **kwargs):
+         super().__init__(*args, **kwargs)
+         """
+         - img_size of (1024, 128) = (temporal_length, n_freq_bins) is fixed, as described in the paper
+         - AudioMAE accepts a mono channel (i.e., in_chans=1)
+         """
+         self.MEAN = -4.2677393  # as reported in the paper
+         self.STD = 4.5689974  # as reported in the paper
+
+     def load_wav_file(self, file_path:str):
+         """
+         to use this, `torchaudio` and `ffmpeg` must be installed
+         - the `ffmpeg` version must be >=4.4 and <7
+         - `ffmpeg` can be installed with `conda install -c conda-forge ffmpeg==6.1.1`
+         """
+         audio, sample_rate = torchaudio.load(file_path)  # audio: (n_channels, length)
+
+         # length clip
+         audio_len = audio.shape[-1] / sample_rate
+         if audio_len > 10.0:
+             print('current audio length is:', audio_len)
+             print('[WARNING] AudioMAE only accepts audio up to 10s long. The audio frames exceeding 10s will be clipped.')
+
+         # Check if the audio has multiple channels
+         if audio.shape[0] > 1:
+             # Convert multi-channel audio to mono by taking the mean across channels;
+             # AudioMAE accepts a mono channel.
+             audio = torch.mean(audio, dim=0, keepdim=True)
+
+         # resample the audio to 16 kHz;
+         # AudioMAE accepts 16 kHz
+         if sample_rate != 16000:
+             converter = transforms.Resample(orig_freq=sample_rate, new_freq=16000)
+             audio = converter(audio)
+         return audio
+
+     def waveform_to_melspec(self, waveform:torch.FloatTensor):
+         # Compute the Mel spectrogram using Kaldi-compatible features;
+         # the parameters are chosen as described in the AudioMAE paper (4.2 implementation details)
+         mel_spectrogram = kaldi.fbank(
+             waveform,
+             num_mel_bins=128,
+             frame_length=25.0,
+             frame_shift=10.0,
+             htk_compat=True,
+             use_energy=False,
+             sample_frequency=16000,
+             window_type='hanning',
+             dither=0.0
+         )
+
+         # Ensure the output shape matches 1x1024x128 by padding or trimming the time dimension
+         expected_frames = 1024  # as described in the paper
+         current_frames = mel_spectrogram.shape[0]
+         if current_frames > expected_frames:
+             mel_spectrogram = mel_spectrogram[:expected_frames, :]
+         elif current_frames < expected_frames:
+             padding = expected_frames - current_frames
+             mel_spectrogram = torch.nn.functional.pad(mel_spectrogram, (0, 0,         # (left, right) for the last dim (n_freq_bins)
+                                                                         0, padding),  # (left, right) for the first dim (time frames)
+                                                       )
+
+         # scale,
+         # as in the AudioMAE implementation [REF: https://github.com/facebookresearch/AudioMAE/blob/bd60e29651285f80d32a6405082835ad26e6f19f/dataset.py#L300]
+         mel_spectrogram = (mel_spectrogram - self.MEAN) / (self.STD * 2)  # (length, n_freq_bins) = (1024, 128)
+         return mel_spectrogram
+
+     @torch.no_grad()
+     def encode(self, file_path:str, device):
+         self.eval()
+
+         waveform = self.load_wav_file(file_path)
+         melspec = self.waveform_to_melspec(waveform)  # (length, n_freq_bins) = (1024, 128)
+         melspec = melspec[None,None,:,:]  # (1, 1, length, n_freq_bins) = (1, 1, 1024, 128)
+         z = self.forward_features(melspec.to(device)).cpu()  # (b, 1+n, d); d=768
+         z = z[:,1:,:]  # (b, n, d); remove [CLS], the class token
+
+         b, c, w, h = melspec.shape  # w: temporal dim; h: freq dim
+         wprime = round(w / self.patch_embed.patch_size[0])  # width in the latent space
+         hprime = round(h / self.patch_embed.patch_size[1])  # height in the latent space
+
+         # reconstruct the temporal and freq dims
+         z = rearrange(z, 'b (w h) d -> b d h w', h=hprime)  # (b, d, h', w')
+
+         # remove the batch dim
+         z = z[0]  # (d, h', w')
+         return z  # (d, h', w')
+
+
+
+ class PretrainedAudioMAEEncoder(PreTrainedModel):
+     config_class = AudioMAEConfig
+
+     def __init__(self, config):
+         super().__init__(config)
+         self.encoder = AudioMAEEncoder(img_size=config.img_size, in_chans=config.in_chans, num_classes=config.num_classes)
+
+     def forward(self, file_path:str):
+         device = self.device
+         return self.encoder.encode(file_path, device)  # (d, h', w')
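The shape bookkeeping in `encode` above follows from the ViT patch grid. Below is a small sketch of the arithmetic, assuming the 16x16 patch size used by the AudioMAE ViT-B encoder (an assumption; it is not spelled out in `config.json`):

```python
# latent grid arithmetic for the default config (a sketch; a 16x16 patch size is assumed)
img_size = (1024, 128)   # (temporal_length, n_freq_bins), as fixed in config.json
patch_size = (16, 16)    # assumed ViT-B patch size

wprime = img_size[0] // patch_size[0]   # 1024 / 16 = 64 latent temporal steps
hprime = img_size[1] // patch_size[1]   # 128 / 16  = 8  latent frequency bins
n_tokens = wprime * hprime              # 512 patch tokens (+1 CLS token inside the ViT)

# forward_features yields (1, 1 + 512, 768); dropping CLS and regrouping gives (768, 8, 64),
# matching the (latent_dim_size, latent_freq_dim, latent_temporal_dim) shape in the README
print(n_tokens, (768, hprime, wprime))
```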
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c868088d7f6f9ee8c29292bfa029a552bad781e1c59abb3692762330811bf535
+ size 342607672
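The three lines above are a Git LFS pointer rather than the weights themselves. As a sketch, the downloaded file can be checked against the recorded oid (assuming it is available locally as `model.safetensors`, e.g. after `git lfs pull`):

```python
import hashlib

# compare the local file's SHA-256 with the oid recorded in the LFS pointer above
expected = "c868088d7f6f9ee8c29292bfa029a552bad781e1c59abb3692762330811bf535"

h = hashlib.sha256()
with open("model.safetensors", "rb") as f:           # assumed local path after the LFS object is fetched
    for chunk in iter(lambda: f.read(1 << 20), b""):
        h.update(chunk)

assert h.hexdigest() == expected, "checksum mismatch: the LFS object was not fully downloaded"
```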
requirements.txt ADDED
@@ -0,0 +1,4 @@
+ torch==2.4.0
+ torchaudio==2.4.0
+ transformers==4.44.0
+ timm==1.0.8
save_audioMAE_init.ipynb ADDED
The diff for this file is too large to render.
 
save_audioMAE_self_sustainable.ipynb ADDED
The diff for this file is too large to render.