---
datasets:
- edinburghcstr/ami
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
---
# AST-based Speaker Identification on AMI
## Model description
This model is a **fine-tuned** version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593)
for speaker classification on the AMI Meeting Corpus. It was trained on **50** speakers (adjust `num_labels` if your setup differs), with inputs given as 128-bin mel-spectrograms of 1024 frames.
- **Base architecture**: Audio Spectrogram Transformer (AST)
- **Training**: ~10 epochs, batch size=4, learning rate=1e-5, AdamW optimizer, mixed precision
- **Data**: Stratified samples from AMI train/validation/test splits
- **Performance**: Not good; this was just a small experiment for diarization.
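To get a feel for how much of the 1024-frame input a clip actually fills, the sketch below estimates the frame count from the clip length. It assumes the Kaldi-style filterbank settings commonly used with AST (25 ms window, 10 ms hop, 16 kHz audio). The function name `num_fbank_frames` is hypothetical, and the values should be checked against this checkpoint's processor config.

```python
def num_fbank_frames(num_samples, sample_rate=16000,
                     frame_length_ms=25.0, frame_shift_ms=10.0):
    """Estimate mel-spectrogram frames for a clip (no edge padding).

    Assumed defaults: 25 ms window, 10 ms hop, 16 kHz audio.
    """
    window = int(sample_rate * frame_length_ms / 1000)  # samples per frame
    hop = int(sample_rate * frame_shift_ms / 1000)      # samples between frames
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // hop


MAX_FRAMES = 1024  # input length stated in the model description

for seconds in (1, 5, 10, 15):
    frames = num_fbank_frames(seconds * 16000)
    used = min(frames, MAX_FRAMES)
    print(f"{seconds:>2} s -> {frames} frames, fills {used / MAX_FRAMES:.0%} of the input")
```

Under these assumptions a 1-second clip covers only a small part of the 1024 frames and the rest is padding, while clips longer than roughly 10 seconds would exceed the input length.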
## How to use
```python
from transformers import AutoProcessor, ASTForAudioClassification
import torch
import numpy as np
# 1) Load the model and processor
MODEL_ID = "agutig/AST_diarizer"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = ASTForAudioClassification.from_pretrained(MODEL_ID)
model.eval()
# 2) Prepare a 1-second placeholder sample of random noise (or load your own)
sr = 16000
audio = np.random.randn(sr).astype(np.float32)
# Alternatively:
# import librosa
# audio, _ = librosa.load("your_audio.wav", sr=sr)
# 3) Preprocess and run inference
inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, num_labels]
probs = torch.softmax(logits, dim=-1)[0]
pred_i = int(probs.argmax())
print(f"Predicted speaker index: {pred_i}")
```
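A single argmax hides how confident the model is, which matters for a weak classifier. As a rough sketch, the helper below turns a flat list of logits into the top-k speakers with their probabilities. The name `top_k_speakers` and the fallback label format are hypothetical; `id2label` is assumed to be an index-to-name mapping such as `model.config.id2label`, which may or may not hold meaningful speaker names for this checkpoint.

```python
import math


def top_k_speakers(logits, id2label=None, k=3):
    """Return the k most probable (label, probability) pairs.

    `logits` is a flat list of floats, e.g. `logits[0].tolist()`.
    `id2label` is an optional index -> name mapping (assumption).
    """
    peak = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    ranked = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)[:k]
    names = id2label or {}
    return [(names.get(i, f"speaker_{i}"), probs[i]) for i in ranked]


# Toy logits for four speakers (made-up numbers, not model output)
example = top_k_speakers([0.1, 2.0, -1.0, 1.5], k=2)
print(example)
```

With the model loaded as in the snippet above, this could be called as `top_k_speakers(logits[0].tolist(), model.config.id2label)`.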
## Usage with `pipeline`
```python
from transformers import pipeline
speaker_id = pipeline(
    task="audio-classification",
    model="agutig/AST_diarizer",
    top_k=50,  # return a score for each of the 50 speakers
)
results = speaker_id("path/to/audio.wav")
print(results)
```
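Since the model was built as a diarization experiment, one possible way to turn predictions into speaker segments is to classify fixed-length windows and merge consecutive windows that share a label. The sketch below covers only the merging step. The window length, the function name `merge_windows`, and the input labels are illustrative assumptions, not the author's pipeline.

```python
def merge_windows(labels, window_s=1.0):
    """Merge consecutive identical window labels into (start, end, label) segments.

    `labels` holds one predicted speaker per non-overlapping window of
    `window_s` seconds, in chronological order.
    """
    segments = []
    for i, label in enumerate(labels):
        start, end = i * window_s, (i + 1) * window_s
        if segments and segments[-1][2] == label:
            # Same speaker as the previous window: extend that segment
            segments[-1] = (segments[-1][0], end, label)
        else:
            segments.append((start, end, label))
    return segments


# Hypothetical per-window predictions, e.g. the top label from each pipeline call
window_labels = ["spk_3", "spk_3", "spk_7", "spk_7", "spk_7", "spk_3"]
for start, end, label in merge_windows(window_labels):
    print(f"{start:4.1f}-{end:4.1f} s  {label}")
```

Short windows give finer speaker boundaries but leave more of each input as padding, so the window length is a trade-off worth tuning.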
## Evaluation & Benchmarks
Classification:


Embeddings


## License
- **Model**: Apache 2.0
- **Base code (AST AudioSet)**: MIT License