hankzero101 commited on
Commit
b87bbbc
·
1 Parent(s): cb4ab1f

Upload 52 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +14 -0
  2. Bark_Voice_Cloning_UI.ipynb +1088 -0
  3. Dockerfile +38 -0
  4. README.md +21 -10
  5. app.py +468 -0
  6. bark/__init__.py +2 -0
  7. bark/api.py +158 -0
  8. bark/assets/prompts/announcer.npz +3 -0
  9. bark/assets/prompts/file.npz +3 -0
  10. bark/assets/prompts/v2/en_speaker_0.npz +3 -0
  11. bark/assets/prompts/v2/en_speaker_1.npz +3 -0
  12. bark/assets/prompts/v2/en_speaker_2.npz +3 -0
  13. bark/assets/prompts/v2/en_speaker_3.npz +3 -0
  14. bark/assets/prompts/v2/en_speaker_4.npz +3 -0
  15. bark/assets/prompts/v2/en_speaker_5.npz +3 -0
  16. bark/assets/prompts/v2/en_speaker_6.npz +3 -0
  17. bark/assets/prompts/v2/en_speaker_7.npz +3 -0
  18. bark/assets/prompts/v2/en_speaker_8.npz +3 -0
  19. bark/assets/prompts/v2/en_speaker_9.npz +3 -0
  20. bark/assets/prompts/v2/zh_speaker_0.npz +3 -0
  21. bark/assets/prompts/v2/zh_speaker_1.npz +3 -0
  22. bark/assets/prompts/v2/zh_speaker_2.npz +3 -0
  23. bark/assets/prompts/v2/zh_speaker_3.npz +3 -0
  24. bark/assets/prompts/v2/zh_speaker_4.npz +3 -0
  25. bark/assets/prompts/v2/zh_speaker_5.npz +3 -0
  26. bark/assets/prompts/v2/zh_speaker_6.npz +3 -0
  27. bark/assets/prompts/v2/zh_speaker_7.npz +3 -0
  28. bark/assets/prompts/v2/zh_speaker_8.npz +3 -0
  29. bark/assets/prompts/v2/zh_speaker_9.npz +3 -0
  30. bark/generation.py +864 -0
  31. bark/hubert/__init__.py +0 -0
  32. bark/hubert/customtokenizer.py +195 -0
  33. bark/hubert/hubert_manager.py +48 -0
  34. bark/hubert/pre_kmeans_hubert.py +107 -0
  35. bark/model.py +218 -0
  36. bark/model_fine.py +149 -0
  37. bark/settings.py +7 -0
  38. cloning/__init__.py +0 -0
  39. cloning/clonevoice.py +68 -0
  40. config.yaml +8 -0
  41. pyproject.toml +60 -0
  42. requirements.txt +13 -0
  43. setup.py +3 -0
  44. swap_voice.py +62 -0
  45. training/__init__.py +0 -0
  46. training/data.py +52 -0
  47. training/train.py +47 -0
  48. training/training_prepare.py +73 -0
  49. util/__init__.py +0 -0
  50. util/helper.py +35 -0
.gitignore ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ /outputs
3
+ /speakers
4
+ .vs
5
+ *.npz
6
+ *.wav
7
+ *.npy
8
+ .vs/
9
+ /models
10
+ /bark_ui_enhanced.egg-info
11
+ /build/lib/bark
12
+ *.pth
13
+ *.pt
14
+ *.zip
Bark_Voice_Cloning_UI.ipynb ADDED
@@ -0,0 +1,1088 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "machine_shape": "hm",
8
+ "gpuType": "A100",
9
+ "authorship_tag": "ABX9TyPdN5espelPFPe/F1OA4L5f",
10
+ "include_colab_link": true
11
+ },
12
+ "kernelspec": {
13
+ "name": "python3",
14
+ "display_name": "Python 3"
15
+ },
16
+ "language_info": {
17
+ "name": "python"
18
+ },
19
+ "accelerator": "GPU",
20
+ "widgets": {
21
+ "application/vnd.jupyter.widget-state+json": {
22
+ "425505387f374468870cc4bcb52ea6c5": {
23
+ "model_module": "@jupyter-widgets/controls",
24
+ "model_name": "HBoxModel",
25
+ "model_module_version": "1.5.0",
26
+ "state": {
27
+ "_dom_classes": [],
28
+ "_model_module": "@jupyter-widgets/controls",
29
+ "_model_module_version": "1.5.0",
30
+ "_model_name": "HBoxModel",
31
+ "_view_count": null,
32
+ "_view_module": "@jupyter-widgets/controls",
33
+ "_view_module_version": "1.5.0",
34
+ "_view_name": "HBoxView",
35
+ "box_style": "",
36
+ "children": [
37
+ "IPY_MODEL_9b039beb3d7c4bc59ab95bd5d8a7dfcc",
38
+ "IPY_MODEL_55bf104e557340e5a88962134a765f1b",
39
+ "IPY_MODEL_f0768ce2c3484c4583810f461a0b742e"
40
+ ],
41
+ "layout": "IPY_MODEL_ff7dff340d9f41c29313ca68034be359"
42
+ }
43
+ },
44
+ "9b039beb3d7c4bc59ab95bd5d8a7dfcc": {
45
+ "model_module": "@jupyter-widgets/controls",
46
+ "model_name": "HTMLModel",
47
+ "model_module_version": "1.5.0",
48
+ "state": {
49
+ "_dom_classes": [],
50
+ "_model_module": "@jupyter-widgets/controls",
51
+ "_model_module_version": "1.5.0",
52
+ "_model_name": "HTMLModel",
53
+ "_view_count": null,
54
+ "_view_module": "@jupyter-widgets/controls",
55
+ "_view_module_version": "1.5.0",
56
+ "_view_name": "HTMLView",
57
+ "description": "",
58
+ "description_tooltip": null,
59
+ "layout": "IPY_MODEL_77a54e634f0d44c080eb769a4d2921b0",
60
+ "placeholder": "​",
61
+ "style": "IPY_MODEL_5584c9aaa4e04734bb6833cf7cf76534",
62
+ "value": "Downloading (…)rt_base_ls960_14.pth: 100%"
63
+ }
64
+ },
65
+ "55bf104e557340e5a88962134a765f1b": {
66
+ "model_module": "@jupyter-widgets/controls",
67
+ "model_name": "FloatProgressModel",
68
+ "model_module_version": "1.5.0",
69
+ "state": {
70
+ "_dom_classes": [],
71
+ "_model_module": "@jupyter-widgets/controls",
72
+ "_model_module_version": "1.5.0",
73
+ "_model_name": "FloatProgressModel",
74
+ "_view_count": null,
75
+ "_view_module": "@jupyter-widgets/controls",
76
+ "_view_module_version": "1.5.0",
77
+ "_view_name": "ProgressView",
78
+ "bar_style": "success",
79
+ "description": "",
80
+ "description_tooltip": null,
81
+ "layout": "IPY_MODEL_7a2e70b96a054cdd89f73edd2474e20c",
82
+ "max": 103981977,
83
+ "min": 0,
84
+ "orientation": "horizontal",
85
+ "style": "IPY_MODEL_1120230111694b4d8e63d476b0a35454",
86
+ "value": 103981977
87
+ }
88
+ },
89
+ "f0768ce2c3484c4583810f461a0b742e": {
90
+ "model_module": "@jupyter-widgets/controls",
91
+ "model_name": "HTMLModel",
92
+ "model_module_version": "1.5.0",
93
+ "state": {
94
+ "_dom_classes": [],
95
+ "_model_module": "@jupyter-widgets/controls",
96
+ "_model_module_version": "1.5.0",
97
+ "_model_name": "HTMLModel",
98
+ "_view_count": null,
99
+ "_view_module": "@jupyter-widgets/controls",
100
+ "_view_module_version": "1.5.0",
101
+ "_view_name": "HTMLView",
102
+ "description": "",
103
+ "description_tooltip": null,
104
+ "layout": "IPY_MODEL_643343218af349aaa63afbcd3cbc8009",
105
+ "placeholder": "​",
106
+ "style": "IPY_MODEL_dfb0df17546545a4b74ed7f5f10c7a9a",
107
+ "value": " 104M/104M [00:00<00:00, 406MB/s]"
108
+ }
109
+ },
110
+ "ff7dff340d9f41c29313ca68034be359": {
111
+ "model_module": "@jupyter-widgets/base",
112
+ "model_name": "LayoutModel",
113
+ "model_module_version": "1.2.0",
114
+ "state": {
115
+ "_model_module": "@jupyter-widgets/base",
116
+ "_model_module_version": "1.2.0",
117
+ "_model_name": "LayoutModel",
118
+ "_view_count": null,
119
+ "_view_module": "@jupyter-widgets/base",
120
+ "_view_module_version": "1.2.0",
121
+ "_view_name": "LayoutView",
122
+ "align_content": null,
123
+ "align_items": null,
124
+ "align_self": null,
125
+ "border": null,
126
+ "bottom": null,
127
+ "display": null,
128
+ "flex": null,
129
+ "flex_flow": null,
130
+ "grid_area": null,
131
+ "grid_auto_columns": null,
132
+ "grid_auto_flow": null,
133
+ "grid_auto_rows": null,
134
+ "grid_column": null,
135
+ "grid_gap": null,
136
+ "grid_row": null,
137
+ "grid_template_areas": null,
138
+ "grid_template_columns": null,
139
+ "grid_template_rows": null,
140
+ "height": null,
141
+ "justify_content": null,
142
+ "justify_items": null,
143
+ "left": null,
144
+ "margin": null,
145
+ "max_height": null,
146
+ "max_width": null,
147
+ "min_height": null,
148
+ "min_width": null,
149
+ "object_fit": null,
150
+ "object_position": null,
151
+ "order": null,
152
+ "overflow": null,
153
+ "overflow_x": null,
154
+ "overflow_y": null,
155
+ "padding": null,
156
+ "right": null,
157
+ "top": null,
158
+ "visibility": null,
159
+ "width": null
160
+ }
161
+ },
162
+ "77a54e634f0d44c080eb769a4d2921b0": {
163
+ "model_module": "@jupyter-widgets/base",
164
+ "model_name": "LayoutModel",
165
+ "model_module_version": "1.2.0",
166
+ "state": {
167
+ "_model_module": "@jupyter-widgets/base",
168
+ "_model_module_version": "1.2.0",
169
+ "_model_name": "LayoutModel",
170
+ "_view_count": null,
171
+ "_view_module": "@jupyter-widgets/base",
172
+ "_view_module_version": "1.2.0",
173
+ "_view_name": "LayoutView",
174
+ "align_content": null,
175
+ "align_items": null,
176
+ "align_self": null,
177
+ "border": null,
178
+ "bottom": null,
179
+ "display": null,
180
+ "flex": null,
181
+ "flex_flow": null,
182
+ "grid_area": null,
183
+ "grid_auto_columns": null,
184
+ "grid_auto_flow": null,
185
+ "grid_auto_rows": null,
186
+ "grid_column": null,
187
+ "grid_gap": null,
188
+ "grid_row": null,
189
+ "grid_template_areas": null,
190
+ "grid_template_columns": null,
191
+ "grid_template_rows": null,
192
+ "height": null,
193
+ "justify_content": null,
194
+ "justify_items": null,
195
+ "left": null,
196
+ "margin": null,
197
+ "max_height": null,
198
+ "max_width": null,
199
+ "min_height": null,
200
+ "min_width": null,
201
+ "object_fit": null,
202
+ "object_position": null,
203
+ "order": null,
204
+ "overflow": null,
205
+ "overflow_x": null,
206
+ "overflow_y": null,
207
+ "padding": null,
208
+ "right": null,
209
+ "top": null,
210
+ "visibility": null,
211
+ "width": null
212
+ }
213
+ },
214
+ "5584c9aaa4e04734bb6833cf7cf76534": {
215
+ "model_module": "@jupyter-widgets/controls",
216
+ "model_name": "DescriptionStyleModel",
217
+ "model_module_version": "1.5.0",
218
+ "state": {
219
+ "_model_module": "@jupyter-widgets/controls",
220
+ "_model_module_version": "1.5.0",
221
+ "_model_name": "DescriptionStyleModel",
222
+ "_view_count": null,
223
+ "_view_module": "@jupyter-widgets/base",
224
+ "_view_module_version": "1.2.0",
225
+ "_view_name": "StyleView",
226
+ "description_width": ""
227
+ }
228
+ },
229
+ "7a2e70b96a054cdd89f73edd2474e20c": {
230
+ "model_module": "@jupyter-widgets/base",
231
+ "model_name": "LayoutModel",
232
+ "model_module_version": "1.2.0",
233
+ "state": {
234
+ "_model_module": "@jupyter-widgets/base",
235
+ "_model_module_version": "1.2.0",
236
+ "_model_name": "LayoutModel",
237
+ "_view_count": null,
238
+ "_view_module": "@jupyter-widgets/base",
239
+ "_view_module_version": "1.2.0",
240
+ "_view_name": "LayoutView",
241
+ "align_content": null,
242
+ "align_items": null,
243
+ "align_self": null,
244
+ "border": null,
245
+ "bottom": null,
246
+ "display": null,
247
+ "flex": null,
248
+ "flex_flow": null,
249
+ "grid_area": null,
250
+ "grid_auto_columns": null,
251
+ "grid_auto_flow": null,
252
+ "grid_auto_rows": null,
253
+ "grid_column": null,
254
+ "grid_gap": null,
255
+ "grid_row": null,
256
+ "grid_template_areas": null,
257
+ "grid_template_columns": null,
258
+ "grid_template_rows": null,
259
+ "height": null,
260
+ "justify_content": null,
261
+ "justify_items": null,
262
+ "left": null,
263
+ "margin": null,
264
+ "max_height": null,
265
+ "max_width": null,
266
+ "min_height": null,
267
+ "min_width": null,
268
+ "object_fit": null,
269
+ "object_position": null,
270
+ "order": null,
271
+ "overflow": null,
272
+ "overflow_x": null,
273
+ "overflow_y": null,
274
+ "padding": null,
275
+ "right": null,
276
+ "top": null,
277
+ "visibility": null,
278
+ "width": null
279
+ }
280
+ },
281
+ "1120230111694b4d8e63d476b0a35454": {
282
+ "model_module": "@jupyter-widgets/controls",
283
+ "model_name": "ProgressStyleModel",
284
+ "model_module_version": "1.5.0",
285
+ "state": {
286
+ "_model_module": "@jupyter-widgets/controls",
287
+ "_model_module_version": "1.5.0",
288
+ "_model_name": "ProgressStyleModel",
289
+ "_view_count": null,
290
+ "_view_module": "@jupyter-widgets/base",
291
+ "_view_module_version": "1.2.0",
292
+ "_view_name": "StyleView",
293
+ "bar_color": null,
294
+ "description_width": ""
295
+ }
296
+ },
297
+ "643343218af349aaa63afbcd3cbc8009": {
298
+ "model_module": "@jupyter-widgets/base",
299
+ "model_name": "LayoutModel",
300
+ "model_module_version": "1.2.0",
301
+ "state": {
302
+ "_model_module": "@jupyter-widgets/base",
303
+ "_model_module_version": "1.2.0",
304
+ "_model_name": "LayoutModel",
305
+ "_view_count": null,
306
+ "_view_module": "@jupyter-widgets/base",
307
+ "_view_module_version": "1.2.0",
308
+ "_view_name": "LayoutView",
309
+ "align_content": null,
310
+ "align_items": null,
311
+ "align_self": null,
312
+ "border": null,
313
+ "bottom": null,
314
+ "display": null,
315
+ "flex": null,
316
+ "flex_flow": null,
317
+ "grid_area": null,
318
+ "grid_auto_columns": null,
319
+ "grid_auto_flow": null,
320
+ "grid_auto_rows": null,
321
+ "grid_column": null,
322
+ "grid_gap": null,
323
+ "grid_row": null,
324
+ "grid_template_areas": null,
325
+ "grid_template_columns": null,
326
+ "grid_template_rows": null,
327
+ "height": null,
328
+ "justify_content": null,
329
+ "justify_items": null,
330
+ "left": null,
331
+ "margin": null,
332
+ "max_height": null,
333
+ "max_width": null,
334
+ "min_height": null,
335
+ "min_width": null,
336
+ "object_fit": null,
337
+ "object_position": null,
338
+ "order": null,
339
+ "overflow": null,
340
+ "overflow_x": null,
341
+ "overflow_y": null,
342
+ "padding": null,
343
+ "right": null,
344
+ "top": null,
345
+ "visibility": null,
346
+ "width": null
347
+ }
348
+ },
349
+ "dfb0df17546545a4b74ed7f5f10c7a9a": {
350
+ "model_module": "@jupyter-widgets/controls",
351
+ "model_name": "DescriptionStyleModel",
352
+ "model_module_version": "1.5.0",
353
+ "state": {
354
+ "_model_module": "@jupyter-widgets/controls",
355
+ "_model_module_version": "1.5.0",
356
+ "_model_name": "DescriptionStyleModel",
357
+ "_view_count": null,
358
+ "_view_module": "@jupyter-widgets/base",
359
+ "_view_module_version": "1.2.0",
360
+ "_view_name": "StyleView",
361
+ "description_width": ""
362
+ }
363
+ }
364
+ }
365
+ }
366
+ },
367
+ "cells": [
368
+ {
369
+ "cell_type": "markdown",
370
+ "metadata": {
371
+ "id": "view-in-github",
372
+ "colab_type": "text"
373
+ },
374
+ "source": [
375
+ "<a href=\"https://colab.research.google.com/github/KevinWang676/Bark-Voice-Cloning/blob/main/Bark_Voice_Cloning_UI.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
376
+ ]
377
+ },
378
+ {
379
+ "cell_type": "code",
380
+ "execution_count": 4,
381
+ "metadata": {
382
+ "colab": {
383
+ "base_uri": "https://localhost:8080/"
384
+ },
385
+ "id": "n281rhWYnEbf",
386
+ "outputId": "cf8d7edf-63ff-4b7f-9cc0-97640e113c1b"
387
+ },
388
+ "outputs": [
389
+ {
390
+ "output_type": "stream",
391
+ "name": "stdout",
392
+ "text": [
393
+ "Cloning into 'Bark-Voice-Cloning'...\n",
394
+ "remote: Enumerating objects: 132, done.\u001b[K\n",
395
+ "remote: Counting objects: 100% (59/59), done.\u001b[K\n",
396
+ "remote: Compressing objects: 100% (59/59), done.\u001b[K\n",
397
+ "remote: Total 132 (delta 30), reused 0 (delta 0), pack-reused 73\u001b[K\n",
398
+ "Receiving objects: 100% (132/132), 225.44 KiB | 15.03 MiB/s, done.\n",
399
+ "Resolving deltas: 100% (38/38), done.\n"
400
+ ]
401
+ }
402
+ ],
403
+ "source": [
404
+ "!git clone https://github.com/KevinWang676/Bark-Voice-Cloning.git"
405
+ ]
406
+ },
407
+ {
408
+ "cell_type": "code",
409
+ "source": [
410
+ "cd Bark-Voice-Cloning/"
411
+ ],
412
+ "metadata": {
413
+ "colab": {
414
+ "base_uri": "https://localhost:8080/"
415
+ },
416
+ "id": "uyyMhQgBnJLG",
417
+ "outputId": "4a91aa03-787c-41e8-b59d-dbddc9eed3b8"
418
+ },
419
+ "execution_count": 10,
420
+ "outputs": [
421
+ {
422
+ "output_type": "stream",
423
+ "name": "stdout",
424
+ "text": [
425
+ "/content/Bark-Voice-Cloning\n"
426
+ ]
427
+ }
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "source": [
433
+ "pip install -r requirements.txt"
434
+ ],
435
+ "metadata": {
436
+ "colab": {
437
+ "base_uri": "https://localhost:8080/"
438
+ },
439
+ "id": "fm8b-BXPnPDb",
440
+ "outputId": "df4bfdec-d418-4edd-d41b-7cf4d40a6a2e"
441
+ },
442
+ "execution_count": 6,
443
+ "outputs": [
444
+ {
445
+ "output_type": "stream",
446
+ "name": "stdout",
447
+ "text": [
448
+ "Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/\n",
449
+ "Ignoring fairseq: markers 'platform_system == \"Windows\"' don't match your environment\n",
450
+ "Ignoring soundfile: markers 'platform_system == \"Windows\"' don't match your environment\n",
451
+ "Requirement already satisfied: fairseq in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 1)) (0.12.2)\n",
452
+ "Requirement already satisfied: audiolm-pytorch in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 3)) (1.1.4)\n",
453
+ "Requirement already satisfied: gradio in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 4)) (3.34.0)\n",
454
+ "Requirement already satisfied: funcy in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 5)) (2.0)\n",
455
+ "Requirement already satisfied: linkify in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 6)) (1.4)\n",
456
+ "Requirement already satisfied: mutagen in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 7)) (1.46.0)\n",
457
+ "Requirement already satisfied: pytorch_seed in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 8)) (0.2.0)\n",
458
+ "Requirement already satisfied: pyyaml in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 9)) (6.0)\n",
459
+ "Requirement already satisfied: sentencepiece in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 10)) (0.1.99)\n",
460
+ "Requirement already satisfied: sox in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 12)) (1.4.1)\n",
461
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (from -r requirements.txt (line 13)) (4.30.1)\n",
462
+ "Requirement already satisfied: cffi in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (1.15.1)\n",
463
+ "Requirement already satisfied: cython in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (0.29.34)\n",
464
+ "Requirement already satisfied: hydra-core<1.1,>=1.0.7 in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (1.0.7)\n",
465
+ "Requirement already satisfied: omegaconf<2.1 in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2.0.6)\n",
466
+ "Requirement already satisfied: regex in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2022.10.31)\n",
467
+ "Requirement already satisfied: sacrebleu>=1.4.12 in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2.3.1)\n",
468
+ "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2.0.1+cu118)\n",
469
+ "Requirement already satisfied: tqdm in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (4.65.0)\n",
470
+ "Requirement already satisfied: bitarray in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2.7.5)\n",
471
+ "Requirement already satisfied: torchaudio>=0.8.0 in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (2.0.2+cu118)\n",
472
+ "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from fairseq->-r requirements.txt (line 1)) (1.22.4)\n",
473
+ "Requirement already satisfied: accelerate in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.20.3)\n",
474
+ "Requirement already satisfied: beartype in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.14.1)\n",
475
+ "Requirement already satisfied: einops>=0.6.1 in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.6.1)\n",
476
+ "Requirement already satisfied: ema-pytorch>=0.2.2 in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.2.3)\n",
477
+ "Requirement already satisfied: encodec in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.1.1)\n",
478
+ "Requirement already satisfied: joblib in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (1.2.0)\n",
479
+ "Requirement already satisfied: lion-pytorch in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (0.1.2)\n",
480
+ "Requirement already satisfied: local-attention>=1.8.4 in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (1.8.6)\n",
481
+ "Requirement already satisfied: scikit-learn in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (1.2.2)\n",
482
+ "Requirement already satisfied: vector-quantize-pytorch>=1.5.14 in /usr/local/lib/python3.10/dist-packages (from audiolm-pytorch->-r requirements.txt (line 3)) (1.6.11)\n",
483
+ "Requirement already satisfied: aiofiles in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (23.1.0)\n",
484
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (3.8.4)\n",
485
+ "Requirement already satisfied: altair>=4.2.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (4.2.2)\n",
486
+ "Requirement already satisfied: fastapi in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.97.0)\n",
487
+ "Requirement already satisfied: ffmpy in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.3.0)\n",
488
+ "Requirement already satisfied: gradio-client>=0.2.6 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.2.6)\n",
489
+ "Requirement already satisfied: httpx in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.24.1)\n",
490
+ "Requirement already satisfied: huggingface-hub>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.15.1)\n",
491
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (3.1.2)\n",
492
+ "Requirement already satisfied: markdown-it-py[linkify]>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (2.2.0)\n",
493
+ "Requirement already satisfied: markupsafe in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (2.1.2)\n",
494
+ "Requirement already satisfied: matplotlib in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (3.7.1)\n",
495
+ "Requirement already satisfied: mdit-py-plugins<=0.3.3 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.3.3)\n",
496
+ "Requirement already satisfied: orjson in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (3.9.1)\n",
497
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (1.5.3)\n",
498
+ "Requirement already satisfied: pillow in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (8.4.0)\n",
499
+ "Requirement already satisfied: pydantic in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (1.10.7)\n",
500
+ "Requirement already satisfied: pydub in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.25.1)\n",
501
+ "Requirement already satisfied: pygments>=2.12.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (2.14.0)\n",
502
+ "Requirement already satisfied: python-multipart in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.0.6)\n",
503
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (2.27.1)\n",
504
+ "Requirement already satisfied: semantic-version in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (2.10.0)\n",
505
+ "Requirement already satisfied: typing-extensions in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (4.5.0)\n",
506
+ "Requirement already satisfied: uvicorn>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (0.22.0)\n",
507
+ "Requirement already satisfied: websockets>=10.0 in /usr/local/lib/python3.10/dist-packages (from gradio->-r requirements.txt (line 4)) (11.0.3)\n",
508
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers->-r requirements.txt (line 13)) (3.12.0)\n",
509
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers->-r requirements.txt (line 13)) (23.1)\n",
510
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers->-r requirements.txt (line 13)) (0.13.3)\n",
511
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers->-r requirements.txt (line 13)) (0.3.1)\n",
512
+ "Requirement already satisfied: entrypoints in /usr/local/lib/python3.10/dist-packages (from altair>=4.2.0->gradio->-r requirements.txt (line 4)) (0.4)\n",
513
+ "Requirement already satisfied: jsonschema>=3.0 in /usr/local/lib/python3.10/dist-packages (from altair>=4.2.0->gradio->-r requirements.txt (line 4)) (4.3.3)\n",
514
+ "Requirement already satisfied: toolz in /usr/local/lib/python3.10/dist-packages (from altair>=4.2.0->gradio->-r requirements.txt (line 4)) (0.12.0)\n",
515
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from gradio-client>=0.2.6->gradio->-r requirements.txt (line 4)) (2023.4.0)\n",
516
+ "Requirement already satisfied: antlr4-python3-runtime==4.8 in /usr/local/lib/python3.10/dist-packages (from hydra-core<1.1,>=1.0.7->fairseq->-r requirements.txt (line 1)) (4.8)\n",
517
+ "Requirement already satisfied: mdurl~=0.1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio->-r requirements.txt (line 4)) (0.1.2)\n",
518
+ "Requirement already satisfied: linkify-it-py<3,>=1 in /usr/local/lib/python3.10/dist-packages (from markdown-it-py[linkify]>=2.0.0->gradio->-r requirements.txt (line 4)) (2.0.2)\n",
519
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->gradio->-r requirements.txt (line 4)) (2.8.2)\n",
520
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->gradio->-r requirements.txt (line 4)) (2022.7.1)\n",
521
+ "Requirement already satisfied: portalocker in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.4.12->fairseq->-r requirements.txt (line 1)) (2.7.0)\n",
522
+ "Requirement already satisfied: tabulate>=0.8.9 in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.4.12->fairseq->-r requirements.txt (line 1)) (0.8.10)\n",
523
+ "Requirement already satisfied: colorama in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.4.12->fairseq->-r requirements.txt (line 1)) (0.4.6)\n",
524
+ "Requirement already satisfied: lxml in /usr/local/lib/python3.10/dist-packages (from sacrebleu>=1.4.12->fairseq->-r requirements.txt (line 1)) (4.9.2)\n",
525
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch->fairseq->-r requirements.txt (line 1)) (1.11.1)\n",
526
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch->fairseq->-r requirements.txt (line 1)) (3.1)\n",
527
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch->fairseq->-r requirements.txt (line 1)) (2.0.0)\n",
528
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->fairseq->-r requirements.txt (line 1)) (3.25.2)\n",
529
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch->fairseq->-r requirements.txt (line 1)) (16.0.5)\n",
530
+ "Requirement already satisfied: click>=7.0 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio->-r requirements.txt (line 4)) (8.1.3)\n",
531
+ "Requirement already satisfied: h11>=0.8 in /usr/local/lib/python3.10/dist-packages (from uvicorn>=0.14.0->gradio->-r requirements.txt (line 4)) (0.14.0)\n",
532
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate->audiolm-pytorch->-r requirements.txt (line 3)) (5.9.5)\n",
533
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (23.1.0)\n",
534
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (2.0.12)\n",
535
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (6.0.4)\n",
536
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (4.0.2)\n",
537
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (1.9.2)\n",
538
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (1.3.3)\n",
539
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->gradio->-r requirements.txt (line 4)) (1.3.1)\n",
540
+ "Requirement already satisfied: pycparser in /usr/local/lib/python3.10/dist-packages (from cffi->fairseq->-r requirements.txt (line 1)) (2.21)\n",
541
+ "Requirement already satisfied: starlette<0.28.0,>=0.27.0 in /usr/local/lib/python3.10/dist-packages (from fastapi->gradio->-r requirements.txt (line 4)) (0.27.0)\n",
542
+ "Requirement already satisfied: certifi in /usr/local/lib/python3.10/dist-packages (from httpx->gradio->-r requirements.txt (line 4)) (2022.12.7)\n",
543
+ "Requirement already satisfied: httpcore<0.18.0,>=0.15.0 in /usr/local/lib/python3.10/dist-packages (from httpx->gradio->-r requirements.txt (line 4)) (0.17.2)\n",
544
+ "Requirement already satisfied: idna in /usr/local/lib/python3.10/dist-packages (from httpx->gradio->-r requirements.txt (line 4)) (3.4)\n",
545
+ "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from httpx->gradio->-r requirements.txt (line 4)) (1.3.0)\n",
546
+ "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->gradio->-r requirements.txt (line 4)) (1.0.7)\n",
547
+ "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.10/dist-packages (from matplotlib->gradio->-r requirements.txt (line 4)) (0.11.0)\n",
548
+ "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.10/dist-packages (from matplotlib->gradio->-r requirements.txt (line 4)) (4.39.3)\n",
549
+ "Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->gradio->-r requirements.txt (line 4)) (1.4.4)\n",
550
+ "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.10/dist-packages (from matplotlib->gradio->-r requirements.txt (line 4)) (3.0.9)\n",
551
+ "Requirement already satisfied: urllib3<1.27,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->gradio->-r requirements.txt (line 4)) (1.26.15)\n",
552
+ "Requirement already satisfied: scipy>=1.3.2 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->audiolm-pytorch->-r requirements.txt (line 3)) (1.10.1)\n",
553
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from scikit-learn->audiolm-pytorch->-r requirements.txt (line 3)) (3.1.0)\n",
554
+ "Requirement already satisfied: anyio<5.0,>=3.0 in /usr/local/lib/python3.10/dist-packages (from httpcore<0.18.0,>=0.15.0->httpx->gradio->-r requirements.txt (line 4)) (3.6.2)\n",
555
+ "Requirement already satisfied: pyrsistent!=0.17.0,!=0.17.1,!=0.17.2,>=0.14.0 in /usr/local/lib/python3.10/dist-packages (from jsonschema>=3.0->altair>=4.2.0->gradio->-r requirements.txt (line 4)) (0.19.3)\n",
556
+ "Requirement already satisfied: uc-micro-py in /usr/local/lib/python3.10/dist-packages (from linkify-it-py<3,>=1->markdown-it-py[linkify]>=2.0.0->gradio->-r requirements.txt (line 4)) (1.0.2)\n",
557
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->gradio->-r requirements.txt (line 4)) (1.16.0)\n",
558
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch->fairseq->-r requirements.txt (line 1)) (1.3.0)\n"
559
+ ]
560
+ }
561
+ ]
562
+ },
563
+ {
564
+ "cell_type": "code",
565
+ "source": [
566
+ "from cProfile import label\n",
567
+ "import dataclasses\n",
568
+ "from distutils.command.check import check\n",
569
+ "from doctest import Example\n",
570
+ "import gradio as gr\n",
571
+ "import os\n",
572
+ "import sys\n",
573
+ "import numpy as np\n",
574
+ "import logging\n",
575
+ "import torch\n",
576
+ "import pytorch_seed\n",
577
+ "import time\n",
578
+ "\n",
579
+ "from xml.sax import saxutils\n",
580
+ "from bark.api import generate_with_settings\n",
581
+ "from bark.api import save_as_prompt\n",
582
+ "from util.settings import Settings\n",
583
+ "#import nltk\n",
584
+ "\n",
585
+ "from bark import SAMPLE_RATE\n",
586
+ "from cloning.clonevoice import clone_voice\n",
587
+ "from bark.generation import SAMPLE_RATE, preload_models, _load_history_prompt, codec_decode\n",
588
+ "from scipy.io.wavfile import write as write_wav\n",
589
+ "from util.parseinput import split_and_recombine_text, build_ssml, is_ssml, create_clips_from_ssml\n",
590
+ "from datetime import datetime\n",
591
+ "from tqdm.auto import tqdm\n",
592
+ "from util.helper import create_filename, add_id3_tag\n",
593
+ "from swap_voice import swap_voice_from_audio\n",
594
+ "from training.training_prepare import prepare_semantics_from_text, prepare_wavs_from_semantics\n",
595
+ "from training.train import training_prepare_files, train\n",
596
+ "\n",
597
+ "settings = Settings('config.yaml')\n",
598
+ "\n",
599
+ "\n",
600
+ "def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, eos_prob, quick_generation, complete_settings, seed, batchcount, progress=gr.Progress(track_tqdm=True)):\n",
601
+ " # Chunk the text into smaller pieces then combine the generated audio\n",
602
+ "\n",
603
+ " # generation settings\n",
604
+ " if selected_speaker == 'None':\n",
605
+ " selected_speaker = None\n",
606
+ "\n",
607
+ " voice_name = selected_speaker\n",
608
+ "\n",
609
+ " if text == None or len(text) < 1:\n",
610
+ " if selected_speaker == None:\n",
611
+ " raise gr.Error('No text entered!')\n",
612
+ "\n",
613
+ " # Extract audio data from speaker if no text and speaker selected\n",
614
+ " voicedata = _load_history_prompt(voice_name)\n",
615
+ " audio_arr = codec_decode(voicedata[\"fine_prompt\"])\n",
616
+ " result = create_filename(settings.output_folder_path, \"None\", \"extract\",\".wav\")\n",
617
+ " save_wav(audio_arr, result)\n",
618
+ " return result\n",
619
+ "\n",
620
+ " if batchcount < 1:\n",
621
+ " batchcount = 1\n",
622
+ "\n",
623
+ "\n",
624
+ " silenceshort = np.zeros(int((float(settings.silence_sentence) / 1000.0) * SAMPLE_RATE), dtype=np.int16) # quarter second of silence\n",
625
+ " silencelong = np.zeros(int((float(settings.silence_speakers) / 1000.0) * SAMPLE_RATE), dtype=np.float32) # half a second of silence\n",
626
+ " use_last_generation_as_history = \"Use last generation as history\" in complete_settings\n",
627
+ " save_last_generation = \"Save generation as Voice\" in complete_settings\n",
628
+ " for l in range(batchcount):\n",
629
+ " currentseed = seed\n",
630
+ " if seed != None and seed > 2**32 - 1:\n",
631
+ " logger.warning(f\"Seed {seed} > 2**32 - 1 (max), setting to random\")\n",
632
+ " currentseed = None\n",
633
+ " if currentseed == None or currentseed <= 0:\n",
634
+ " currentseed = np.random.default_rng().integers(1, 2**32 - 1)\n",
635
+ " assert(0 < currentseed and currentseed < 2**32)\n",
636
+ "\n",
637
+ " progress(0, desc=\"Generating\")\n",
638
+ "\n",
639
+ " full_generation = None\n",
640
+ "\n",
641
+ " all_parts = []\n",
642
+ " complete_text = \"\"\n",
643
+ " text = text.lstrip()\n",
644
+ " if is_ssml(text):\n",
645
+ " list_speak = create_clips_from_ssml(text)\n",
646
+ " prev_speaker = None\n",
647
+ " for i, clip in tqdm(enumerate(list_speak), total=len(list_speak)):\n",
648
+ " selected_speaker = clip[0]\n",
649
+ " # Add pause break between speakers\n",
650
+ " if i > 0 and selected_speaker != prev_speaker:\n",
651
+ " all_parts += [silencelong.copy()]\n",
652
+ " prev_speaker = selected_speaker\n",
653
+ " text = clip[1]\n",
654
+ " text = saxutils.unescape(text)\n",
655
+ " if selected_speaker == \"None\":\n",
656
+ " selected_speaker = None\n",
657
+ "\n",
658
+ " print(f\"\\nGenerating Text ({i+1}/{len(list_speak)}) -> {selected_speaker} (Seed {currentseed}):`{text}`\")\n",
659
+ " complete_text += text\n",
660
+ " with pytorch_seed.SavedRNG(currentseed):\n",
661
+ " audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)\n",
662
+ " currentseed = torch.random.initial_seed()\n",
663
+ " if len(list_speak) > 1:\n",
664
+ " filename = create_filename(settings.output_folder_path, currentseed, \"audioclip\",\".wav\")\n",
665
+ " save_wav(audio_array, filename)\n",
666
+ " add_id3_tag(filename, text, selected_speaker, currentseed)\n",
667
+ "\n",
668
+ " all_parts += [audio_array]\n",
669
+ " else:\n",
670
+ " texts = split_and_recombine_text(text, settings.input_text_desired_length, settings.input_text_max_length)\n",
671
+ " for i, text in tqdm(enumerate(texts), total=len(texts)):\n",
672
+ " print(f\"\\nGenerating Text ({i+1}/{len(texts)}) -> {selected_speaker} (Seed {currentseed}):`{text}`\")\n",
673
+ " complete_text += text\n",
674
+ " if quick_generation == True:\n",
675
+ " with pytorch_seed.SavedRNG(currentseed):\n",
676
+ " audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)\n",
677
+ " currentseed = torch.random.initial_seed()\n",
678
+ " else:\n",
679
+ " full_output = use_last_generation_as_history or save_last_generation\n",
680
+ " if full_output:\n",
681
+ " full_generation, audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob, output_full=True)\n",
682
+ " else:\n",
683
+ " audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)\n",
684
+ "\n",
685
+ " # Noticed this in the HF Demo - convert to 16bit int -32767/32767 - most used audio format \n",
686
+ " # audio_array = (audio_array * 32767).astype(np.int16)\n",
687
+ "\n",
688
+ " if len(texts) > 1:\n",
689
+ " filename = create_filename(settings.output_folder_path, currentseed, \"audioclip\",\".wav\")\n",
690
+ " save_wav(audio_array, filename)\n",
691
+ " add_id3_tag(filename, text, selected_speaker, currentseed)\n",
692
+ "\n",
693
+ " if quick_generation == False and (save_last_generation == True or use_last_generation_as_history == True):\n",
694
+ " # save to npz\n",
695
+ " voice_name = create_filename(settings.output_folder_path, seed, \"audioclip\", \".npz\")\n",
696
+ " save_as_prompt(voice_name, full_generation)\n",
697
+ " if use_last_generation_as_history:\n",
698
+ " selected_speaker = voice_name\n",
699
+ "\n",
700
+ " all_parts += [audio_array]\n",
701
+ " # Add short pause between sentences\n",
702
+ " if text[-1] in \"!?.\\n\" and i > 1:\n",
703
+ " all_parts += [silenceshort.copy()]\n",
704
+ "\n",
705
+ " # save & play audio\n",
706
+ " result = create_filename(settings.output_folder_path, currentseed, \"final\",\".wav\")\n",
707
+ " save_wav(np.concatenate(all_parts), result)\n",
708
+ " # write id3 tag with text truncated to 60 chars, as a precaution...\n",
709
+ " add_id3_tag(result, complete_text, selected_speaker, currentseed)\n",
710
+ "\n",
711
+ " return result\n",
712
+ "\n",
713
+ "\n",
714
+ "\n",
715
+ "def save_wav(audio_array, filename):\n",
716
+ " write_wav(filename, SAMPLE_RATE, audio_array)\n",
717
+ "\n",
718
+ "def save_voice(filename, semantic_prompt, coarse_prompt, fine_prompt):\n",
719
+ " np.savez_compressed(\n",
720
+ " filename,\n",
721
+ " semantic_prompt=semantic_prompt,\n",
722
+ " coarse_prompt=coarse_prompt,\n",
723
+ " fine_prompt=fine_prompt\n",
724
+ " )\n",
725
+ " \n",
726
+ "\n",
727
+ "def on_quick_gen_changed(checkbox):\n",
728
+ " if checkbox == False:\n",
729
+ " return gr.CheckboxGroup.update(visible=True)\n",
730
+ " return gr.CheckboxGroup.update(visible=False)\n",
731
+ "\n",
732
+ "def delete_output_files(checkbox_state):\n",
733
+ " if checkbox_state:\n",
734
+ " outputs_folder = os.path.join(os.getcwd(), settings.output_folder_path)\n",
735
+ " if os.path.exists(outputs_folder):\n",
736
+ " purgedir(outputs_folder)\n",
737
+ " return False\n",
738
+ "\n",
739
+ "\n",
740
+ "# https://stackoverflow.com/a/54494779\n",
741
+ "def purgedir(parent):\n",
742
+ " for root, dirs, files in os.walk(parent): \n",
743
+ " for item in files:\n",
744
+ " # Delete subordinate files \n",
745
+ " filespec = os.path.join(root, item)\n",
746
+ " os.unlink(filespec)\n",
747
+ " for item in dirs:\n",
748
+ " # Recursively perform this operation for subordinate directories \n",
749
+ " purgedir(os.path.join(root, item))\n",
750
+ "\n",
751
+ "def convert_text_to_ssml(text, selected_speaker):\n",
752
+ " return build_ssml(text, selected_speaker)\n",
753
+ "\n",
754
+ "\n",
755
+ "def training_prepare(selected_step, num_text_generations, progress=gr.Progress(track_tqdm=True)):\n",
756
+ " if selected_step == prepare_training_list[0]:\n",
757
+ " prepare_semantics_from_text()\n",
758
+ " else:\n",
759
+ " prepare_wavs_from_semantics()\n",
760
+ " return None\n",
761
+ "\n",
762
+ "\n",
763
+ "def start_training(save_model_epoch, max_epochs, progress=gr.Progress(track_tqdm=True)):\n",
764
+ " training_prepare_files(\"./training/data/\", \"./training/data/checkpoint/hubert_base_ls960.pt\")\n",
765
+ " train(\"./training/data/\", save_model_epoch, max_epochs)\n",
766
+ " return None\n",
767
+ "\n",
768
+ "\n",
769
+ "\n",
770
+ "def apply_settings(themes, input_server_name, input_server_port, input_server_public, input_desired_len, input_max_len, input_silence_break, input_silence_speaker):\n",
771
+ " settings.selected_theme = themes\n",
772
+ " settings.server_name = input_server_name\n",
773
+ " settings.server_port = input_server_port\n",
774
+ " settings.server_share = input_server_public\n",
775
+ " settings.input_text_desired_length = input_desired_len\n",
776
+ " settings.input_text_max_length = input_max_len\n",
777
+ " settings.silence_sentence = input_silence_break\n",
778
+ " settings.silence_speaker = input_silence_speaker\n",
779
+ " settings.save()\n",
780
+ "\n",
781
+ "def restart():\n",
782
+ " global restart_server\n",
783
+ " restart_server = True\n",
784
+ "\n",
785
+ "\n",
786
+ "def create_version_html():\n",
787
+ " python_version = \".\".join([str(x) for x in sys.version_info[0:3]])\n",
788
+ " versions_html = f\"\"\"\n",
789
+ "python: <span title=\"{sys.version}\">{python_version}</span>\n",
790
+ " • \n",
791
+ "torch: {getattr(torch, '__long_version__',torch.__version__)}\n",
792
+ " • \n",
793
+ "gradio: {gr.__version__}\n",
794
+ "\"\"\"\n",
795
+ " return versions_html\n",
796
+ "\n",
797
+ " \n",
798
+ "\n",
799
+ "logger = logging.getLogger(__name__)\n",
800
+ "APPTITLE = \"Bark Voice Cloning UI\"\n",
801
+ "\n",
802
+ "\n",
803
+ "autolaunch = False\n",
804
+ "\n",
805
+ "if len(sys.argv) > 1:\n",
806
+ " autolaunch = \"-autolaunch\" in sys.argv\n",
807
+ "\n",
808
+ "\n",
809
+ "if torch.cuda.is_available() == False:\n",
810
+ " os.environ['BARK_FORCE_CPU'] = 'True'\n",
811
+ " logger.warning(\"No CUDA detected, fallback to CPU!\")\n",
812
+ "\n",
813
+ "print(f'smallmodels={os.environ.get(\"SUNO_USE_SMALL_MODELS\", False)}')\n",
814
+ "print(f'enablemps={os.environ.get(\"SUNO_ENABLE_MPS\", False)}')\n",
815
+ "print(f'offloadcpu={os.environ.get(\"SUNO_OFFLOAD_CPU\", False)}')\n",
816
+ "print(f'forcecpu={os.environ.get(\"BARK_FORCE_CPU\", False)}')\n",
817
+ "print(f'autolaunch={autolaunch}\\n\\n')\n",
818
+ "\n",
819
+ "#print(\"Updating nltk\\n\")\n",
820
+ "#nltk.download('punkt')\n",
821
+ "\n",
822
+ "print(\"Preloading Models\\n\")\n",
823
+ "preload_models()\n",
824
+ "\n",
825
+ "available_themes = [\"Default\", \"gradio/glass\", \"gradio/monochrome\", \"gradio/seafoam\", \"gradio/soft\", \"gstaff/xkcd\", \"freddyaboulton/dracula_revamped\", \"ysharma/steampunk\"]\n",
826
+ "tokenizer_language_list = [\"de\",\"en\", \"pl\"]\n",
827
+ "prepare_training_list = [\"Step 1: Semantics from Text\",\"Step 2: WAV from Semantics\"]\n",
828
+ "\n",
829
+ "seed = -1\n",
830
+ "server_name = settings.server_name\n",
831
+ "if len(server_name) < 1:\n",
832
+ " server_name = None\n",
833
+ "server_port = settings.server_port\n",
834
+ "if server_port <= 0:\n",
835
+ " server_port = None\n",
836
+ "global run_server\n",
837
+ "global restart_server\n",
838
+ "\n",
839
+ "run_server = True\n",
840
+ "\n",
841
+ "while run_server:\n",
842
+ " # Collect all existing speakers/voices in dir\n",
843
+ " speakers_list = []\n",
844
+ "\n",
845
+ " for root, dirs, files in os.walk(\"./bark/assets/prompts\"):\n",
846
+ " for file in files:\n",
847
+ " if file.endswith(\".npz\"):\n",
848
+ " pathpart = root.replace(\"./bark/assets/prompts\", \"\")\n",
849
+ " name = os.path.join(pathpart, file[:-4])\n",
850
+ " if name.startswith(\"/\") or name.startswith(\"\\\\\"):\n",
851
+ " name = name[1:]\n",
852
+ " speakers_list.append(name)\n",
853
+ "\n",
854
+ " speakers_list = sorted(speakers_list, key=lambda x: x.lower())\n",
855
+ " speakers_list.insert(0, 'None')\n",
856
+ "\n",
857
+ " print(f'Launching {APPTITLE} Server')\n",
858
+ "\n",
859
+ " # Create Gradio Blocks\n",
860
+ "\n",
861
+ " with gr.Blocks(title=f\"{APPTITLE}\", mode=f\"{APPTITLE}\", theme=settings.selected_theme) as barkgui:\n",
862
+ " with gr.Row():\n",
863
+ " with gr.Column():\n",
864
+ " gr.Markdown(f\"### [{APPTITLE}](https://github.com/KevinWang676/Bark-Voice-Cloning)\")\n",
865
+ " with gr.Column():\n",
866
+ " gr.HTML(create_version_html(), elem_id=\"versions\")\n",
867
+ "\n",
868
+ " with gr.Tab(\"Clone Voice\"):\n",
869
+ " with gr.Row():\n",
870
+ " input_audio_filename = gr.Audio(label=\"Input audio.wav\", source=\"upload\", type=\"filepath\")\n",
871
+ " #transcription_text = gr.Textbox(label=\"Transcription Text\", lines=1, placeholder=\"Enter Text of your Audio Sample here...\")\n",
872
+ " with gr.Row():\n",
873
+ " with gr.Column():\n",
874
+ " initialname = \"/content/Bark-Voice-Cloning/bark/assets/prompts/file\"\n",
875
+ " output_voice = gr.Textbox(label=\"Filename of trained Voice (do not change the initial name)\", lines=1, placeholder=initialname, value=initialname)\n",
876
+ " with gr.Column():\n",
877
+ " tokenizerlang = gr.Dropdown(tokenizer_language_list, label=\"Base Language Tokenizer\", value=tokenizer_language_list[1])\n",
878
+ " with gr.Row():\n",
879
+ " clone_voice_button = gr.Button(\"Create Voice\")\n",
880
+ " with gr.Row():\n",
881
+ " dummy = gr.Text(label=\"Progress\")\n",
882
+ " npz_file = gr.File(label=\".npz file\")\n",
883
+ " speakers_list.insert(0, npz_file) # add prompt\n",
884
+ "\n",
885
+ " with gr.Tab(\"TTS\"):\n",
886
+ " with gr.Row():\n",
887
+ " with gr.Column():\n",
888
+ " placeholder = \"Enter text here.\"\n",
889
+ " input_text = gr.Textbox(label=\"Input Text\", lines=4, placeholder=placeholder)\n",
890
+ " with gr.Column():\n",
891
+ " seedcomponent = gr.Number(label=\"Seed (default -1 = Random)\", precision=0, value=-1)\n",
892
+ " batchcount = gr.Number(label=\"Batch count\", precision=0, value=1)\n",
893
+ " with gr.Row():\n",
894
+ " with gr.Column():\n",
895
+ " examples = [\n",
896
+ " \"Special meanings: [laughter] [laughs] [sighs] [music] [gasps] [clears throat] MAN: WOMAN:\",\n",
897
+ " \"♪ Never gonna make you cry, never gonna say goodbye, never gonna tell a lie and hurt you ♪\",\n",
898
+ " \"And now — a picture of a larch [laughter]\",\n",
899
+ " \"\"\"\n",
900
+ " WOMAN: I would like an oatmilk latte please.\n",
901
+ " MAN: Wow, that's expensive!\n",
902
+ " \"\"\",\n",
903
+ " \"\"\"<?xml version=\"1.0\"?>\n",
904
+ " <speak version=\"1.0\" xmlns=\"http://www.w3.org/2001/10/synthesis\"\n",
905
+ " xmlns:xsi=\"http://www.w3.org/2001/XMLSchema-instance\"\n",
906
+ " xsi:schemaLocation=\"http://www.w3.org/2001/10/synthesis\n",
907
+ " http://www.w3.org/TR/speech-synthesis/synthesis.xsd\"\n",
908
+ " xml:lang=\"en-US\">\n",
909
+ " <voice name=\"/v2/en_speaker_9\">Look at that drunk guy!</voice>\n",
910
+ " <voice name=\"/v2/en_speaker_3\">Who is he?</voice>\n",
911
+ " <voice name=\"/v2/en_speaker_9\">WOMAN: [clears throat] 10 years ago, he proposed me and I rejected him.</voice>\n",
912
+ " <voice name=\"/v2/en_speaker_3\">Oh my God [laughs] he is still celebrating</voice>\n",
913
+ " </speak>\"\"\"\n",
914
+ " ]\n",
915
+ " examples = gr.Examples(examples=examples, inputs=input_text)\n",
916
+ " with gr.Column():\n",
917
+ " convert_to_ssml_button = gr.Button(\"Convert Input Text to SSML\")\n",
918
+ "\n",
919
+ " with gr.Row():\n",
920
+ " with gr.Column():\n",
921
+ " gr.Markdown(\"[Voice Prompt Library](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c)\")\n",
922
+ " speaker = gr.Dropdown(speakers_list, value=speakers_list[0], label=\"Voice\")\n",
923
+ " \n",
924
+ " with gr.Column():\n",
925
+ " text_temp = gr.Slider(0.1, 1.0, value=0.6, label=\"Generation Temperature\", info=\"1.0 more diverse, 0.1 more conservative\")\n",
926
+ " waveform_temp = gr.Slider(0.1, 1.0, value=0.7, label=\"Waveform temperature\", info=\"1.0 more diverse, 0.1 more conservative\")\n",
927
+ "\n",
928
+ " with gr.Row():\n",
929
+ " with gr.Column():\n",
930
+ " quick_gen_checkbox = gr.Checkbox(label=\"Quick Generation\", value=True)\n",
931
+ " settings_checkboxes = [\"Use last generation as history\", \"Save generation as Voice\"]\n",
932
+ " complete_settings = gr.CheckboxGroup(choices=settings_checkboxes, value=settings_checkboxes, label=\"Detailed Generation Settings\", type=\"value\", interactive=True, visible=False)\n",
933
+ " with gr.Column():\n",
934
+ " eos_prob = gr.Slider(0.0, 0.5, value=0.05, label=\"End of sentence probability\")\n",
935
+ "\n",
936
+ " with gr.Row():\n",
937
+ " with gr.Column():\n",
938
+ " tts_create_button = gr.Button(\"Generate\")\n",
939
+ " with gr.Column():\n",
940
+ " hidden_checkbox = gr.Checkbox(visible=False)\n",
941
+ " button_stop_generation = gr.Button(\"Stop generation\")\n",
942
+ " with gr.Row():\n",
943
+ " output_audio = gr.Audio(label=\"Generated Audio\", type=\"filepath\")\n",
944
+ "\n",
945
+ " with gr.Tab(\"Swap Voice\"):\n",
946
+ " with gr.Row():\n",
947
+ " swap_audio_filename = gr.Audio(label=\"Input audio.wav to swap voice\", source=\"upload\", type=\"filepath\")\n",
948
+ " with gr.Row():\n",
949
+ " with gr.Column():\n",
950
+ " swap_tokenizer_lang = gr.Dropdown(tokenizer_language_list, label=\"Base Language Tokenizer\", value=tokenizer_language_list[1])\n",
951
+ " swap_seed = gr.Number(label=\"Seed (default -1 = Random)\", precision=0, value=-1)\n",
952
+ " with gr.Column():\n",
953
+ " speaker_swap = gr.Dropdown(speakers_list, value=speakers_list[0], label=\"Voice\")\n",
954
+ " swap_batchcount = gr.Number(label=\"Batch count\", precision=0, value=1)\n",
955
+ " with gr.Row():\n",
956
+ " swap_voice_button = gr.Button(\"Swap Voice\")\n",
957
+ " with gr.Row():\n",
958
+ " output_swap = gr.Audio(label=\"Generated Audio\", type=\"filepath\")\n",
959
+ "\n",
960
+ " \n",
961
+ " quick_gen_checkbox.change(fn=on_quick_gen_changed, inputs=quick_gen_checkbox, outputs=complete_settings)\n",
962
+ " convert_to_ssml_button.click(convert_text_to_ssml, inputs=[input_text, speaker],outputs=input_text)\n",
963
+ " gen_click = tts_create_button.click(generate_text_to_speech, inputs=[input_text, speaker, text_temp, waveform_temp, eos_prob, quick_gen_checkbox, complete_settings, seedcomponent, batchcount],outputs=output_audio)\n",
964
+ " button_stop_generation.click(fn=None, inputs=None, outputs=None, cancels=[gen_click])\n",
965
+ " \n",
966
+ "\n",
967
+ "\n",
968
+ " swap_voice_button.click(swap_voice_from_audio, inputs=[swap_audio_filename, speaker_swap, swap_tokenizer_lang, swap_seed, swap_batchcount], outputs=output_swap)\n",
969
+ " clone_voice_button.click(clone_voice, inputs=[input_audio_filename, output_voice], outputs=[dummy, npz_file])\n",
970
+ "\n",
971
+ "\n",
972
+ " restart_server = False\n",
973
+ " try:\n",
974
+ " barkgui.queue().launch(show_error=True)\n",
975
+ " except:\n",
976
+ " restart_server = True\n",
977
+ " run_server = False\n",
978
+ " try:\n",
979
+ " while restart_server == False:\n",
980
+ " time.sleep(1.0)\n",
981
+ " except (KeyboardInterrupt, OSError):\n",
982
+ " print(\"Keyboard interruption in main thread... closing server.\")\n",
983
+ " run_server = False\n",
984
+ " barkgui.close()"
985
+ ],
986
+ "metadata": {
987
+ "colab": {
988
+ "base_uri": "https://localhost:8080/",
989
+ "height": 981,
990
+ "referenced_widgets": [
991
+ "425505387f374468870cc4bcb52ea6c5",
992
+ "9b039beb3d7c4bc59ab95bd5d8a7dfcc",
993
+ "55bf104e557340e5a88962134a765f1b",
994
+ "f0768ce2c3484c4583810f461a0b742e",
995
+ "ff7dff340d9f41c29313ca68034be359",
996
+ "77a54e634f0d44c080eb769a4d2921b0",
997
+ "5584c9aaa4e04734bb6833cf7cf76534",
998
+ "7a2e70b96a054cdd89f73edd2474e20c",
999
+ "1120230111694b4d8e63d476b0a35454",
1000
+ "643343218af349aaa63afbcd3cbc8009",
1001
+ "dfb0df17546545a4b74ed7f5f10c7a9a"
1002
+ ]
1003
+ },
1004
+ "id": "jDsXfOlEnTO-",
1005
+ "outputId": "debf279f-7788-411f-ee2b-4a0522e20122"
1006
+ },
1007
+ "execution_count": null,
1008
+ "outputs": [
1009
+ {
1010
+ "output_type": "stream",
1011
+ "name": "stdout",
1012
+ "text": [
1013
+ "smallmodels=False\n",
1014
+ "enablemps=False\n",
1015
+ "offloadcpu=False\n",
1016
+ "forcecpu=False\n",
1017
+ "autolaunch=False\n",
1018
+ "\n",
1019
+ "\n",
1020
+ "Preloading Models\n",
1021
+ "\n",
1022
+ "Launching Bark Voice Cloning UI Server\n",
1023
+ "Setting queue=True in a Colab notebook requires sharing enabled. Setting `share=True` (you can turn this off by setting `share=False` in `launch()` explicitly).\n",
1024
+ "\n",
1025
+ "Colab notebook detected. To show errors in colab notebook, set debug=True in launch()\n",
1026
+ "Running on public URL: https://5fbed86c1148a1f8e5.gradio.live\n",
1027
+ "\n",
1028
+ "This share link expires in 72 hours. For free permanent hosting and GPU upgrades, run `gradio deploy` from Terminal to deploy to Spaces (https://huggingface.co/spaces)\n"
1029
+ ]
1030
+ },
1031
+ {
1032
+ "output_type": "display_data",
1033
+ "data": {
1034
+ "text/plain": [
1035
+ "<IPython.core.display.HTML object>"
1036
+ ],
1037
+ "text/html": [
1038
+ "<div><iframe src=\"https://5fbed86c1148a1f8e5.gradio.live\" width=\"100%\" height=\"500\" allow=\"autoplay; camera; microphone; clipboard-read; clipboard-write;\" frameborder=\"0\" allowfullscreen></iframe></div>"
1039
+ ]
1040
+ },
1041
+ "metadata": {}
1042
+ },
1043
+ {
1044
+ "output_type": "stream",
1045
+ "name": "stdout",
1046
+ "text": [
1047
+ "Downloading HuBERT base model from https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt\n",
1048
+ "Downloaded HuBERT\n",
1049
+ "en_tokenizer.pth not found. Downloading HuBERT custom tokenizer\n"
1050
+ ]
1051
+ },
1052
+ {
1053
+ "output_type": "display_data",
1054
+ "data": {
1055
+ "text/plain": [
1056
+ "Downloading (…)rt_base_ls960_14.pth: 0%| | 0.00/104M [00:00<?, ?B/s]"
1057
+ ],
1058
+ "application/vnd.jupyter.widget-view+json": {
1059
+ "version_major": 2,
1060
+ "version_minor": 0,
1061
+ "model_id": "425505387f374468870cc4bcb52ea6c5"
1062
+ }
1063
+ },
1064
+ "metadata": {}
1065
+ },
1066
+ {
1067
+ "output_type": "stream",
1068
+ "name": "stdout",
1069
+ "text": [
1070
+ "Downloaded tokenizer\n",
1071
+ "Loading Hubert ./models/hubert/hubert.pt\n",
1072
+ "\n",
1073
+ "Generating Text (1/1) -> file (Seed 529525761):`Authors are required to disclose financial or non-financial interests that are directly or indirectly related to`\n"
1074
+ ]
1075
+ }
1076
+ ]
1077
+ },
1078
+ {
1079
+ "cell_type": "code",
1080
+ "source": [],
1081
+ "metadata": {
1082
+ "id": "Mt5lkF1gnX54"
1083
+ },
1084
+ "execution_count": null,
1085
+ "outputs": []
1086
+ }
1087
+ ]
1088
+ }
Dockerfile ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM debian:stable
2
+
3
+ # Install system packages
4
+ RUN apt update && apt install -y git pip
5
+
6
+ # Create non-root user
7
+ RUN useradd -m -d /bark bark
8
+
9
+ # Run as new user
10
+ USER bark
11
+ WORKDIR /bark
12
+
13
+ # Clone git repo
14
+ RUN git clone https://github.com/C0untFloyd/bark-gui
15
+
16
+ # Switch to git directory
17
+ WORKDIR /bark/bark-gui
18
+
19
+ # Append pip bin path to PATH
20
+ ENV PATH=$PATH:/bark/.local/bin
21
+
22
+ # Install dependencies
23
+ RUN pip install .
24
+ RUN pip install -r requirements.txt
25
+
26
+ # Listen on all addresses, since we are in a container.
27
+ RUN sed -i "s/server_name: ''/server_name: 0.0.0.0/g" ./config.yaml
28
+
29
+ # Suggested volumes
30
+ VOLUME /bark/bark-gui/assets/prompts/custom
31
+ VOLUME /bark/bark-gui/models
32
+ VOLUME /bark/.cache/huggingface/hub
33
+
34
+ # Default port for web-ui
35
+ EXPOSE 7860/tcp
36
+
37
+ # Start script
38
+ CMD python3 webui.py
README.md CHANGED
@@ -1,10 +1,21 @@
1
- ---
2
- title: Barkclone
3
- emoji: 📈
4
- colorFrom: green
5
- colorTo: green
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Bark-Voice-Cloning 🐶
2
+
3
+ Based on [bark-gui](https://github.com/C0untFloyd/bark-gui). Thanks to [C0untFloyd](https://github.com/C0untFloyd).
4
+
5
+ Quick start: [Colab Notebook](https://colab.research.google.com/github/KevinWang676/Bark-Voice-Cloning/blob/main/Bark_Voice_Cloning_UI.ipynb) ⚡
6
+
7
+ HuggingFace Demo: [Bark Voice Cloning](https://huggingface.co/spaces/kevinwang676/Bark-Voice-Cloning) 🤗
8
+
9
+ ### If you like the quick start, please star this repository. ⭐⭐⭐
10
+
11
+ ## Easy to use:
12
+
13
+ (1) First upload audio for voice cloning and click `Create Voice`.
14
+
15
+ ![image](https://github.com/KevinWang676/Bark-Voice-Cloning/assets/126712357/65e2b695-f529-4fb5-9549-4e86e6a4d8b2)
16
+
17
+ (2) Choose the option called "file" in `Voice` if you'd like to use voice cloning.
18
+
19
+ (3) Click `Generate`. Done!
20
+
21
+ ![image](https://github.com/KevinWang676/Bark-Voice-Cloning/assets/126712357/20911e37-768d-47d5-bb86-d12a3ab04c5d)
app.py ADDED
@@ -0,0 +1,468 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cProfile import label
2
+ import dataclasses
3
+ from distutils.command.check import check
4
+ from doctest import Example
5
+ import gradio as gr
6
+ import os
7
+ import sys
8
+ import numpy as np
9
+ import logging
10
+ import torch
11
+ import pytorch_seed
12
+ import time
13
+
14
+ from xml.sax import saxutils
15
+ from bark.api import generate_with_settings
16
+ from bark.api import save_as_prompt
17
+ from util.settings import Settings
18
+ #import nltk
19
+
20
+ from bark import SAMPLE_RATE
21
+ from cloning.clonevoice import clone_voice
22
+ from bark.generation import SAMPLE_RATE, preload_models, _load_history_prompt, codec_decode
23
+ from scipy.io.wavfile import write as write_wav
24
+ from util.parseinput import split_and_recombine_text, build_ssml, is_ssml, create_clips_from_ssml
25
+ from datetime import datetime
26
+ from tqdm.auto import tqdm
27
+ from util.helper import create_filename, add_id3_tag
28
+ from swap_voice import swap_voice_from_audio
29
+ from training.training_prepare import prepare_semantics_from_text, prepare_wavs_from_semantics
30
+ from training.train import training_prepare_files, train
31
+
32
+ settings = Settings('config.yaml')
33
+
34
+
35
def generate_text_to_speech(text, selected_speaker, text_temp, waveform_temp, eos_prob, quick_generation, complete_settings, seed, batchcount, progress=gr.Progress(track_tqdm=True)):
    """Generate speech for *text* with Bark and return the path of the final WAV.

    Splits the input into sentence-sized chunks (or SSML ``<voice>`` clips),
    generates audio per chunk, inserts silences between sentences/speakers,
    concatenates everything, and writes one "final" WAV per batch iteration.

    Args:
        text: input text or an SSML document (detected via ``is_ssml``).
        selected_speaker: voice/prompt name, or the string 'None' for no voice.
        text_temp: semantic-stage generation temperature.
        waveform_temp: coarse-stage generation temperature.
        eos_prob: min end-of-sentence probability passed as ``eos_p``.
        quick_generation: if True, seeds the RNG per chunk and skips
            history-prompt handling.
        complete_settings: list of checkbox labels controlling history/voice saving.
        seed: RNG seed; <= 0 or None means pick a random one.
        batchcount: number of complete generations to run (min 1).
        progress: Gradio progress tracker (tracks tqdm bars).

    Returns:
        Path of the last written WAV file. Note: when *batchcount* > 1 only the
        last batch's path is returned.
    """
    # Chunk the text into smaller pieces then combine the generated audio

    # generation settings
    if selected_speaker == 'None':
        selected_speaker = None

    voice_name = selected_speaker

    if text == None or len(text) < 1:
        if selected_speaker == None:
            raise gr.Error('No text entered!')

        # Extract audio data from speaker if no text and speaker selected
        voicedata = _load_history_prompt(voice_name)
        audio_arr = codec_decode(voicedata["fine_prompt"])
        result = create_filename(settings.output_folder_path, "None", "extract",".wav")
        save_wav(audio_arr, result)
        return result

    if batchcount < 1:
        batchcount = 1

    # Pre-built silence buffers, lengths come from the settings file (ms).
    # NOTE(review): dtypes differ (int16 vs float32) while generated audio is
    # float -- presumably harmless for np.concatenate but verify scaling.
    silenceshort = np.zeros(int((float(settings.silence_sentence) / 1000.0) * SAMPLE_RATE), dtype=np.int16) # quarter second of silence
    silencelong = np.zeros(int((float(settings.silence_speakers) / 1000.0) * SAMPLE_RATE), dtype=np.float32) # half a second of silence
    use_last_generation_as_history = "Use last generation as history" in complete_settings
    save_last_generation = "Save generation as Voice" in complete_settings
    for l in range(batchcount):
        # Re-derive the working seed each batch; out-of-range seeds fall back
        # to a random one so torch never receives a value >= 2**32.
        currentseed = seed
        if seed != None and seed > 2**32 - 1:
            logger.warning(f"Seed {seed} > 2**32 - 1 (max), setting to random")
            currentseed = None
        if currentseed == None or currentseed <= 0:
            currentseed = np.random.default_rng().integers(1, 2**32 - 1)
        assert(0 < currentseed and currentseed < 2**32)

        progress(0, desc="Generating")

        full_generation = None

        all_parts = []
        complete_text = ""
        text = text.lstrip()
        if is_ssml(text):
            # SSML path: each clip is a (speaker, text) pair from <voice> tags.
            list_speak = create_clips_from_ssml(text)
            prev_speaker = None
            for i, clip in tqdm(enumerate(list_speak), total=len(list_speak)):
                selected_speaker = clip[0]
                # Add pause break between speakers
                if i > 0 and selected_speaker != prev_speaker:
                    all_parts += [silencelong.copy()]
                prev_speaker = selected_speaker
                text = clip[1]
                text = saxutils.unescape(text)
                if selected_speaker == "None":
                    selected_speaker = None

                print(f"\nGenerating Text ({i+1}/{len(list_speak)}) -> {selected_speaker} (Seed {currentseed}):`{text}`")
                complete_text += text
                # SavedRNG restores the ambient RNG state afterwards; the seed
                # read back here chains into the next clip.
                with pytorch_seed.SavedRNG(currentseed):
                    audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)
                    currentseed = torch.random.initial_seed()
                if len(list_speak) > 1:
                    # Also save each individual clip when there is more than one.
                    filename = create_filename(settings.output_folder_path, currentseed, "audioclip",".wav")
                    save_wav(audio_array, filename)
                    add_id3_tag(filename, text, selected_speaker, currentseed)

                all_parts += [audio_array]
        else:
            # Plain-text path: split into chunks near the desired length.
            texts = split_and_recombine_text(text, settings.input_text_desired_length, settings.input_text_max_length)
            for i, text in tqdm(enumerate(texts), total=len(texts)):
                print(f"\nGenerating Text ({i+1}/{len(texts)}) -> {selected_speaker} (Seed {currentseed}):`{text}`")
                complete_text += text
                if quick_generation == True:
                    with pytorch_seed.SavedRNG(currentseed):
                        audio_array = generate_with_settings(text_prompt=text, voice_name=selected_speaker, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)
                        currentseed = torch.random.initial_seed()
                else:
                    # Full output is only needed when the generation must be
                    # reusable as a history prompt or saved as a voice.
                    full_output = use_last_generation_as_history or save_last_generation
                    if full_output:
                        full_generation, audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob, output_full=True)
                    else:
                        audio_array = generate_with_settings(text_prompt=text, voice_name=voice_name, semantic_temp=text_temp, coarse_temp=waveform_temp, eos_p=eos_prob)

                # Noticed this in the HF Demo - convert to 16bit int -32767/32767 - most used audio format
                # audio_array = (audio_array * 32767).astype(np.int16)

                if len(texts) > 1:
                    filename = create_filename(settings.output_folder_path, currentseed, "audioclip",".wav")
                    save_wav(audio_array, filename)
                    add_id3_tag(filename, text, selected_speaker, currentseed)

                if quick_generation == False and (save_last_generation == True or use_last_generation_as_history == True):
                    # save to npz
                    voice_name = create_filename(settings.output_folder_path, seed, "audioclip", ".npz")
                    save_as_prompt(voice_name, full_generation)
                    if use_last_generation_as_history:
                        selected_speaker = voice_name

                all_parts += [audio_array]
                # Add short pause between sentences
                if text[-1] in "!?.\n" and i > 1:
                    all_parts += [silenceshort.copy()]

        # save & play audio
        result = create_filename(settings.output_folder_path, currentseed, "final",".wav")
        save_wav(np.concatenate(all_parts), result)
        # write id3 tag with text truncated to 60 chars, as a precaution...
        add_id3_tag(result, complete_text, selected_speaker, currentseed)

    return result
147
+
148
+
149
+
150
def save_wav(audio_array, filename):
    """Write *audio_array* to *filename* as a WAV file at Bark's sample rate."""
    write_wav(filename, SAMPLE_RATE, audio_array)
152
+
153
def save_voice(filename, semantic_prompt, coarse_prompt, fine_prompt):
    """Persist the three Bark prompt arrays as a compressed ``.npz`` voice file."""
    prompts = {
        "semantic_prompt": semantic_prompt,
        "coarse_prompt": coarse_prompt,
        "fine_prompt": fine_prompt,
    }
    np.savez_compressed(filename, **prompts)
160
+
161
+
162
def on_quick_gen_changed(checkbox):
    """Show the detailed-settings group only while quick generation is unchecked."""
    show_details = (checkbox == False)
    return gr.CheckboxGroup.update(visible=show_details)
166
+
167
def delete_output_files(checkbox_state):
    """Purge the output folder when the confirmation checkbox fired.

    Always returns False so the hidden checkbox is reset in the UI.
    """
    if checkbox_state:
        target = os.path.join(os.getcwd(), settings.output_folder_path)
        if os.path.exists(target):
            purgedir(target)
    return False
173
+
174
+
175
# https://stackoverflow.com/a/54494779
def purgedir(parent):
    """Delete every file below *parent*, recursively; directories are kept.

    ``os.walk`` already descends into every subdirectory, so a single
    traversal removes all files.  The original additionally recursed into
    each subdirectory by hand, re-walking every subtree redundantly.
    """
    for root, dirs, files in os.walk(parent):
        for item in files:
            os.unlink(os.path.join(root, item))
185
+
186
def convert_text_to_ssml(text, selected_speaker):
    """Wrap plain *text* into an SSML document voiced by *selected_speaker*."""
    ssml = build_ssml(text, selected_speaker)
    return ssml
188
+
189
+
190
def training_prepare(selected_step, num_text_generations, progress=gr.Progress(track_tqdm=True)):
    """Run the selected dataset-preparation step (semantics or WAVs)."""
    step_is_semantics = selected_step == prepare_training_list[0]
    if step_is_semantics:
        prepare_semantics_from_text()
    else:
        prepare_wavs_from_semantics()
    return None
196
+
197
+
198
def start_training(save_model_epoch, max_epochs, progress=gr.Progress(track_tqdm=True)):
    """Stage the training files, then train for *max_epochs* epochs.

    Saves a checkpoint every *save_model_epoch* epochs.
    """
    training_prepare_files("./training/data/", "./training/data/checkpoint/hubert_base_ls960.pt")
    train("./training/data/", save_model_epoch, max_epochs)
    return None
202
+
203
+
204
+
205
def apply_settings(themes, input_server_name, input_server_port, input_server_public, input_desired_len, input_max_len, input_silence_break, input_silence_speaker):
    """Copy the values from the Settings tab into the Settings object and save it.

    Bug fix: the rest of the file reads ``settings.silence_speakers``
    (plural -- see the generation silence buffer and the Settings slider),
    but this function previously wrote ``settings.silence_speaker``
    (singular), so the speaker-pause value was stored under the wrong
    attribute and never took effect.
    """
    settings.selected_theme = themes
    settings.server_name = input_server_name
    settings.server_port = input_server_port
    settings.server_share = input_server_public
    settings.input_text_desired_length = input_desired_len
    settings.input_text_max_length = input_max_len
    settings.silence_sentence = input_silence_break
    settings.silence_speakers = input_silence_speaker  # was: settings.silence_speaker (typo)
    settings.save()
215
+
216
def restart():
    # Signal the outer server loop to tear down and relaunch the Gradio UI.
    global restart_server
    restart_server = True
219
+
220
+
221
def create_version_html():
    """Build the small HTML snippet showing the python/torch/gradio versions.

    Rendered in the page header via ``gr.HTML``.  The template lines are kept
    at column 0 because the triple-quoted string's whitespace is emitted
    verbatim into the page.
    """
    python_version = ".".join([str(x) for x in sys.version_info[0:3]])
    versions_html = f"""
python: <span title="{sys.version}">{python_version}</span>
 • 
torch: {getattr(torch, '__long_version__',torch.__version__)}
 • 
gradio: {gr.__version__}
"""
    return versions_html
231
+
232
+
233
+
234
# ---- Module-level startup: logging, CLI flags, device fallback, model preload ----
logger = logging.getLogger(__name__)
APPTITLE = "Bark UI Enhanced v0.7"


autolaunch = False

# "-autolaunch" anywhere on the command line opens the browser automatically.
if len(sys.argv) > 1:
    autolaunch = "-autolaunch" in sys.argv


# Without CUDA, force the Bark backend onto the CPU before models load.
if torch.cuda.is_available() == False:
    os.environ['BARK_FORCE_CPU'] = 'True'
    logger.warning("No CUDA detected, fallback to CPU!")

# Echo the effective environment switches for debugging.
print(f'smallmodels={os.environ.get("SUNO_USE_SMALL_MODELS", False)}')
print(f'enablemps={os.environ.get("SUNO_ENABLE_MPS", False)}')
print(f'offloadcpu={os.environ.get("SUNO_OFFLOAD_CPU", False)}')
print(f'forcecpu={os.environ.get("BARK_FORCE_CPU", False)}')
print(f'autolaunch={autolaunch}\n\n')

#print("Updating nltk\n")
#nltk.download('punkt')

print("Preloading Models\n")
preload_models()

# Choices offered in the UI dropdowns.
available_themes = ["Default", "gradio/glass", "gradio/monochrome", "gradio/seafoam", "gradio/soft", "gstaff/xkcd", "freddyaboulton/dracula_revamped", "ysharma/steampunk"]
tokenizer_language_list = ["de","en", "pl"]
prepare_training_list = ["Step 1: Semantics from Text","Step 2: WAV from Semantics"]

seed = -1
# Empty name / non-positive port mean "use Gradio defaults".
# NOTE(review): server_name/server_port are computed here but launch() below
# does not pass them -- presumably intentional for the hosted Space build.
server_name = settings.server_name
if len(server_name) < 1:
    server_name = None
server_port = settings.server_port
if server_port <= 0:
    server_port = None
global run_server
global restart_server

run_server = True
275
+
276
# ---- Main server loop: rebuild the UI and relaunch after each restart ----
while run_server:
    # Collect all existing speakers/voices in dir
    speakers_list = []

    for root, dirs, files in os.walk("./bark/assets/prompts"):
        for file in files:
            if file.endswith(".npz"):
                # Voice name = relative subdir + filename without ".npz".
                pathpart = root.replace("./bark/assets/prompts", "")
                name = os.path.join(pathpart, file[:-4])
                if name.startswith("/") or name.startswith("\\"):
                    name = name[1:]
                speakers_list.append(name)

    speakers_list = sorted(speakers_list, key=lambda x: x.lower())
    speakers_list.insert(0, 'None')

    print(f'Launching {APPTITLE} Server')

    # Create Gradio Blocks

    with gr.Blocks(title=f"{APPTITLE}", mode=f"{APPTITLE}", theme=settings.selected_theme) as barkgui:
        # Header row: app title link + version info.
        with gr.Row():
            with gr.Column():
                gr.Markdown(f"### [{APPTITLE}](https://github.com/C0untFloyd/bark-gui)")
            with gr.Column():
                gr.HTML(create_version_html(), elem_id="versions")

        # --- Tab: clone a new voice from an uploaded audio sample ---
        with gr.Tab("Clone Voice"):
            with gr.Row():
                input_audio_filename = gr.Audio(label="Input audio.wav", source="upload", type="filepath")
                #transcription_text = gr.Textbox(label="Transcription Text", lines=1, placeholder="Enter Text of your Audio Sample here...")
            with gr.Row():
                with gr.Column():
                    initialname = "/content/Bark-Voice-Cloning/bark/assets/prompts/file"
                    output_voice = gr.Textbox(label="Filename of trained Voice (do not change the initial name)", lines=1, placeholder=initialname, value=initialname)
                with gr.Column():
                    tokenizerlang = gr.Dropdown(tokenizer_language_list, label="Base Language Tokenizer", value=tokenizer_language_list[1])
            with gr.Row():
                clone_voice_button = gr.Button("Create Voice")
            with gr.Row():
                dummy = gr.Text(label="Progress")
                npz_file = gr.File(label=".npz file")
                # NOTE(review): this inserts the gr.File *component* (not a
                # path string) at the head of the speakers list -- confirm the
                # Voice dropdown handles it as intended.
                speakers_list.insert(0, npz_file) # add prompt

        # --- Tab: text-to-speech generation ---
        with gr.Tab("TTS"):
            with gr.Row():
                with gr.Column():
                    placeholder = "Enter text here."
                    input_text = gr.Textbox(label="Input Text", lines=4, placeholder=placeholder)
                with gr.Column():
                    seedcomponent = gr.Number(label="Seed (default -1 = Random)", precision=0, value=-1)
                    batchcount = gr.Number(label="Batch count", precision=0, value=1)
            with gr.Row():
                with gr.Column():
                    # Example prompts, incl. a full SSML multi-speaker document.
                    examples = [
                        "Special meanings: [laughter] [laughs] [sighs] [music] [gasps] [clears throat] MAN: WOMAN:",
                        "♪ Never gonna make you cry, never gonna say goodbye, never gonna tell a lie and hurt you ♪",
                        "And now — a picture of a larch [laughter]",
                        """
         WOMAN: I would like an oatmilk latte please.
         MAN: Wow, that's expensive!
        """,
                        """<?xml version="1.0"?>
<speak version="1.0" xmlns="http://www.w3.org/2001/10/synthesis"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://www.w3.org/2001/10/synthesis
http://www.w3.org/TR/speech-synthesis/synthesis.xsd"
xml:lang="en-US">
<voice name="/v2/en_speaker_9">Look at that drunk guy!</voice>
<voice name="/v2/en_speaker_3">Who is he?</voice>
<voice name="/v2/en_speaker_9">WOMAN: [clears throat] 10 years ago, he proposed me and I rejected him.</voice>
<voice name="/v2/en_speaker_3">Oh my God [laughs] he is still celebrating</voice>
</speak>"""
                        ]
                    examples = gr.Examples(examples=examples, inputs=input_text)
                with gr.Column():
                    convert_to_ssml_button = gr.Button("Convert Input Text to SSML")

            with gr.Row():
                with gr.Column():
                    gr.Markdown("[Voice Prompt Library](https://suno-ai.notion.site/8b8e8749ed514b0cbf3f699013548683?v=bc67cff786b04b50b3ceb756fd05f68c)")
                    speaker = gr.Dropdown(speakers_list, value=speakers_list[0], label="Voice")
                with gr.Column():
                    text_temp = gr.Slider(0.1, 1.0, value=0.6, label="Generation Temperature", info="1.0 more diverse, 0.1 more conservative")
                    waveform_temp = gr.Slider(0.1, 1.0, value=0.7, label="Waveform temperature", info="1.0 more diverse, 0.1 more conservative")

            with gr.Row():
                with gr.Column():
                    quick_gen_checkbox = gr.Checkbox(label="Quick Generation", value=True)
                    settings_checkboxes = ["Use last generation as history", "Save generation as Voice"]
                    # Hidden while quick generation is on (see on_quick_gen_changed).
                    complete_settings = gr.CheckboxGroup(choices=settings_checkboxes, value=settings_checkboxes, label="Detailed Generation Settings", type="value", interactive=True, visible=False)
                with gr.Column():
                    eos_prob = gr.Slider(0.0, 0.5, value=0.05, label="End of sentence probability")

            with gr.Row():
                with gr.Column():
                    tts_create_button = gr.Button("Generate")
                with gr.Column():
                    # hidden_checkbox carries the JS confirm() result back to Python.
                    hidden_checkbox = gr.Checkbox(visible=False)
                    button_stop_generation = gr.Button("Stop generation")
            with gr.Row():
                output_audio = gr.Audio(label="Generated Audio", type="filepath")

        # --- Tab: re-voice an existing recording ---
        with gr.Tab("Swap Voice"):
            with gr.Row():
                swap_audio_filename = gr.Audio(label="Input audio.wav to swap voice", source="upload", type="filepath")
            with gr.Row():
                with gr.Column():
                    swap_tokenizer_lang = gr.Dropdown(tokenizer_language_list, label="Base Language Tokenizer", value=tokenizer_language_list[1])
                    swap_seed = gr.Number(label="Seed (default -1 = Random)", precision=0, value=-1)
                with gr.Column():
                    speaker_swap = gr.Dropdown(speakers_list, value=speakers_list[0], label="Voice")
                    swap_batchcount = gr.Number(label="Batch count", precision=0, value=1)
            with gr.Row():
                swap_voice_button = gr.Button("Swap Voice")
            with gr.Row():
                output_swap = gr.Audio(label="Generated Audio", type="filepath")

        # --- Tab: build the training dataset ---
        with gr.Tab("Training Data Prepare"):
            gr.Markdown("This tab should be used to generate the training dataset. For Step 1 put some books into the inputtext folder in UTF-8 Text Format.")
            prepare_semantics_number = gr.Number(label="Number of semantics to create", precision=0, value=3079)
            prepare_dropdown = gr.Dropdown(prepare_training_list, value=prepare_training_list[0], label="Prepare")
            training_prepare_button = gr.Button("Generate")
            dummytrd = gr.Text(label="Progress")

        # --- Tab: run training ---
        with gr.Tab("Training"):
            with gr.Row():
                gr.Markdown("This tab is used to train the actual model (language).")
            with gr.Row():
                with gr.Column():
                    save_model_epoch = gr.Number(label="Auto-save model after number of epochs", precision=0, value=1)
                with gr.Column():
                    max_epochs = gr.Number(label="Train for number of epochs", precision=0, value=6)
            with gr.Row():
                with gr.Column():
                    allowed_chars = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~'
                    allowedcharsfilter = gr.Textbox(label="Allowed chars for text input", lines=1, value=allowed_chars)
                with gr.Column():
                    train_button = gr.Button("Start Training")
            with gr.Row():
                dummytrain = gr.Text(label="Progress")


        # --- Tab: server / generation settings ---
        with gr.Tab("Settings"):
            with gr.Row():
                themes = gr.Dropdown(available_themes, label="Theme", info="Change needs complete restart", value=settings.selected_theme)
            with gr.Row():
                input_server_name = gr.Textbox(label="Server Name", lines=1, info="Leave blank to run locally", value=settings.server_name)
                input_server_port = gr.Number(label="Server Port", precision=0, info="Leave at 0 to use default", value=settings.server_port)
                share_checkbox = gr.Checkbox(label="Public Server", value=settings.server_share)
            with gr.Row():
                input_desired_len = gr.Slider(100, 150, value=settings.input_text_desired_length, label="Desired Input Text Length", info="Ideal length to split input sentences")
                input_max_len = gr.Slider(150, 256, value=settings.input_text_max_length, label="Max Input Text Length", info="Maximum Input Text Length")
            with gr.Row():
                input_silence_break = gr.Slider(1, 1000, value=settings.silence_sentence, label="Sentence Pause Time (ms)", info="Silence between sentences in milliseconds")
                input_silence_speakers = gr.Slider(1, 5000, value=settings.silence_speakers, label="Speaker Pause Time (ms)", info="Silence between different speakers in milliseconds")

            with gr.Row():
                button_apply_settings = gr.Button("Apply Settings")
                button_apply_restart = gr.Button("Restart Server")
                button_delete_files = gr.Button("Clear output folder")

        # --- Event wiring (all components exist by now) ---
        quick_gen_checkbox.change(fn=on_quick_gen_changed, inputs=quick_gen_checkbox, outputs=complete_settings)
        convert_to_ssml_button.click(convert_text_to_ssml, inputs=[input_text, speaker],outputs=input_text)
        gen_click = tts_create_button.click(generate_text_to_speech, inputs=[input_text, speaker, text_temp, waveform_temp, eos_prob, quick_gen_checkbox, complete_settings, seedcomponent, batchcount],outputs=output_audio)
        button_stop_generation.click(fn=None, inputs=None, outputs=None, cancels=[gen_click])

        # Javascript hack to display modal confirmation dialog
        js = "(x) => confirm('Are you sure? This will remove all files from output folder')"
        button_delete_files.click(None, None, hidden_checkbox, _js=js)
        hidden_checkbox.change(delete_output_files, [hidden_checkbox], [hidden_checkbox])

        swap_voice_button.click(swap_voice_from_audio, inputs=[swap_audio_filename, speaker_swap, swap_tokenizer_lang, swap_seed, swap_batchcount], outputs=output_swap)
        clone_voice_button.click(clone_voice, inputs=[input_audio_filename, tokenizerlang, output_voice], outputs=[dummy, npz_file])
        training_prepare_button.click(training_prepare, inputs=[prepare_dropdown, prepare_semantics_number], outputs=dummytrd)
        train_button.click(start_training, inputs=[save_model_epoch, max_epochs], outputs=dummytrain)
        button_apply_settings.click(apply_settings, inputs=[themes, input_server_name, input_server_port, share_checkbox, input_desired_len, input_max_len, input_silence_break, input_silence_speakers])
        button_apply_restart.click(restart)

    # Launch and block until either an error or a requested restart.
    restart_server = False
    try:
        barkgui.queue().launch(show_error=True)
    except:
        # NOTE(review): bare except -- any launch failure silently stops the loop.
        restart_server = True
        run_server = False
    try:
        # Idle until restart() flips the flag from a UI callback.
        while restart_server == False:
            time.sleep(1.0)
    except (KeyboardInterrupt, OSError):
        print("Keyboard interruption in main thread... closing server.")
        run_server = False
    barkgui.close()
bark/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
2
+ from .generation import SAMPLE_RATE, preload_models
bark/api.py ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict, Optional, Union
2
+
3
+ import numpy as np
4
+
5
+ from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
6
+
7
+
8
def generate_with_settings(text_prompt, semantic_temp=0.6, eos_p=0.2, coarse_temp=0.7, fine_temp=0.5, voice_name=None, output_full=False):
    """Run the full Bark pipeline with per-stage temperature control.

    Stages: text -> semantic tokens -> coarse tokens -> fine tokens -> audio.
    When *output_full* is true, also returns the intermediate prompts so the
    generation can be reused as a history prompt.
    """
    semantic_tokens = generate_text_semantic(
        text_prompt,
        history_prompt=voice_name,
        temp=semantic_temp,
        min_eos_p=eos_p,
        use_kv_caching=True,
    )

    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=voice_name,
        temp=coarse_temp,
        use_kv_caching=True,
    )

    fine_tokens = generate_fine(
        coarse_tokens,
        history_prompt=voice_name,
        temp=fine_temp,
    )

    audio_arr = codec_decode(fine_tokens)
    if not output_full:
        return audio_arr

    full_generation = {
        'semantic_prompt': semantic_tokens,
        'coarse_prompt': coarse_tokens,
        'fine_prompt': fine_tokens
    }
    return full_generation, audio_arr
39
+
40
+
41
def text_to_semantic(
    text: str,
    history_prompt: Optional[Union[Dict, str]] = None,
    temp: float = 0.7,
    silent: bool = False,
):
    """Generate semantic array from text.

    Args:
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar

    Returns:
        numpy semantic array to be fed into `semantic_to_waveform`
    """
    return generate_text_semantic(
        text,
        history_prompt=history_prompt,
        temp=temp,
        silent=silent,
        use_kv_caching=True,
    )
66
+
67
+
68
def semantic_to_waveform(
    semantic_tokens: np.ndarray,
    history_prompt: Optional[Union[Dict, str]] = None,
    temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from semantic input.

    Args:
        semantic_tokens: semantic token output from `text_to_semantic`
        history_prompt: history choice for audio cloning
        temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    coarse_tokens = generate_coarse(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=temp,
        silent=silent,
        use_kv_caching=True,
    )
    # Fine stage uses a fixed temperature of 0.5 regardless of *temp*.
    fine_tokens = generate_fine(
        coarse_tokens,
        history_prompt=history_prompt,
        temp=0.5,
    )
    audio_arr = codec_decode(fine_tokens)
    if not output_full:
        return audio_arr

    full_generation = {
        "semantic_prompt": semantic_tokens,
        "coarse_prompt": coarse_tokens,
        "fine_prompt": fine_tokens,
    }
    return full_generation, audio_arr
108
+
109
+
110
def save_as_prompt(filepath, full_generation):
    """Save a full generation dict as a ``.npz`` voice prompt file.

    Args:
        filepath: destination path; must end with ``.npz``.
        full_generation: dict with ``semantic_prompt``, ``coarse_prompt``
            and ``fine_prompt`` arrays.

    Raises:
        ValueError: if the path or the payload is malformed.  (The original
            used ``assert`` for validation, which is silently stripped when
            Python runs with ``-O``.)
    """
    if not filepath.endswith(".npz"):
        raise ValueError("filepath must end with .npz")
    if not isinstance(full_generation, dict):
        raise ValueError("full_generation must be a dict")
    for key in ("semantic_prompt", "coarse_prompt", "fine_prompt"):
        if key not in full_generation:
            raise ValueError(f"full_generation is missing '{key}'")
    np.savez(filepath, **full_generation)
117
+
118
+
119
def generate_audio(
    text: str,
    history_prompt: Optional[Union[Dict, str]] = None,
    text_temp: float = 0.7,
    waveform_temp: float = 0.7,
    silent: bool = False,
    output_full: bool = False,
):
    """Generate audio array from input text.

    Args:
        text: text to be turned into audio
        history_prompt: history choice for audio cloning
        text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
        silent: disable progress bar
        output_full: return full generation to be used as a history prompt

    Returns:
        numpy audio array at sample frequency 24khz
    """
    semantic_tokens = text_to_semantic(
        text,
        history_prompt=history_prompt,
        temp=text_temp,
        silent=silent,
    )
    # semantic_to_waveform already returns either the bare audio array or the
    # (full_generation, audio) pair depending on *output_full*.
    return semantic_to_waveform(
        semantic_tokens,
        history_prompt=history_prompt,
        temp=waveform_temp,
        silent=silent,
        output_full=output_full,
    )
bark/assets/prompts/announcer.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:26f2d1a9e3b6fe453cf5fc8191de26cbfae6276c5b0f7c376c6a0f3c35867f83
3
+ size 16794
bark/assets/prompts/file.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d86d7af8d8adc44f3a9fb4c0bb47249e066c379d523933e05c46936d8d6113cd
3
+ size 35092
bark/assets/prompts/v2/en_speaker_0.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:932f40d879ba8659f1ca26319ba64ea3b0647b2050fe24313bf42b0dff1fe241
3
+ size 28100
bark/assets/prompts/v2/en_speaker_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5e7f18015e1ab9b6302ded1e28a971af5306a72f193bb6c411f1948a083c8578
3
+ size 25220
bark/assets/prompts/v2/en_speaker_2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d218990680ece5f2d4fc18ea4783b016b3ae353ec413eaee2058f2d57263c9b3
3
+ size 26236
bark/assets/prompts/v2/en_speaker_3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:92c2e2a29145c83738e9b63f082fd1c873d9422468a155463cb27f814aeaea66
3
+ size 34980
bark/assets/prompts/v2/en_speaker_4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:992f91991a9a5359d72f00b09a11a550e71bb8ebfc0cfd877e39d7d41f98b714
3
+ size 23780
bark/assets/prompts/v2/en_speaker_5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:18831c3f6014e4a2ff60ad5169b1fae06e28ed07f43f8a3616aafb84515091bf
3
+ size 24740
bark/assets/prompts/v2/en_speaker_6.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:fab38dc6b6bc9226bcc414f4c5a9524bc1b2441865a586153fb620127a8faa4e
3
+ size 25540
bark/assets/prompts/v2/en_speaker_7.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8f4c4eb33f5994be8de5cfd1744ebce13da1618a6da3a7d244514178c61ef7db
3
+ size 22716
bark/assets/prompts/v2/en_speaker_8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8fc9f11b539588f51bbf78150a73e0365c49b2306bd72e5a22b28ef09c4fb15d
3
+ size 23300
bark/assets/prompts/v2/en_speaker_9.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78b3ba32eb9aeb9ed34556856c40633ecc8332d1c3ae3c81e6f5015ac3eefbd5
3
+ size 30180
bark/assets/prompts/v2/zh_speaker_0.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bd7ac118a3e944b3f20c89f2446056a00850a630ee16318922acc6572ce80929
3
+ size 20636
bark/assets/prompts/v2/zh_speaker_1.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0eacf5c862dfd3c5ac825f2ebb26f323e64309cb712e7e264cbd31c5bca3f038
3
+ size 19836
bark/assets/prompts/v2/zh_speaker_2.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e324b47f8250e5798c314f395d4e049575e7ca369d0b6074e91c7bba70e9f26d
3
+ size 21060
bark/assets/prompts/v2/zh_speaker_3.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:98c476abc7bf634ffb2d71d363284e7bd8c8abd5e33ec5ca21d4aa5b15730d18
3
+ size 31300
bark/assets/prompts/v2/zh_speaker_4.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1fa8673a9895ad3302d13ac94193b5ad5da481f1cc276e6181fa895acaae133b
3
+ size 29964
bark/assets/prompts/v2/zh_speaker_5.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:226edfe5fabc72eeb83a13e350599bc8babe5adc2264b3cdb661fd1258dc4044
3
+ size 17436
bark/assets/prompts/v2/zh_speaker_6.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:285d51fbe81cc263636b5b487fbb6633e6f3cf92c53ca9ab8e6b7f55d4b4a31d
3
+ size 16900
bark/assets/prompts/v2/zh_speaker_7.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0967cdb14ffa79895747b0d52df9f15bdad80d6c55b7630894345c9a7ec87c91
3
+ size 21060
bark/assets/prompts/v2/zh_speaker_8.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c028f78530013f29ab8c0c1cf4fe2138106fbe5252951f5f36e0168056779549
3
+ size 19300
bark/assets/prompts/v2/zh_speaker_9.npz ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6265bb827008d7af8a45a8e057fe3e91efb347d56208180a9ed990ad54e4d75e
3
+ size 16156
bark/generation.py ADDED
@@ -0,0 +1,864 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import contextlib
import gc
import os
import re
import requests
import sys

from encodec import EncodecModel
import funcy
import logging
import numpy as np
from scipy.special import softmax
import torch
import torch.nn.functional as F
import tqdm
from transformers import BertTokenizer
from huggingface_hub import hf_hub_download, hf_hub_url

from .model import GPTConfig, GPT
from .model_fine import FineGPT, FineGPTConfig
from .settings import initenv

initenv(sys.argv)
global_force_cpu = os.environ.get("BARK_FORCE_CPU", False)
# BUG FIX: when BARK_FORCE_CPU is set, os.environ.get returns a *string*,
# and any non-empty string compares unequal to True — so the old check
# `global_force_cpu != True` enabled GPU autocast even for BARK_FORCE_CPU=1.
# Parse the value as a boolean flag instead (same accepted spellings as
# _cast_bool_env_var below).
if (
    str(global_force_cpu).lower() not in ("true", "1", "t") and
    torch.cuda.is_available() and
    hasattr(torch.cuda, "amp") and
    hasattr(torch.cuda.amp, "autocast") and
    hasattr(torch.cuda, "is_bf16_supported") and
    torch.cuda.is_bf16_supported()
):
    # bfloat16 autocast for faster mixed-precision inference on supported GPUs
    autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
else:
    @contextlib.contextmanager
    def autocast():
        # no-op stand-in so call sites can always write `with autocast():`
        yield
39
+
40
+
41
# hold models in global scope to lazy load
global models
models = {}

global models_devices
models_devices = {}


# transformer context window (tokens) shared by the GPT stages
CONTEXT_WINDOW_SIZE = 1024

# semantic stage: token emission rate and vocabulary size
SEMANTIC_RATE_HZ = 49.9
SEMANTIC_VOCAB_SIZE = 10_000

# EnCodec codebook layout: the coarse stage predicts the first
# N_COARSE_CODEBOOKS codebooks, the fine stage fills up to N_FINE_CODEBOOKS
CODEBOOK_SIZE = 1024
N_COARSE_CODEBOOKS = 2
N_FINE_CODEBOOKS = 8
COARSE_RATE_HZ = 75

# output audio sample rate (Hz)
SAMPLE_RATE = 24_000
61
+
62
# (display name, ISO-639-1 code) pairs for the languages Bark ships voices for
SUPPORTED_LANGS = [
    ("English", "en"),
    ("German", "de"),
    ("Spanish", "es"),
    ("French", "fr"),
    ("Hindi", "hi"),
    ("Italian", "it"),
    ("Japanese", "ja"),
    ("Korean", "ko"),
    ("Polish", "pl"),
    ("Portuguese", "pt"),
    ("Russian", "ru"),
    ("Turkish", "tr"),
    ("Chinese", "zh"),
]

# every built-in speaker prompt name: "announcer" plus <lang>_speaker_0..9,
# in both the bare (v1) and v2-subdirectory layouts
ALLOWED_PROMPTS = {"announcer"} | {
    f"{prefix}{lang}_speaker_{n}"
    for _, lang in SUPPORTED_LANGS
    for prefix in ("", f"v2{os.path.sep}")
    for n in range(10)
}
84
+
85
+ logger = logging.getLogger(__name__)
86
+
87
+
88
+ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
89
+
90
+
91
+ #default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
92
+ #CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
93
+ #CACHE_DIR = os.path.join(os.getcwd(), "models"
94
+ CACHE_DIR = "./models"
95
+
96
+
97
+ def _cast_bool_env_var(s):
98
+ return s.lower() in ('true', '1', 't')
99
+
100
+ USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False"))
101
+ GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False"))
102
+ OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False"))
103
+
104
# checkpoint locations on the Hugging Face hub; "<type>_small" keys are the
# reduced-size variants of each stage
_BARK_REPO_ID = "suno/bark"
REMOTE_MODEL_PATHS = {
    key: {"repo_id": _BARK_REPO_ID, "file_name": file_name}
    for key, file_name in (
        ("text_small", "text.pt"),
        ("coarse_small", "coarse.pt"),
        ("fine_small", "fine.pt"),
        ("text", "text_2.pt"),
        ("coarse", "coarse_2.pt"),
        ("fine", "fine_2.pt"),
    )
}
131
+
132
# torch 2.x's fused scaled_dot_product_attention gives a large GPU speedup;
# warn once at import time if this torch build predates it (CPU-only setups
# are not warned since the upgrade would not help them much)
if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
    logger.warning(
        "torch version does not support flash attention. You will get faster" +
        " inference speed by upgrade torch to newest nightly version."
    )
139
def grab_best_device(use_gpu=True):
    """Return the preferred torch device string: "cuda", "mps", or "cpu"."""
    if not use_gpu:
        return "cpu"
    if torch.cuda.device_count() > 0:
        return "cuda"
    # MPS is opt-in via SUNO_ENABLE_MPS; check availability first
    if torch.backends.mps.is_available() and GLOBAL_ENABLE_MPS:
        return "mps"
    return "cpu"
149
def _get_ckpt_path(model_type, use_small=False):
    """Local path under CACHE_DIR where the checkpoint for `model_type` lives."""
    wants_small = use_small or USE_SMALL_MODELS
    model_key = f"{model_type}_small" if wants_small else model_type
    return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[model_key]["file_name"])
155
+ """
156
+ def _download(from_hf_path, file_name, destfilename):
157
+ os.makedirs(CACHE_DIR, exist_ok=True)
158
+ hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR, local_dir_use_symlinks=False)
159
+ # Bug in original repo? Downloaded name differs from expected...
160
+ if not os.path.exists(destfilename):
161
+ localname = os.path.join(CACHE_DIR, file_name)
162
+ os.rename(localname, destfilename)
163
+ """
164
+ def _download(from_hf_path, file_name):
165
+ os.makedirs(CACHE_DIR, exist_ok=True)
166
+ hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
167
+
168
+
169
class InferenceContext:
    """Context manager that pins cudnn benchmarking to a fixed setting.

    Benchmarking is off by default: inputs vary in length between calls,
    which would make cudnn re-tune its kernels constantly.
    """

    def __init__(self, benchmark=False):
        # we can't expect inputs to be the same length, so disable benchmarking by default
        self._chosen_cudnn_benchmark = benchmark
        self._cudnn_benchmark = None

    def __enter__(self):
        # stash the caller's current setting so __exit__ can restore it
        self._cudnn_benchmark = torch.backends.cudnn.benchmark
        torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark

    def __exit__(self, exc_type, exc_value, exc_traceback):
        torch.backends.cudnn.benchmark = self._cudnn_benchmark
183
# allow TF32 tensor-core math on CUDA: small precision loss, large matmul speedup
if torch.cuda.is_available():
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
188
@contextlib.contextmanager
def _inference_mode():
    """All-in-one inference context: pinned cudnn benchmarking, no autograd, autocast."""
    with InferenceContext():
        with torch.inference_mode(), torch.no_grad(), autocast():
            yield
194
+ def _clear_cuda_cache():
195
+ if torch.cuda.is_available():
196
+ torch.cuda.empty_cache()
197
+ torch.cuda.synchronize()
198
+
199
+
200
def clean_models(model_key=None):
    """Drop one cached model (or all of them) and reclaim memory.

    Args:
        model_key: key into the global `models` dict, or None to drop everything.
    """
    global models
    # BUG FIX: materialize the key list — the old code iterated `models.keys()`
    # (a live view) while deleting entries, which raises
    # "RuntimeError: dictionary changed size during iteration" when
    # model_key is None and models is non-empty.
    model_keys = [model_key] if model_key is not None else list(models.keys())
    for k in model_keys:
        if k in models:
            del models[k]
    _clear_cuda_cache()
    gc.collect()
210
def _load_model(ckpt_path, device, use_small=False, model_type="text"):
    """Load one GPT checkpoint from disk (downloading it first if missing).

    Args:
        ckpt_path: local path to the checkpoint (see _get_ckpt_path).
        device: torch device string to place the model on.
        use_small: load the reduced-size variant.
        model_type: "text", "coarse", or "fine".

    Returns:
        For "text": {"model": GPT, "tokenizer": BertTokenizer}; otherwise the model.
    """
    # pick the config/model classes for the requested stage
    if model_type == "text":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "coarse":
        ConfigClass = GPTConfig
        ModelClass = GPT
    elif model_type == "fine":
        ConfigClass = FineGPTConfig
        ModelClass = FineGPT
    else:
        raise NotImplementedError()

    # Force-remove Models to allow running on >12Gb GPU
    # CF: Probably not needed anymore
    #global models
    #models.clear()
    #gc.collect()
    #torch.cuda.empty_cache()
    # to here...

    model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
    model_info = REMOTE_MODEL_PATHS[model_key]
    if not os.path.exists(ckpt_path):
        logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
        ## added next two lines to make it super clear which model is being downloaded
        remote_filename = hf_hub_url(model_info["repo_id"], model_info["file_name"])
        print(f"Downloading {model_key} {model_info['repo_id']} remote model file {remote_filename} {model_info['file_name']} to {CACHE_DIR}")
        _download(model_info["repo_id"], model_info["file_name"])
    # add next line to make it super clear which model is being loaded
    print(f"Loading {model_key} model from {ckpt_path} to {device}") # added
    checkpoint = torch.load(ckpt_path, map_location=device)
    # this is a hack: older checkpoints store a single "vocab_size";
    # split it into the input/output sizes the config now expects
    model_args = checkpoint["model_args"]
    if "input_vocab_size" not in model_args:
        model_args["input_vocab_size"] = model_args["vocab_size"]
        model_args["output_vocab_size"] = model_args["vocab_size"]
        del model_args["vocab_size"]
    gptconf = ConfigClass(**checkpoint["model_args"])
    model = ModelClass(gptconf)
    state_dict = checkpoint["model"]
    # fixup checkpoint: strip torch.compile's "_orig_mod." key prefix
    unwanted_prefix = "_orig_mod."
    for k, v in list(state_dict.items()):
        if k.startswith(unwanted_prefix):
            state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
    # sanity-check key agreement, ignoring the non-persistent attention bias buffers
    extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
    extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
    missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
    missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
    if len(extra_keys) != 0:
        raise ValueError(f"extra keys found: {extra_keys}")
    if len(missing_keys) != 0:
        raise ValueError(f"missing keys: {missing_keys}")
    # strict=False because the ".attn.bias" buffers are intentionally absent
    model.load_state_dict(state_dict, strict=False)
    n_params = model.get_num_params()
    val_loss = checkpoint["best_val_loss"].item()
    logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
    model.eval()
    model.to(device)
    # drop checkpoint references before clearing the CUDA cache
    del checkpoint, state_dict
    _clear_cuda_cache()
    if model_type == "text":
        # the text stage also needs its BERT tokenizer; bundle them together
        tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
        return {
            "model": model,
            "tokenizer": tokenizer,
        }
    return model
281
def _load_codec_model(device):
    """Instantiate the pretrained 24 kHz EnCodec model used to decode audio codes."""
    model = EncodecModel.encodec_model_24khz()
    # 6 kbps target bandwidth — presumably matches the 8-codebook layout
    # (N_FINE_CODEBOOKS) used elsewhere in this module; confirm against encodec docs
    model.set_target_bandwidth(6.0)
    model.eval()
    model.to(device)
    _clear_cuda_cache()
    return model
290
def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
    """Lazily load one of the text/coarse/fine GPTs into the global `models` cache.

    Returns the cached entry: for "text" a {"model", "tokenizer"} dict,
    otherwise the model itself.
    """
    _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
    if model_type not in ("text", "coarse", "fine"):
        raise NotImplementedError()
    global models
    global models_devices
    device = grab_best_device(use_gpu=use_gpu)
    model_key = f"{model_type}"
    if OFFLOAD_CPU:
        # remember the intended device but park the weights on CPU until use
        models_devices[model_key] = device
        device = "cpu"
    if model_key not in models or force_reload:
        ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
        clean_models(model_key=model_key)
        model = _load_model_f(ckpt_path, device)
        models[model_key] = model
    if model_type == "text":
        # text entry is a dict holding both the GPT and its tokenizer
        models[model_key]["model"].to(device)
    else:
        models[model_key].to(device)
    return models[model_key]
313
def load_codec_model(use_gpu=True, force_reload=False):
    """Lazily load the EnCodec codec model into the global `models` cache."""
    global models
    global models_devices
    device = grab_best_device(use_gpu=use_gpu)
    if device == "mps":
        # encodec doesn't support mps
        device = "cpu"
    model_key = "codec"
    if OFFLOAD_CPU:
        # remember the intended device but park the weights on CPU until use
        models_devices[model_key] = device
        device = "cpu"
    if model_key not in models or force_reload:
        clean_models(model_key=model_key)
        model = _load_codec_model(device)
        models[model_key] = model
    models[model_key].to(device)
    return models[model_key]
332
def preload_models(
    text_use_gpu=True,
    text_use_small=False,
    coarse_use_gpu=True,
    coarse_use_small=False,
    fine_use_gpu=True,
    fine_use_small=False,
    codec_use_gpu=True,
    force_reload=False
):
    """Load all the necessary models for the pipeline."""
    wants_gpu = text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
    if grab_best_device() == "cpu" and wants_gpu:
        logger.warning("No GPU being used. Careful, inference might be very slow!")
    # load the three GPT stages, then the codec
    load_model(
        model_type="text",
        use_gpu=text_use_gpu,
        use_small=text_use_small,
        force_reload=force_reload,
    )
    load_model(
        model_type="coarse",
        use_gpu=coarse_use_gpu,
        use_small=coarse_use_small,
        force_reload=force_reload,
    )
    load_model(
        model_type="fine",
        use_gpu=fine_use_gpu,
        use_small=fine_use_small,
        force_reload=force_reload,
    )
    load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
361
+
362
+ ####
363
+ # Generation Functionality
364
+ ####
365
+
366
+
367
+ def _tokenize(tokenizer, text):
368
+ return tokenizer.encode(text, add_special_tokens=False)
369
+
370
+
371
+ def _detokenize(tokenizer, enc_text):
372
+ return tokenizer.decode(enc_text)
373
+
374
+
375
+ def _normalize_whitespace(text):
376
+ return re.sub(r"\s+", " ", text).strip()
377
+
378
+
379
# special token ids for the text model's combined vocabulary:
# text tokens are shifted up by TEXT_ENCODING_OFFSET (presumably to keep
# them clear of the 10k semantic ids — confirm against the text model);
# the PAD/INFER ids sit at the top of the vocabulary
TEXT_ENCODING_OFFSET = 10_048
SEMANTIC_PAD_TOKEN = 10_000
TEXT_PAD_TOKEN = 129_595
SEMANTIC_INFER_TOKEN = 129_599
384
+
385
+ def _load_history_prompt(history_prompt_input):
386
+ if isinstance(history_prompt_input, str) and history_prompt_input.endswith(".npz"):
387
+ history_prompt = np.load(history_prompt_input)
388
+ elif isinstance(history_prompt_input, str):
389
+ # make sure this works on non-ubuntu
390
+ history_prompt_input = os.path.join(*history_prompt_input.split("/"))
391
+ # if history_prompt_input not in ALLOWED_PROMPTS:
392
+ # raise ValueError("history prompt not found")
393
+ history_prompt = np.load(
394
+ os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt_input}.npz")
395
+ )
396
+ elif isinstance(history_prompt_input, dict):
397
+ assert("semantic_prompt" in history_prompt_input)
398
+ assert("coarse_prompt" in history_prompt_input)
399
+ assert("fine_prompt" in history_prompt_input)
400
+ history_prompt = history_prompt_input
401
+ else:
402
+ raise ValueError("history prompt format unrecognized")
403
+ return history_prompt
404
+
405
+
406
def generate_text_semantic(
    text,
    history_prompt=None,
    temp=0.7,
    top_k=None,
    top_p=None,
    silent=False,
    min_eos_p=0.2,
    max_gen_duration_s=None,
    allow_early_stop=True,
    use_kv_caching=False,
):
    """Generate semantic tokens from text.

    Args:
        text: input string; whitespace is normalized and at most 256 BERT
            tokens are used (longer input is truncated with a warning).
        history_prompt: optional voice prompt (path, built-in name, or dict).
        temp: softmax temperature for sampling.
        top_k: optional top-k filtering of the logits.
        top_p: optional nucleus (top-p) filtering of the logits.
        silent: disable the progress bar.
        min_eos_p: stop once the eos probability reaches this value
            (only when allow_early_stop is True).
        max_gen_duration_s: optional cap on the generated duration.
        allow_early_stop: whether the eos token may end generation early.
        use_kv_caching: reuse the transformer KV cache between steps.

    Returns:
        1-D np.ndarray of semantic token ids in [0, SEMANTIC_VOCAB_SIZE).
    """
    assert isinstance(text, str)
    text = _normalize_whitespace(text)
    assert len(text.strip()) > 0
    if history_prompt is not None:
        history_prompt = _load_history_prompt(history_prompt)
        semantic_history = history_prompt["semantic_prompt"]
        assert (
            isinstance(semantic_history, np.ndarray)
            and len(semantic_history.shape) == 1
            and len(semantic_history) > 0
            and semantic_history.min() >= 0
            and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
        )
    else:
        semantic_history = None
    # load models if not yet exist
    global models
    global models_devices
    if "text" not in models:
        preload_models()
    model_container = models["text"]
    model = model_container["model"]
    tokenizer = model_container["tokenizer"]
    encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
    if OFFLOAD_CPU:
        model.to(models_devices["text"])
    device = next(model.parameters()).device
    # truncate to the 256-token text window, then pad it out
    if len(encoded_text) > 256:
        p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
        logger.warning(f"warning, text too long, lopping of last {p}%")
        encoded_text = encoded_text[:256]
    encoded_text = np.pad(
        encoded_text,
        (0, 256 - len(encoded_text)),
        constant_values=TEXT_PAD_TOKEN,
        mode="constant",
    )
    if semantic_history is not None:
        semantic_history = semantic_history.astype(np.int64)
        # lop off if history is too long, pad if needed
        semantic_history = semantic_history[-256:]
        semantic_history = np.pad(
            semantic_history,
            (0, 256 - len(semantic_history)),
            constant_values=SEMANTIC_PAD_TOKEN,
            mode="constant",
        )
    else:
        semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
    # model input layout: [256 text tokens | 256 history tokens | INFER marker]
    x = torch.from_numpy(
        np.hstack([
            encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
        ]).astype(np.int64)
    )[None]
    assert x.shape[1] == 256 + 256 + 1
    with _inference_mode():
        x = x.to(device)
        n_tot_steps = 768
        # custom tqdm updates since we don't know when eos will occur
        pbar = tqdm.tqdm(disable=silent, total=100)
        pbar_state = 0
        tot_generated_duration_s = 0
        kv_cache = None
        for n in range(n_tot_steps):
            # with KV caching only the newest token needs to be fed in
            if use_kv_caching and kv_cache is not None:
                x_input = x[:, [-1]]
            else:
                x_input = x
            logits, kv_cache = model(
                x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
            )
            relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
            if allow_early_stop:
                relevant_logits = torch.hstack(
                    (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]])  # eos
                )
            if top_p is not None:
                # faster to convert to numpy
                original_device = relevant_logits.device
                relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                sorted_indices = np.argsort(relevant_logits)[::-1]
                sorted_logits = relevant_logits[sorted_indices]
                cumulative_probs = np.cumsum(softmax(sorted_logits))
                sorted_indices_to_remove = cumulative_probs > top_p
                # shift right so the first token above the threshold is kept
                sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                sorted_indices_to_remove[0] = False
                relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                relevant_logits = torch.from_numpy(relevant_logits)
                relevant_logits = relevant_logits.to(original_device)
            if top_k is not None:
                v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                relevant_logits[relevant_logits < v[-1]] = -float("Inf")
            probs = F.softmax(relevant_logits / temp, dim=-1)
            # multinomial bugged on mps: shuttle to cpu if necessary
            inf_device = probs.device
            if probs.device.type == "mps":
                probs = probs.to("cpu")
            item_next = torch.multinomial(probs, num_samples=1)
            probs = probs.to(inf_device)
            item_next = item_next.to(inf_device)
            if allow_early_stop and (
                item_next == SEMANTIC_VOCAB_SIZE
                or (min_eos_p is not None and probs[-1] >= min_eos_p)
            ):
                # eos found, so break
                pbar.update(100 - pbar_state)
                break
            x = torch.cat((x, item_next[None]), dim=1)
            tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
            if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
                pbar.update(100 - pbar_state)
                break
            if n == n_tot_steps - 1:
                pbar.update(100 - pbar_state)
                break
            del logits, relevant_logits, probs, item_next
            req_pbar_state = np.min([100, int(round(100 * n / n_tot_steps))])
            if req_pbar_state > pbar_state:
                pbar.update(req_pbar_state - pbar_state)
                pbar_state = req_pbar_state
        pbar.close()
        # strip the prompt portion; everything after it is generated
        out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
    if OFFLOAD_CPU:
        model.to("cpu")
    assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
    _clear_cuda_cache()
    return out
547
+
548
def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
    """Interleave a (n_codebooks, T) array into 1-D, offsetting each later row.

    Row k is shifted by k * offset_size (unless offset_size is None) so the
    flattened ids from different codebooks occupy disjoint ranges; the result
    is column-major interleaved: [cb0[0], cb1[0], cb0[1], cb1[1], ...].
    """
    assert len(arr.shape) == 2
    shifted = arr.copy()
    if offset_size is not None:
        for row in range(1, shifted.shape[0]):
            shifted[row, :] += offset_size * row
    return shifted.ravel("F")
557
+
558
# special ids in the coarse model's vocabulary; note 12_048 == SEMANTIC_VOCAB_SIZE
# + N_COARSE_CODEBOOKS * CODEBOOK_SIZE, i.e. they sit just above the codebook ids
COARSE_SEMANTIC_PAD_TOKEN = 12_048
COARSE_INFER_TOKEN = 12_050
561
+
562
def generate_coarse(
    x_semantic,
    history_prompt=None,
    temp=0.7,
    top_k=None,
    top_p=None,
    silent=False,
    max_coarse_history=630,  # min 60 (faster), max 630 (more context)
    sliding_window_len=60,
    use_kv_caching=False,
):
    """Generate coarse audio codes from semantic tokens.

    Autoregressively samples the first N_COARSE_CODEBOOKS EnCodec codebooks,
    conditioning each sliding window on up to 256 semantic tokens plus up to
    `max_coarse_history` previously generated coarse tokens.

    Returns:
        np.ndarray of shape (N_COARSE_CODEBOOKS, T) with codebook ids in
        [0, CODEBOOK_SIZE).
    """
    # CF: Uncommented because it breaks swap voice more than once
    # assert (
    #     isinstance(x_semantic, np.ndarray)
    #     and len(x_semantic.shape) == 1
    #     and len(x_semantic) > 0
    #     and x_semantic.min() >= 0
    #     and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
    # )
    assert 60 <= max_coarse_history <= 630
    assert max_coarse_history + sliding_window_len <= 1024 - 256
    # coarse tokens emitted per semantic token (both codebooks counted)
    semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
    max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
    if history_prompt is not None:
        history_prompt = _load_history_prompt(history_prompt)
        x_semantic_history = history_prompt["semantic_prompt"]
        x_coarse_history = history_prompt["coarse_prompt"]
        assert (
            isinstance(x_semantic_history, np.ndarray)
            and len(x_semantic_history.shape) == 1
            and len(x_semantic_history) > 0
            and x_semantic_history.min() >= 0
            and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
            and isinstance(x_coarse_history, np.ndarray)
            and len(x_coarse_history.shape) == 2
            and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
            and x_coarse_history.shape[-1] >= 0
            and x_coarse_history.min() >= 0
            and x_coarse_history.max() <= CODEBOOK_SIZE - 1
            #and (
            #    round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
            #    == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
            #)
        )
        # flatten the 2-row coarse history and shift it above the semantic ids
        x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
        # trim histories correctly
        n_semantic_hist_provided = np.min(
            [
                max_semantic_history,
                len(x_semantic_history) - len(x_semantic_history) % 2,
                int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
            ]
        )
        n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
        x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
        x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
        # TODO: bit of a hack for time alignment (sounds better)
        x_coarse_history = x_coarse_history[:-2]
    else:
        x_semantic_history = np.array([], dtype=np.int32)
        x_coarse_history = np.array([], dtype=np.int32)
    # load models if not yet exist
    global models
    global models_devices
    if "coarse" not in models:
        preload_models()
    model = models["coarse"]
    if OFFLOAD_CPU:
        model.to(models_devices["coarse"])
    device = next(model.parameters()).device
    # start loop
    n_steps = int(
        round(
            np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
            * N_COARSE_CODEBOOKS
        )
    )
    assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
    x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
    x_coarse = x_coarse_history.astype(np.int32)
    base_semantic_idx = len(x_semantic_history)
    with _inference_mode():
        x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
        x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
        n_window_steps = int(np.ceil(n_steps / sliding_window_len))
        n_step = 0
        for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
            # semantic position the current window is aligned to
            semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
            # pad from right side
            x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
            x_in = x_in[:, :256]
            x_in = F.pad(
                x_in,
                (0, 256 - x_in.shape[-1]),
                "constant",
                COARSE_SEMANTIC_PAD_TOKEN,
            )
            # window layout: [semantic context | INFER marker | recent coarse tokens]
            x_in = torch.hstack(
                [
                    x_in,
                    torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
                    x_coarse_in[:, -max_coarse_history:],
                ]
            )
            kv_cache = None
            for _ in range(sliding_window_len):
                if n_step >= n_steps:
                    continue
                # even steps emit codebook 0, odd steps codebook 1
                is_major_step = n_step % N_COARSE_CODEBOOKS == 0

                if use_kv_caching and kv_cache is not None:
                    x_input = x_in[:, [-1]]
                else:
                    x_input = x_in

                logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
                # restrict sampling to the id range of the codebook being emitted
                logit_start_idx = (
                    SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
                )
                logit_end_idx = (
                    SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
                )
                relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
                if top_p is not None:
                    # faster to convert to numpy
                    original_device = relevant_logits.device
                    relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
                    sorted_indices = np.argsort(relevant_logits)[::-1]
                    sorted_logits = relevant_logits[sorted_indices]
                    cumulative_probs = np.cumsum(softmax(sorted_logits))
                    sorted_indices_to_remove = cumulative_probs > top_p
                    sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
                    sorted_indices_to_remove[0] = False
                    relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
                    relevant_logits = torch.from_numpy(relevant_logits)
                    relevant_logits = relevant_logits.to(original_device)
                if top_k is not None:
                    v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
                    relevant_logits[relevant_logits < v[-1]] = -float("Inf")
                probs = F.softmax(relevant_logits / temp, dim=-1)
                # multinomial bugged on mps: shuttle to cpu if necessary
                inf_device = probs.device
                if probs.device.type == "mps":
                    probs = probs.to("cpu")
                item_next = torch.multinomial(probs, num_samples=1)
                probs = probs.to(inf_device)
                item_next = item_next.to(inf_device)
                item_next += logit_start_idx
                x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
                x_in = torch.cat((x_in, item_next[None]), dim=1)
                del logits, relevant_logits, probs, item_next
                n_step += 1
            del x_in
        del x_semantic_in
    if OFFLOAD_CPU:
        model.to("cpu")
    # drop the history prefix; only the newly generated tokens remain
    gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
    del x_coarse_in
    assert len(gen_coarse_arr) == n_steps
    # de-interleave back into (N_COARSE_CODEBOOKS, T) and undo the id offsets
    gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
    for n in range(1, N_COARSE_CODEBOOKS):
        gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
    _clear_cuda_cache()
    return gen_coarse_audio_arr
728
+
729
def generate_fine(
    x_coarse_gen,
    history_prompt=None,
    temp=0.5,
    silent=True,
):
    """Generate full audio codes from coarse audio codes.

    Fills in codebooks n_coarse..N_FINE_CODEBOOKS-1 with the (non-causal)
    fine model over 1024-wide windows advanced in 512-step strides.

    Returns:
        np.ndarray of shape (N_FINE_CODEBOOKS, T), same T as x_coarse_gen.
    """
    assert (
        isinstance(x_coarse_gen, np.ndarray)
        and len(x_coarse_gen.shape) == 2
        and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
        and x_coarse_gen.shape[1] > 0
        and x_coarse_gen.min() >= 0
        and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
    )
    if history_prompt is not None:
        history_prompt = _load_history_prompt(history_prompt)
        x_fine_history = history_prompt["fine_prompt"]
        assert (
            isinstance(x_fine_history, np.ndarray)
            and len(x_fine_history.shape) == 2
            and x_fine_history.shape[0] == N_FINE_CODEBOOKS
            and x_fine_history.shape[1] >= 0
            and x_fine_history.min() >= 0
            and x_fine_history.max() <= CODEBOOK_SIZE - 1
        )
    else:
        x_fine_history = None
    n_coarse = x_coarse_gen.shape[0]
    # load models if not yet exist
    global models
    global models_devices
    if "fine" not in models:
        preload_models()
    model = models["fine"]
    if OFFLOAD_CPU:
        model.to(models_devices["fine"])
    device = next(model.parameters()).device
    # make input arr: coarse rows on top, CODEBOOK_SIZE used as the pad id below
    in_arr = np.vstack(
        [
            x_coarse_gen,
            np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
            + CODEBOOK_SIZE,  # padding
        ]
    ).astype(np.int32)
    # prepend history if available (max 512)
    if x_fine_history is not None:
        x_fine_history = x_fine_history.astype(np.int32)
        in_arr = np.hstack(
            [
                x_fine_history[:, -512:].astype(np.int32),
                in_arr,
            ]
        )
        n_history = x_fine_history[:, -512:].shape[1]
    else:
        n_history = 0
    n_remove_from_end = 0
    # need to pad if too short (since non-causal model)
    if in_arr.shape[1] < 1024:
        n_remove_from_end = 1024 - in_arr.shape[1]
        in_arr = np.hstack(
            [
                in_arr,
                np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
            ]
        )
    # we can be lazy about fractional loop and just keep overwriting codebooks
    n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
    with _inference_mode():
        # model works on (T, n_codebooks) layout, hence the transpose
        in_arr = torch.tensor(in_arr.T).to(device)
        for n in tqdm.tqdm(range(n_loops), disable=silent):
            start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
            start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
            rel_start_fill_idx = start_fill_idx - start_idx
            in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                logits = model(nn, in_buffer)
                if temp is None:
                    # greedy decoding
                    relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
                    codebook_preds = torch.argmax(relevant_logits, -1)
                else:
                    relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
                    probs = F.softmax(relevant_logits, dim=-1)
                    # multinomial bugged on mps: shuttle to cpu if necessary
                    inf_device = probs.device
                    if probs.device.type == "mps":
                        probs = probs.to("cpu")
                    codebook_preds = torch.hstack(
                        [
                            torch.multinomial(probs[nnn], num_samples=1).to(inf_device)
                            for nnn in range(rel_start_fill_idx, 1024)
                        ]
                    )
                in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
                del logits, codebook_preds
            # transfer over info into model_in and convert to numpy
            for nn in range(n_coarse, N_FINE_CODEBOOKS):
                in_arr[
                    start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
                ] = in_buffer[0, rel_start_fill_idx:, nn]
            del in_buffer
        gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
        del in_arr
    if OFFLOAD_CPU:
        model.to("cpu")
    # strip the history prefix and any right-side padding added above
    gen_fine_arr = gen_fine_arr[:, n_history:]
    if n_remove_from_end > 0:
        gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
    assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
    _clear_cuda_cache()
    return gen_fine_arr
843
+
844
def codec_decode(fine_tokens):
    """Turn quantized audio codes into audio array using encodec."""
    global models
    global models_devices
    # lazily bring up the pipeline if nothing is loaded yet
    if "codec" not in models:
        preload_models()
    model = models["codec"]
    if OFFLOAD_CPU:
        model.to(models_devices["codec"])
    device = next(model.parameters()).device
    arr = torch.from_numpy(fine_tokens)[None]
    arr = arr.to(device).transpose(0, 1)
    # codes -> latent embedding -> waveform
    emb = model.quantizer.decode(arr)
    out = model.decoder(emb)
    audio_arr = out.detach().cpu().numpy().squeeze()
    del arr, emb, out
    if OFFLOAD_CPU:
        model.to("cpu")
    return audio_arr
bark/hubert/__init__.py ADDED
File without changes
bark/hubert/customtokenizer.py ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Custom tokenizer model.
3
+ Author: https://www.github.com/gitmylo/
4
+ License: MIT
5
+ """
6
+
7
+ import json
8
+ import os.path
9
+ from zipfile import ZipFile
10
+
11
+ import numpy
12
+ import torch
13
+ from torch import nn, optim
14
+ from torch.serialization import MAP_LOCATION
15
+ from tqdm.auto import tqdm
16
+
17
+
18
class CustomTokenizer(nn.Module):
    """LSTM-based tokenizer mapping HuBERT feature vectors to discrete semantic tokens.

    Author: https://www.github.com/gitmylo/ (MIT license)

    version 0: LSTM -> Linear head.
    version 1: LSTM -> intermediate Linear(4096) -> Linear head.
    """

    def __init__(self, hidden_size=1024, input_size=768, output_size=10000, version=0):
        super(CustomTokenizer, self).__init__()
        next_size = input_size
        if version == 0:
            self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
            next_size = hidden_size
        if version == 1:
            self.lstm = nn.LSTM(input_size, hidden_size, 2, batch_first=True)
            self.intermediate = nn.Linear(hidden_size, 4096)
            next_size = 4096

        self.fc = nn.Linear(next_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)
        self.optimizer: optim.Optimizer = None  # created by prepare_training()
        self.lossfunc = nn.CrossEntropyLoss()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.version = version

    def forward(self, x):
        """Return log-probabilities of shape (N, output_size) for features x of shape (N, input_size)."""
        x, _ = self.lstm(x)
        if self.version == 1:
            x = self.intermediate(x)
        x = self.fc(x)
        x = self.softmax(x)
        return x

    @torch.no_grad()
    def get_token(self, x):
        """
        Convert a sequence of feature vectors to discrete tokens.

        :param x: An array with shape (N, input_size) where N is a whole number greater or equal to 1, and input_size is the input size used when creating the model.
        :return: An array with shape (N,) where every entry is a whole number in range 0...output_size - 1.
        """
        return torch.argmax(self(x), dim=1)

    def prepare_training(self):
        """Create the Adam optimizer used by train_step()."""
        self.optimizer = optim.Adam(self.parameters(), 0.001)

    def train_step(self, x_train, y_train, log_loss=False):
        """Run one optimization step on a (features, tokens) pair; call prepare_training() first."""
        optimizer = self.optimizer
        lossfunc = self.lossfunc
        # Zero the gradients
        self.zero_grad()

        # Forward pass
        y_pred = self(x_train)

        # Features and labels may differ slightly in length; trim to the shorter one.
        y_train_len = len(y_train)
        y_pred_len = y_pred.shape[0]

        if y_train_len > y_pred_len:
            diff = y_train_len - y_pred_len
            y_train = y_train[diff:]
        elif y_train_len < y_pred_len:
            diff = y_pred_len - y_train_len
            y_pred = y_pred[:-diff, :]

        y_train_hot = torch.zeros(len(y_train), self.output_size)
        y_train_hot[range(len(y_train)), y_train] = 1
        # BUGFIX: follow the prediction's device instead of hard-coding 'cuda',
        # so training also works on CPU-only machines. On CUDA this is identical.
        y_train_hot = y_train_hot.to(y_pred.device)

        # Calculate the loss
        loss = lossfunc(y_pred, y_train_hot)

        # Print loss
        if log_loss:
            print('Loss', loss.item())

        # Backward pass
        loss.backward()

        # Update the weights
        optimizer.step()

    def save(self, path):
        """Save weights to `path` and append a JSON '.info' entry describing the architecture."""
        info_path = '.'.join(os.path.basename(path).split('.')[:-1]) + '/.info'
        torch.save(self.state_dict(), path)
        data_from_model = Data(self.input_size, self.hidden_size, self.output_size, self.version)
        # torch.save writes a zip archive, so the metadata can be appended to it.
        with ZipFile(path, 'a') as model_zip:
            model_zip.writestr(info_path, data_from_model.save())

    @staticmethod
    def load_from_checkpoint(path, map_location: MAP_LOCATION = None):
        """Load a tokenizer from `path`, reading architecture info from the embedded '.info' entry if present."""
        old = True
        with ZipFile(path) as model_zip:
            filesMatch = [file for file in model_zip.namelist() if file.endswith('/.info')]
            file = filesMatch[0] if filesMatch else None
            if file:
                old = False
                print(f"Loading Custom Hubert Tokenizer {path}")
                data_from_model = Data.load(model_zip.read(file).decode('utf-8'))
        if old:
            # No embedded metadata: assume the original default architecture.
            model = CustomTokenizer()
        else:
            model = CustomTokenizer(data_from_model.hidden_size, data_from_model.input_size, data_from_model.output_size, data_from_model.version)
        # BUGFIX: pass map_location through to torch.load so checkpoints saved
        # on GPU can be loaded on machines without CUDA.
        model.load_state_dict(torch.load(path, map_location=map_location))
        if map_location:
            model = model.to(map_location)
        return model
125
+
126
+
127
+
128
class Data:
    """Plain JSON-serializable record of a CustomTokenizer's architecture."""

    input_size: int
    hidden_size: int
    output_size: int
    version: int

    def __init__(self, input_size=768, hidden_size=1024, output_size=10000, version=0):
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.version = version

    @staticmethod
    def load(string):
        """Rebuild a Data record from a JSON string produced by save()."""
        parsed = json.loads(string)
        return Data(
            parsed['input_size'],
            parsed['hidden_size'],
            parsed['output_size'],
            parsed['version'],
        )

    def save(self):
        """Serialize this record to a JSON string (inverse of load())."""
        return json.dumps({
            'input_size': self.input_size,
            'hidden_size': self.hidden_size,
            'output_size': self.output_size,
            'version': self.version,
        })
153
+
154
+
155
def auto_train(data_path, save_path='model.pth', load_model: str | None = None, save_epochs=1, max_epochs=14):
    """Train a CustomTokenizer on prepared data.

    Expects `<data_path>/ready/` to contain paired numpy files:
      *_semantic_features.npy -> inputs (HuBERT features)
      *_semantic.npy          -> targets (semantic token ids)

    After every group of `save_epochs` passes a checkpoint is written both to
    `save_path` and to a per-epoch file, until `max_epochs` checkpoints exist.
    NOTE: training is hard-coded to run on CUDA.

    :param data_path: directory containing the 'ready' subfolder; also the save location
    :param save_path: checkpoint file name, joined onto data_path
    :param load_model: optional existing checkpoint to resume from
    :param save_epochs: training passes between checkpoints
    :param max_epochs: number of checkpoints to produce before stopping
    """
    data_x, data_y = [], []

    if load_model and os.path.isfile(load_model):
        print('Loading model from', load_model)
        model_training = CustomTokenizer.load_from_checkpoint(load_model, 'cuda')
    else:
        print('Creating new model.')
        model_training = CustomTokenizer(version=1).to('cuda')  # Settings for the model to run without lstm
    save_path = os.path.join(data_path, save_path)
    base_save_path = '.'.join(save_path.split('.')[:-1])

    sem_string = '_semantic.npy'
    feat_string = '_semantic_features.npy'

    ready = os.path.join(data_path, 'ready')
    for input_file in os.listdir(ready):
        full_path = os.path.join(ready, input_file)
        if input_file.endswith(sem_string):
            data_y.append(numpy.load(full_path))
        elif input_file.endswith(feat_string):
            data_x.append(numpy.load(full_path))
    model_training.prepare_training()

    epoch = 1
    # BUGFIX: the progress bar advances once per inner save-epoch pass, so the
    # correct total is max_epochs * save_epochs; the old formula
    # ((len(data_x) * len(data_y)) / 50) * save_epochs mixed in the dataset
    # size and produced a meaningless percentage.
    with tqdm(total=max_epochs * save_epochs) as pbar1:
        while epoch <= max_epochs:
            for i in range(save_epochs):
                j = 0
                for x, y in zip(data_x, data_y):
                    # Print loss every 50 steps
                    model_training.train_step(torch.tensor(x).to('cuda'), torch.tensor(y).to('cuda'), j % 50 == 0)
                    j += 1
                pbar1.update()

            save_p = save_path
            save_p_2 = f'{base_save_path}_epoch_{epoch}.pth'
            model_training.save(save_p)
            model_training.save(save_p_2)
            print(f'Epoch {epoch} completed')
            epoch += 1
    print(f'Done training for {max_epochs} epochs!')
bark/hubert/hubert_manager.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path
2
+ import shutil
3
+ import urllib.request
4
+
5
+ import huggingface_hub
6
+
7
+
8
class HuBERTManager:
    """Downloads and caches the HuBERT base model and language-specific quantizer tokenizers
    under ./models/hubert/."""

    @staticmethod
    def make_sure_hubert_installed(download_url: str = 'https://dl.fbaipublicfiles.com/hubert/hubert_base_ls960.pt', file_name: str = 'hubert.pt'):
        """Ensure models/hubert/<file_name> exists, downloading it if missing; return its path."""
        install_dir = os.path.join('models', 'hubert')
        if not os.path.isdir(install_dir):
            os.makedirs(install_dir, exist_ok=True)
        install_file = os.path.join(install_dir, file_name)
        if not os.path.isfile(install_file):
            print(f'Downloading HuBERT base model from {download_url}')
            urllib.request.urlretrieve(download_url, install_file)
            print('Downloaded HuBERT')
        return install_file

    @staticmethod
    def make_sure_tokenizer_installed(model: str = 'quantifier_hubert_base_ls960_14.pth', repo: str = 'GitMylo/bark-voice-cloning', tokenizer_lang: str = 'en'):
        """Ensure the quantizer tokenizer for `tokenizer_lang` ('en', 'de' or 'pl') exists locally.

        Downloads it from the Hugging Face Hub when missing and returns its path.

        :raises ValueError: for an unsupported tokenizer language
        """
        local_file = tokenizer_lang + '_tokenizer.pth'
        install_dir = os.path.join('models', 'hubert')
        if not os.path.isdir(install_dir):
            os.makedirs(install_dir, exist_ok=True)
        install_file = os.path.join(install_dir, local_file)
        if not os.path.isfile(install_file):
            # TODO: refactor to use a lookup table instead of an if/elif chain
            if tokenizer_lang == 'en':
                repo = 'GitMylo/bark-voice-cloning'
                model = 'quantifier_hubert_base_ls960_14.pth'
            elif tokenizer_lang == 'de':
                repo = 'CountFloyd/bark-voice-cloning-german-HuBERT-quantizer'
                model = 'german-HuBERT-quantizer_14_epoch.pth'
            elif tokenizer_lang == 'pl':
                repo = 'Hobis/bark-voice-cloning-polish-HuBERT-quantizer'
                model = 'polish-HuBERT-quantizer_8_epoch.pth'
            else:
                # BUGFIX: raising a plain string is a TypeError in Python 3
                # ("exceptions must derive from BaseException"); raise a
                # proper exception so callers can handle it.
                raise ValueError('Unknown Tokenizer Language!')
            print(f'{local_file} not found. Downloading HuBERT custom tokenizer')
            huggingface_hub.hf_hub_download(repo, model, local_dir=install_dir, local_dir_use_symlinks=False)
            shutil.move(os.path.join(install_dir, model), install_file)
            print('Downloaded tokenizer')
        return install_file
bark/hubert/pre_kmeans_hubert.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified HuBERT model without kmeans.
3
+ Original author: https://github.com/lucidrains/
4
+ Modified by: https://www.github.com/gitmylo/
5
+ License: MIT
6
+ """
7
+
8
+ # Modified code from https://github.com/lucidrains/audiolm-pytorch/blob/main/audiolm_pytorch/hubert_kmeans.py
9
+
10
+ from pathlib import Path
11
+
12
+ import torch
13
+ from torch import nn
14
+ from einops import pack, unpack
15
+
16
+ import fairseq
17
+
18
+ from torchaudio.functional import resample
19
+
20
+ from audiolm_pytorch.utils import curtail_to_multiple
21
+
22
+ import logging
23
+ logging.root.setLevel(logging.ERROR)
24
+
25
+
26
def exists(val):
    """True iff `val` is not None."""
    return val is not None
28
+
29
+
30
def default(val, d):
    """Return `val` unless it is None, in which case return the fallback `d`."""
    return d if val is None else val
32
+
33
+
34
class CustomHubert(nn.Module):
    """
    checkpoint and kmeans can be downloaded at https://github.com/facebookresearch/fairseq/tree/main/examples/hubert
    or you can train your own

    Wraps a fairseq HuBERT checkpoint and returns the continuous hidden-state
    features of layer `output_layer`; the k-means quantization step of the
    original pipeline is intentionally omitted here.
    """

    def __init__(
        self,
        checkpoint_path,
        target_sample_hz=16000,
        seq_len_multiple_of=None,
        output_layer=9,
        device=None
    ):
        super().__init__()
        # Audio fed to forward() is resampled to this rate when a source rate is given.
        self.target_sample_hz = target_sample_hz
        self.seq_len_multiple_of = seq_len_multiple_of
        self.output_layer = output_layer

        if device is not None:
            self.to(device)

        model_path = Path(checkpoint_path)

        assert model_path.exists(), f'path {checkpoint_path} does not exist'

        print(f"Loading Hubert {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path)
        load_model_input = {checkpoint_path: checkpoint}
        # fairseq returns an ensemble (list of models); only the first is used.
        model, *_ = fairseq.checkpoint_utils.load_model_ensemble_and_task(load_model_input)

        if device is not None:
            model[0].to(device)

        self.model = model[0]
        # Inference only: keep the backbone in eval mode.
        self.model.eval()

    @property
    def groups(self):
        # Single feature group (kept for interface compatibility with the
        # quantized variant this class replaces).
        return 1

    @torch.no_grad()
    def forward(
        self,
        wav_input,
        flatten=True,
        input_sample_hz=None
    ):
        """Return HuBERT features for `wav_input`.

        When `input_sample_hz` is given, audio is first resampled to
        `target_sample_hz`. NOTE: despite the historical variable name below,
        the return value holds continuous feature vectors, not k-means indices.
        """
        device = wav_input.device

        if exists(input_sample_hz):
            wav_input = resample(wav_input, input_sample_hz, self.target_sample_hz)

        if exists(self.seq_len_multiple_of):
            # Trim so the length divides evenly, as the downstream model expects.
            wav_input = curtail_to_multiple(wav_input, self.seq_len_multiple_of)

        embed = self.model(
            wav_input,
            features_only=True,
            mask=False,  # thanks to @maitycyrus for noticing that mask is defaulted to True in the fairseq code
            output_layer=self.output_layer
        )

        # Collapse leading dims into one sequence dimension: (…, d) -> (N, d).
        embed, packed_shape = pack([embed['x']], '* d')

        # codebook_indices = self.kmeans.predict(embed.cpu().detach().numpy())

        codebook_indices = torch.from_numpy(embed.cpu().detach().numpy()).to(device)  # .long()

        if flatten:
            return codebook_indices

        # Restore the original leading shape when the caller wants it unflattened.
        codebook_indices, = unpack(codebook_indices, packed_shape, '*')
        return codebook_indices
bark/model.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ import math
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
class LayerNorm(nn.Module):
    """LayerNorm with an optional bias (PyTorch's own doesn't support simply bias=False)."""

    def __init__(self, ndim, bias):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(ndim))
        # When bias is disabled, F.layer_norm accepts None for the bias term.
        self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None

    def forward(self, input):
        # Normalize over the trailing `ndim` dims with the learned affine params.
        return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
22
+
23
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with optional key/value caching for
    incremental decoding. forward() returns `(output, present_kv)`."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        if not self.flash:
            # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))

    def forward(self, x, past_kv=None, use_cache=False):
        # past_kv: optional (key, value) tensors from earlier steps; when given,
        # x holds only the newest token(s) and the cached k/v are prepended.
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k ,v  = self.c_attn(x).split(self.n_embd, dim=2)
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)

        if past_kv is not None:
            past_key = past_kv[0]
            past_value = past_kv[1]
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        FULL_T = k.shape[-2]  # total sequence length including any cached past

        if use_cache is True:
            present = (k, v)
        else:
            present = None

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            if past_kv is not None:
                # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
                # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
                # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
                # to work around this we set is_causal=False.
                is_causal = False
            else:
                is_causal = True

            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            # Slice the causal mask so the T query rows line up with the last T
            # positions of the (possibly cache-extended) FULL_T key columns.
            att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return (y, present)
93
+
94
class MLP(nn.Module):
    """Transformer feed-forward block: Linear(4x expand) -> GELU -> Linear -> Dropout."""

    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
        self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
        self.dropout = nn.Dropout(config.dropout)
        self.gelu = nn.GELU()

    def forward(self, x):
        hidden = self.gelu(self.c_fc(x))
        return self.dropout(self.c_proj(hidden))
109
+
110
class Block(nn.Module):
    """One transformer layer: pre-norm causal attention (with optional KV cache)
    followed by a pre-norm MLP, both with residual connections."""

    def __init__(self, config, layer_idx):
        super().__init__()
        self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
        self.attn = CausalSelfAttention(config)
        self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
        self.mlp = MLP(config)
        self.layer_idx = layer_idx

    def forward(self, x, past_kv=None, use_cache=False):
        attn_out, present_kv = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
        x = x + attn_out
        x = x + self.mlp(self.ln_2(x))
        return (x, present_kv)
125
+
126
@dataclass
class GPTConfig:
    """Hyperparameters for the GPT backbone used by the Bark models."""
    block_size: int = 1024           # max sequence length / positional embedding table size
    input_vocab_size: int = 10_048   # token ids accepted by the input embedding
    output_vocab_size: int = 10_048  # logits produced by the lm_head
    n_layer: int = 12                # number of transformer blocks
    n_head: int = 12                 # attention heads per block
    n_embd: int = 768                # embedding / hidden dimension
    dropout: float = 0.0
    bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
136
+
137
class GPT(nn.Module):
    """GPT language model over Bark token sequences.

    forward() returns `(logits, new_kv)`: logits cover ONLY the last position
    (inference-time optimization) and new_kv is the per-layer key/value cache
    when `use_cache` is set.
    """

    def __init__(self, config):
        super().__init__()
        assert config.input_vocab_size is not None
        assert config.output_vocab_size is not None
        assert config.block_size is not None
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            wte = nn.Embedding(config.input_vocab_size, config.n_embd),
            wpe = nn.Embedding(config.block_size, config.n_embd),
            drop = nn.Dropout(config.dropout),
            h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
            ln_f = LayerNorm(config.n_embd, bias=config.bias),
        ))
        self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            n_params -= self.transformer.wte.weight.numel()
            n_params -= self.transformer.wpe.weight.numel()
        return n_params

    def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
        """Forward pass over token ids `idx` of shape (b, t).

        merge_context: treats the first 512 tokens as two 256-token contexts
        whose embeddings are summed element-wise before the remaining tokens
        are appended (effective length becomes t - 256).
        past_kv / use_cache: per-layer KV cache for incremental decoding; with
        past_kv given, idx must contain exactly one new token.
        """
        device = idx.device
        b, t = idx.size()
        if past_kv is not None:
            assert t == 1  # incremental decoding feeds a single new token
            tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
        else:
            if merge_context:
                assert(idx.shape[1] >= 256+256+1)
                t = idx.shape[1] - 256  # the two 256-token contexts collapse into one
            else:
                assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"

            # forward the GPT model itself
            if merge_context:
                # sum the embeddings of the two 256-token context halves, then
                # append the embeddings of the remaining tokens
                tok_emb = torch.cat([
                    self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
                    self.transformer.wte(idx[:,256+256:])
                ], dim=1)
            else:
                tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)

        if past_kv is None:
            past_length = 0
            past_kv = tuple([None] * len(self.transformer.h))
        else:
            past_length = past_kv[0][0].size(-2)  # cached sequence length so far

        if position_ids is None:
            # position ids continue from the end of the cached sequence
            position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
            position_ids = position_ids.unsqueeze(0) # shape (1, t)
            assert position_ids.shape == (1, t)

        pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)

        x = self.transformer.drop(tok_emb + pos_emb)

        new_kv = () if use_cache else None

        for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
            x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)

            if use_cache:
                new_kv = new_kv + (kv,)

        x = self.transformer.ln_f(x)

        # inference-time mini-optimization: only forward the lm_head on the very last position
        logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim

        return (logits, new_kv)
bark/model_fine.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ from dataclasses import dataclass
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ from .model import GPT, GPTConfig, MLP
13
+
14
+
15
class NonCausalSelfAttention(nn.Module):
    """Bidirectional (unmasked) multi-head self-attention used by the fine model."""

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, computed in one matmul
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # regularization
        self.attn_dropout = nn.Dropout(config.dropout)
        self.resid_dropout = nn.Dropout(config.dropout)
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        self.dropout = config.dropout
        # the fused SDPA kernel is only used when available and dropout is disabled
        self.flash = (
            hasattr(torch.nn.functional, "scaled_dot_product_attention") and self.dropout == 0.0
        )

    def forward(self, x):
        batch, seq, dim = x.size()  # batch size, sequence length, embedding dim
        head_dim = dim // self.n_head

        # project once, then split into per-head q/k/v of shape (B, nh, T, hs)
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
        q = q.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
        k = k.view(batch, seq, self.n_head, head_dim).transpose(1, 2)
        v = v.view(batch, seq, self.n_head, head_dim).transpose(1, 2)

        if self.flash:
            # fused kernel; no causal mask for this non-causal variant
            out = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
            )
        else:
            # manual scaled dot-product attention
            scores = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            weights = self.attn_dropout(F.softmax(scores, dim=-1))
            out = weights @ v  # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)

        # merge heads back: (B, nh, T, hs) -> (B, T, C)
        out = out.transpose(1, 2).contiguous().view(batch, seq, dim)
        return self.resid_dropout(self.c_proj(out))
62
+
63
+
64
class FineBlock(nn.Module):
    """Transformer layer for the fine model: pre-norm non-causal attention and
    pre-norm MLP, each with a residual connection."""

    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.n_embd)
        self.attn = NonCausalSelfAttention(config)
        self.ln_2 = nn.LayerNorm(config.n_embd)
        self.mlp = MLP(config)

    def forward(self, x):
        x = x + self.attn(self.ln_1(x))
        return x + self.mlp(self.ln_2(x))
76
+
77
+
78
class FineGPT(GPT):
    """Non-causal GPT used as Bark's fine acoustic model: predicts one codebook
    channel (`pred_idx`) from all channels up to and including it."""

    def __init__(self, config):
        super().__init__(config)
        del self.lm_head  # replaced by one prediction head per predicted codebook
        self.config = config
        self.n_codes_total = config.n_codes_total
        self.transformer = nn.ModuleDict(
            dict(
                # one embedding table per codebook channel
                wtes=nn.ModuleList(
                    [
                        nn.Embedding(config.input_vocab_size, config.n_embd)
                        for _ in range(config.n_codes_total)
                    ]
                ),
                wpe=nn.Embedding(config.block_size, config.n_embd),
                drop=nn.Dropout(config.dropout),
                h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
                ln_f=nn.LayerNorm(config.n_embd),
            )
        )
        # one lm head for each codebook that is predicted (i.e. not "given")
        self.lm_heads = nn.ModuleList(
            [
                nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
                for _ in range(config.n_codes_given, self.n_codes_total)
            ]
        )
        # weight sharing: tie the head for codebook i+1 to that codebook's
        # embedding table (see get_num_params docstring below)
        for i in range(self.n_codes_total - config.n_codes_given):
            self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight

    def forward(self, pred_idx, idx):
        """Return logits for codebook channel `pred_idx` given token ids `idx`
        of shape (b, t, n_codes_total). `pred_idx` must be >= 1."""
        device = idx.device
        b, t, codes = idx.size()
        assert (
            t <= self.config.block_size
        ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
        assert pred_idx > 0, "cannot predict 0th codebook"
        assert codes == self.n_codes_total, (b, t, codes)
        pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)

        # forward the GPT model itself
        tok_embs = [
            wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
        ] # token embeddings of shape (b, t, n_embd)
        tok_emb = torch.cat(tok_embs, dim=-1)
        pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
        # only embeddings of channels 0..pred_idx contribute to the input
        x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
        x = self.transformer.drop(x + pos_emb)
        for block in self.transformer.h:
            x = block(x)
        x = self.transformer.ln_f(x)
        logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
        return logits

    def get_num_params(self, non_embedding=True):
        """
        Return the number of parameters in the model.
        For non-embedding count (default), the position embeddings get subtracted.
        The token embeddings would too, except due to the parameter sharing these
        params are actually used as weights in the final layer, so we include them.
        """
        n_params = sum(p.numel() for p in self.parameters())
        if non_embedding:
            for wte in self.transformer.wtes:
                n_params -= wte.weight.numel()
            n_params -= self.transformer.wpe.weight.numel()
        return n_params
144
+
145
+
146
@dataclass
class FineGPTConfig(GPTConfig):
    """GPTConfig extended with codebook counts for the fine acoustic model."""
    n_codes_total: int = 8  # total codebook channels in the input (FineGPT.forward asserts this)
    n_codes_given: int = 1  # channels that are given rather than predicted; heads exist for the rest
bark/settings.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
def initenv(args):
    """Translate command-line flags in `args` into the environment variables
    read elsewhere by the Bark code. Each variable is set to the string
    'True' or 'False' depending on whether its flag is present."""
    flag_to_env = {
        "-smallmodels": "SUNO_USE_SMALL_MODELS",
        "-forcecpu": "BARK_FORCE_CPU",
        "-enablemps": "SUNO_ENABLE_MPS",
        "-offloadcpu": "SUNO_OFFLOAD_CPU",
    }
    for flag, env_name in flag_to_env.items():
        os.environ[env_name] = str(flag in args)
cloning/__init__.py ADDED
File without changes
cloning/clonevoice.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bark.generation import load_codec_model, generate_text_semantic, grab_best_device
2
+ from encodec.utils import convert_audio
3
+ from bark.hubert.hubert_manager import HuBERTManager
4
+ from bark.hubert.pre_kmeans_hubert import CustomHubert
5
+ from bark.hubert.customtokenizer import CustomTokenizer
6
+
7
+ import torchaudio
8
+ import torch
9
+ import os
10
+ import gradio
11
+
12
+
13
def clone_voice(audio_filepath, tokenizer_lang, dest_filename, progress=gradio.Progress(track_tqdm=True)):
    """Clone a voice from an audio sample into a Bark speaker prompt.

    Extracts semantic tokens (HuBERT + custom quantizer) and EnCodec codes
    from `audio_filepath` and saves them as `<dest_filename>.npz`, usable as
    a Bark history prompt.

    :param audio_filepath: path to the source audio sample
    :param tokenizer_lang: language of the quantizer tokenizer ('en', 'de', 'pl')
    :param dest_filename: output path without the .npz extension
    :param progress: gradio progress reporter
    :return: ["Finished", output_path]
    """
    import numpy as np

    # BUGFIX: BARK_FORCE_CPU holds the *string* 'True'/'False' (see
    # bark/settings.py). Any non-empty string is truthy, so the old
    # `not os.environ.get("BARK_FORCE_CPU", False)` forced CPU even when the
    # variable was set to 'False'. Compare against the actual value instead.
    use_gpu = os.environ.get("BARK_FORCE_CPU", "False") != "True"
    progress(0, desc="Loading Codec")
    model = load_codec_model(use_gpu=use_gpu)

    # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
    hubert_manager = HuBERTManager()
    hubert_manager.make_sure_hubert_installed()
    hubert_manager.make_sure_tokenizer_installed(tokenizer_lang=tokenizer_lang)

    # Load HuBERT for semantic vectors and the matching quantizer tokenizer
    device = grab_best_device(use_gpu)
    hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device)
    tokenizer = CustomTokenizer.load_from_checkpoint(f'./models/hubert/{tokenizer_lang}_tokenizer.pth').to(device)  # Automatically uses the right layers

    progress(0.25, desc="Converting WAV")

    # Load and pre-process the audio waveform
    wav, sr = torchaudio.load(audio_filepath)
    if wav.shape[0] == 2:  # Stereo to mono if needed
        wav = wav.mean(0, keepdim=True)

    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
    progress(0.5, desc="Extracting codes")

    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(semantic_vectors)

    # Extract discrete codes from EnCodec
    with torch.no_grad():
        encoded_frames = model.encode(wav.unsqueeze(0))
    codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze()  # [n_q, T]

    # move everything to cpu before serializing
    codes = codes.cpu().numpy()
    semantic_tokens = semantic_tokens.cpu().numpy()

    # coarse prompt keeps only the first two codebooks
    output_path = dest_filename + '.npz'
    np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)
    return ["Finished", output_path]
config.yaml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ input_text_desired_length: 110
2
+ input_text_max_length: 170
3
+ selected_theme: JohnSmith9982/small_and_pretty
4
+ server_name: ''
5
+ server_port: 0
6
+ server_share: false
7
+ silence_between_sentences: 250
8
+ silence_between_speakers: 500
pyproject.toml ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "bark-ui-enhanced"
7
+ version = "0.7.0"
8
+ description = "Bark text to audio model with addition features and a Web UI"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ authors = [
12
+ {name = "Suno Inc (original Bark)", email = "hello@suno.ai"},
13
+ {name = "Count Floyd"},
14
+ ]
15
+ # MIT License
16
+ license = {file = "LICENSE"}
17
+
18
+ dependencies = [
19
+ "boto3",
20
+ "encodec",
21
+ "funcy",
22
+ "huggingface-hub>=0.14.1",
23
+ "numpy",
24
+ "scipy",
25
+ "tokenizers",
26
+ "torch",
27
+ "tqdm",
28
+ "transformers",
29
+ ]
30
+
31
+ [project.urls]
32
+ source = "https://github.com/C0untFloyd/bark-gui"
33
+
34
+ [project.optional-dependencies]
35
+ dev = [
36
+ "bandit",
37
+ "black",
38
+ "codecov",
39
+ "flake8",
40
+ "hypothesis>=6.14,<7",
41
+ "isort>=5.0.0,<6",
42
+ "jupyter",
43
+ "mypy",
44
+ "nbconvert",
45
+ "nbformat",
46
+ "pydocstyle",
47
+ "pylint",
48
+ "pytest",
49
+ "pytest-cov",
50
+ ]
51
+
52
+ [tool.setuptools]
53
+ packages = ["bark"]
54
+
55
+ [tool.setuptools.package-data]
56
+ bark = ["assets/prompts/*.npz", "assets/prompts/v2/*.npz"]
57
+
58
+
59
+ [tool.black]
60
+ line-length = 100
requirements.txt ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ fairseq; platform_system != "Windows"
2
+ fairseq@https://github.com/Sharrnah/fairseq/releases/download/v0.12.4/fairseq-0.12.4-cp310-cp310-win_amd64.whl; platform_system == "Windows"
3
+ audiolm-pytorch
4
+ gradio
5
+ funcy
6
+ linkify
7
+ mutagen
8
+ pytorch_seed
9
+ pyyaml
10
+ sentencepiece
11
+ soundfile; platform_system == "Windows"
12
+ sox; platform_system != "Windows"
13
+ transformers
setup.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
# Minimal PEP 517 shim: all project metadata lives in pyproject.toml.
from setuptools import setup

setup()
swap_voice.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from bark.generation import load_codec_model, generate_text_semantic, grab_best_device
2
+ from bark import SAMPLE_RATE
3
+ from encodec.utils import convert_audio
4
+ from bark.hubert.hubert_manager import HuBERTManager
5
+ from bark.hubert.pre_kmeans_hubert import CustomHubert
6
+ from bark.hubert.customtokenizer import CustomTokenizer
7
+ from bark.api import semantic_to_waveform
8
+ from scipy.io.wavfile import write as write_wav
9
+ from util.helper import create_filename
10
+ from util.settings import Settings
11
+
12
+
13
+ import torchaudio
14
+ import torch
15
+ import os
16
+ import gradio
17
+
18
def swap_voice_from_audio(swap_audio_filename, selected_speaker, tokenizer_lang, seed, batchcount, progress=gradio.Progress(track_tqdm=True)):
    """Re-synthesize the speech in an audio file with another Bark speaker.

    Extracts semantic tokens from the input audio via HuBERT plus a custom
    tokenizer, renders them back to a waveform conditioned on
    ``selected_speaker``, and writes the result to the configured output folder.

    Args:
        swap_audio_filename: path of the source audio file to convert.
        selected_speaker: Bark history prompt (voice) used for re-synthesis.
        tokenizer_lang: language id selecting the tokenizer checkpoint file.
        seed: accepted but unused in this function — NOTE(review): confirm
            whether seeding was meant to apply here.
        batchcount: accepted but unused in this function — NOTE(review): confirm.
        progress: gradio progress reporter (also tracks tqdm bars).

    Returns:
        Path of the written WAV file.
    """
    # NOTE(review): os.environ.get returns a *string* when set, so even
    # BARK_FORCE_CPU=0/false forces CPU — confirm intended semantics.
    use_gpu = not os.environ.get("BARK_FORCE_CPU", False)
    progress(0, desc="Loading Codec")

    # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
    # Download/verify HuBERT and the language-specific tokenizer on demand.
    hubert_manager = HuBERTManager()
    hubert_manager.make_sure_hubert_installed()
    hubert_manager.make_sure_tokenizer_installed(tokenizer_lang=tokenizer_lang)

    # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer
    # Load HuBERT for semantic tokens

    # Load the HuBERT model
    device = grab_best_device(use_gpu)
    hubert_model = CustomHubert(checkpoint_path='./models/hubert/hubert.pt').to(device)
    model = load_codec_model(use_gpu=use_gpu)

    # Load the CustomTokenizer model
    tokenizer = CustomTokenizer.load_from_checkpoint(f'./models/hubert/{tokenizer_lang}_tokenizer.pth').to(device)  # Automatically uses the right layers

    progress(0.25, desc="Converting WAV")

    # Load and pre-process the audio waveform
    wav, sr = torchaudio.load(swap_audio_filename)
    if wav.shape[0] == 2:  # Stereo to mono if needed
        wav = wav.mean(0, keepdim=True)

    # Resample/reshape to the codec model's expected rate and channel count.
    wav = convert_audio(wav, sr, model.sample_rate, model.channels)
    wav = wav.to(device)
    # Audio -> HuBERT feature vectors -> discrete semantic tokens.
    semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)
    semantic_tokens = tokenizer.get_token(semantic_vectors)

    # Render the semantic tokens as audio with the selected voice.
    audio = semantic_to_waveform(
        semantic_tokens,
        history_prompt=selected_speaker,
        temp=0.7,
        silent=False,
        output_full=False)

    settings = Settings('config.yaml')

    result = create_filename(settings.output_folder_path, None, "swapvoice",".wav")
    write_wav(result, SAMPLE_RATE, audio)
    return result
training/__init__.py ADDED
File without changes
training/data.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import requests
3
+ import os, glob
4
+
5
+ # english literature
6
+ books = [
7
+ 'https://www.gutenberg.org/cache/epub/1513/pg1513.txt',
8
+ 'https://www.gutenberg.org/files/2701/2701-0.txt',
9
+ 'https://www.gutenberg.org/cache/epub/84/pg84.txt',
10
+ 'https://www.gutenberg.org/cache/epub/2641/pg2641.txt',
11
+ 'https://www.gutenberg.org/cache/epub/1342/pg1342.txt',
12
+ 'https://www.gutenberg.org/cache/epub/100/pg100.txt'
13
+ ]
14
+
15
+ #default english
16
+ # allowed_chars = ' abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'
17
+
18
+ #german
19
+ allowed_chars = ' aäbcdefghijklmnoöpqrsßtuüvwxyzABCDEFGHIJKLMNOÖPQRSTUÜVWXYZ0123456789!@#$%^&*()-_+=\"\':;[]{}/<>,.`~\n\\'
20
+
21
+
22
def download_book(book):
    """Download a Project Gutenberg book and return its text as a UTF-8 str.

    Args:
        book: URL of the plain-text book.

    Returns:
        The decoded book contents.
    """
    # Fix: requests has no default timeout, so a stalled connection could
    # hang the data pipeline forever; 60s is generous for a text file.
    return requests.get(book, timeout=60).content.decode('utf-8')
24
+
25
+
26
def filter_data(data):
    """Remove every character of *data* that is not in ``allowed_chars``.

    Args:
        data: raw book text.

    Returns:
        The filtered text, preserving original order.
    """
    print('Filtering data')
    # Perf: set membership is O(1) per character vs O(len(allowed_chars))
    # for the original substring scan — data is whole books, so this matters.
    allowed = set(allowed_chars)
    return ''.join([char for char in data if char in allowed])
29
+
30
+
31
def load_books(fromfolder=False):
    """Load and filter the training corpus, returning it as one string.

    Args:
        fromfolder: when True, read all ``text/*.txt`` files from the current
            working directory; otherwise download the hard-coded Gutenberg
            ``books`` list.

    Returns:
        All filtered texts joined with single spaces.
    """
    text_data = []
    if fromfolder:
        current_working_directory = os.getcwd()
        print(current_working_directory)
        path = 'text'
        for filename in glob.glob(os.path.join(path, '*.txt')):
            # Fix: explicit UTF-8 — the default encoding is platform-dependent
            # and would mangle the German corpus on Windows (cp1252).
            with open(os.path.join(os.getcwd(), filename), 'r', encoding='utf-8') as f:  # open in readonly mode
                # Fix: the f-string had no placeholder; log the actual file.
                print(f'Loading {filename}')
                text_data.append(filter_data(str(f.read())))
    else:
        print(f'Loading {len(books)} books into ram')
        for book in books:
            text_data.append(filter_data(str(download_book(book))))
    print('Loaded books')
    return ' '.join(text_data)
47
+
48
+
49
def random_split_chunk(data, size=14):
    """Return a random contiguous run of up to *size* space-separated words.

    Args:
        data: the corpus text to sample from.
        size: maximum number of words in the returned chunk.

    Returns:
        A space-joined chunk starting at a random word (shorter than *size*
        near the end of the corpus).
    """
    words = data.split(' ')
    start = random.randrange(0, len(words))
    return ' '.join(words[start:start + size])
training/train.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import fnmatch
3
+ import shutil
4
+
5
+ import numpy
6
+ import torchaudio
7
+ import gradio
8
+
9
+ from bark.hubert.pre_kmeans_hubert import CustomHubert
10
+ from bark.hubert.customtokenizer import auto_train
11
+ from tqdm.auto import tqdm
12
+
13
+
14
def training_prepare_files(path, model, progress=gradio.Progress(track_tqdm=True)):
    """Prepare paired HuBERT-feature / semantic-token files for training.

    For every WAV in ``./training/data/output_wav`` that has a matching
    semantic ``.npy`` in ``./training/data/output``, extracts HuBERT features
    and writes the pair into ``<path>/ready`` as ``<i>_semantic_features.npy``
    and ``<i>_semantic.npy``. Does nothing if ``<path>/ready`` already holds
    ``.npy`` files (assumed already prepared).

    Args:
        path: training root directory; pairs are written to ``<path>/ready``.
        model: path of the HuBERT checkpoint to load.
        progress: gradio progress reporter (tracks the tqdm loop).
    """
    semanticsfolder = "./training/data/output"
    wavfolder = "./training/data/output_wav"
    ready = os.path.join(path, 'ready')

    testfiles = fnmatch.filter(os.listdir(ready), '*.npy')
    if not testfiles:  # idiom fix: emptiness check instead of len(...) < 1
        # prepare and copy for training
        hubert_model = CustomHubert(checkpoint_path=model)

        wavfiles = fnmatch.filter(os.listdir(wavfolder), '*.wav')
        for i, f in tqdm(enumerate(wavfiles), total=len(wavfiles)):
            semaname = '.'.join(f.split('.')[:-1])  # Cut off the extension
            semaname = f'{semaname}.npy'
            semafilename = os.path.join(semanticsfolder, semaname)
            if not os.path.isfile(semafilename):
                print(f'Skipping {f} no semantics pair found!')
                continue

            print('Processing', f)
            wav, sr = torchaudio.load(os.path.join(wavfolder, f))
            if wav.shape[0] == 2:  # Stereo to mono if needed
                wav = wav.mean(0, keepdim=True)
            output = hubert_model.forward(wav, input_sample_hz=sr)
            out_array = output.cpu().numpy()
            fname = f'{i}_semantic_features.npy'
            numpy.save(os.path.join(ready, fname), out_array)
            fname = f'{i}_semantic.npy'
            shutil.copy(semafilename, os.path.join(ready, fname))
45
def train(path, save_every, max_epochs):
    """Run tokenizer auto-training on the prepared data in *path*.

    A checkpoint is saved every *save_every* epochs.

    NOTE(review): ``max_epochs`` is accepted but never forwarded to
    ``auto_train``, so training runs until stopped externally — confirm
    whether ``auto_train`` should receive an epoch limit.
    """
    auto_train(path, save_epochs=save_every)
training/training_prepare.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import uuid
3
+ import numpy
4
+ import os
5
+ import random
6
+ import fnmatch
7
+
8
+ from tqdm.auto import tqdm
9
+ from scipy.io import wavfile
10
+
11
+ from bark.generation import load_model, SAMPLE_RATE
12
+ from bark.api import semantic_to_waveform
13
+
14
+ from bark import text_to_semantic
15
+ from bark.generation import load_model
16
+
17
+ from training.data import load_books, random_split_chunk
18
+
19
+ output = 'training/data/output'
20
+ output_wav = 'training/data/output_wav'
21
+
22
+
23
def prepare_semantics_from_text(num_generations):
    """Generate Bark semantic-token files from random corpus chunks.

    Draws random short text chunks from the local corpus, runs Bark's
    text-to-semantic model on each, and saves the tokens as uniquely named
    ``.npy`` files under ``output``.

    Args:
        num_generations: number of samples to generate. Values that are
            falsy or <= 0 preserve the original behaviour of generating
            until interrupted.
    """
    loaded_data = load_books(True)

    print('Loading semantics model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='text')

    if not os.path.isdir(output):
        os.mkdir(output)

    # Fix: num_generations was accepted but ignored, making the loop
    # unconditionally infinite; honor it as a sample count (<= 0 = infinite).
    limit = int(num_generations or 0)
    loop = 1
    while limit <= 0 or loop <= limit:
        filename = uuid.uuid4().hex + '.npy'  # unique name per sample
        file_name = os.path.join(output, filename)
        text = ''
        # Re-draw until we get a non-empty chunk (the corpus may yield blanks).
        while not len(text) > 0:
            text = random_split_chunk(loaded_data)  # Obtain a short chunk of text
            text = text.strip()
        print(f'{loop} Generating semantics for text:', text)
        loop += 1
        # Random temperature in [0.6, 0.8] adds variety to the dataset.
        semantics = text_to_semantic(text, temp=round(random.uniform(0.6, 0.8), ndigits=2))
        numpy.save(file_name, semantics)
45
+
46
def prepare_wavs_from_semantics():
    """Render every semantic ``.npy`` in ``output`` to a WAV in ``output_wav``.

    Loads Bark's coarse and fine models, then converts each semantic-token
    file to audio, skipping files whose WAV already exists so interrupted
    runs can be resumed.

    Raises:
        Exception: if the ``output`` folder does not exist yet.
    """
    if not os.path.isdir(output):
        raise Exception('No \'output\' folder, make sure you run create_data.py first!')
    if not os.path.isdir(output_wav):
        os.mkdir(output_wav)

    print('Loading coarse model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='coarse')
    print('Loading fine model')
    load_model(use_gpu=True, use_small=False, force_reload=False, model_type='fine')

    files = fnmatch.filter(os.listdir(output), '*.npy')
    total = len(files)  # fix: dropped the unused `current` counter

    for i, f in tqdm(enumerate(files), total=len(files)):
        real_name = '.'.join(f.split('.')[:-1])  # Cut off the extension
        file_name = os.path.join(output, f)
        out_file = os.path.join(output_wav, f'{real_name}.wav')
        if not os.path.isfile(out_file) and os.path.isfile(file_name):  # Don't process files that have already been processed, to be able to continue previous generations
            print(f'Processing ({i+1}/{total}) -> {f}')
            # Random temperature in [0.6, 0.8] matches the semantics stage.
            wav = semantic_to_waveform(numpy.load(file_name), temp=round(random.uniform(0.6, 0.8), ndigits=2))
            # Change to PCM16
            # wav = (wav * 32767).astype(np.int16)
            wavfile.write(out_file, SAMPLE_RATE, wav)

    print('Done!')
util/__init__.py ADDED
File without changes
util/helper.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import datetime
3
+ from mutagen.wave import WAVE
4
+ from mutagen.id3._frames import *
5
+
6
def create_filename(path, seed, name, extension):
    """Build a dated, timestamped output path, creating directories as needed.

    The result has the shape
    ``<cwd>/<path>/<MM-DD-YYYY>/<name>_<HH-MM-SS>[_s<seed>]<extension>``.

    Args:
        path: output folder, relative to the current working directory.
        seed: generation seed to embed in the name, or None to omit it.
        name: base name of the file.
        extension: file extension including the leading dot.

    Returns:
        The full path of the (not yet created) output file.
    """
    now = datetime.now()
    date_str = now.strftime("%m-%d-%Y")
    sub_folder = os.path.join(os.getcwd(), path, date_str)
    # exist_ok avoids the race between an exists() check and makedirs(),
    # and collapses the two separate mkdir steps of the original.
    os.makedirs(sub_folder, exist_ok=True)

    time_str = now.strftime("%H-%M-%S")
    if seed is None:  # fix: identity comparison with None (PEP 8)
        file_name = f"{name}_{time_str}{extension}"
    else:
        file_name = f"{name}_{time_str}_s{seed}{extension}"
    return os.path.join(sub_folder, file_name)
25
def add_id3_tag(filename, text, speakername, seed):
    """Embed generation metadata as ID3 frames in a WAV file.

    Args:
        filename: path of the WAV file to tag (modified in place).
        text: the prompt text; stored truncated to 60 characters.
        speakername: voice name, or None for an unconditional generation.
        seed: generation seed recorded alongside the voice name.
    """
    audio = WAVE(filename)
    if speakername is None:  # fix: identity comparison with None (PEP 8)
        speakername = "Unconditional"

    # write id3 tag with text truncated to 60 chars, as a precaution...
    audio["TIT2"] = TIT2(encoding=3, text=text[:60])
    audio["TPE1"] = TPE1(encoding=3, text=f"Voice {speakername} using Seed={seed}")
    audio["TPUB"] = TPUB(encoding=3, text="Bark by Suno AI")
    audio["COMMENT"] = COMM(encoding=3, text="Generated with Bark GUI - Text-Prompted Generative Audio Model. Visit https://github.com/C0untFloyd/bark-gui")
    audio.save()