---
license: mit
tags:
- Audio
- SSL
- EAT
- flexEAT
- audio embeddings
- audio features
library_name: transformers
language:
- en
base_model:
- worstchan/EAT-base_epoch30_pretrain
pipeline_tag: audio-classification
---

**⚠️ Codebase Update: Input Flexibility & Fine-Tuning Preparation**

This repository has been updated to support **dynamic input lengths**. The model now computes positional embeddings on the fly, removing the fixed-duration restriction of previous versions. The pre-trained weights and core architecture are unchanged, but new infrastructure has been added to facilitate **downstream fine-tuning** (including dynamic classification heads and normalization layers). Additional fine-tuning configurations and hyperparameters will be documented in future updates.

# EAT-base (Epoch 30, Pre-trained Checkpoint)

This is the **pre-trained EAT-base model** at epoch 30, trained on the AS-2M dataset using the EAT framework for audio self-supervised learning. It offers efficient feature extraction and can also serve as a strong initialization for fine-tuning on a wide range of downstream audio understanding tasks such as classification and captioning.

For more details on the EAT framework, please refer to the [GitHub repository](https://github.com/cwx-worst-one/EAT) and our paper [EAT: Self-Supervised Pre-Training with Efficient Audio Transformer](https://arxiv.org/abs/2401.03497).

## 🔧 Usage

You can load and use the model for feature extraction directly via Hugging Face Transformers:

```python
import numpy as np
import soundfile as sf
import torch
import torchaudio
from transformers import AutoModel

model_id = "HTill/flexEAT-base_epoch30_pretrain"
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()

source_file = "/path/to/input.wav"
target_file = "/path/to/output.npy"
norm_mean = -4.268
norm_std = 4.569

# Load audio, downmix to mono if needed, and resample to 16 kHz
wav, sr = sf.read(source_file)
if wav.ndim > 1:
    wav = wav.mean(axis=1)
waveform = torch.tensor(wav).float().cuda()
if sr != 16000:
    waveform = torchaudio.functional.resample(waveform, sr, 16000)

# Remove DC offset and compute a 128-bin log-mel spectrogram (10 ms frame shift)
waveform = waveform - waveform.mean()
mel = torchaudio.compliance.kaldi.fbank(
    waveform.unsqueeze(0),
    htk_compat=True,
    sample_frequency=16000,
    use_energy=False,
    window_type='hanning',
    num_mel_bins=128,
    dither=0.0,
    frame_shift=10
).unsqueeze(0)

# Normalize with the dataset statistics above
mel = (mel - norm_mean) / (norm_std * 2)
mel = mel.unsqueeze(0).cuda()  # shape: [1, 1, T, F]

# Extract features
with torch.no_grad():
    feat = model.extract_features(mel)

feat = feat.squeeze(0).cpu().numpy()
np.save(target_file, feat)

print(f"Feature shape: {feat.shape}")
print(f"Saved to: {target_file}")
```

## 📌 Notes

The model supports both **frame-level** (~50 Hz) and **utterance-level** (CLS token) representations. See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions, and the sketches at the end of this card for minimal examples.

## 📚 Citation

If you find this model useful, please consider citing our [paper](https://arxiv.org/abs/2401.03497):

```bibtex
@article{chen2024eat,
  title={EAT: Self-supervised pre-training with efficient audio transformer},
  author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
  journal={arXiv preprint arXiv:2401.03497},
  year={2024}
}
```
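## 🧪 Sketches

The usage snippet above saves the full feature sequence. Below is a minimal sketch of splitting it into the two granularities mentioned in the notes. It assumes the returned sequence places the utterance-level CLS token first, followed by the frame-level features; this ordering is an assumption, so verify it against your output shape and the feature extraction guide before relying on it.

```python
# Continues from the Usage snippet above; `feat` is a NumPy array of shape [N, D].
# Assumption (not confirmed by this card): row 0 is the CLS token and the
# remaining rows are the ~50 Hz frame-level features.
utterance_embedding = feat[0]   # [D] global (utterance-level) representation
frame_embeddings = feat[1:]     # [N - 1, D] frame-level representations
print(utterance_embedding.shape, frame_embeddings.shape)
```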
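Since this version advertises dynamic input lengths, a quick way to confirm the behavior is to feed inputs of different durations through `extract_features`. This sketch uses random tensors as stand-ins for normalized log-mel input (with a 10 ms frame shift, one second corresponds to 100 frames); the durations are arbitrary choices for illustration.

```python
import torch

# Random stand-ins for normalized log-mel input at two arbitrary durations.
for seconds in (2, 10):
    dummy = torch.randn(1, 1, seconds * 100, 128).cuda()  # [1, 1, T, F]
    with torch.no_grad():
        out = model.extract_features(dummy)
    print(f"{seconds}s input -> feature shape {tuple(out.shape)}")
```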