---
library_name: transformers
license: mit
base_model: TencentGameMate/chinese-hubert-base
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: hubert-base-ser
  results: []
---
<!-- This model card has been generated automatically according to the information the Trainer had access to. You
should probably proofread and complete it, then remove this comment. -->
# hubert-base-ser
This model is a fine-tuned version of [TencentGameMate/chinese-hubert-base](https://huggingface.co/TencentGameMate/chinese-hubert-base) on an unknown dataset.
It achieves the following results on the evaluation set:
- Loss: 0.1466
- Accuracy: 0.9526
## How to use
### Requirements
```bash
# Required packages
# (in a Jupyter or Colab cell, prefix each command with `!` or use `%pip install`)
pip install git+https://github.com/huggingface/datasets.git
pip install git+https://github.com/huggingface/transformers.git
pip install torchaudio
```
### Prediction
```python
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, Wav2Vec2FeatureExtractor, HubertPreTrainedModel, HubertModel
from transformers.utils import ModelOutput


def speech_file_to_array_fn(path, sampling_rate):
    """Load an audio file, down-mix it to mono and resample it to `sampling_rate`."""
    speech_array, _sampling_rate = torchaudio.load(path)  # shape: (channels, samples)
    speech_array = speech_array.mean(dim=0)  # average the channels so stereo files become mono
    resampler = torchaudio.transforms.Resample(_sampling_rate, sampling_rate)
    speech = resampler(speech_array).numpy()
    return speech


@dataclass
class SpeechClassifierOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits: torch.FloatTensor = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class HubertClassificationHead(nn.Module):
    """Head for hubert classification task."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = features
        x = self.dropout(x)
        x = self.dense(x)
        x = torch.tanh(x)
        x = self.dropout(x)
        x = self.out_proj(x)
        return x


class HubertForSpeechClassification(HubertPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.pooling_mode = config.pooling_mode  # custom attribute read from the checkpoint's config
        self.hubert = HubertModel(config)
        self.classifier = HubertClassificationHead(config)
        self.post_init()

    def merged_strategy(self, hidden_states, mode="mean"):
        # Pool frame-level features (batch, frames, hidden) into one vector per clip (batch, hidden)
        if mode == "mean":
            outputs = torch.mean(hidden_states, dim=1)
        elif mode == "sum":
            outputs = torch.sum(hidden_states, dim=1)
        elif mode == "max":
            outputs = torch.max(hidden_states, dim=1)[0]
        else:
            raise ValueError(
                "The pooling method hasn't been defined! Your pooling mode must be one of these ['mean', 'sum', 'max']")
        return outputs

    def forward(self, x):
        outputs = self.hubert(x)
        hidden_states = outputs[0]  # last hidden state: (batch, frames, hidden_size)
        hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(hidden_states)
        # Return a SpeechClassifierOutput object
        return SpeechClassifierOutput(logits=logits)


def main():
    print("Loading model...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name_or_path = "ZipperDeng/hubert-base-ser"
    config = AutoConfig.from_pretrained(model_name_or_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
    sampling_rate = feature_extractor.sampling_rate
    model = HubertForSpeechClassification.from_pretrained(model_name_or_path).to(device)
    model.eval()  # make sure dropout is disabled for inference

    def predict_single_file(file_path, sampling_rate):
        """Predict the emotion of a single audio file."""
        try:
            speech = speech_file_to_array_fn(file_path, sampling_rate)
            features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
            input_values = features.input_values.to(device)
            with torch.no_grad():
                logits = model(input_values).logits
            scores = F.softmax(logits, dim=1).cpu().numpy()[0]
            outputs = [{"Label": config.id2label[i], "Score": f"{score * 100:.1f}%"} for i, score in enumerate(scores)]
            return outputs
        except Exception as e:
            print(f"Error while processing file {file_path}: {e}")
            return None

    # Check that the test data directory exists (replace this path with your own)
    test_data = r"F:\test_ser"
    if not os.path.exists(test_data):
        print(f"Test data directory does not exist: {test_data}")
        print("Make sure the directory exists and contains audio files")
        return
    file_path_list = [
        os.path.join(test_data, path)
        for path in os.listdir(test_data)
        if path.lower().endswith((".wav", ".mp3", ".flac"))
    ]
    print(f"Found {len(file_path_list)} audio files")
    # Process the files one at a time
    for file_path in file_path_list:
        print(f"\nProcessing file: {file_path}")
        outputs = predict_single_file(file_path, sampling_rate)
        if outputs is None:
            continue  # the error has already been reported
        print("Prediction:")
        for result in outputs:
            print(f"  {result['Label']}: {result['Score']}")


if __name__ == "__main__":
    main()
```
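`HubertModel` returns one hidden vector per audio frame, so clips of different lengths produce different numbers of frames. `merged_strategy` collapses the frame axis (`dim=1`) with the mode stored in `config.pooling_mode`, giving a fixed-size vector for the classification head. The sketch below reproduces that reduction for a single clip in plain Python; the function name `pool_frames` and the toy input are illustrative only and are not part of this model's code.

```python
from statistics import fmean


def pool_frames(hidden_states, mode="mean"):
    """Reduce a (num_frames, hidden_size) list of frame vectors to one hidden_size vector."""
    columns = list(zip(*hidden_states))  # regroup values by hidden dimension
    if mode == "mean":
        return [fmean(col) for col in columns]
    if mode == "sum":
        return [sum(col) for col in columns]
    if mode == "max":
        return [max(col) for col in columns]
    raise ValueError(f"Unsupported pooling mode {mode!r}; expected 'mean', 'sum' or 'max'")


if __name__ == "__main__":
    # Toy input: 3 frames with hidden_size 2 (not real model outputs)
    frames = [[1.0, 4.0], [3.0, 2.0], [2.0, 0.0]]
    for mode in ("mean", "sum", "max"):
        print(mode, pool_frames(frames, mode))
```

Whatever the mode, the pooled vector has length `hidden_size` regardless of clip duration. With batched, padded inputs a plain mean would also average over frames computed from padding; the prediction script sidesteps this by processing one file at a time.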
## Model description
More information needed
## Intended uses & limitations
More information needed
## Training and evaluation data
More information needed
## Training procedure
### Training hyperparameters
The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 32
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 64
- optimizer: AdamW (`adamw_torch`) with betas=(0.9, 0.999), epsilon=1e-08 and no additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 1.0
- mixed_precision_training: Native AMP
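Two of these settings are related by simple arithmetic. `total_train_batch_size` is the per-device batch size times the gradient accumulation steps times the number of devices: 32 × 2 = 64, which is consistent with training on a single device. `lr_scheduler_type: linear` decays the learning rate linearly from 0.0001 towards 0; no warmup is listed, so the sketch assumes zero warmup by default. The exact number of optimizer steps is not given in the card, so `total_steps` below is a hypothetical value chosen for illustration.

```python
def effective_batch_size(per_device_batch, grad_accum_steps, num_devices=1):
    """Number of samples contributing to each optimizer step."""
    return per_device_batch * grad_accum_steps * num_devices


def linear_lr(step, base_lr, total_steps, warmup_steps=0):
    """Linear schedule: optional linear warmup, then linear decay to 0 at total_steps."""
    if step < warmup_steps:
        return base_lr * step / max(1, warmup_steps)
    remaining = max(0, total_steps - step)
    return base_lr * remaining / max(1, total_steps - warmup_steps)


if __name__ == "__main__":
    print(effective_batch_size(32, 2))  # 64, matches total_train_batch_size
    total_steps = 440  # hypothetical; the card does not list the exact number of optimizer steps
    for step in (0, 110, 220, 330, 440):
        print(step, f"{linear_lr(step, 1e-4, total_steps):.2e}")
```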
### Training results
| Training Loss | Epoch | Step | Validation Loss | Accuracy |
|:-------------:|:------:|:----:|:---------------:|:--------:|
| 0.9709 | 0.0229 | 10 | 0.8923 | 0.6399 |
| 0.9219 | 0.0457 | 20 | 0.6903 | 0.7664 |
| 0.7112 | 0.0686 | 30 | 0.5838 | 0.7909 |
| 0.567 | 0.0914 | 40 | 0.5405 | 0.8159 |
| 0.6184 | 0.1143 | 50 | 0.4148 | 0.8581 |
| 0.5291 | 0.1371 | 60 | 0.4444 | 0.8511 |
| 0.533 | 0.16 | 70 | 0.4643 | 0.8271 |
| 0.4753 | 0.1829 | 80 | 0.3560 | 0.8767 |
| 0.4252 | 0.2057 | 90 | 0.5889 | 0.8103 |
| 0.5007 | 0.2286 | 100 | 0.3882 | 0.8663 |
| 0.5605 | 0.2514 | 110 | 0.3221 | 0.8921 |
| 0.4875 | 0.2743 | 120 | 0.3639 | 0.8559 |
| 0.4277 | 0.2971 | 130 | 0.3571 | 0.8746 |
| 0.3415 | 0.32 | 140 | 0.3382 | 0.8861 |
| 0.413 | 0.3429 | 150 | 0.2596 | 0.9104 |
| 0.377 | 0.3657 | 160 | 0.3519 | 0.8711 |
| 0.4219 | 0.3886 | 170 | 0.2979 | 0.8947 |
| 0.3317 | 0.4114 | 180 | 0.2227 | 0.9226 |
| 0.3131 | 0.4343 | 190 | 0.3680 | 0.8693 |
| 0.3266 | 0.4571 | 200 | 0.2098 | 0.9309 |
| 0.3306 | 0.48 | 210 | 0.3849 | 0.8824 |
| 0.3037 | 0.5029 | 220 | 0.2852 | 0.9024 |
| 0.3086 | 0.5257 | 230 | 0.2725 | 0.9121 |
| 0.2576 | 0.5486 | 240 | 0.1869 | 0.9356 |
| 0.2469 | 0.5714 | 250 | 0.2262 | 0.9243 |
| 0.2405 | 0.5943 | 260 | 0.1963 | 0.9347 |
| 0.2802 | 0.6171 | 270 | 0.3680 | 0.8804 |
| 0.2442 | 0.64 | 280 | 0.2053 | 0.9293 |
| 0.2302 | 0.6629 | 290 | 0.3356 | 0.8967 |
| 0.2492 | 0.6857 | 300 | 0.1880 | 0.9371 |
| 0.2089 | 0.7086 | 310 | 0.2076 | 0.9289 |
| 0.2824 | 0.7314 | 320 | 0.1999 | 0.9301 |
| 0.2009 | 0.7543 | 330 | 0.1492 | 0.9521 |
| 0.2001 | 0.7771 | 340 | 0.1496 | 0.9517 |
| 0.2298 | 0.8 | 350 | 0.1579 | 0.9490 |
| 0.1802 | 0.8229 | 360 | 0.1506 | 0.9501 |
| 0.1914 | 0.8457 | 370 | 0.2036 | 0.9311 |
| 0.1897 | 0.8686 | 380 | 0.1838 | 0.9383 |
| 0.1203 | 0.8914 | 390 | 0.1459 | 0.9504 |
| 0.1372 | 0.9143 | 400 | 0.1748 | 0.9419 |
| 0.1942 | 0.9371 | 410 | 0.1813 | 0.9406 |
| 0.1886 | 0.96 | 420 | 0.1536 | 0.9510 |
| 0.1872 | 0.9829 | 430 | 0.1466 | 0.9526 |
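The evaluation results reported at the top of this card (loss 0.1466, accuracy 0.9526) are identical to the values logged at step 430, the highest accuracy in the table. The lowest validation loss (0.1459) was logged earlier, at step 390. The short sketch below makes that comparison on the last eleven rows of the table (values copied from above; earlier rows are omitted for brevity).

```python
# (step, validation_loss, accuracy) for steps 330-430, copied from the table above
ROWS = [
    (330, 0.1492, 0.9521),
    (340, 0.1496, 0.9517),
    (350, 0.1579, 0.9490),
    (360, 0.1506, 0.9501),
    (370, 0.2036, 0.9311),
    (380, 0.1838, 0.9383),
    (390, 0.1459, 0.9504),
    (400, 0.1748, 0.9419),
    (410, 0.1813, 0.9406),
    (420, 0.1536, 0.9510),
    (430, 0.1466, 0.9526),
]

# Pick the row that is best under each criterion
best_by_loss = min(ROWS, key=lambda row: row[1])
best_by_accuracy = max(ROWS, key=lambda row: row[2])

print(f"Lowest validation loss: step {best_by_loss[0]} (loss={best_by_loss[1]}, accuracy={best_by_loss[2]})")
print(f"Highest accuracy: step {best_by_accuracy[0]} (loss={best_by_accuracy[1]}, accuracy={best_by_accuracy[2]})")
```

When retraining with the `Trainer`, `load_best_model_at_end=True` combined with `metric_for_best_model` selects between these criteria automatically; the card does not say whether either option was set for this run.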
### Framework versions
- Transformers 4.47.0
- Pytorch 2.4.1+cu118
- Datasets 3.6.0
- Tokenizers 0.21.0
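When results differ from those above, one first check is whether the installed packages match these versions. A minimal sketch using only the standard library; it assumes the usual PyPI distribution names (`transformers`, `torch`, `datasets`, `tokenizers`) and compares version strings exactly, so a CPU-only `torch` build reported as `2.4.1` would be flagged as a mismatch against `2.4.1+cu118`.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions listed under "Framework versions" in this card
EXPECTED = {
    "transformers": "4.47.0",
    "torch": "2.4.1+cu118",
    "datasets": "3.6.0",
    "tokenizers": "0.21.0",
}


def find_mismatches(expected, get_version=version):
    """Return {package: (expected, installed_or_None)} for every package that differs."""
    mismatches = {}
    for package, wanted in expected.items():
        try:
            installed = get_version(package)
        except PackageNotFoundError:
            installed = None  # package is not installed at all
        if installed != wanted:
            mismatches[package] = (wanted, installed)
    return mismatches


if __name__ == "__main__":
    for package, (wanted, installed) in find_mismatches(EXPECTED).items():
        print(f"{package}: expected {wanted}, found {installed or 'not installed'}")
```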