ZipperDeng committed · verified · Commit 52cff6c · 1 parent: cd63b02

add predict code

Files changed (1): README.md (+253 −106)
README.md CHANGED
@@ -1,106 +1,253 @@
- ---
- library_name: transformers
- license: mit
- base_model: TencentGameMate/chinese-hubert-base
- tags:
- - generated_from_trainer
- metrics:
- - accuracy
- model-index:
- - name: hubert-base-ser
-   results: []
- ---
-
- <!-- This model card has been generated automatically according to the information the Trainer had access to. You
- should probably proofread and complete it, then remove this comment. -->
-
- # hubert-base-ser
-
- This model is a fine-tuned version of [TencentGameMate/chinese-hubert-base](https://huggingface.co/TencentGameMate/chinese-hubert-base) on an unknown dataset.
- It achieves the following results on the evaluation set:
- - Loss: 0.1466
- - Accuracy: 0.9526
-
- ## Model description
-
- More information needed
-
- ## Intended uses & limitations
-
- More information needed
-
- ## Training and evaluation data
-
- More information needed
-
- ## Training procedure
-
- ### Training hyperparameters
-
- The following hyperparameters were used during training:
- - learning_rate: 0.0001
- - train_batch_size: 32
- - eval_batch_size: 4
- - seed: 42
- - gradient_accumulation_steps: 2
- - total_train_batch_size: 64
- - optimizer: Use adamw_torch with betas=(0.9,0.999) and epsilon=1e-08 and optimizer_args=No additional optimizer arguments
- - lr_scheduler_type: linear
- - num_epochs: 1.0
- - mixed_precision_training: Native AMP
-
- ### Training results
-
- | Training Loss | Epoch | Step | Validation Loss | Accuracy |
- |:-------------:|:------:|:----:|:---------------:|:--------:|
- | 0.9709 | 0.0229 | 10 | 0.8923 | 0.6399 |
- | 0.9219 | 0.0457 | 20 | 0.6903 | 0.7664 |
- | 0.7112 | 0.0686 | 30 | 0.5838 | 0.7909 |
- | 0.567 | 0.0914 | 40 | 0.5405 | 0.8159 |
- | 0.6184 | 0.1143 | 50 | 0.4148 | 0.8581 |
- | 0.5291 | 0.1371 | 60 | 0.4444 | 0.8511 |
- | 0.533 | 0.16 | 70 | 0.4643 | 0.8271 |
- | 0.4753 | 0.1829 | 80 | 0.3560 | 0.8767 |
- | 0.4252 | 0.2057 | 90 | 0.5889 | 0.8103 |
- | 0.5007 | 0.2286 | 100 | 0.3882 | 0.8663 |
- | 0.5605 | 0.2514 | 110 | 0.3221 | 0.8921 |
- | 0.4875 | 0.2743 | 120 | 0.3639 | 0.8559 |
- | 0.4277 | 0.2971 | 130 | 0.3571 | 0.8746 |
- | 0.3415 | 0.32 | 140 | 0.3382 | 0.8861 |
- | 0.413 | 0.3429 | 150 | 0.2596 | 0.9104 |
- | 0.377 | 0.3657 | 160 | 0.3519 | 0.8711 |
- | 0.4219 | 0.3886 | 170 | 0.2979 | 0.8947 |
- | 0.3317 | 0.4114 | 180 | 0.2227 | 0.9226 |
- | 0.3131 | 0.4343 | 190 | 0.3680 | 0.8693 |
- | 0.3266 | 0.4571 | 200 | 0.2098 | 0.9309 |
- | 0.3306 | 0.48 | 210 | 0.3849 | 0.8824 |
- | 0.3037 | 0.5029 | 220 | 0.2852 | 0.9024 |
- | 0.3086 | 0.5257 | 230 | 0.2725 | 0.9121 |
- | 0.2576 | 0.5486 | 240 | 0.1869 | 0.9356 |
- | 0.2469 | 0.5714 | 250 | 0.2262 | 0.9243 |
- | 0.2405 | 0.5943 | 260 | 0.1963 | 0.9347 |
- | 0.2802 | 0.6171 | 270 | 0.3680 | 0.8804 |
- | 0.2442 | 0.64 | 280 | 0.2053 | 0.9293 |
- | 0.2302 | 0.6629 | 290 | 0.3356 | 0.8967 |
- | 0.2492 | 0.6857 | 300 | 0.1880 | 0.9371 |
- | 0.2089 | 0.7086 | 310 | 0.2076 | 0.9289 |
- | 0.2824 | 0.7314 | 320 | 0.1999 | 0.9301 |
- | 0.2009 | 0.7543 | 330 | 0.1492 | 0.9521 |
- | 0.2001 | 0.7771 | 340 | 0.1496 | 0.9517 |
- | 0.2298 | 0.8 | 350 | 0.1579 | 0.9490 |
- | 0.1802 | 0.8229 | 360 | 0.1506 | 0.9501 |
- | 0.1914 | 0.8457 | 370 | 0.2036 | 0.9311 |
- | 0.1897 | 0.8686 | 380 | 0.1838 | 0.9383 |
- | 0.1203 | 0.8914 | 390 | 0.1459 | 0.9504 |
- | 0.1372 | 0.9143 | 400 | 0.1748 | 0.9419 |
- | 0.1942 | 0.9371 | 410 | 0.1813 | 0.9406 |
- | 0.1886 | 0.96 | 420 | 0.1536 | 0.9510 |
- | 0.1872 | 0.9829 | 430 | 0.1466 | 0.9526 |
-
-
- ### Framework versions
-
- - Transformers 4.47.0
- - Pytorch 2.4.1+cu118
- - Datasets 3.6.0
- - Tokenizers 0.21.0
+ ---
+ library_name: transformers
+ license: mit
+ base_model: TencentGameMate/chinese-hubert-base
+ tags:
+ - generated_from_trainer
+ metrics:
+ - accuracy
+ model-index:
+ - name: hubert-base-ser
+   results: []
+ ---
+
+ <!-- This model card has been generated automatically according to the information the Trainer had access to. You
+ should probably proofread and complete it, then remove this comment. -->
+
+ # hubert-base-ser
+
+ This model is a fine-tuned version of [TencentGameMate/chinese-hubert-base](https://huggingface.co/TencentGameMate/chinese-hubert-base) on an unknown dataset.
+ It achieves the following results on the evaluation set:
+ - Loss: 0.1466
+ - Accuracy: 0.9526
+
+ ## How to use
+
+ ### Requirements
+
+ ```bash
+ pip install git+https://github.com/huggingface/datasets.git
+ pip install git+https://github.com/huggingface/transformers.git
+ pip install torchaudio
+ ```
+
+ ### Prediction
+
+ ```python
+ import os
+ import torch
+ import torchaudio
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from typing import Optional, Tuple
+ from dataclasses import dataclass
+ from transformers import AutoConfig, Wav2Vec2FeatureExtractor, HubertPreTrainedModel, HubertModel
+ from transformers.utils import ModelOutput
+
+
+ def speech_file_to_array_fn(path, sampling_rate):
+     """Load an audio file and resample it to the target sampling rate."""
+     speech_array, source_sampling_rate = torchaudio.load(path)
+     resampler = torchaudio.transforms.Resample(source_sampling_rate, sampling_rate)
+     speech = resampler(speech_array).squeeze().numpy()
+     return speech
+
+
+ @dataclass
+ class SpeechClassifierOutput(ModelOutput):
+     loss: Optional[torch.FloatTensor] = None
+     logits: torch.FloatTensor = None
+     hidden_states: Optional[Tuple[torch.FloatTensor]] = None
+     attentions: Optional[Tuple[torch.FloatTensor]] = None
+
+
+ class HubertClassificationHead(nn.Module):
+     """Head for the HuBERT classification task."""
+
+     def __init__(self, config):
+         super().__init__()
+         self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+         self.dropout = nn.Dropout(config.final_dropout)
+         self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+     def forward(self, features, **kwargs):
+         x = self.dropout(features)
+         x = self.dense(x)
+         x = torch.tanh(x)
+         x = self.dropout(x)
+         x = self.out_proj(x)
+         return x
+
+
+ class HubertForSpeechClassification(HubertPreTrainedModel):
+     def __init__(self, config):
+         super().__init__(config)
+         self.config = config
+         self.pooling_mode = config.pooling_mode
+
+         self.hubert = HubertModel(config)
+         self.classifier = HubertClassificationHead(config)
+         self.init_weights()
+
+     def merged_strategy(self, hidden_states, mode="mean"):
+         """Pool frame-level hidden states into one utterance-level vector."""
+         if mode == "mean":
+             outputs = torch.mean(hidden_states, dim=1)
+         elif mode == "sum":
+             outputs = torch.sum(hidden_states, dim=1)
+         elif mode == "max":
+             outputs = torch.max(hidden_states, dim=1)[0]
+         else:
+             raise ValueError(
+                 "Undefined pooling mode! It must be one of ['mean', 'sum', 'max']")
+         return outputs
+
+     def forward(self, x):
+         outputs = self.hubert(x)
+         hidden_states = outputs[0]
+         hidden_states = self.merged_strategy(hidden_states, mode=self.pooling_mode)
+         logits = self.classifier(hidden_states)
+         # Return a SpeechClassifierOutput object
+         return SpeechClassifierOutput(logits=logits)
+
+
+ def main():
+     print("Loading model...")
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     model_name_or_path = "ZipperDeng/hubert-base-ser"
+     config = AutoConfig.from_pretrained(model_name_or_path)
+     feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(model_name_or_path)
+     sampling_rate = feature_extractor.sampling_rate
+     model = HubertForSpeechClassification.from_pretrained(model_name_or_path).to(device)
+
+     def predict_single_file(file_path, sampling_rate):
+         """Predict the emotion of a single audio file."""
+         try:
+             speech = speech_file_to_array_fn(file_path, sampling_rate)
+             features = feature_extractor(speech, sampling_rate=sampling_rate, return_tensors="pt", padding=True)
+             input_values = features.input_values.to(device)
+
+             with torch.no_grad():
+                 logits = model(input_values).logits
+
+             scores = F.softmax(logits, dim=1).detach().cpu().numpy()[0]
+             return [{"Label": config.id2label[i], "Score": f"{score * 100:.1f}%"} for i, score in enumerate(scores)]
+         except Exception as e:
+             print(f"Error while processing {file_path}: {e}")
+             return None
+
+     # Check that the test data directory exists
+     test_data = r"F:\test_ser"
+     if not os.path.exists(test_data):
+         print(f"Test data directory does not exist: {test_data}")
+         print("Make sure the directory exists and contains audio files")
+         return
+
+     file_path_list = [os.path.join(test_data, path) for path in os.listdir(test_data) if path.endswith((".wav", ".mp3", ".flac"))]
+     print(f"Found {len(file_path_list)} audio files")
+
+     # Process the files one by one
+     for file_path in file_path_list:
+         print(f"\nProcessing file: {file_path}")
+         outputs = predict_single_file(file_path, sampling_rate)
+         if outputs is None:
+             continue
+         print("Prediction results:")
+         for result in outputs:
+             print(f"  {result['Label']}: {result['Score']}")
+
+
+ if __name__ == "__main__":
+     main()
+ ```
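The `merged_strategy` pooling in the prediction code collapses frame-level hidden states of shape `[batch, frames, hidden]` into one utterance-level vector per clip before classification. A minimal dependency-free sketch of the `"mean"` mode, using hypothetical toy values in place of real HuBERT activations:

```python
# Toy stand-in for HuBERT hidden states: 1 utterance, 3 frames, hidden size 2.
hidden_states = [[[1.0, 2.0],
                  [3.0, 4.0],
                  [5.0, 6.0]]]

def mean_pool(states):
    """Average over the frame axis, like merged_strategy(mode="mean") on dim=1."""
    pooled = []
    for utterance in states:
        n_frames = len(utterance)
        pooled.append([sum(frame[i] for frame in utterance) / n_frames
                       for i in range(len(utterance[0]))])
    return pooled

print(mean_pool(hidden_states))  # [[3.0, 4.0]]
```

The result has one vector per utterance regardless of clip length, which is why variable-length audio can feed a fixed-size linear head.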
+
+ ## Model description
+
+ More information needed
+
+ ## Intended uses & limitations
+
+ More information needed
+
+ ## Training and evaluation data
+
+ More information needed
+
+ ## Training procedure
+
+ ### Training hyperparameters
+
+ The following hyperparameters were used during training:
+ - learning_rate: 0.0001
+ - train_batch_size: 32
+ - eval_batch_size: 4
+ - seed: 42
+ - gradient_accumulation_steps: 2
+ - total_train_batch_size: 64
+ - optimizer: adamw_torch with betas=(0.9, 0.999), epsilon=1e-08, and no additional optimizer arguments
+ - lr_scheduler_type: linear
+ - num_epochs: 1.0
+ - mixed_precision_training: Native AMP
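The batch-size values in the list are related: with gradient accumulation, the optimizer steps once per `gradient_accumulation_steps` micro-batches, so the effective (total) train batch size is the product of the two. A quick sanity check:

```python
# With gradient accumulation, gradients from several micro-batches are summed
# before each optimizer step, so the effective batch size is the product.
train_batch_size = 32            # per-device micro-batch size
gradient_accumulation_steps = 2  # micro-batches per optimizer step
total_train_batch_size = train_batch_size * gradient_accumulation_steps
print(total_train_batch_size)    # 64
```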
+
+ ### Training results
+
+ | Training Loss | Epoch | Step | Validation Loss | Accuracy |
+ |:-------------:|:------:|:----:|:---------------:|:--------:|
+ | 0.9709 | 0.0229 | 10 | 0.8923 | 0.6399 |
+ | 0.9219 | 0.0457 | 20 | 0.6903 | 0.7664 |
+ | 0.7112 | 0.0686 | 30 | 0.5838 | 0.7909 |
+ | 0.567 | 0.0914 | 40 | 0.5405 | 0.8159 |
+ | 0.6184 | 0.1143 | 50 | 0.4148 | 0.8581 |
+ | 0.5291 | 0.1371 | 60 | 0.4444 | 0.8511 |
+ | 0.533 | 0.16 | 70 | 0.4643 | 0.8271 |
+ | 0.4753 | 0.1829 | 80 | 0.3560 | 0.8767 |
+ | 0.4252 | 0.2057 | 90 | 0.5889 | 0.8103 |
+ | 0.5007 | 0.2286 | 100 | 0.3882 | 0.8663 |
+ | 0.5605 | 0.2514 | 110 | 0.3221 | 0.8921 |
+ | 0.4875 | 0.2743 | 120 | 0.3639 | 0.8559 |
+ | 0.4277 | 0.2971 | 130 | 0.3571 | 0.8746 |
+ | 0.3415 | 0.32 | 140 | 0.3382 | 0.8861 |
+ | 0.413 | 0.3429 | 150 | 0.2596 | 0.9104 |
+ | 0.377 | 0.3657 | 160 | 0.3519 | 0.8711 |
+ | 0.4219 | 0.3886 | 170 | 0.2979 | 0.8947 |
+ | 0.3317 | 0.4114 | 180 | 0.2227 | 0.9226 |
+ | 0.3131 | 0.4343 | 190 | 0.3680 | 0.8693 |
+ | 0.3266 | 0.4571 | 200 | 0.2098 | 0.9309 |
+ | 0.3306 | 0.48 | 210 | 0.3849 | 0.8824 |
+ | 0.3037 | 0.5029 | 220 | 0.2852 | 0.9024 |
+ | 0.3086 | 0.5257 | 230 | 0.2725 | 0.9121 |
+ | 0.2576 | 0.5486 | 240 | 0.1869 | 0.9356 |
+ | 0.2469 | 0.5714 | 250 | 0.2262 | 0.9243 |
+ | 0.2405 | 0.5943 | 260 | 0.1963 | 0.9347 |
+ | 0.2802 | 0.6171 | 270 | 0.3680 | 0.8804 |
+ | 0.2442 | 0.64 | 280 | 0.2053 | 0.9293 |
+ | 0.2302 | 0.6629 | 290 | 0.3356 | 0.8967 |
+ | 0.2492 | 0.6857 | 300 | 0.1880 | 0.9371 |
+ | 0.2089 | 0.7086 | 310 | 0.2076 | 0.9289 |
+ | 0.2824 | 0.7314 | 320 | 0.1999 | 0.9301 |
+ | 0.2009 | 0.7543 | 330 | 0.1492 | 0.9521 |
+ | 0.2001 | 0.7771 | 340 | 0.1496 | 0.9517 |
+ | 0.2298 | 0.8 | 350 | 0.1579 | 0.9490 |
+ | 0.1802 | 0.8229 | 360 | 0.1506 | 0.9501 |
+ | 0.1914 | 0.8457 | 370 | 0.2036 | 0.9311 |
+ | 0.1897 | 0.8686 | 380 | 0.1838 | 0.9383 |
+ | 0.1203 | 0.8914 | 390 | 0.1459 | 0.9504 |
+ | 0.1372 | 0.9143 | 400 | 0.1748 | 0.9419 |
+ | 0.1942 | 0.9371 | 410 | 0.1813 | 0.9406 |
+ | 0.1886 | 0.96 | 420 | 0.1536 | 0.9510 |
+ | 0.1872 | 0.9829 | 430 | 0.1466 | 0.9526 |
+
+
+ ### Framework versions
+
+ - Transformers 4.47.0
+ - Pytorch 2.4.1+cu118
+ - Datasets 3.6.0
+ - Tokenizers 0.21.0