---
license: mit
tags:
- Audio
- SSL
- EAT
- flexEAT
- audio embeddings
- audio features
library_name: transformers
language:
- en
base_model:
- worstchan/EAT-base_epoch30_pretrain
pipeline_tag: audio-classification
---
**⚠️ Codebase Update: Input Flexibility & Fine-Tuning Preparation**

This repository has been updated to support **dynamic input lengths**. The model now computes positional embeddings on the fly, which removes the fixed audio duration required by previous versions.

The core pre-trained weights and architectural logic remain unchanged. We have added new infrastructure to facilitate **downstream fine-tuning** (including dynamic classification heads and normalization layers). Additional fine-tuning configurations and hyperparameters will be documented in future updates.
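
To illustrate what computing positional embeddings on the fly can look like, here is a minimal NumPy sketch of fixed 2D sin-cos embeddings generated for an arbitrary patch grid. It is a generic illustration and not this repository's code: the function names `sincos_1d` and `sincos_2d`, the embedding size of 768, and the grid sizes are all assumptions chosen for the example.

```python
import numpy as np

def sincos_1d(embed_dim, positions):
    """Fixed sin-cos embedding for a 1-D array of positions; embed_dim must be even."""
    omega = np.arange(embed_dim // 2, dtype=np.float64) / (embed_dim / 2.0)
    omega = 1.0 / (10000.0 ** omega)            # (embed_dim/2,) frequencies
    angles = np.outer(positions, omega)         # (num_positions, embed_dim/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)

def sincos_2d(embed_dim, grid_t, grid_f):
    """Embedding for a grid_t x grid_f patch grid; embed_dim must be divisible by 4."""
    t_idx, f_idx = np.meshgrid(np.arange(grid_t), np.arange(grid_f), indexing="ij")
    emb_t = sincos_1d(embed_dim // 2, t_idx.reshape(-1))  # half the channels encode time
    emb_f = sincos_1d(embed_dim // 2, f_idx.reshape(-1))  # the other half encode frequency
    return np.concatenate([emb_t, emb_f], axis=1)         # (grid_t * grid_f, embed_dim)

# The same function serves clips of different lengths (grid sizes are illustrative).
short_clip = sincos_2d(768, grid_t=31, grid_f=8)
long_clip = sincos_2d(768, grid_t=64, grid_f=8)
print(short_clip.shape, long_clip.shape)
```

Because nothing in this scheme is learned per position, the embedding table does not need to be stored at a fixed size; it can be rebuilt for whatever grid a given clip produces.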
# EAT-base (Epoch 30, Pre-trained Checkpoint)
This is the **pre-trained EAT-base model** at epoch 30, trained on the AS-2M dataset using the EAT framework for audio self-supervised learning.
It offers efficient feature extraction and can also serve as a strong initialization for fine-tuning on a wide range of downstream audio understanding tasks such as classification and captioning.
For more details on the EAT framework, please refer to the [GitHub repository](https://github.com/cwx-worst-one/EAT) and our paper [EAT: Self-Supervised Pre-Training with Efficient Audio Transformer](https://arxiv.org/abs/2401.03497).
## 🔧 Usage
You can load and use the model for feature extraction directly via Hugging Face Transformers:
```python
import torchaudio
import torch
import soundfile as sf
import numpy as np
from transformers import AutoModel
model_id = "HTill/flexEAT-base_epoch30_pretrain"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
source_file = "/path/to/input.wav"
target_file = "/path/to/output.npy"
norm_mean = -4.268
norm_std = 4.569
# Load and resample audio
wav, sr = sf.read(source_file)
if wav.ndim > 1:
    wav = wav.mean(axis=1)  # downmix multi-channel audio to mono
waveform = torch.tensor(wav).float().cuda()
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)
# Remove the DC offset, then compute a 128-bin log-mel filterbank
waveform = waveform - waveform.mean()
mel = torchaudio.compliance.kaldi.fbank(
    waveform.unsqueeze(0),
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10
).unsqueeze(0)
# Normalize
mel = (mel - norm_mean) / (norm_std * 2)
mel = mel.unsqueeze(0).cuda() # shape: [1, 1, T, F]
# Extract features
with torch.no_grad():
    feat = model.extract_features(mel)
feat = feat.squeeze(0).cpu().numpy()
np.save(target_file, feat)
print(f"Feature shape: {feat.shape}")
print(f"Saved to: {target_file}")
```
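
Since the input length is not fixed, the time dimension `T` of the mel-spectrogram depends on the clip duration. The sketch below estimates `T` from the number of samples, using the Kaldi framing rule with the default 25 ms window and the 10 ms shift set above. The helper names are hypothetical. The second function additionally assumes non-overlapping 16×16 patches, as described in the EAT paper, and that incomplete trailing patches are dropped; the repository's actual handling of such frames may differ.

```python
def num_fbank_frames(num_samples, sample_rate=16000,
                     frame_length_ms=25.0, frame_shift_ms=10.0):
    """Frames produced by Kaldi-style framing without edge padding (snip_edges=True)."""
    window = int(sample_rate * frame_length_ms / 1000)  # 400 samples at 16 kHz
    shift = int(sample_rate * frame_shift_ms / 1000)    # 160 samples at 16 kHz
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // shift

def estimated_patch_tokens(num_frames, num_mel_bins=128, patch=16):
    """Patch-token count, ASSUMING 16x16 non-overlapping patches with floor division."""
    return (num_frames // patch) * (num_mel_bins // patch)

frames_10s = num_fbank_frames(10 * 16000)
tokens_10s = estimated_patch_tokens(frames_10s)
print(frames_10s, tokens_10s)
```

For a 10-second clip at 16 kHz this gives 998 frames and, under the patching assumption, 496 patch tokens, which is roughly 50 tokens per second of audio.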
## 📌 Notes
The model supports both **frame-level** (\~50Hz) and **utterance-level** (CLS token) representations.
See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions.
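
As a minimal sketch of how the two representation types can be separated, the snippet below splits a feature array into an utterance-level vector and frame-level features. It assumes the CLS token is the first row of the array returned by `extract_features`, which is how the upstream extraction script treats it; check the output of this repository before relying on that. The helper name `split_cls_and_frames` and the placeholder array are hypothetical.

```python
import numpy as np

def split_cls_and_frames(feat):
    """Split a [1 + N, D] feature array, ASSUMING the CLS token sits at index 0."""
    utterance = feat[0]   # [D]    utterance-level embedding
    frames = feat[1:]     # [N, D] frame-level embeddings
    return utterance, frames

# Placeholder standing in for the saved model output (shape chosen for illustration).
feat = np.random.rand(497, 768).astype(np.float32)
utterance, frames = split_cls_and_frames(feat)
print(utterance.shape, frames.shape)
```

The utterance-level vector suits clip-level tasks such as classification, while the frame-level array keeps the temporal detail needed for tasks like event detection.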
## 📚 Citation
If you find this model useful, please consider citing our [paper](https://arxiv.org/abs/2401.03497):
```bibtex
@article{chen2024eat,
title={EAT: Self-supervised pre-training with efficient audio transformer},
author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
journal={arXiv preprint arXiv:2401.03497},
year={2024}
}
```