---
datasets:
- edinburghcstr/ami
base_model:
- MIT/ast-finetuned-audioset-10-10-0.4593
---
# AST-based Speaker Identification on AMI

## Model description

This model is a **fine-tuned** version of [MIT/ast-finetuned-audioset-10-10-0.4593](https://huggingface.co/MIT/ast-finetuned-audioset-10-10-0.4593) for speaker classification on the AMI Meeting Corpus. It was trained on **50** speakers (adjust `num_labels` if your setup differs), and its inputs are 128-bin mel-spectrograms of 1024 frames.
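
To relate clip length to the 1024-frame input, the sketch below estimates how many spectrogram frames a clip produces. It assumes Kaldi-style front-end settings (16 kHz audio, 25 ms window, 10 ms hop, no edge padding), which this card does not state; `frames_for_clip` and `padded_frames` are illustrative helpers, not part of `transformers`.

```python
# Assumed front-end settings (Kaldi-style fbank defaults; not confirmed by this card).
SAMPLE_RATE = 16_000   # Hz
WINDOW_MS = 25         # analysis window length
HOP_MS = 10            # shift between consecutive frames
MAX_FRAMES = 1024      # time dimension the model expects
NUM_MEL_BINS = 128     # frequency dimension


def frames_for_clip(num_samples: int) -> int:
    """Number of spectrogram frames before padding or truncation."""
    window = SAMPLE_RATE * WINDOW_MS // 1000   # 400 samples
    hop = SAMPLE_RATE * HOP_MS // 1000         # 160 samples
    if num_samples < window:
        return 0
    return 1 + (num_samples - window) // hop


def padded_frames(num_samples: int) -> int:
    """Frames that would be padding once the clip is fit to MAX_FRAMES."""
    return max(0, MAX_FRAMES - frames_for_clip(num_samples))


# A 1-second clip: real frames vs. padded frames.
print(frames_for_clip(SAMPLE_RATE), padded_frames(SAMPLE_RATE))
```

Under these assumptions a 1-second clip fills only 98 of the 1024 frames and the rest is padding, while roughly 10.3 seconds of audio fills the whole input; longer clips would need to be truncated or split into windows.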

- **Base architecture**: Audio Spectrogram Transformer (AST)  
- **Training**: ~10 epochs, batch size=4, learning rate=1e-5, AdamW optimizer, mixed precision  
- **Data**: Stratified samples from AMI train/validation/test splits  
- **Performance**: Poor. This was only a small experiment for speaker diarization.



## How to use

```python
from transformers import AutoProcessor, ASTForAudioClassification
import torch
import numpy as np

# 1) Load the model and processor
MODEL_ID = "agutig/AST_diarizer"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model     = ASTForAudioClassification.from_pretrained(MODEL_ID)
model.eval()

# 2) Prepare a 1-second dummy clip of random noise (replace with your own audio)
sr = 16000
audio = np.random.randn(sr).astype(np.float32)
# Alternatively:
# import librosa
# audio, _ = librosa.load("your_audio.wav", sr=sr)

# 3) Preprocess and run inference
inputs = processor(audio, sampling_rate=sr, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits  # shape [1, num_labels]
    probs  = torch.softmax(logits, dim=-1)[0]
    pred_i = int(probs.argmax())

print(f"Predicted speaker index: {pred_i}")
```
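
The predicted index alone is hard to interpret. A minimal sketch of turning raw logits into ranked speaker labels follows, written in plain Python so it runs without the model. The logit values and the label names are made up for illustration; in practice the mapping would presumably come from `model.config.id2label` and the logits from the snippet above.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    peak = max(logits)
    exps = [math.exp(x - peak) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def top_speakers(logits, id2label, k=3):
    """Return the k most probable (label, probability) pairs, best first."""
    probs = softmax(logits)
    ranked = sorted(enumerate(probs), key=lambda pair: pair[1], reverse=True)
    return [(id2label[i], p) for i, p in ranked[:k]]


# Hypothetical values: real logits would come from model(**inputs).logits[0].tolist()
example_logits = [0.2, 2.5, -1.0, 1.1]
example_labels = {0: "spk_a", 1: "spk_b", 2: "spk_c", 3: "spk_d"}

for label, prob in top_speakers(example_logits, example_labels, k=2):
    print(f"{label}: {prob:.3f}")
```

Reporting the top few candidates with their probabilities, rather than a single index, makes low-confidence predictions easier to spot.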

## Usage with `pipeline`

```python
from transformers import pipeline

speaker_id = pipeline(
    task="audio-classification",
    model="agutig/AST_diarizer",
    top_k=50,  # audio-classification uses top_k, not return_all_scores; 50 = all speakers
)

results = speaker_id("path/to/audio.wav")
print(results)
```
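
Since the model was built as a diarization experiment, one way to apply it to a long recording is to classify fixed-length windows and then merge consecutive windows that share a speaker. The sketch below shows only the merging step in plain Python; the window length, the label names, and the `merge_windows` helper are assumptions for illustration, not the author's method.

```python
def merge_windows(labels, window_s=1.0):
    """Collapse per-window speaker labels into (start, end, speaker) turns.

    labels: predicted speaker for each consecutive, non-overlapping window.
    window_s: window duration in seconds (assumed value).
    """
    segments = []
    for i, speaker in enumerate(labels):
        start, end = i * window_s, (i + 1) * window_s
        if segments and segments[-1][2] == speaker:
            # Same speaker as the previous window: extend the current turn.
            segments[-1] = (segments[-1][0], end, speaker)
        else:
            # Speaker changed (or first window): open a new turn.
            segments.append((start, end, speaker))
    return segments


# Hypothetical per-window predictions, e.g. the top label from each pipeline call.
window_labels = ["spk_a", "spk_a", "spk_b", "spk_b", "spk_b", "spk_a"]
print(merge_windows(window_labels))
```

This simple merge treats every window prediction as correct; given the performance noted above, some smoothing over neighbouring windows would likely be needed before the turns are usable.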

## Evaluation & Benchmarks

Classification:
![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/ZfvDY9M32wTtsePzwJV3v.png)

![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/BUm30OuKmUWehOIqFjUdO.png)

Embeddings:

![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/-WBY39T4M4f9pRrZGjqMk.png)


![image/png](https://cdn-uploads.huggingface.co/production/uploads/6759a53b608daa4d287fd97c/ZEWyyjfZtxUZjzFgywqJI.png)



## License

- **Model**: Apache 2.0  
- **Base code (AST AudioSet)**: MIT License