---
library_name: transformers
license: mit
base_model: TencentGameMate/chinese-hubert-base
tags:
- generated_from_trainer
metrics:
- accuracy
model-index:
- name: hubert-base-ser
  results: []
---

# hubert-base-ser

This model is a fine-tuned version of [TencentGameMate/chinese-hubert-base](https://huggingface.co/TencentGameMate/chinese-hubert-base) on an unknown dataset.
It achieves the following results on the evaluation set:
- Loss: 0.1466
- Accuracy: 0.9526

## How to use

### Requirements

```bash
# required packages (inside a Jupyter notebook, prefix each command with "!")
pip install git+https://github.com/huggingface/datasets.git
pip install git+https://github.com/huggingface/transformers.git
pip install torchaudio
```

### Prediction

The classification head is not part of `transformers`, so the script below defines it before loading the checkpoint.
It then predicts the emotion of every `.wav` / `.mp3` / `.flac` file in a directory (pass the directory to `main()`) and prints a score per label.

```python
import os
from dataclasses import dataclass
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from transformers import AutoConfig, HubertModel, HubertPreTrainedModel, Wav2Vec2FeatureExtractor
from transformers.utils import ModelOutput

# File extensions picked up by main(); matched case-insensitively.
AUDIO_EXTENSIONS = (".wav", ".mp3", ".flac")


def speech_file_to_array_fn(path, sampling_rate):
    """Load an audio file as a 1-D mono float array resampled to ``sampling_rate`` Hz."""
    speech_array, source_sampling_rate = torchaudio.load(path)
    # torchaudio.load returns (channels, samples). Average the channels so a
    # stereo file becomes one mono waveform; a 2-D array would otherwise be
    # treated as a batch by the feature extractor and only channel 0 scored.
    speech_array = speech_array.mean(dim=0)
    if source_sampling_rate != sampling_rate:
        speech_array = torchaudio.functional.resample(
            speech_array, source_sampling_rate, sampling_rate
        )
    return speech_array.numpy()


@dataclass
class SpeechClassifierOutput(ModelOutput):
    """Output of HubertForSpeechClassification; only ``logits`` is filled at inference."""

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    hidden_states: Optional[Tuple[torch.FloatTensor]] = None
    attentions: Optional[Tuple[torch.FloatTensor]] = None


class HubertClassificationHead(nn.Module):
    """Head for the HuBERT classification task: dropout -> dense -> tanh -> dropout -> projection."""

    def __init__(self, config):
        super().__init__()
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.dropout = nn.Dropout(config.final_dropout)
        self.out_proj = nn.Linear(config.hidden_size, config.num_labels)

    def forward(self, features, **kwargs):
        x = self.dropout(features)
        x = torch.tanh(self.dense(x))
        x = self.dropout(x)
        return self.out_proj(x)


class HubertForSpeechClassification(HubertPreTrainedModel):
    """HuBERT encoder, pooled over the time axis, followed by a classification head."""

    def __init__(self, config):
        super().__init__(config)
        # The fine-tuned config stores the pooling mode; fall back to the same
        # default as merged_strategy() if it is missing.
        self.pooling_mode = getattr(config, "pooling_mode", "mean")
        self.hubert = HubertModel(config)
        self.classifier = HubertClassificationHead(config)
        self.post_init()

    def merged_strategy(self, hidden_states, mode="mean"):
        """Pool ``hidden_states`` over dim 1 (time) using "mean", "sum" or "max".

        Raises:
            ValueError: if ``mode`` is not one of the supported pooling modes.
        """
        if mode == "mean":
            return torch.mean(hidden_states, dim=1)
        if mode == "sum":
            return torch.sum(hidden_states, dim=1)
        if mode == "max":
            return torch.max(hidden_states, dim=1)[0]
        raise ValueError(
            "The pooling method hasn't been defined! "
            "Your pooling mode must be one of these ['mean', 'sum', 'max']"
        )

    def forward(self, x):
        outputs = self.hubert(x)
        hidden_states = outputs[0]  # last hidden state of the encoder
        pooled = self.merged_strategy(hidden_states, mode=self.pooling_mode)
        logits = self.classifier(pooled)
        # Return a SpeechClassifierOutput so callers can read `.logits`.
        return SpeechClassifierOutput(logits=logits)


def main(test_data=r"F:\test_ser"):
    """Predict the emotion of every audio file in ``test_data`` and print per-label scores."""
    # Validate the input directory before spending time on loading the model.
    if not os.path.isdir(test_data):
        print(f"测试数据目录不存在: {test_data}")
        print("请确保目录存在并包含音频文件")
        return

    print("正在加载模型...")
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model_name_or_path = "ZipperDeng/hubert-base-ser"
    config = AutoConfig.from_pretrained(model_name_or_path)
    feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
    sampling_rate = feature_extractor.sampling_rate
    model = HubertForSpeechClassification.from_pretrained(model_name_or_path).to(device)
    model.eval()  # from_pretrained already does this; kept explicit for clarity

    def predict_single_file(file_path, sampling_rate):
        """Predict the emotion of one audio file.

        Returns a list of {"Label": ..., "Score": "NN.N%"} dicts, one per label,
        or None if the file could not be processed (the error is printed).
        """
        try:
            speech = speech_file_to_array_fn(file_path, sampling_rate)
            features = feature_extractor(
                speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True
            )
            input_values = features.input_values.to(device)
            with torch.no_grad():
                logits = model(input_values).logits
            scores = F.softmax(logits, dim=1).cpu().numpy()[0]
            # Format once; rounding first and then formatting double-rounds.
            return [
                {"Label": config.id2label[i], "Score": f"{score * 100:.1f}%"}
                for i, score in enumerate(scores)
            ]
        except Exception as e:  # best effort: report the bad file and keep going
            print(f"处理文件 {file_path} 时出错: {e}")
            return None

    file_path_list = sorted(
        os.path.join(test_data, name)
        for name in os.listdir(test_data)
        if name.lower().endswith(AUDIO_EXTENSIONS)
    )
    print(f"找到 {len(file_path_list)} 个音频文件")

    # Process the files one by one.
    for file_path in file_path_list:
        print(f"\n处理文件: {file_path}")
        outputs = predict_single_file(file_path, sampling_rate)
        if outputs is None:
            # The error was already printed; skip instead of iterating over None.
            continue
        print("预测结果:")
        for result in outputs:
            print(f" {result['Label']}: {result['Score']}")


if __name__ == "__main__":
    main()
```

## Model description

More information needed

## Intended uses & limitations

More information needed

## Training and evaluation data

More information needed

## Training procedure

### Training hyperparameters

The following hyperparameters were used during training:
- learning_rate: 0.0001
- train_batch_size: 32
- eval_batch_size: 4
- seed: 42
- gradient_accumulation_steps: 2
- total_train_batch_size: 64
- optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- lr_scheduler_type: linear
- num_epochs: 1.0
- mixed_precision_training: Native AMP

### Training results

| Training Loss | Epoch  | Step | Validation Loss | Accuracy |
|:-------------:|:------:|:----:|:---------------:|:--------:|
| 0.9709        | 0.0229 | 10   | 0.8923          | 0.6399   |
| 0.9219        | 0.0457 | 20   | 0.6903          | 0.7664   |
| 0.7112        | 0.0686 | 30   | 0.5838          | 0.7909   |
| 0.567         | 0.0914 | 40   | 0.5405          | 0.8159   |
| 0.6184        | 0.1143 | 50   | 0.4148          | 0.8581   |
| 0.5291        | 0.1371 | 60   | 0.4444          | 0.8511   |
| 0.533         | 0.16   | 70   | 0.4643          | 0.8271   |
| 0.4753        | 0.1829 | 80   | 0.3560          | 0.8767   |
| 0.4252        | 0.2057 | 90   | 0.5889          | 0.8103   |
| 0.5007        | 0.2286 | 100  | 0.3882          | 0.8663   |
| 0.5605        | 0.2514 | 110  | 0.3221          | 0.8921   |
| 0.4875        | 0.2743 | 120  | 0.3639          | 0.8559   |
| 0.4277        | 0.2971 | 130  | 0.3571          | 0.8746   |
| 0.3415        | 0.32   | 140  | 0.3382          | 0.8861   |
| 0.413         | 0.3429 | 150  | 0.2596          | 0.9104   |
| 0.377         | 0.3657 | 160  | 0.3519          | 0.8711   |
| 0.4219        | 0.3886 | 170  | 0.2979          | 0.8947   |
| 0.3317        | 0.4114 | 180  | 0.2227          | 0.9226   |
| 0.3131        | 0.4343 | 190  | 0.3680          | 0.8693   |
| 0.3266        | 0.4571 | 200  | 0.2098          | 0.9309   |
| 0.3306        | 0.48   | 210  | 0.3849          | 0.8824   |
| 0.3037        | 0.5029 | 220  | 0.2852          | 0.9024   |
| 0.3086        | 0.5257 | 230  | 0.2725          | 0.9121   |
| 0.2576        | 0.5486 | 240  | 0.1869          | 0.9356   |
| 0.2469        | 0.5714 | 250  | 0.2262          | 0.9243   |
| 0.2405        | 0.5943 | 260  | 0.1963          | 0.9347   |
| 0.2802        | 0.6171 | 270  | 0.3680          | 0.8804   |
| 0.2442        | 0.64   | 280  | 0.2053          | 0.9293   |
| 0.2302        | 0.6629 | 290  | 0.3356          | 0.8967   |
| 0.2492        | 0.6857 | 300  | 0.1880          | 0.9371   |
| 0.2089        | 0.7086 | 310  | 0.2076          | 0.9289   |
| 0.2824        | 0.7314 | 320  | 0.1999          | 0.9301   |
| 0.2009        | 0.7543 | 330  | 0.1492          | 0.9521   |
| 0.2001        | 0.7771 | 340  | 0.1496          | 0.9517   |
| 0.2298        | 0.8    | 350  | 0.1579          | 0.9490   |
| 0.1802        | 0.8229 | 360  | 0.1506          | 0.9501   |
| 0.1914        | 0.8457 | 370  | 0.2036          | 0.9311   |
| 0.1897        | 0.8686 | 380  | 0.1838          | 0.9383   |
| 0.1203        | 0.8914 | 390  | 0.1459          | 0.9504   |
| 0.1372        | 0.9143 | 400  | 0.1748          | 0.9419   |
| 0.1942        | 0.9371 | 410  | 0.1813          | 0.9406   |
| 0.1886        | 0.96   | 420  | 0.1536          | 0.9510   |
| 0.1872        | 0.9829 | 430  | 0.1466          | 0.9526   |

### Framework versions

- Transformers 4.47.0
- Pytorch 2.4.1+cu118
- Datasets 3.6.0
- Tokenizers 0.21.0