inoryQwQ
committed on
Commit
·
ce28028
1
Parent(s):
4e849a4
Update models, remove decoder_main
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- models-ax630c/{base-decoder-loop.axmodel → base/base-decoder.axmodel} +2 -2
- models-ax630c/{base-encoder.axmodel → base/base-encoder.axmodel} +2 -2
- models-ax630c/{base-tokens.txt → base/base-tokens.txt} +0 -0
- models-ax630c/base/base_config.json +30 -0
- models-ax630c/{base-decoder-main.axmodel → small/small-decoder.axmodel} +2 -2
- models-ax630c/small/small-tokens.txt +0 -0
- models-ax630c/small/small_config.json +30 -0
- models-ax630c/{base-positional_embedding.bin → tiny/tiny-decoder.axmodel} +2 -2
- models-ax650/small/small-positional_embedding.bin → models-ax630c/tiny/tiny-encoder.axmodel +2 -2
- models-ax630c/tiny/tiny-tokens.txt +0 -0
- models-ax630c/tiny/tiny_config.json +30 -0
- models-ax650/base/base-decoder-loop.axmodel +0 -3
- models-ax650/base/base-decoder-main.axmodel +0 -3
- models-ax650/base/base-decoder.axmodel +3 -0
- models-ax650/base/base-encoder.axmodel +2 -2
- models-ax650/small/small-decoder-loop.axmodel +0 -3
- models-ax650/small/small-decoder-main.axmodel +0 -3
- models-ax650/small/small-decoder.axmodel +3 -0
- models-ax650/small/small-encoder.axmodel +2 -2
- models-ax650/tiny/tiny-decoder-loop.axmodel +0 -3
- models-ax650/tiny/tiny-decoder-main.axmodel +0 -3
- models-ax650/tiny/tiny-decoder.axmodel +3 -0
- models-ax650/tiny/tiny-encoder.axmodel +2 -2
- models-ax650/tiny/tiny-positional_embedding.bin +0 -3
- models-ax650/turbo/turbo-decoder-loop.axmodel +0 -3
- models-ax650/turbo/turbo-decoder-main.axmodel +0 -3
- models-ax650/turbo/turbo-decoder.axmodel +3 -0
- models-ax650/turbo/turbo-encoder.axmodel +2 -2
- models-ax650/turbo/turbo-positional_embedding.bin +0 -3
- models-onnx/base/base-decoder-loop.onnx +0 -3
- models-onnx/base/base-decoder-main.onnx +0 -3
- models-onnx/base/base-decoder.onnx +3 -0
- models-onnx/base/base-encoder.onnx +2 -2
- models-onnx/base/base-positional_embedding.bin +0 -3
- models-onnx/small/small-decoder.onnx +3 -0
- models-onnx/small/small-encoder.onnx +3 -0
- models-onnx/small/small-positional_embedding.bin +0 -3
- models-onnx/tiny/tiny-decoder-loop.onnx +0 -3
- models-onnx/tiny/tiny-decoder-main.onnx +0 -3
- models-onnx/tiny/tiny-decoder.onnx +3 -0
- models-onnx/tiny/tiny-encoder.onnx +2 -2
- models-onnx/tiny/tiny-positional_embedding.bin +0 -3
- models-onnx/turbo/turbo-decoder.onnx +3 -0
- models-ax650/base/base-positional_embedding.bin → models-onnx/turbo/turbo-encoder.onnx +2 -2
- models-onnx/turbo/turbo-tokens.txt +0 -0
- python/assets/multilingual.tiktoken +0 -0
- python/languages.py +1 -1
- python/main.py +39 -19
- python/test_wer.py +96 -61
- python/whisper.py +179 -128
models-ax630c/{base-decoder-loop.axmodel → base/base-decoder.axmodel}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:51e595565f0121eb4dc9ee14172cb8a111a56bf280a927660fdee5fbffa9d52e
|
| 3 |
+
size 184323085
|
models-ax630c/{base-encoder.axmodel → base/base-encoder.axmodel}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:00ac3d4d0aa81f3910d4aa9c777e81fbf3b4bc22f26a9d9ac38f236392261603
|
| 3 |
+
size 56706622
|
models-ax630c/{base-tokens.txt → base/base-tokens.txt}
RENAMED
|
File without changes
|
models-ax630c/base/base_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "whisper-base",
|
| 3 |
+
"version": "1",
|
| 4 |
+
"maintainer": "k2-fsa",
|
| 5 |
+
"n_mels": 80,
|
| 6 |
+
"n_audio_ctx": 1500,
|
| 7 |
+
"n_audio_state": 512,
|
| 8 |
+
"n_audio_head": 8,
|
| 9 |
+
"n_audio_layer": 6,
|
| 10 |
+
"n_vocab": 51865,
|
| 11 |
+
"n_text_ctx": 448,
|
| 12 |
+
"n_text_state": 512,
|
| 13 |
+
"n_text_head": 8,
|
| 14 |
+
"n_text_layer": 6,
|
| 15 |
+
"sot_sequence": "50258,50259,50359",
|
| 16 |
+
"all_language_tokens": "50346,50356,50292,50319,50325,50330,50327,50328,50341,50331,50307,50353,50279,50320,50322,50326,50340,50267,50269,50297,50301,50316,50323,50287,50276,50342,50304,50277,50311,50350,50286,50278,50290,50309,50284,50268,50300,50321,50272,50291,50281,50266,50357,50333,50293,50299,50294,50337,50271,50263,50296,50264,50343,50315,50314,50270,50352,50317,50349,50348,50283,50265,50308,50305,50336,50261,50335,50262,50345,50344,50351,50310,50329,50332,50289,50274,50302,50259,50324,50282,50285,50313,50280,50334,50260,50303,50312,50318,50295,50273,50338,50298,50347,50288,50354,50355,50275,50306,50339",
|
| 17 |
+
"all_language_codes": "my,jw,bg,gl,yo,be,af,oc,tk,tg,et,ln,he,mr,si,so,ps,pt,pl,cy,lv,kk,km,ta,hi,nn,az,fi,is,as,hu,vi,ur,br,ro,tr,fa,pa,ar,hr,el,ja,su,gu,lt,te,la,uz,nl,ru,ml,ko,mt,bs,mn,ca,haw,sq,mg,tl,cs,fr,mk,sl,lo,de,yi,es,lb,sa,tt,eu,ka,sd,th,it,bn,en,sn,ms,da,ne,uk,am,zh,sr,hy,sw,mi,sv,fo,sk,bo,no,ha,ba,id,kn,ht",
|
| 18 |
+
"sot": 50258,
|
| 19 |
+
"sot_index": 0,
|
| 20 |
+
"eot": 50257,
|
| 21 |
+
"blank_id": 220,
|
| 22 |
+
"is_multilingual": 1,
|
| 23 |
+
"no_speech": 50362,
|
| 24 |
+
"non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
|
| 25 |
+
"transcribe": 50359,
|
| 26 |
+
"translate": 50358,
|
| 27 |
+
"sot_prev": 50361,
|
| 28 |
+
"sot_lm": 50360,
|
| 29 |
+
"no_timestamps": 50363
|
| 30 |
+
}
|
models-ax630c/{base-decoder-main.axmodel → small/small-decoder.axmodel}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:9d13b4c3693f72e280162a1fd78c52ffc8ecc9318d7d2bd56db9810945c88f1b
|
| 3 |
+
size 345275786
|
models-ax630c/small/small-tokens.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models-ax630c/small/small_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "whisper-small",
|
| 3 |
+
"version": "1",
|
| 4 |
+
"maintainer": "k2-fsa",
|
| 5 |
+
"n_mels": 80,
|
| 6 |
+
"n_audio_ctx": 1500,
|
| 7 |
+
"n_audio_state": 768,
|
| 8 |
+
"n_audio_head": 12,
|
| 9 |
+
"n_audio_layer": 12,
|
| 10 |
+
"n_vocab": 51865,
|
| 11 |
+
"n_text_ctx": 448,
|
| 12 |
+
"n_text_state": 768,
|
| 13 |
+
"n_text_head": 12,
|
| 14 |
+
"n_text_layer": 12,
|
| 15 |
+
"sot_sequence": "50258,50259,50359",
|
| 16 |
+
"all_language_tokens": "50349,50276,50334,50346,50335,50313,50353,50284,50280,50268,50266,50267,50330,50350,50259,50311,50336,50340,50288,50302,50291,50279,50299,50351,50270,50301,50314,50295,50331,50285,50303,50326,50342,50355,50348,50341,50354,50321,50272,50269,50357,50333,50283,50309,50271,50324,50323,50290,50327,50298,50319,50356,50282,50332,50275,50263,50294,50305,50293,50317,50338,50287,50292,50316,50343,50289,50260,50328,50312,50344,50325,50304,50339,50320,50308,50274,50262,50345,50278,50296,50337,50310,50329,50318,50347,50265,50307,50264,50352,50315,50273,50277,50300,50261,50286,50306,50322,50281,50297",
|
| 17 |
+
"all_language_codes": "mg,hi,am,my,yi,ne,ln,ro,uk,tr,ja,pt,be,as,en,is,lo,ps,no,bn,hr,he,te,tt,ca,lv,mn,mi,tg,da,sr,so,nn,ba,tl,tk,ha,pa,ar,pl,su,gu,cs,br,nl,sn,km,ur,af,sk,gl,jw,ms,sd,id,ru,la,sl,lt,sq,fo,ta,bg,kk,mt,th,zh,oc,hy,sa,yo,az,ht,mr,mk,it,es,lb,vi,ml,uz,eu,ka,sw,bo,fr,et,ko,haw,bs,sv,fi,fa,de,hu,kn,si,el,cy",
|
| 18 |
+
"sot": 50258,
|
| 19 |
+
"sot_index": 0,
|
| 20 |
+
"eot": 50257,
|
| 21 |
+
"blank_id": 220,
|
| 22 |
+
"is_multilingual": 1,
|
| 23 |
+
"no_speech": 50362,
|
| 24 |
+
"non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
|
| 25 |
+
"transcribe": 50359,
|
| 26 |
+
"translate": 50358,
|
| 27 |
+
"sot_prev": 50361,
|
| 28 |
+
"sot_lm": 50360,
|
| 29 |
+
"no_timestamps": 50363
|
| 30 |
+
}
|
models-ax630c/{base-positional_embedding.bin → tiny/tiny-decoder.axmodel}
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:64a611a2575597fed3e705d9faf941df0b33d58bc10a733fa31f5e937fd58ec4
|
| 3 |
+
size 129647157
|
models-ax650/small/small-positional_embedding.bin → models-ax630c/tiny/tiny-encoder.axmodel
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0528e2d7e317668e43a5641695f93d0c80d9902e28f0e9f2fdef76470855efe7
|
| 3 |
+
size 26853722
|
models-ax630c/tiny/tiny-tokens.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
models-ax630c/tiny/tiny_config.json
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"model_type": "whisper-tiny",
|
| 3 |
+
"version": "1",
|
| 4 |
+
"maintainer": "k2-fsa",
|
| 5 |
+
"n_mels": 80,
|
| 6 |
+
"n_audio_ctx": 1500,
|
| 7 |
+
"n_audio_state": 384,
|
| 8 |
+
"n_audio_head": 6,
|
| 9 |
+
"n_audio_layer": 4,
|
| 10 |
+
"n_vocab": 51865,
|
| 11 |
+
"n_text_ctx": 448,
|
| 12 |
+
"n_text_state": 384,
|
| 13 |
+
"n_text_head": 6,
|
| 14 |
+
"n_text_layer": 4,
|
| 15 |
+
"sot_sequence": "50258,50259,50359",
|
| 16 |
+
"all_language_tokens": "50286,50307,50345,50299,50290,50265,50275,50294,50309,50262,50318,50331,50282,50349,50270,50326,50279,50272,50355,50287,50269,50315,50278,50283,50266,50260,50297,50348,50346,50296,50277,50334,50342,50313,50288,50322,50325,50259,50302,50332,50338,50344,50261,50330,50304,50357,50314,50340,50291,50352,50320,50271,50316,50336,50323,50293,50263,50308,50284,50273,50267,50312,50321,50328,50285,50298,50301,50327,50354,50303,50356,50351,50295,50339,50292,50319,50264,50310,50276,50335,50311,50341,50350,50268,50289,50281,50324,50333,50317,50343,50305,50274,50306,50353,50300,50347,50329,50337,50280",
|
| 17 |
+
"all_language_codes": "hu,et,lb,te,ur,fr,id,la,br,es,sw,tg,ms,mg,ca,so,he,ar,ba,ta,pl,bs,vi,cs,ja,zh,cy,tl,my,ml,fi,am,nn,ne,no,si,yo,en,bn,sd,fo,sa,de,be,az,su,mn,ps,hr,haw,mr,nl,kk,lo,km,lt,ru,mk,ro,sv,pt,hy,pa,oc,da,sk,lv,af,ha,sr,jw,tt,mi,ht,bg,gl,ko,eu,hi,yi,is,tk,as,tr,th,el,sn,gu,sq,mt,sl,it,kn,ln,fa,bo,ka,uz,uk",
|
| 18 |
+
"sot": 50258,
|
| 19 |
+
"sot_index": 0,
|
| 20 |
+
"eot": 50257,
|
| 21 |
+
"blank_id": 220,
|
| 22 |
+
"is_multilingual": 1,
|
| 23 |
+
"no_speech": 50362,
|
| 24 |
+
"non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
|
| 25 |
+
"transcribe": 50359,
|
| 26 |
+
"translate": 50358,
|
| 27 |
+
"sot_prev": 50361,
|
| 28 |
+
"sot_lm": 50360,
|
| 29 |
+
"no_timestamps": 50363
|
| 30 |
+
}
|
models-ax650/base/base-decoder-loop.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:43fbc9a1672eabd705bb68fdfd6b0837c4d3bceec5e07c80cc829cf47417e11d
|
| 3 |
-
size 183531172
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/base/base-decoder-main.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:7d9167ae1c52ed1cb318fb27e45734c2ceb0560444078693bf831ce02f2c0331
|
| 3 |
-
size 183985586
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/base/base-decoder.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a5212e05f8c0d2d4b2a319eaf87eb93253a9bb476dee8fa3d8a85f3137b61045
|
| 3 |
+
size 184444593
|
models-ax650/base/base-encoder.axmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:c07d03194566f292d26cf1f0e104d5740e1a249a0dc92b23c5b15b9e96496c24
|
| 3 |
+
size 33132600
|
models-ax650/small/small-decoder-loop.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:b472a0f3539d17fece09e92bf6cd69ebf391928a6050896bbf86b558a25def22
|
| 3 |
-
size 269002567
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/small/small-decoder-main.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:f3bfc577f60c35192d8ce8cc24f9ca4aa84af72756ba11af9d178d337cb7eb1c
|
| 3 |
-
size 285531695
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/small/small-decoder.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d4ed1e62122ed495624efa0275a0d4cb2450996b6f1a5e1d9ea9d48026d1bb66
|
| 3 |
+
size 350609498
|
models-ax650/small/small-encoder.axmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6e4e8462e3b6ac3ea9465d61560bc42e1398a06ea640d9ff5cdb82636ab73d47
|
| 3 |
+
size 136275980
|
models-ax650/tiny/tiny-decoder-loop.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:139aa1429a6f439b7a5d1e5f481cb761c673afd4718a25968ad979fccfdfecaf
|
| 3 |
-
size 128541899
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/tiny/tiny-decoder-main.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:8c1d08ec9309a26103a955cb432e0ddc25476da706a6a0a94108d225a48385aa
|
| 3 |
-
size 128909975
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/tiny/tiny-decoder.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ecd77406ae1c04883beb7517f3e9f0ca0c4b35e640467997f126efb539b96306
|
| 3 |
+
size 129267343
|
models-ax650/tiny/tiny-encoder.axmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:698670eb7410eeb2a84e26cee8918a780c97d1232a9dc7f946f61949105299a9
|
| 3 |
+
size 14085412
|
models-ax650/tiny/tiny-positional_embedding.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c13450ae630323a0bdd39b1226f92a7ac251131a909c7efdb7d2f5516736eb83
|
| 3 |
-
size 688128
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/turbo/turbo-decoder-loop.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:1b50daea5f4776006bf0e635ca5224f118649a9b3c37b2b821ae4c321db096ec
|
| 3 |
-
size 499257709
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/turbo/turbo-decoder-main.axmodel
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:56d34930691311002e401d8243abbae632593b57406e8ae780cd02c9076a783d
|
| 3 |
-
size 500341239
|
|
|
|
|
|
|
|
|
|
|
|
models-ax650/turbo/turbo-decoder.axmodel
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3d5a52c2c1f2c08fdb3b664201a79b42e3439184ffcc5f19e151eeb4a88cadd8
|
| 3 |
+
size 501696186
|
models-ax650/turbo/turbo-encoder.axmodel
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:b2cbdf3941b8d739318148505c380167a41e7d710d1a3cfda1a257c2c7fe428f
|
| 3 |
+
size 893420089
|
models-ax650/turbo/turbo-positional_embedding.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:a94790ac6719da6e134835255274ee7fc6066ad5e0f08a0f747c1e1cf6407dc3
|
| 3 |
-
size 2293760
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/base/base-decoder-loop.onnx
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:1616a829b7d3d643616633551204b8d0f008fb7a7dc38919eda2e8c6c6ed9714
|
| 3 |
-
size 194571088
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/base/base-decoder-main.onnx
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:1096b83590016bdbe74c66c7ccad1c0120abd6d37214560b1dfe4cd886a0e683
|
| 3 |
-
size 205485892
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/base/base-decoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:a2ba37d199ec7c1794facba8efc84a484b0224f6650e362fda7d2db75827023a
|
| 3 |
+
size 195497242
|
models-onnx/base/base-encoder.onnx
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1c75e7cca22000432ec4f8f50726299aa20db34a9c154646e8ac10c0ddb4699b
|
| 3 |
+
size 95025778
|
models-onnx/base/base-positional_embedding.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
|
| 3 |
-
size 917504
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/small/small-decoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f696ce159cffc13416ef624b91dae341de9d3f7e720cfc832000fe987f6d50b4
|
| 3 |
+
size 557821592
|
models-onnx/small/small-encoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:f7dce3226784147e42021007dede9c2c144d33487711a1950b5935b9c4829f1b
|
| 3 |
+
size 409408370
|
models-onnx/small/small-positional_embedding.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c10bc44f2bd94bdf1b7aa03581309fa536132b3fe79bfe22c9a6934a42cd8b58
|
| 3 |
-
size 1376256
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/tiny/tiny-decoder-loop.onnx
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:5cbb3533939e2dfdf567b27762b12cf0956b7d7982bfb915228d24789f483058
|
| 3 |
-
size 112843354
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/tiny/tiny-decoder-main.onnx
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:59ced1cf4e9a6f2aef0a2457f64f846e5682033abb4b894ba7680a60c792ad73
|
| 3 |
-
size 118301861
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/tiny/tiny-decoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0d1a16a5a70e9a5940b68559eb5142c7c53b8f595f2f5e4a01d7e168acca6eb1
|
| 3 |
+
size 113537271
|
models-onnx/tiny/tiny-encoder.onnx
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:324e93c5ddc2e922273ebeb16bba8c453aa009c698e2262103ac5ce21df1c3ed
|
| 3 |
+
size 37605342
|
models-onnx/tiny/tiny-positional_embedding.bin
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:c13450ae630323a0bdd39b1226f92a7ac251131a909c7efdb7d2f5516736eb83
|
| 3 |
-
size 688128
|
|
|
|
|
|
|
|
|
|
|
|
models-onnx/turbo/turbo-decoder.onnx
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:75b7d28ee244aa1d39d02d98514cd951e7697f6fb742c7f84f808b8aab9b1d2a
|
| 3 |
+
size 635240242
|
models-ax650/base/base-positional_embedding.bin → models-onnx/turbo/turbo-encoder.onnx
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:883b1a459251b53f7431bdb81d2a573d0965c42be73179cbdeb0041154fb0a7d
|
| 3 |
+
size 389433
|
models-onnx/turbo/turbo-tokens.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
python/assets/multilingual.tiktoken
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
python/languages.py
CHANGED
|
@@ -99,4 +99,4 @@ WHISPER_LANGUAGES = {
|
|
| 99 |
"jw": "javanese",
|
| 100 |
"su": "sundanese",
|
| 101 |
"yue": "cantonese",
|
| 102 |
-
}
|
|
|
|
| 99 |
"jw": "javanese",
|
| 100 |
"su": "sundanese",
|
| 101 |
"yue": "cantonese",
|
| 102 |
+
}
|
python/main.py
CHANGED
|
@@ -6,29 +6,49 @@ import time
|
|
| 6 |
|
| 7 |
def get_args():
|
| 8 |
parser = argparse.ArgumentParser(
|
| 9 |
-
prog="whisper",
|
| 10 |
-
description="Run Whisper on input audio file"
|
| 11 |
)
|
| 12 |
parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
|
| 13 |
-
parser.add_argument(
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 18 |
return parser.parse_args()
|
| 19 |
|
| 20 |
|
| 21 |
-
def print_args(args):
|
| 22 |
-
print(f"wav: {args.wav}")
|
| 23 |
-
print(f"model_type: {args.model_type}")
|
| 24 |
-
print(f"model_path: {args.model_path}")
|
| 25 |
-
print(f"language: {args.language}")
|
| 26 |
-
print(f"task: {args.task}")
|
| 27 |
-
|
| 28 |
-
|
| 29 |
def main():
|
| 30 |
args = get_args()
|
| 31 |
-
|
| 32 |
|
| 33 |
# Check wav existence
|
| 34 |
wav_path = args.wav
|
|
@@ -36,19 +56,19 @@ def main():
|
|
| 36 |
|
| 37 |
model = Whisper(args.model_type, args.model_path, args.language, args.task)
|
| 38 |
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
print("\n预测结果:")
|
| 42 |
start = time.time()
|
| 43 |
print(model.run(wav_path))
|
| 44 |
end = time.time()
|
| 45 |
|
| 46 |
if args.print_rtf:
|
| 47 |
import librosa
|
|
|
|
| 48 |
samples, sr = librosa.load(wav_path, sr=16000)
|
| 49 |
duration = len(samples) / sr
|
| 50 |
process_time = end - start
|
| 51 |
print(f"RTF: {process_time / duration}")
|
| 52 |
|
|
|
|
| 53 |
if __name__ == "__main__":
|
| 54 |
main()
|
|
|
|
| 6 |
|
| 7 |
def get_args():
|
| 8 |
parser = argparse.ArgumentParser(
|
| 9 |
+
prog="whisper", description="Run Whisper on input audio file"
|
|
|
|
| 10 |
)
|
| 11 |
parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
|
| 12 |
+
parser.add_argument(
|
| 13 |
+
"--model_type",
|
| 14 |
+
"-t",
|
| 15 |
+
type=str,
|
| 16 |
+
choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
|
| 17 |
+
required=True,
|
| 18 |
+
help="model type, only support tiny, base and small currently",
|
| 19 |
+
)
|
| 20 |
+
parser.add_argument(
|
| 21 |
+
"--model_path",
|
| 22 |
+
"-p",
|
| 23 |
+
type=str,
|
| 24 |
+
required=False,
|
| 25 |
+
default="../models-ax650",
|
| 26 |
+
help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
|
| 27 |
+
)
|
| 28 |
+
parser.add_argument(
|
| 29 |
+
"--language",
|
| 30 |
+
"-l",
|
| 31 |
+
type=str,
|
| 32 |
+
required=False,
|
| 33 |
+
default="zh",
|
| 34 |
+
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 35 |
+
)
|
| 36 |
+
parser.add_argument(
|
| 37 |
+
"--task",
|
| 38 |
+
type=str,
|
| 39 |
+
required=False,
|
| 40 |
+
choices=["translate", "transcribe"],
|
| 41 |
+
default="transcribe",
|
| 42 |
+
)
|
| 43 |
+
parser.add_argument(
|
| 44 |
+
"--print_rtf", action="store_true", help="Print Real-Time Factor"
|
| 45 |
+
)
|
| 46 |
return parser.parse_args()
|
| 47 |
|
| 48 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
def main():
|
| 50 |
args = get_args()
|
| 51 |
+
print(vars(args))
|
| 52 |
|
| 53 |
# Check wav existence
|
| 54 |
wav_path = args.wav
|
|
|
|
| 56 |
|
| 57 |
model = Whisper(args.model_type, args.model_path, args.language, args.task)
|
| 58 |
|
| 59 |
+
print("ASR result:")
|
|
|
|
|
|
|
| 60 |
start = time.time()
|
| 61 |
print(model.run(wav_path))
|
| 62 |
end = time.time()
|
| 63 |
|
| 64 |
if args.print_rtf:
|
| 65 |
import librosa
|
| 66 |
+
|
| 67 |
samples, sr = librosa.load(wav_path, sr=16000)
|
| 68 |
duration = len(samples) / sr
|
| 69 |
process_time = end - start
|
| 70 |
print(f"RTF: {process_time / duration}")
|
| 71 |
|
| 72 |
+
|
| 73 |
if __name__ == "__main__":
|
| 74 |
main()
|
python/test_wer.py
CHANGED
|
@@ -10,35 +10,35 @@ def setup_logging():
|
|
| 10 |
# 获取脚本所在目录
|
| 11 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
log_file = os.path.join(script_dir, "test_wer.log")
|
| 13 |
-
|
| 14 |
# 配置日志格式
|
| 15 |
-
log_format =
|
| 16 |
-
date_format =
|
| 17 |
-
|
| 18 |
# 创建logger
|
| 19 |
logger = logging.getLogger()
|
| 20 |
logger.setLevel(logging.INFO)
|
| 21 |
-
|
| 22 |
# 清除现有的handler
|
| 23 |
for handler in logger.handlers[:]:
|
| 24 |
logger.removeHandler(handler)
|
| 25 |
-
|
| 26 |
# 创建文件handler
|
| 27 |
-
file_handler = logging.FileHandler(log_file, mode=
|
| 28 |
file_handler.setLevel(logging.INFO)
|
| 29 |
file_formatter = logging.Formatter(log_format, date_format)
|
| 30 |
file_handler.setFormatter(file_formatter)
|
| 31 |
-
|
| 32 |
# 创建控制台handler
|
| 33 |
console_handler = logging.StreamHandler()
|
| 34 |
console_handler.setLevel(logging.INFO)
|
| 35 |
console_formatter = logging.Formatter(log_format, date_format)
|
| 36 |
console_handler.setFormatter(console_formatter)
|
| 37 |
-
|
| 38 |
# 添加handler到logger
|
| 39 |
logger.addHandler(file_handler)
|
| 40 |
logger.addHandler(console_handler)
|
| 41 |
-
|
| 42 |
return logger
|
| 43 |
|
| 44 |
|
|
@@ -46,21 +46,21 @@ class AIShellDataset:
|
|
| 46 |
def __init__(self, gt_path: str):
|
| 47 |
"""
|
| 48 |
初始化数据集
|
| 49 |
-
|
| 50 |
Args:
|
| 51 |
json_path: voice.json文件的路径
|
| 52 |
"""
|
| 53 |
self.gt_path = gt_path
|
| 54 |
self.dataset_dir = os.path.dirname(gt_path)
|
| 55 |
self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
|
| 56 |
-
|
| 57 |
# 检查必要文件和文件夹是否存在
|
| 58 |
assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
|
| 59 |
assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
|
| 60 |
-
|
| 61 |
# 加载数据
|
| 62 |
self.data = []
|
| 63 |
-
with open(gt_path,
|
| 64 |
for line in f:
|
| 65 |
line = line.strip()
|
| 66 |
audio_path, gt = line.split(" ")
|
|
@@ -70,50 +70,50 @@ class AIShellDataset:
|
|
| 70 |
# 使用logging而不是print
|
| 71 |
logger = logging.getLogger()
|
| 72 |
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 73 |
-
|
| 74 |
def __iter__(self):
|
| 75 |
"""返回迭代器"""
|
| 76 |
self.index = 0
|
| 77 |
return self
|
| 78 |
-
|
| 79 |
def __next__(self):
|
| 80 |
"""返回下一个数据项"""
|
| 81 |
if self.index >= len(self.data):
|
| 82 |
raise StopIteration
|
| 83 |
-
|
| 84 |
item = self.data[self.index]
|
| 85 |
audio_path = item["audio_path"]
|
| 86 |
ground_truth = item["gt"]
|
| 87 |
-
|
| 88 |
self.index += 1
|
| 89 |
return audio_path, ground_truth
|
| 90 |
-
|
| 91 |
def __len__(self):
|
| 92 |
"""返回数据集大小"""
|
| 93 |
return len(self.data)
|
| 94 |
-
|
| 95 |
|
| 96 |
class CommonVoiceDataset:
|
| 97 |
"""Common Voice数据集解析器"""
|
| 98 |
-
|
| 99 |
def __init__(self, tsv_path: str):
|
| 100 |
"""
|
| 101 |
初始化数据集
|
| 102 |
-
|
| 103 |
Args:
|
| 104 |
json_path: voice.json文件的路径
|
| 105 |
"""
|
| 106 |
self.tsv_path = tsv_path
|
| 107 |
self.dataset_dir = os.path.dirname(tsv_path)
|
| 108 |
self.voice_dir = os.path.join(self.dataset_dir, "clips")
|
| 109 |
-
|
| 110 |
# 检查必要文件和文件夹是否存在
|
| 111 |
assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
|
| 112 |
assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
|
| 113 |
-
|
| 114 |
# 加载JSON数据
|
| 115 |
self.data = []
|
| 116 |
-
with open(tsv_path,
|
| 117 |
f.readline()
|
| 118 |
for line in f:
|
| 119 |
line = line.strip()
|
|
@@ -122,43 +122,77 @@ class CommonVoiceDataset:
|
|
| 122 |
gt = splits[2]
|
| 123 |
audio_path = os.path.join(self.voice_dir, audio_path)
|
| 124 |
self.data.append({"audio_path": audio_path, "gt": gt})
|
| 125 |
-
|
| 126 |
# 使用logging而不是print
|
| 127 |
logger = logging.getLogger()
|
| 128 |
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 129 |
-
|
| 130 |
def __iter__(self):
|
| 131 |
"""返回迭代器"""
|
| 132 |
self.index = 0
|
| 133 |
return self
|
| 134 |
-
|
| 135 |
def __next__(self):
|
| 136 |
"""返回下一个数据项"""
|
| 137 |
if self.index >= len(self.data):
|
| 138 |
raise StopIteration
|
| 139 |
-
|
| 140 |
item = self.data[self.index]
|
| 141 |
audio_path = item["audio_path"]
|
| 142 |
ground_truth = item["gt"]
|
| 143 |
-
|
| 144 |
self.index += 1
|
| 145 |
return audio_path, ground_truth
|
| 146 |
-
|
| 147 |
def __len__(self):
|
| 148 |
"""返回数据集大小"""
|
| 149 |
return len(self.data)
|
| 150 |
|
|
|
|
| 151 |
def get_args():
|
| 152 |
-
parser = argparse.ArgumentParser(
|
| 153 |
-
|
| 154 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 155 |
)
|
| 156 |
-
parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
|
| 157 |
-
parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
|
| 158 |
-
parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
|
| 159 |
-
parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
|
| 160 |
-
parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
|
| 161 |
-
parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
|
| 162 |
return parser.parse_args()
|
| 163 |
|
| 164 |
|
|
@@ -173,42 +207,42 @@ def print_args(args):
|
|
| 173 |
|
| 174 |
|
| 175 |
def min_distance(word1: str, word2: str) -> int:
|
| 176 |
-
|
| 177 |
row = len(word1) + 1
|
| 178 |
column = len(word2) + 1
|
| 179 |
-
|
| 180 |
-
cache = [
|
| 181 |
-
|
| 182 |
for i in range(row):
|
| 183 |
for j in range(column):
|
| 184 |
-
|
| 185 |
-
if i ==0 and j ==0:
|
| 186 |
cache[i][j] = 0
|
| 187 |
-
elif i == 0 and j!=0:
|
| 188 |
cache[i][j] = j
|
| 189 |
-
elif j == 0 and i!=0:
|
| 190 |
cache[i][j] = i
|
| 191 |
else:
|
| 192 |
-
if word1[i-1] == word2[j-1]:
|
| 193 |
-
cache[i][j] = cache[i-1][j-1]
|
| 194 |
else:
|
| 195 |
-
replace = cache[i-1][j-1] + 1
|
| 196 |
-
insert = cache[i][j-1] + 1
|
| 197 |
-
remove = cache[i-1][j] + 1
|
| 198 |
-
|
| 199 |
cache[i][j] = min(replace, insert, remove)
|
| 200 |
-
|
| 201 |
-
return cache[row-1][column-1]
|
| 202 |
|
| 203 |
|
| 204 |
def remove_punctuation(text):
|
| 205 |
# 定义正则表达式模式,匹配所有标点符号
|
| 206 |
# 这个模式包括常见的标点符号和中文标点
|
| 207 |
-
pattern = r
|
| 208 |
-
|
| 209 |
# 使用sub方法将所有匹配的标点符号替换为空字符串
|
| 210 |
-
cleaned_text = re.sub(pattern,
|
| 211 |
-
|
| 212 |
return cleaned_text
|
| 213 |
|
| 214 |
|
|
@@ -254,7 +288,7 @@ def main():
|
|
| 254 |
|
| 255 |
hyp.append(hypothesis)
|
| 256 |
references.append(reference)
|
| 257 |
-
|
| 258 |
line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
|
| 259 |
wer_file.write(line_content + "\n")
|
| 260 |
logger.info(line_content)
|
|
@@ -268,5 +302,6 @@ def main():
|
|
| 268 |
wer_file.write(f"Total WER: {total_character_error_rate}%")
|
| 269 |
wer_file.close()
|
| 270 |
|
|
|
|
| 271 |
if __name__ == "__main__":
|
| 272 |
main()
|
|
|
|
| 10 |
# 获取脚本所在目录
|
| 11 |
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 12 |
log_file = os.path.join(script_dir, "test_wer.log")
|
| 13 |
+
|
| 14 |
# 配置日志格式
|
| 15 |
+
log_format = "%(asctime)s - %(levelname)s - %(message)s"
|
| 16 |
+
date_format = "%Y-%m-%d %H:%M:%S"
|
| 17 |
+
|
| 18 |
# 创建logger
|
| 19 |
logger = logging.getLogger()
|
| 20 |
logger.setLevel(logging.INFO)
|
| 21 |
+
|
| 22 |
# 清除现有的handler
|
| 23 |
for handler in logger.handlers[:]:
|
| 24 |
logger.removeHandler(handler)
|
| 25 |
+
|
| 26 |
# 创建文件handler
|
| 27 |
+
file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
|
| 28 |
file_handler.setLevel(logging.INFO)
|
| 29 |
file_formatter = logging.Formatter(log_format, date_format)
|
| 30 |
file_handler.setFormatter(file_formatter)
|
| 31 |
+
|
| 32 |
# 创建控制台handler
|
| 33 |
console_handler = logging.StreamHandler()
|
| 34 |
console_handler.setLevel(logging.INFO)
|
| 35 |
console_formatter = logging.Formatter(log_format, date_format)
|
| 36 |
console_handler.setFormatter(console_formatter)
|
| 37 |
+
|
| 38 |
# 添加handler到logger
|
| 39 |
logger.addHandler(file_handler)
|
| 40 |
logger.addHandler(console_handler)
|
| 41 |
+
|
| 42 |
return logger
|
| 43 |
|
| 44 |
|
|
|
|
| 46 |
def __init__(self, gt_path: str):
|
| 47 |
"""
|
| 48 |
初始化数据集
|
| 49 |
+
|
| 50 |
Args:
|
| 51 |
json_path: voice.json文件的路径
|
| 52 |
"""
|
| 53 |
self.gt_path = gt_path
|
| 54 |
self.dataset_dir = os.path.dirname(gt_path)
|
| 55 |
self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
|
| 56 |
+
|
| 57 |
# 检查必要文件和文件夹是否存在
|
| 58 |
assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
|
| 59 |
assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
|
| 60 |
+
|
| 61 |
# 加载数据
|
| 62 |
self.data = []
|
| 63 |
+
with open(gt_path, "r", encoding="utf-8") as f:
|
| 64 |
for line in f:
|
| 65 |
line = line.strip()
|
| 66 |
audio_path, gt = line.split(" ")
|
|
|
|
| 70 |
# 使用logging而不是print
|
| 71 |
logger = logging.getLogger()
|
| 72 |
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 73 |
+
|
| 74 |
def __iter__(self):
|
| 75 |
"""返回迭代器"""
|
| 76 |
self.index = 0
|
| 77 |
return self
|
| 78 |
+
|
| 79 |
def __next__(self):
|
| 80 |
"""返回下一个数据项"""
|
| 81 |
if self.index >= len(self.data):
|
| 82 |
raise StopIteration
|
| 83 |
+
|
| 84 |
item = self.data[self.index]
|
| 85 |
audio_path = item["audio_path"]
|
| 86 |
ground_truth = item["gt"]
|
| 87 |
+
|
| 88 |
self.index += 1
|
| 89 |
return audio_path, ground_truth
|
| 90 |
+
|
| 91 |
def __len__(self):
|
| 92 |
"""返回数据集大小"""
|
| 93 |
return len(self.data)
|
| 94 |
+
|
| 95 |
|
| 96 |
class CommonVoiceDataset:
|
| 97 |
"""Common Voice数据集解析器"""
|
| 98 |
+
|
| 99 |
def __init__(self, tsv_path: str):
|
| 100 |
"""
|
| 101 |
初始化数据集
|
| 102 |
+
|
| 103 |
Args:
|
| 104 |
json_path: voice.json文件的路径
|
| 105 |
"""
|
| 106 |
self.tsv_path = tsv_path
|
| 107 |
self.dataset_dir = os.path.dirname(tsv_path)
|
| 108 |
self.voice_dir = os.path.join(self.dataset_dir, "clips")
|
| 109 |
+
|
| 110 |
# 检查必要文件和文件夹是否存在
|
| 111 |
assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
|
| 112 |
assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
|
| 113 |
+
|
| 114 |
# 加载JSON数据
|
| 115 |
self.data = []
|
| 116 |
+
with open(tsv_path, "r", encoding="utf-8") as f:
|
| 117 |
f.readline()
|
| 118 |
for line in f:
|
| 119 |
line = line.strip()
|
|
|
|
| 122 |
gt = splits[2]
|
| 123 |
audio_path = os.path.join(self.voice_dir, audio_path)
|
| 124 |
self.data.append({"audio_path": audio_path, "gt": gt})
|
| 125 |
+
|
| 126 |
# 使用logging而不是print
|
| 127 |
logger = logging.getLogger()
|
| 128 |
logger.info(f"加载了 {len(self.data)} 条数据")
|
| 129 |
+
|
| 130 |
def __iter__(self):
|
| 131 |
"""返回迭代器"""
|
| 132 |
self.index = 0
|
| 133 |
return self
|
| 134 |
+
|
| 135 |
def __next__(self):
|
| 136 |
"""返回下一个数据项"""
|
| 137 |
if self.index >= len(self.data):
|
| 138 |
raise StopIteration
|
| 139 |
+
|
| 140 |
item = self.data[self.index]
|
| 141 |
audio_path = item["audio_path"]
|
| 142 |
ground_truth = item["gt"]
|
| 143 |
+
|
| 144 |
self.index += 1
|
| 145 |
return audio_path, ground_truth
|
| 146 |
+
|
| 147 |
def __len__(self):
|
| 148 |
"""返回数据集大小"""
|
| 149 |
return len(self.data)
|
| 150 |
|
| 151 |
+
|
| 152 |
def get_args():
|
| 153 |
+
parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
|
| 154 |
+
parser.add_argument(
|
| 155 |
+
"--dataset",
|
| 156 |
+
"-d",
|
| 157 |
+
type=str,
|
| 158 |
+
required=True,
|
| 159 |
+
choices=["aishell", "common_voice"],
|
| 160 |
+
help="Test dataset",
|
| 161 |
+
)
|
| 162 |
+
parser.add_argument(
|
| 163 |
+
"--gt_path",
|
| 164 |
+
"-g",
|
| 165 |
+
type=str,
|
| 166 |
+
required=True,
|
| 167 |
+
help="Test dataset ground truth file",
|
| 168 |
+
)
|
| 169 |
+
parser.add_argument(
|
| 170 |
+
"--max_num", type=int, default=-1, required=False, help="Maximum test data num"
|
| 171 |
+
)
|
| 172 |
+
parser.add_argument(
|
| 173 |
+
"--model_type",
|
| 174 |
+
"-t",
|
| 175 |
+
type=str,
|
| 176 |
+
choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
|
| 177 |
+
required=True,
|
| 178 |
+
help="model type, only support tiny, base and small currently",
|
| 179 |
+
)
|
| 180 |
+
parser.add_argument(
|
| 181 |
+
"--model_path",
|
| 182 |
+
"-p",
|
| 183 |
+
type=str,
|
| 184 |
+
required=False,
|
| 185 |
+
default="../models/models-ax650",
|
| 186 |
+
help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
|
| 187 |
+
)
|
| 188 |
+
parser.add_argument(
|
| 189 |
+
"--language",
|
| 190 |
+
"-l",
|
| 191 |
+
type=str,
|
| 192 |
+
required=False,
|
| 193 |
+
default="zh",
|
| 194 |
+
help="Target language, support en, zh, ja, and others. See languages.py for more options.",
|
| 195 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 196 |
return parser.parse_args()
|
| 197 |
|
| 198 |
|
|
|
|
| 207 |
|
| 208 |
|
| 209 |
def min_distance(word1: str, word2: str) -> int:
|
| 210 |
+
|
| 211 |
row = len(word1) + 1
|
| 212 |
column = len(word2) + 1
|
| 213 |
+
|
| 214 |
+
cache = [[0] * column for i in range(row)]
|
| 215 |
+
|
| 216 |
for i in range(row):
|
| 217 |
for j in range(column):
|
| 218 |
+
|
| 219 |
+
if i == 0 and j == 0:
|
| 220 |
cache[i][j] = 0
|
| 221 |
+
elif i == 0 and j != 0:
|
| 222 |
cache[i][j] = j
|
| 223 |
+
elif j == 0 and i != 0:
|
| 224 |
cache[i][j] = i
|
| 225 |
else:
|
| 226 |
+
if word1[i - 1] == word2[j - 1]:
|
| 227 |
+
cache[i][j] = cache[i - 1][j - 1]
|
| 228 |
else:
|
| 229 |
+
replace = cache[i - 1][j - 1] + 1
|
| 230 |
+
insert = cache[i][j - 1] + 1
|
| 231 |
+
remove = cache[i - 1][j] + 1
|
| 232 |
+
|
| 233 |
cache[i][j] = min(replace, insert, remove)
|
| 234 |
+
|
| 235 |
+
return cache[row - 1][column - 1]
|
| 236 |
|
| 237 |
|
| 238 |
def remove_punctuation(text):
|
| 239 |
# 定义正则表达式模式,匹配所有标点符号
|
| 240 |
# 这个模式包括常见的标点符号和中文标点
|
| 241 |
+
pattern = r"[^\w\s]|_"
|
| 242 |
+
|
| 243 |
# 使用sub方法将所有匹配的标点符号替换为空字符串
|
| 244 |
+
cleaned_text = re.sub(pattern, "", text)
|
| 245 |
+
|
| 246 |
return cleaned_text
|
| 247 |
|
| 248 |
|
|
|
|
| 288 |
|
| 289 |
hyp.append(hypothesis)
|
| 290 |
references.append(reference)
|
| 291 |
+
|
| 292 |
line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
|
| 293 |
wer_file.write(line_content + "\n")
|
| 294 |
logger.info(line_content)
|
|
|
|
| 302 |
wer_file.write(f"Total WER: {total_character_error_rate}%")
|
| 303 |
wer_file.close()
|
| 304 |
|
| 305 |
+
|
| 306 |
if __name__ == "__main__":
|
| 307 |
main()
|
python/whisper.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
import axengine as axe
|
| 2 |
import numpy as np
|
| 3 |
import librosa
|
| 4 |
import os
|
|
@@ -9,27 +9,27 @@ from dataclasses import dataclass
|
|
| 9 |
import zhconv
|
| 10 |
|
| 11 |
|
| 12 |
-
NEG_INF = float("-inf")
|
| 13 |
-
|
| 14 |
@dataclass
|
| 15 |
class WhisperConfig:
|
| 16 |
-
n_mels
|
| 17 |
-
sample_rate
|
| 18 |
-
n_fft
|
| 19 |
-
hop_length
|
| 20 |
-
|
| 21 |
-
sot
|
| 22 |
-
eot
|
| 23 |
-
blank_id
|
| 24 |
-
no_timestamps
|
| 25 |
-
no_speech
|
| 26 |
-
translate
|
| 27 |
-
transcribe
|
| 28 |
-
n_vocab
|
| 29 |
-
n_text_ctx
|
| 30 |
-
n_text_state
|
| 31 |
-
|
| 32 |
-
sot_sequence
|
|
|
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
class Whisper:
|
|
@@ -38,35 +38,41 @@ class Whisper:
|
|
| 38 |
|
| 39 |
self.language = language
|
| 40 |
self.task = task
|
| 41 |
-
self.encoder, self.
|
| 42 |
-
|
|
|
|
| 43 |
self.config = self.load_config(model_config)
|
| 44 |
|
| 45 |
-
|
| 46 |
def load_model(self, model_type, model_path, language, task):
|
| 47 |
encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
|
| 48 |
-
|
| 49 |
-
decoder_loop_path = f"{model_type}/{model_type}-decoder-loop.axmodel"
|
| 50 |
-
pe_path = f"{model_type}/{model_type}-positional_embedding.bin"
|
| 51 |
model_config_file = f"{model_type}/{model_type}_config.json"
|
|
|
|
| 52 |
|
| 53 |
-
required_files = [
|
|
|
|
|
|
|
|
|
|
| 54 |
# Check file existence
|
| 55 |
for i, file_path in enumerate(required_files):
|
| 56 |
assert os.path.exists(file_path), f"{file_path} NOT exist"
|
| 57 |
|
| 58 |
# Load encoder
|
| 59 |
-
encoder = axe.InferenceSession(
|
|
|
|
|
|
|
| 60 |
# Load decoder main
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
# Load position embedding
|
| 65 |
-
pe = np.fromfile(required_files[3], dtype=np.float32)
|
| 66 |
# Load tokens
|
| 67 |
-
model_config = json.load(open(required_files[
|
| 68 |
-
model_config["all_language_tokens"] = [
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
tokenizer = get_tokenizer(
|
| 71 |
model_config["is_multilingual"],
|
| 72 |
num_languages=len(model_config["all_language_codes"]),
|
|
@@ -74,8 +80,9 @@ class Whisper:
|
|
| 74 |
task=task,
|
| 75 |
)
|
| 76 |
|
| 77 |
-
|
| 78 |
-
|
|
|
|
| 79 |
|
| 80 |
def load_config(self, model_config):
|
| 81 |
config = WhisperConfig
|
|
@@ -94,34 +101,46 @@ class Whisper:
|
|
| 94 |
config.n_vocab = model_config["n_vocab"]
|
| 95 |
config.n_text_ctx = model_config["n_text_ctx"]
|
| 96 |
config.n_text_state = model_config["n_text_state"]
|
|
|
|
| 97 |
|
| 98 |
-
lang_token = model_config["all_language_tokens"][
|
| 99 |
-
|
| 100 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
|
| 102 |
return config
|
| 103 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
|
| 105 |
def load_audio(self, audio: str):
|
| 106 |
data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
|
| 107 |
samples = np.ascontiguousarray(data)
|
| 108 |
return samples, sample_rate
|
| 109 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 110 |
|
| 111 |
-
def compute_feature(self, audio: np.ndarray, padding = 480000):
|
| 112 |
-
if padding > 0:
|
| 113 |
-
audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
|
| 114 |
-
|
| 115 |
-
mel = librosa.feature.melspectrogram(y=audio,
|
| 116 |
-
sr=self.config.sample_rate,
|
| 117 |
-
n_fft=self.config.n_fft,
|
| 118 |
-
hop_length=self.config.hop_length,
|
| 119 |
-
window="hann",
|
| 120 |
-
center=True,
|
| 121 |
-
pad_mode="reflect",
|
| 122 |
-
power=2.0,
|
| 123 |
-
n_mels=self.config.n_mels)
|
| 124 |
-
|
| 125 |
log_spec = np.log10(np.maximum(mel, 1e-10))
|
| 126 |
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
| 127 |
mel = (log_spec + 4.0) / 4.0
|
|
@@ -129,31 +148,71 @@ class Whisper:
|
|
| 129 |
target = 3000
|
| 130 |
if mel.shape[1] > target:
|
| 131 |
# -50 so that there are some zero tail paddings.
|
| 132 |
-
mel = mel[:, :
|
| 133 |
mel[:, -50:] = 0
|
| 134 |
|
| 135 |
# We don't need to pad it to 30 seconds now!
|
| 136 |
if mel.shape[1] < target:
|
| 137 |
-
mel = np.concatenate(
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
return logits
|
| 156 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
def run(self, audio: Union[str, np.ndarray]) -> str:
|
| 159 |
if isinstance(audio, str):
|
|
@@ -161,64 +220,56 @@ class Whisper:
|
|
| 161 |
|
| 162 |
mel = self.compute_feature(audio)
|
| 163 |
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
x = self.decoder_loop.run(None, input_feed={
|
| 198 |
-
"tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
|
| 199 |
-
"in_n_layer_self_k_cache": n_layer_self_k_cache,
|
| 200 |
-
"in_n_layer_self_v_cache": n_layer_self_v_cache,
|
| 201 |
-
"n_layer_cross_k": n_layer_cross_k,
|
| 202 |
-
"n_layer_cross_v": n_layer_cross_v,
|
| 203 |
-
"positional_embedding": self.pe[offset * self.config.n_text_state : (offset + 1) * self.config.n_text_state][None, ...],
|
| 204 |
-
"mask": mask
|
| 205 |
-
})
|
| 206 |
-
logits, n_layer_self_k_cache, n_layer_self_v_cache = x
|
| 207 |
-
|
| 208 |
-
# Decode token
|
| 209 |
offset += 1
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
|
| 213 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 214 |
|
| 215 |
if self.language == "zh":
|
| 216 |
try:
|
| 217 |
-
sim_zh = zhconv.convert(text,
|
| 218 |
return sim_zh
|
| 219 |
except:
|
| 220 |
return text
|
| 221 |
-
|
| 222 |
-
return text
|
| 223 |
|
| 224 |
-
|
|
|
|
| 1 |
+
import axengine as axe
|
| 2 |
import numpy as np
|
| 3 |
import librosa
|
| 4 |
import os
|
|
|
|
| 9 |
import zhconv
|
| 10 |
|
| 11 |
|
|
|
|
|
|
|
| 12 |
@dataclass
|
| 13 |
class WhisperConfig:
|
| 14 |
+
n_mels: int = 0
|
| 15 |
+
sample_rate: int = 0
|
| 16 |
+
n_fft: int = 0
|
| 17 |
+
hop_length: int = 0
|
| 18 |
+
|
| 19 |
+
sot: int = 0
|
| 20 |
+
eot: int = 0
|
| 21 |
+
blank_id: int = 0
|
| 22 |
+
no_timestamps: int = 0
|
| 23 |
+
no_speech: int = 0
|
| 24 |
+
translate: int = 0
|
| 25 |
+
transcribe: int = 0
|
| 26 |
+
n_vocab: int = 0
|
| 27 |
+
n_text_ctx: int = 0
|
| 28 |
+
n_text_state: int = 0
|
| 29 |
+
|
| 30 |
+
sot_sequence: np.ndarray = field(
|
| 31 |
+
default_factory=lambda: np.array([0, 0, 0, 0], dtype=np.int32)
|
| 32 |
+
)
|
| 33 |
|
| 34 |
|
| 35 |
class Whisper:
|
|
|
|
| 38 |
|
| 39 |
self.language = language
|
| 40 |
self.task = task
|
| 41 |
+
self.encoder, self.decoder, self.tokenizer, model_config = self.load_model(
|
| 42 |
+
model_type, model_path, language, task
|
| 43 |
+
)
|
| 44 |
self.config = self.load_config(model_config)
|
| 45 |
|
|
|
|
| 46 |
def load_model(self, model_type, model_path, language, task):
|
| 47 |
encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
|
| 48 |
+
decoder_path = f"{model_type}/{model_type}-decoder.axmodel"
|
|
|
|
|
|
|
| 49 |
model_config_file = f"{model_type}/{model_type}_config.json"
|
| 50 |
+
token_file = f"{model_type}/{model_type}-tokens.txt"
|
| 51 |
|
| 52 |
+
required_files = [
|
| 53 |
+
os.path.join(model_path, i)
|
| 54 |
+
for i in (encoder_path, decoder_path, model_config_file, token_file)
|
| 55 |
+
]
|
| 56 |
# Check file existence
|
| 57 |
for i, file_path in enumerate(required_files):
|
| 58 |
assert os.path.exists(file_path), f"{file_path} NOT exist"
|
| 59 |
|
| 60 |
# Load encoder
|
| 61 |
+
encoder = axe.InferenceSession(
|
| 62 |
+
required_files[0], providers=["AxEngineExecutionProvider"]
|
| 63 |
+
)
|
| 64 |
# Load decoder main
|
| 65 |
+
decoder = axe.InferenceSession(
|
| 66 |
+
required_files[1], providers=["AxEngineExecutionProvider"]
|
| 67 |
+
)
|
|
|
|
|
|
|
| 68 |
# Load tokens
|
| 69 |
+
model_config = json.load(open(required_files[2], "r"))
|
| 70 |
+
model_config["all_language_tokens"] = [
|
| 71 |
+
int(i) for i in model_config["all_language_tokens"].split(",")
|
| 72 |
+
]
|
| 73 |
+
model_config["all_language_codes"] = [
|
| 74 |
+
i for i in model_config["all_language_codes"].split(",")
|
| 75 |
+
]
|
| 76 |
tokenizer = get_tokenizer(
|
| 77 |
model_config["is_multilingual"],
|
| 78 |
num_languages=len(model_config["all_language_codes"]),
|
|
|
|
| 80 |
task=task,
|
| 81 |
)
|
| 82 |
|
| 83 |
+
self.id2token = self.load_tokens(required_files[3])
|
| 84 |
+
|
| 85 |
+
return encoder, decoder, tokenizer, model_config
|
| 86 |
|
| 87 |
def load_config(self, model_config):
|
| 88 |
config = WhisperConfig
|
|
|
|
| 101 |
config.n_vocab = model_config["n_vocab"]
|
| 102 |
config.n_text_ctx = model_config["n_text_ctx"]
|
| 103 |
config.n_text_state = model_config["n_text_state"]
|
| 104 |
+
config.n_text_layer = model_config["n_text_layer"]
|
| 105 |
|
| 106 |
+
lang_token = model_config["all_language_tokens"][
|
| 107 |
+
model_config["all_language_codes"].index(self.language)
|
| 108 |
+
]
|
| 109 |
+
task_token = (
|
| 110 |
+
config.transcribe if self.task == "transcribe" else config.translate
|
| 111 |
+
)
|
| 112 |
+
config.sot_sequence = np.array(
|
| 113 |
+
[config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
|
| 114 |
+
)
|
| 115 |
|
| 116 |
return config
|
| 117 |
+
|
| 118 |
+
def load_tokens(self, filename):
|
| 119 |
+
tokens = dict()
|
| 120 |
+
with open(filename, "r") as f:
|
| 121 |
+
for line in f:
|
| 122 |
+
t, i = line.split()
|
| 123 |
+
tokens[int(i)] = t
|
| 124 |
+
return tokens
|
| 125 |
|
| 126 |
def load_audio(self, audio: str):
|
| 127 |
data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
|
| 128 |
samples = np.ascontiguousarray(data)
|
| 129 |
return samples, sample_rate
|
| 130 |
|
| 131 |
+
def compute_feature(self, audio: np.ndarray):
|
| 132 |
+
mel = librosa.feature.melspectrogram(
|
| 133 |
+
y=audio,
|
| 134 |
+
sr=self.config.sample_rate,
|
| 135 |
+
n_fft=self.config.n_fft,
|
| 136 |
+
hop_length=self.config.hop_length,
|
| 137 |
+
window="hann",
|
| 138 |
+
center=True,
|
| 139 |
+
pad_mode="reflect",
|
| 140 |
+
power=2.0,
|
| 141 |
+
n_mels=self.config.n_mels,
|
| 142 |
+
)
|
| 143 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 144 |
log_spec = np.log10(np.maximum(mel, 1e-10))
|
| 145 |
log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
|
| 146 |
mel = (log_spec + 4.0) / 4.0
|
|
|
|
| 148 |
target = 3000
|
| 149 |
if mel.shape[1] > target:
|
| 150 |
# -50 so that there are some zero tail paddings.
|
| 151 |
+
mel = mel[:, :target]
|
| 152 |
mel[:, -50:] = 0
|
| 153 |
|
| 154 |
# We don't need to pad it to 30 seconds now!
|
| 155 |
if mel.shape[1] < target:
|
| 156 |
+
mel = np.concatenate(
|
| 157 |
+
(
|
| 158 |
+
mel,
|
| 159 |
+
np.zeros(
|
| 160 |
+
(self.config.n_mels, target - mel.shape[1]), dtype=np.float32
|
| 161 |
+
),
|
| 162 |
+
),
|
| 163 |
+
axis=-1,
|
| 164 |
+
)
|
| 165 |
+
|
| 166 |
+
return mel[np.newaxis, ...]
|
| 167 |
+
|
| 168 |
+
def run_encoder(
|
| 169 |
+
self,
|
| 170 |
+
mel: np.ndarray,
|
| 171 |
+
) -> List[np.ndarray]:
|
| 172 |
+
cross_kv = self.encoder.run(
|
| 173 |
+
None,
|
| 174 |
+
{
|
| 175 |
+
self.encoder.get_inputs()[0].name: mel,
|
| 176 |
+
},
|
| 177 |
+
)
|
| 178 |
+
return cross_kv
|
| 179 |
|
| 180 |
+
def run_decoder(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
|
| 181 |
+
feed = {
|
| 182 |
+
self.decoder.get_inputs()[i].name: inputs[i] for i in range(len(inputs))
|
| 183 |
+
}
|
|
|
|
| 184 |
|
| 185 |
+
out = self.decoder.run(
|
| 186 |
+
None,
|
| 187 |
+
feed,
|
| 188 |
+
)
|
| 189 |
+
return out
|
| 190 |
+
|
| 191 |
+
def get_self_cache(self) -> List[np.ndarray]:
|
| 192 |
+
self_cache = []
|
| 193 |
+
batch_size = 1
|
| 194 |
+
for i in range(self.config.n_text_layer):
|
| 195 |
+
k = np.zeros(
|
| 196 |
+
(batch_size, self.config.n_text_ctx, self.config.n_text_state),
|
| 197 |
+
dtype=np.float32,
|
| 198 |
+
)
|
| 199 |
+
v = np.zeros(
|
| 200 |
+
(batch_size, self.config.n_text_ctx, self.config.n_text_state),
|
| 201 |
+
dtype=np.float32,
|
| 202 |
+
)
|
| 203 |
+
self_cache.extend([k, v])
|
| 204 |
+
return self_cache
|
| 205 |
+
|
| 206 |
+
def causal_mask_1d(self, n: int, L: int):
|
| 207 |
+
"""
|
| 208 |
+
Returns a 1-D int mask of shape (L,) with:
|
| 209 |
+
0 -> allowed
|
| 210 |
+
1 -> masked (will be converted to -inf later)
|
| 211 |
+
"""
|
| 212 |
+
mask = np.ones((L,), dtype=np.int32)
|
| 213 |
+
if n > 0:
|
| 214 |
+
mask[:n] = 0
|
| 215 |
+
return mask
|
| 216 |
|
| 217 |
def run(self, audio: Union[str, np.ndarray]) -> str:
|
| 218 |
if isinstance(audio, str):
|
|
|
|
| 220 |
|
| 221 |
mel = self.compute_feature(audio)
|
| 222 |
|
| 223 |
+
cross_kv = self.run_encoder(mel)
|
| 224 |
+
|
| 225 |
+
self_kv = self.get_self_cache()
|
| 226 |
+
|
| 227 |
+
offset = np.array([0], dtype=np.int32)
|
| 228 |
+
for t in self.config.sot_sequence:
|
| 229 |
+
token = np.array([[t]], dtype=np.int32) # sot
|
| 230 |
+
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 231 |
+
|
| 232 |
+
out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
|
| 233 |
+
|
| 234 |
+
for i in range(1, len(out)):
|
| 235 |
+
self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
|
| 236 |
+
|
| 237 |
+
offset += 1
|
| 238 |
+
|
| 239 |
+
idx = out[0][0, 0].argmax()
|
| 240 |
+
|
| 241 |
+
eot = self.config.eot
|
| 242 |
+
|
| 243 |
+
ans = []
|
| 244 |
+
|
| 245 |
+
while idx != eot and offset.item() < 100:
|
| 246 |
+
ans.append(idx)
|
| 247 |
+
token = np.array([[idx]], dtype=np.int32)
|
| 248 |
+
|
| 249 |
+
mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
|
| 250 |
+
|
| 251 |
+
out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
|
| 252 |
+
|
| 253 |
+
for i in range(1, len(out)):
|
| 254 |
+
self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
|
| 255 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 256 |
offset += 1
|
| 257 |
+
idx = out[0][0, 0].argmax()
|
| 258 |
+
|
| 259 |
+
# print(ans)
|
| 260 |
+
|
| 261 |
+
s = b""
|
| 262 |
+
for i in ans:
|
| 263 |
+
if i in self.id2token:
|
| 264 |
+
s += base64.b64decode(self.id2token[i])
|
| 265 |
+
|
| 266 |
+
text = s.decode().strip()
|
| 267 |
|
| 268 |
if self.language == "zh":
|
| 269 |
try:
|
| 270 |
+
sim_zh = zhconv.convert(text, "zh-hans")
|
| 271 |
return sim_zh
|
| 272 |
except:
|
| 273 |
return text
|
|
|
|
|
|
|
| 274 |
|
| 275 |
+
return text
|