| import torch |
| from torch import nn |
|
|
|
|
| |
class AudioClassifier(nn.Module):
    """MLP classifier over fixed-size audio embedding vectors.

    Architecture: an input stage (Linear -> BatchNorm1d -> Mish -> Dropout),
    ``num_hidden_layers`` identical hidden stages, and a final linear
    projection to per-class logits.
    """

    def __init__(
        self,
        label2id: dict,
        feature_dim: int = 256,
        hidden_dim: int = 256,
        device: str = "cpu",
        dropout_rate: float = 0.5,
        num_hidden_layers: int = 2,
    ):
        """Build the classifier.

        Args:
            label2id: Mapping from class label to integer id; its length
                determines the number of output classes.
            feature_dim: Dimensionality of the input embedding vector.
            hidden_dim: Width of each hidden layer.
            device: Device that inference inputs are moved to.
                NOTE(review): the module itself is not moved here — callers
                must still call ``.to(device)`` on the model.
            dropout_rate: Dropout probability after each activation.
            num_hidden_layers: Number of extra hidden blocks after the input
                stage.
        """
        super().__init__()
        self.num_classes = len(label2id)
        self.device = device
        self.label2id = label2id
        # Inverse mapping: predicted index -> human-readable label.
        self.id2label = {v: k for k, v in label2id.items()}

        self.fc1 = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.Mish(),
            nn.Dropout(dropout_rate),
        )

        # Stack of identical hidden blocks.
        self.hidden_layers = nn.ModuleList(
            nn.Sequential(
                nn.Linear(hidden_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.Mish(),
                nn.Dropout(dropout_rate),
            )
            for _ in range(num_hidden_layers)
        )

        self.fc_last = nn.Linear(hidden_dim, self.num_classes)

    def forward(self, x):
        """Return raw class logits for a batch of embeddings (N, feature_dim)."""
        x = self.fc1(x)
        for layer in self.hidden_layers:
            x = layer(x)
        return self.fc_last(x)

    def infer_from_features(self, features):
        """Classify a single feature vector.

        Args:
            features: 1-D embedding of length ``feature_dim`` (array-like or
                tensor); a batch dimension is added internally.

        Returns:
            List of ``(label, probability)`` pairs covering every class,
            sorted by descending probability.
        """
        # as_tensor avoids a redundant copy (and the copy-construct warning)
        # when `features` is already a tensor.
        features = (
            torch.as_tensor(features, dtype=torch.float32)
            .unsqueeze(0)
            .to(self.device)
        )

        # Remember the current mode so inference does not clobber a caller's
        # training state (the original left the model permanently in eval()).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # Call the module, not .forward(), so registered hooks fire.
                output = self(features)
        finally:
            self.train(was_training)

        probs = torch.softmax(output, dim=1)
        probs, indices = torch.sort(probs, descending=True)
        # squeeze(0) — not squeeze() — so a 1-class model still yields 1-D
        # arrays and the zip below keeps working.
        probs = probs.cpu().numpy().squeeze(0)
        indices = indices.cpu().numpy().squeeze(0)
        return [(self.id2label[i], p) for i, p in zip(indices, probs)]

    def infer_from_file(self, file_path):
        """Extract an embedding from an audio file and classify it."""
        feature = extract_features(file_path, device=self.device)
        return self.infer_from_features(feature)
|
|
|
|
| from pyannote.audio import Inference, Model |
|
|
# Pretrained speaker-embedding model (WeSpeaker ResNet34-LM, VoxCeleb).
# NOTE(review): loaded at import time — a module-level side effect that may
# download weights; consider lazy initialisation if import cost matters.
emb_model = Model.from_pretrained("pyannote/wespeaker-voxceleb-resnet34-LM")
# window="whole" -> a single embedding vector for the entire audio file.
inference = Inference(emb_model, window="whole")
|
|
|
|
def extract_features(file_path, device="cpu"):
    """Return a speaker embedding for the whole audio file at *file_path*.

    Uses the module-level pyannote ``inference`` pipeline, moving it to the
    requested device first.
    """
    target = torch.device(device)
    inference.to(target)
    return inference(file_path)
|
|