inoryQwQ committed on
Commit
ce28028
·
1 Parent(s): 4e849a4

Update models, remove decoder_main

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. models-ax630c/{base-decoder-loop.axmodel → base/base-decoder.axmodel} +2 -2
  2. models-ax630c/{base-encoder.axmodel → base/base-encoder.axmodel} +2 -2
  3. models-ax630c/{base-tokens.txt → base/base-tokens.txt} +0 -0
  4. models-ax630c/base/base_config.json +30 -0
  5. models-ax630c/{base-decoder-main.axmodel → small/small-decoder.axmodel} +2 -2
  6. models-ax630c/small/small-tokens.txt +0 -0
  7. models-ax630c/small/small_config.json +30 -0
  8. models-ax630c/{base-positional_embedding.bin → tiny/tiny-decoder.axmodel} +2 -2
  9. models-ax650/small/small-positional_embedding.bin → models-ax630c/tiny/tiny-encoder.axmodel +2 -2
  10. models-ax630c/tiny/tiny-tokens.txt +0 -0
  11. models-ax630c/tiny/tiny_config.json +30 -0
  12. models-ax650/base/base-decoder-loop.axmodel +0 -3
  13. models-ax650/base/base-decoder-main.axmodel +0 -3
  14. models-ax650/base/base-decoder.axmodel +3 -0
  15. models-ax650/base/base-encoder.axmodel +2 -2
  16. models-ax650/small/small-decoder-loop.axmodel +0 -3
  17. models-ax650/small/small-decoder-main.axmodel +0 -3
  18. models-ax650/small/small-decoder.axmodel +3 -0
  19. models-ax650/small/small-encoder.axmodel +2 -2
  20. models-ax650/tiny/tiny-decoder-loop.axmodel +0 -3
  21. models-ax650/tiny/tiny-decoder-main.axmodel +0 -3
  22. models-ax650/tiny/tiny-decoder.axmodel +3 -0
  23. models-ax650/tiny/tiny-encoder.axmodel +2 -2
  24. models-ax650/tiny/tiny-positional_embedding.bin +0 -3
  25. models-ax650/turbo/turbo-decoder-loop.axmodel +0 -3
  26. models-ax650/turbo/turbo-decoder-main.axmodel +0 -3
  27. models-ax650/turbo/turbo-decoder.axmodel +3 -0
  28. models-ax650/turbo/turbo-encoder.axmodel +2 -2
  29. models-ax650/turbo/turbo-positional_embedding.bin +0 -3
  30. models-onnx/base/base-decoder-loop.onnx +0 -3
  31. models-onnx/base/base-decoder-main.onnx +0 -3
  32. models-onnx/base/base-decoder.onnx +3 -0
  33. models-onnx/base/base-encoder.onnx +2 -2
  34. models-onnx/base/base-positional_embedding.bin +0 -3
  35. models-onnx/small/small-decoder.onnx +3 -0
  36. models-onnx/small/small-encoder.onnx +3 -0
  37. models-onnx/small/small-positional_embedding.bin +0 -3
  38. models-onnx/tiny/tiny-decoder-loop.onnx +0 -3
  39. models-onnx/tiny/tiny-decoder-main.onnx +0 -3
  40. models-onnx/tiny/tiny-decoder.onnx +3 -0
  41. models-onnx/tiny/tiny-encoder.onnx +2 -2
  42. models-onnx/tiny/tiny-positional_embedding.bin +0 -3
  43. models-onnx/turbo/turbo-decoder.onnx +3 -0
  44. models-ax650/base/base-positional_embedding.bin → models-onnx/turbo/turbo-encoder.onnx +2 -2
  45. models-onnx/turbo/turbo-tokens.txt +0 -0
  46. python/assets/multilingual.tiktoken +0 -0
  47. python/languages.py +1 -1
  48. python/main.py +39 -19
  49. python/test_wer.py +96 -61
  50. python/whisper.py +179 -128
models-ax630c/{base-decoder-loop.axmodel → base/base-decoder.axmodel} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:4b12160aaa1ca31248a32ce05713fd72e273b16444389853c1f52990cf5130eb
3
- size 130364397
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51e595565f0121eb4dc9ee14172cb8a111a56bf280a927660fdee5fbffa9d52e
3
+ size 184323085
models-ax630c/{base-encoder.axmodel → base/base-encoder.axmodel} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:b9f89ed5bbe31bcf98aa0e479ced1699b39816db2d3e2e2ff84c6e887af2b79b
3
- size 56024079
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00ac3d4d0aa81f3910d4aa9c777e81fbf3b4bc22f26a9d9ac38f236392261603
3
+ size 56706622
models-ax630c/{base-tokens.txt → base/base-tokens.txt} RENAMED
File without changes
models-ax630c/base/base_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "whisper-base",
3
+ "version": "1",
4
+ "maintainer": "k2-fsa",
5
+ "n_mels": 80,
6
+ "n_audio_ctx": 1500,
7
+ "n_audio_state": 512,
8
+ "n_audio_head": 8,
9
+ "n_audio_layer": 6,
10
+ "n_vocab": 51865,
11
+ "n_text_ctx": 448,
12
+ "n_text_state": 512,
13
+ "n_text_head": 8,
14
+ "n_text_layer": 6,
15
+ "sot_sequence": "50258,50259,50359",
16
+ "all_language_tokens": "50346,50356,50292,50319,50325,50330,50327,50328,50341,50331,50307,50353,50279,50320,50322,50326,50340,50267,50269,50297,50301,50316,50323,50287,50276,50342,50304,50277,50311,50350,50286,50278,50290,50309,50284,50268,50300,50321,50272,50291,50281,50266,50357,50333,50293,50299,50294,50337,50271,50263,50296,50264,50343,50315,50314,50270,50352,50317,50349,50348,50283,50265,50308,50305,50336,50261,50335,50262,50345,50344,50351,50310,50329,50332,50289,50274,50302,50259,50324,50282,50285,50313,50280,50334,50260,50303,50312,50318,50295,50273,50338,50298,50347,50288,50354,50355,50275,50306,50339",
17
+ "all_language_codes": "my,jw,bg,gl,yo,be,af,oc,tk,tg,et,ln,he,mr,si,so,ps,pt,pl,cy,lv,kk,km,ta,hi,nn,az,fi,is,as,hu,vi,ur,br,ro,tr,fa,pa,ar,hr,el,ja,su,gu,lt,te,la,uz,nl,ru,ml,ko,mt,bs,mn,ca,haw,sq,mg,tl,cs,fr,mk,sl,lo,de,yi,es,lb,sa,tt,eu,ka,sd,th,it,bn,en,sn,ms,da,ne,uk,am,zh,sr,hy,sw,mi,sv,fo,sk,bo,no,ha,ba,id,kn,ht",
18
+ "sot": 50258,
19
+ "sot_index": 0,
20
+ "eot": 50257,
21
+ "blank_id": 220,
22
+ "is_multilingual": 1,
23
+ "no_speech": 50362,
24
+ "non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
25
+ "transcribe": 50359,
26
+ "translate": 50358,
27
+ "sot_prev": 50361,
28
+ "sot_lm": 50360,
29
+ "no_timestamps": 50363
30
+ }
models-ax630c/{base-decoder-main.axmodel → small/small-decoder.axmodel} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:940f273d111e3aee53cdb692a384a29556981aa146afbb2f558f6aac262c0621
3
- size 135675471
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9d13b4c3693f72e280162a1fd78c52ffc8ecc9318d7d2bd56db9810945c88f1b
3
+ size 345275786
models-ax630c/small/small-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-ax630c/small/small_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "whisper-small",
3
+ "version": "1",
4
+ "maintainer": "k2-fsa",
5
+ "n_mels": 80,
6
+ "n_audio_ctx": 1500,
7
+ "n_audio_state": 768,
8
+ "n_audio_head": 12,
9
+ "n_audio_layer": 12,
10
+ "n_vocab": 51865,
11
+ "n_text_ctx": 448,
12
+ "n_text_state": 768,
13
+ "n_text_head": 12,
14
+ "n_text_layer": 12,
15
+ "sot_sequence": "50258,50259,50359",
16
+ "all_language_tokens": "50349,50276,50334,50346,50335,50313,50353,50284,50280,50268,50266,50267,50330,50350,50259,50311,50336,50340,50288,50302,50291,50279,50299,50351,50270,50301,50314,50295,50331,50285,50303,50326,50342,50355,50348,50341,50354,50321,50272,50269,50357,50333,50283,50309,50271,50324,50323,50290,50327,50298,50319,50356,50282,50332,50275,50263,50294,50305,50293,50317,50338,50287,50292,50316,50343,50289,50260,50328,50312,50344,50325,50304,50339,50320,50308,50274,50262,50345,50278,50296,50337,50310,50329,50318,50347,50265,50307,50264,50352,50315,50273,50277,50300,50261,50286,50306,50322,50281,50297",
17
+ "all_language_codes": "mg,hi,am,my,yi,ne,ln,ro,uk,tr,ja,pt,be,as,en,is,lo,ps,no,bn,hr,he,te,tt,ca,lv,mn,mi,tg,da,sr,so,nn,ba,tl,tk,ha,pa,ar,pl,su,gu,cs,br,nl,sn,km,ur,af,sk,gl,jw,ms,sd,id,ru,la,sl,lt,sq,fo,ta,bg,kk,mt,th,zh,oc,hy,sa,yo,az,ht,mr,mk,it,es,lb,vi,ml,uz,eu,ka,sw,bo,fr,et,ko,haw,bs,sv,fi,fa,de,hu,kn,si,el,cy",
18
+ "sot": 50258,
19
+ "sot_index": 0,
20
+ "eot": 50257,
21
+ "blank_id": 220,
22
+ "is_multilingual": 1,
23
+ "no_speech": 50362,
24
+ "non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
25
+ "transcribe": 50359,
26
+ "translate": 50358,
27
+ "sot_prev": 50361,
28
+ "sot_lm": 50360,
29
+ "no_timestamps": 50363
30
+ }
models-ax630c/{base-positional_embedding.bin → tiny/tiny-decoder.axmodel} RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
3
- size 917504
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64a611a2575597fed3e705d9faf941df0b33d58bc10a733fa31f5e937fd58ec4
3
+ size 129647157
models-ax650/small/small-positional_embedding.bin → models-ax630c/tiny/tiny-encoder.axmodel RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:c10bc44f2bd94bdf1b7aa03581309fa536132b3fe79bfe22c9a6934a42cd8b58
3
- size 1376256
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0528e2d7e317668e43a5641695f93d0c80d9902e28f0e9f2fdef76470855efe7
3
+ size 26853722
models-ax630c/tiny/tiny-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
models-ax630c/tiny/tiny_config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "model_type": "whisper-tiny",
3
+ "version": "1",
4
+ "maintainer": "k2-fsa",
5
+ "n_mels": 80,
6
+ "n_audio_ctx": 1500,
7
+ "n_audio_state": 384,
8
+ "n_audio_head": 6,
9
+ "n_audio_layer": 4,
10
+ "n_vocab": 51865,
11
+ "n_text_ctx": 448,
12
+ "n_text_state": 384,
13
+ "n_text_head": 6,
14
+ "n_text_layer": 4,
15
+ "sot_sequence": "50258,50259,50359",
16
+ "all_language_tokens": "50286,50307,50345,50299,50290,50265,50275,50294,50309,50262,50318,50331,50282,50349,50270,50326,50279,50272,50355,50287,50269,50315,50278,50283,50266,50260,50297,50348,50346,50296,50277,50334,50342,50313,50288,50322,50325,50259,50302,50332,50338,50344,50261,50330,50304,50357,50314,50340,50291,50352,50320,50271,50316,50336,50323,50293,50263,50308,50284,50273,50267,50312,50321,50328,50285,50298,50301,50327,50354,50303,50356,50351,50295,50339,50292,50319,50264,50310,50276,50335,50311,50341,50350,50268,50289,50281,50324,50333,50317,50343,50305,50274,50306,50353,50300,50347,50329,50337,50280",
17
+ "all_language_codes": "hu,et,lb,te,ur,fr,id,la,br,es,sw,tg,ms,mg,ca,so,he,ar,ba,ta,pl,bs,vi,cs,ja,zh,cy,tl,my,ml,fi,am,nn,ne,no,si,yo,en,bn,sd,fo,sa,de,be,az,su,mn,ps,hr,haw,mr,nl,kk,lo,km,lt,ru,mk,ro,sv,pt,hy,pa,oc,da,sk,lv,af,ha,sr,jw,tt,mi,ht,bg,gl,ko,eu,hi,yi,is,tk,as,tr,th,el,sn,gu,sq,mt,sl,it,kn,ln,fa,bo,ka,uz,uk",
18
+ "sot": 50258,
19
+ "sot_index": 0,
20
+ "eot": 50257,
21
+ "blank_id": 220,
22
+ "is_multilingual": 1,
23
+ "no_speech": 50362,
24
+ "non_speech_tokens": "1,2,7,8,9,10,14,25,26,27,28,29,31,58,59,60,61,62,63,90,91,92,93,359,503,522,542,873,893,902,918,922,931,1350,1853,1982,2460,2627,3246,3253,3268,3536,3846,3961,4183,4667,6585,6647,7273,9061,9383,10428,10929,11938,12033,12331,12562,13793,14157,14635,15265,15618,16553,16604,18362,18956,20075,21675,22520,26130,26161,26435,28279,29464,31650,32302,32470,36865,42863,47425,49870,50254",
25
+ "transcribe": 50359,
26
+ "translate": 50358,
27
+ "sot_prev": 50361,
28
+ "sot_lm": 50360,
29
+ "no_timestamps": 50363
30
+ }
models-ax650/base/base-decoder-loop.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:43fbc9a1672eabd705bb68fdfd6b0837c4d3bceec5e07c80cc829cf47417e11d
3
- size 183531172
 
 
 
 
models-ax650/base/base-decoder-main.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7d9167ae1c52ed1cb318fb27e45734c2ceb0560444078693bf831ce02f2c0331
3
- size 183985586
 
 
 
 
models-ax650/base/base-decoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5212e05f8c0d2d4b2a319eaf87eb93253a9bb476dee8fa3d8a85f3137b61045
3
+ size 184444593
models-ax650/base/base-encoder.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:76899fc4d232fd6f458d3597e8a67a04719715e971bc82e286679014a929f5b6
3
- size 33082024
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c07d03194566f292d26cf1f0e104d5740e1a249a0dc92b23c5b15b9e96496c24
3
+ size 33132600
models-ax650/small/small-decoder-loop.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:b472a0f3539d17fece09e92bf6cd69ebf391928a6050896bbf86b558a25def22
3
- size 269002567
 
 
 
 
models-ax650/small/small-decoder-main.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:f3bfc577f60c35192d8ce8cc24f9ca4aa84af72756ba11af9d178d337cb7eb1c
3
- size 285531695
 
 
 
 
models-ax650/small/small-decoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4ed1e62122ed495624efa0275a0d4cb2450996b6f1a5e1d9ea9d48026d1bb66
3
+ size 350609498
models-ax650/small/small-encoder.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7b3bc8db9762f9b2dfe78bffbc8070fb877b2572c5288253573e49a8c7b37948
3
- size 139705612
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6e4e8462e3b6ac3ea9465d61560bc42e1398a06ea640d9ff5cdb82636ab73d47
3
+ size 136275980
models-ax650/tiny/tiny-decoder-loop.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:139aa1429a6f439b7a5d1e5f481cb761c673afd4718a25968ad979fccfdfecaf
3
- size 128541899
 
 
 
 
models-ax650/tiny/tiny-decoder-main.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:8c1d08ec9309a26103a955cb432e0ddc25476da706a6a0a94108d225a48385aa
3
- size 128909975
 
 
 
 
models-ax650/tiny/tiny-decoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ecd77406ae1c04883beb7517f3e9f0ca0c4b35e640467997f126efb539b96306
3
+ size 129267343
models-ax650/tiny/tiny-encoder.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1e38085752ab08b9eec10b89cb1374e1a57ede680cc0340292e8d7261399acae
3
- size 14102295
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:698670eb7410eeb2a84e26cee8918a780c97d1232a9dc7f946f61949105299a9
3
+ size 14085412
models-ax650/tiny/tiny-positional_embedding.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c13450ae630323a0bdd39b1226f92a7ac251131a909c7efdb7d2f5516736eb83
3
- size 688128
 
 
 
 
models-ax650/turbo/turbo-decoder-loop.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1b50daea5f4776006bf0e635ca5224f118649a9b3c37b2b821ae4c321db096ec
3
- size 499257709
 
 
 
 
models-ax650/turbo/turbo-decoder-main.axmodel DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:56d34930691311002e401d8243abbae632593b57406e8ae780cd02c9076a783d
3
- size 500341239
 
 
 
 
models-ax650/turbo/turbo-decoder.axmodel ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d5a52c2c1f2c08fdb3b664201a79b42e3439184ffcc5f19e151eeb4a88cadd8
3
+ size 501696186
models-ax650/turbo/turbo-encoder.axmodel CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:d8962c9189a43f6fc0ea080df1bdcc43ca916b9337d16d0c6a7b30894c14e3ee
3
- size 892467653
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b2cbdf3941b8d739318148505c380167a41e7d710d1a3cfda1a257c2c7fe428f
3
+ size 893420089
models-ax650/turbo/turbo-positional_embedding.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:a94790ac6719da6e134835255274ee7fc6066ad5e0f08a0f747c1e1cf6407dc3
3
- size 2293760
 
 
 
 
models-onnx/base/base-decoder-loop.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1616a829b7d3d643616633551204b8d0f008fb7a7dc38919eda2e8c6c6ed9714
3
- size 194571088
 
 
 
 
models-onnx/base/base-decoder-main.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:1096b83590016bdbe74c66c7ccad1c0120abd6d37214560b1dfe4cd886a0e683
3
- size 205485892
 
 
 
 
models-onnx/base/base-decoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a2ba37d199ec7c1794facba8efc84a484b0224f6650e362fda7d2db75827023a
3
+ size 195497242
models-onnx/base/base-encoder.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:dd4b51bd569e9b2b2d83a8ed56f3618811f0c593aa95c010069df675027b5f2b
3
- size 95026988
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1c75e7cca22000432ec4f8f50726299aa20db34a9c154646e8ac10c0ddb4699b
3
+ size 95025778
models-onnx/base/base-positional_embedding.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
3
- size 917504
 
 
 
 
models-onnx/small/small-decoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f696ce159cffc13416ef624b91dae341de9d3f7e720cfc832000fe987f6d50b4
3
+ size 557821592
models-onnx/small/small-encoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f7dce3226784147e42021007dede9c2c144d33487711a1950b5935b9c4829f1b
3
+ size 409408370
models-onnx/small/small-positional_embedding.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c10bc44f2bd94bdf1b7aa03581309fa536132b3fe79bfe22c9a6934a42cd8b58
3
- size 1376256
 
 
 
 
models-onnx/tiny/tiny-decoder-loop.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:5cbb3533939e2dfdf567b27762b12cf0956b7d7982bfb915228d24789f483058
3
- size 112843354
 
 
 
 
models-onnx/tiny/tiny-decoder-main.onnx DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:59ced1cf4e9a6f2aef0a2457f64f846e5682033abb4b894ba7680a60c792ad73
3
- size 118301861
 
 
 
 
models-onnx/tiny/tiny-decoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0d1a16a5a70e9a5940b68559eb5142c7c53b8f595f2f5e4a01d7e168acca6eb1
3
+ size 113537271
models-onnx/tiny/tiny-encoder.onnx CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:a8030a6d1f3615b8a5e000995fee88357768c7dbaad05a79f853a4040c97087b
3
- size 37606186
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:324e93c5ddc2e922273ebeb16bba8c453aa009c698e2262103ac5ce21df1c3ed
3
+ size 37605342
models-onnx/tiny/tiny-positional_embedding.bin DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:c13450ae630323a0bdd39b1226f92a7ac251131a909c7efdb7d2f5516736eb83
3
- size 688128
 
 
 
 
models-onnx/turbo/turbo-decoder.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:75b7d28ee244aa1d39d02d98514cd951e7697f6fb742c7f84f808b8aab9b1d2a
3
+ size 635240242
models-ax650/base/base-positional_embedding.bin → models-onnx/turbo/turbo-encoder.onnx RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:88fa1cdbf2b06f86b0ecb7be0fccfc39e906502986572b8cf5319c250e857169
3
- size 917504
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:883b1a459251b53f7431bdb81d2a573d0965c42be73179cbdeb0041154fb0a7d
3
+ size 389433
models-onnx/turbo/turbo-tokens.txt ADDED
The diff for this file is too large to render. See raw diff
 
python/assets/multilingual.tiktoken ADDED
The diff for this file is too large to render. See raw diff
 
python/languages.py CHANGED
@@ -99,4 +99,4 @@ WHISPER_LANGUAGES = {
99
  "jw": "javanese",
100
  "su": "sundanese",
101
  "yue": "cantonese",
102
- }
 
99
  "jw": "javanese",
100
  "su": "sundanese",
101
  "yue": "cantonese",
102
+ }
python/main.py CHANGED
@@ -6,29 +6,49 @@ import time
6
 
7
  def get_args():
8
  parser = argparse.ArgumentParser(
9
- prog="whisper",
10
- description="Run Whisper on input audio file"
11
  )
12
  parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
13
- parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
14
- parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
15
- parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
16
- parser.add_argument("--task", type=str, required=False, choices=["translate", "transcribe"], default="transcribe")
17
- parser.add_argument("--print_rtf", action="store_true", help="Print Real-Time Factor")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  return parser.parse_args()
19
 
20
 
21
- def print_args(args):
22
- print(f"wav: {args.wav}")
23
- print(f"model_type: {args.model_type}")
24
- print(f"model_path: {args.model_path}")
25
- print(f"language: {args.language}")
26
- print(f"task: {args.task}")
27
-
28
-
29
  def main():
30
  args = get_args()
31
- print_args(args)
32
 
33
  # Check wav existence
34
  wav_path = args.wav
@@ -36,19 +56,19 @@ def main():
36
 
37
  model = Whisper(args.model_type, args.model_path, args.language, args.task)
38
 
39
-
40
-
41
- print("\n预测结果:")
42
  start = time.time()
43
  print(model.run(wav_path))
44
  end = time.time()
45
 
46
  if args.print_rtf:
47
  import librosa
 
48
  samples, sr = librosa.load(wav_path, sr=16000)
49
  duration = len(samples) / sr
50
  process_time = end - start
51
  print(f"RTF: {process_time / duration}")
52
 
 
53
  if __name__ == "__main__":
54
  main()
 
6
 
7
  def get_args():
8
  parser = argparse.ArgumentParser(
9
+ prog="whisper", description="Run Whisper on input audio file"
 
10
  )
11
  parser.add_argument("--wav", "-w", type=str, required=True, help="Input audio file")
12
+ parser.add_argument(
13
+ "--model_type",
14
+ "-t",
15
+ type=str,
16
+ choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
17
+ required=True,
18
+ help="model type, only support tiny, base and small currently",
19
+ )
20
+ parser.add_argument(
21
+ "--model_path",
22
+ "-p",
23
+ type=str,
24
+ required=False,
25
+ default="../models-ax650",
26
+ help="model path for *.axmodel, tokens.txt, positional_embedding.bin",
27
+ )
28
+ parser.add_argument(
29
+ "--language",
30
+ "-l",
31
+ type=str,
32
+ required=False,
33
+ default="zh",
34
+ help="Target language, support en, zh, ja, and others. See languages.py for more options.",
35
+ )
36
+ parser.add_argument(
37
+ "--task",
38
+ type=str,
39
+ required=False,
40
+ choices=["translate", "transcribe"],
41
+ default="transcribe",
42
+ )
43
+ parser.add_argument(
44
+ "--print_rtf", action="store_true", help="Print Real-Time Factor"
45
+ )
46
  return parser.parse_args()
47
 
48
 
 
 
 
 
 
 
 
 
49
  def main():
50
  args = get_args()
51
+ print(vars(args))
52
 
53
  # Check wav existence
54
  wav_path = args.wav
 
56
 
57
  model = Whisper(args.model_type, args.model_path, args.language, args.task)
58
 
59
+ print("ASR result:")
 
 
60
  start = time.time()
61
  print(model.run(wav_path))
62
  end = time.time()
63
 
64
  if args.print_rtf:
65
  import librosa
66
+
67
  samples, sr = librosa.load(wav_path, sr=16000)
68
  duration = len(samples) / sr
69
  process_time = end - start
70
  print(f"RTF: {process_time / duration}")
71
 
72
+
73
  if __name__ == "__main__":
74
  main()
python/test_wer.py CHANGED
@@ -10,35 +10,35 @@ def setup_logging():
10
  # 获取脚本所在目录
11
  script_dir = os.path.dirname(os.path.abspath(__file__))
12
  log_file = os.path.join(script_dir, "test_wer.log")
13
-
14
  # 配置日志格式
15
- log_format = '%(asctime)s - %(levelname)s - %(message)s'
16
- date_format = '%Y-%m-%d %H:%M:%S'
17
-
18
  # 创建logger
19
  logger = logging.getLogger()
20
  logger.setLevel(logging.INFO)
21
-
22
  # 清除现有的handler
23
  for handler in logger.handlers[:]:
24
  logger.removeHandler(handler)
25
-
26
  # 创建文件handler
27
- file_handler = logging.FileHandler(log_file, mode='a', encoding='utf-8')
28
  file_handler.setLevel(logging.INFO)
29
  file_formatter = logging.Formatter(log_format, date_format)
30
  file_handler.setFormatter(file_formatter)
31
-
32
  # 创建控制台handler
33
  console_handler = logging.StreamHandler()
34
  console_handler.setLevel(logging.INFO)
35
  console_formatter = logging.Formatter(log_format, date_format)
36
  console_handler.setFormatter(console_formatter)
37
-
38
  # 添加handler到logger
39
  logger.addHandler(file_handler)
40
  logger.addHandler(console_handler)
41
-
42
  return logger
43
 
44
 
@@ -46,21 +46,21 @@ class AIShellDataset:
46
  def __init__(self, gt_path: str):
47
  """
48
  初始化数据集
49
-
50
  Args:
51
  json_path: voice.json文件的路径
52
  """
53
  self.gt_path = gt_path
54
  self.dataset_dir = os.path.dirname(gt_path)
55
  self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
56
-
57
  # 检查必要文件和文件夹是否存在
58
  assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
59
  assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
60
-
61
  # 加载数据
62
  self.data = []
63
- with open(gt_path, 'r', encoding='utf-8') as f:
64
  for line in f:
65
  line = line.strip()
66
  audio_path, gt = line.split(" ")
@@ -70,50 +70,50 @@ class AIShellDataset:
70
  # 使用logging而不是print
71
  logger = logging.getLogger()
72
  logger.info(f"加载了 {len(self.data)} 条数据")
73
-
74
  def __iter__(self):
75
  """返回迭代器"""
76
  self.index = 0
77
  return self
78
-
79
  def __next__(self):
80
  """返回下一个数据项"""
81
  if self.index >= len(self.data):
82
  raise StopIteration
83
-
84
  item = self.data[self.index]
85
  audio_path = item["audio_path"]
86
  ground_truth = item["gt"]
87
-
88
  self.index += 1
89
  return audio_path, ground_truth
90
-
91
  def __len__(self):
92
  """返回数据集大小"""
93
  return len(self.data)
94
-
95
 
96
  class CommonVoiceDataset:
97
  """Common Voice数据集解析器"""
98
-
99
  def __init__(self, tsv_path: str):
100
  """
101
  初始化数据集
102
-
103
  Args:
104
  json_path: voice.json文件的路径
105
  """
106
  self.tsv_path = tsv_path
107
  self.dataset_dir = os.path.dirname(tsv_path)
108
  self.voice_dir = os.path.join(self.dataset_dir, "clips")
109
-
110
  # 检查必要文件和文件夹是否存在
111
  assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
112
  assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
113
-
114
  # 加载JSON数据
115
  self.data = []
116
- with open(tsv_path, 'r', encoding='utf-8') as f:
117
  f.readline()
118
  for line in f:
119
  line = line.strip()
@@ -122,43 +122,77 @@ class CommonVoiceDataset:
122
  gt = splits[2]
123
  audio_path = os.path.join(self.voice_dir, audio_path)
124
  self.data.append({"audio_path": audio_path, "gt": gt})
125
-
126
  # 使用logging而不是print
127
  logger = logging.getLogger()
128
  logger.info(f"加载了 {len(self.data)} 条数据")
129
-
130
  def __iter__(self):
131
  """返回迭代器"""
132
  self.index = 0
133
  return self
134
-
135
  def __next__(self):
136
  """返回下一个数据项"""
137
  if self.index >= len(self.data):
138
  raise StopIteration
139
-
140
  item = self.data[self.index]
141
  audio_path = item["audio_path"]
142
  ground_truth = item["gt"]
143
-
144
  self.index += 1
145
  return audio_path, ground_truth
146
-
147
  def __len__(self):
148
  """返回数据集大小"""
149
  return len(self.data)
150
 
 
151
  def get_args():
152
- parser = argparse.ArgumentParser(
153
- prog="whisper",
154
- description="Test WER on dataset"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
155
  )
156
- parser.add_argument("--dataset", "-d", type=str, required=True, choices=["aishell", "common_voice"], help="Test dataset")
157
- parser.add_argument("--gt_path", "-g", type=str, required=True, help="Test dataset ground truth file")
158
- parser.add_argument("--max_num", type=int, default=-1, required=False, help="Maximum test data num")
159
- parser.add_argument("--model_type", "-t", type=str, choices=["tiny", "base", "small", "large", "large-v3", "turbo"], required=True, help="model type, only support tiny, base and small currently")
160
- parser.add_argument("--model_path", "-p", type=str, required=False, default="../models/models-ax650", help="model path for *.axmodel, tokens.txt, positional_embedding.bin")
161
- parser.add_argument("--language", "-l", type=str, required=False, default="zh", help="Target language, support en, zh, ja, and others. See languages.py for more options.")
162
  return parser.parse_args()
163
 
164
 
@@ -173,42 +207,42 @@ def print_args(args):
173
 
174
 
175
  def min_distance(word1: str, word2: str) -> int:
176
-
177
  row = len(word1) + 1
178
  column = len(word2) + 1
179
-
180
- cache = [ [0]*column for i in range(row) ]
181
-
182
  for i in range(row):
183
  for j in range(column):
184
-
185
- if i ==0 and j ==0:
186
  cache[i][j] = 0
187
- elif i == 0 and j!=0:
188
  cache[i][j] = j
189
- elif j == 0 and i!=0:
190
  cache[i][j] = i
191
  else:
192
- if word1[i-1] == word2[j-1]:
193
- cache[i][j] = cache[i-1][j-1]
194
  else:
195
- replace = cache[i-1][j-1] + 1
196
- insert = cache[i][j-1] + 1
197
- remove = cache[i-1][j] + 1
198
-
199
  cache[i][j] = min(replace, insert, remove)
200
-
201
- return cache[row-1][column-1]
202
 
203
 
204
  def remove_punctuation(text):
205
  # 定义正则表达式模式,匹配所有标点符号
206
  # 这个模式包括常见的标点符号和中文标点
207
- pattern = r'[^\w\s]|_'
208
-
209
  # 使用sub方法将所有匹配的标点符号替换为空字符串
210
- cleaned_text = re.sub(pattern, '', text)
211
-
212
  return cleaned_text
213
 
214
 
@@ -254,7 +288,7 @@ def main():
254
 
255
  hyp.append(hypothesis)
256
  references.append(reference)
257
-
258
  line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
259
  wer_file.write(line_content + "\n")
260
  logger.info(line_content)
@@ -268,5 +302,6 @@ def main():
268
  wer_file.write(f"Total WER: {total_character_error_rate}%")
269
  wer_file.close()
270
 
 
271
  if __name__ == "__main__":
272
  main()
 
10
  # 获取脚本所在目录
11
  script_dir = os.path.dirname(os.path.abspath(__file__))
12
  log_file = os.path.join(script_dir, "test_wer.log")
13
+
14
  # 配置日志格式
15
+ log_format = "%(asctime)s - %(levelname)s - %(message)s"
16
+ date_format = "%Y-%m-%d %H:%M:%S"
17
+
18
  # 创建logger
19
  logger = logging.getLogger()
20
  logger.setLevel(logging.INFO)
21
+
22
  # 清除现有的handler
23
  for handler in logger.handlers[:]:
24
  logger.removeHandler(handler)
25
+
26
  # 创建文件handler
27
+ file_handler = logging.FileHandler(log_file, mode="a", encoding="utf-8")
28
  file_handler.setLevel(logging.INFO)
29
  file_formatter = logging.Formatter(log_format, date_format)
30
  file_handler.setFormatter(file_formatter)
31
+
32
  # 创建控制台handler
33
  console_handler = logging.StreamHandler()
34
  console_handler.setLevel(logging.INFO)
35
  console_formatter = logging.Formatter(log_format, date_format)
36
  console_handler.setFormatter(console_formatter)
37
+
38
  # 添加handler到logger
39
  logger.addHandler(file_handler)
40
  logger.addHandler(console_handler)
41
+
42
  return logger
43
 
44
 
 
46
  def __init__(self, gt_path: str):
47
  """
48
  初始化数据集
49
+
50
  Args:
51
  json_path: voice.json文件的路径
52
  """
53
  self.gt_path = gt_path
54
  self.dataset_dir = os.path.dirname(gt_path)
55
  self.voice_dir = os.path.join(self.dataset_dir, "aishell_S0764")
56
+
57
  # 检查必要文件和文件夹是否存在
58
  assert os.path.exists(gt_path), f"gt文件不存在: {gt_path}"
59
  assert os.path.exists(self.voice_dir), f"aishell_S0764文件夹不存在: {self.voice_dir}"
60
+
61
  # 加载数据
62
  self.data = []
63
+ with open(gt_path, "r", encoding="utf-8") as f:
64
  for line in f:
65
  line = line.strip()
66
  audio_path, gt = line.split(" ")
 
70
  # 使用logging而不是print
71
  logger = logging.getLogger()
72
  logger.info(f"加载了 {len(self.data)} 条数据")
73
+
74
  def __iter__(self):
75
  """返回迭代器"""
76
  self.index = 0
77
  return self
78
+
79
  def __next__(self):
80
  """返回下一个数据项"""
81
  if self.index >= len(self.data):
82
  raise StopIteration
83
+
84
  item = self.data[self.index]
85
  audio_path = item["audio_path"]
86
  ground_truth = item["gt"]
87
+
88
  self.index += 1
89
  return audio_path, ground_truth
90
+
91
  def __len__(self):
92
  """返回数据集大小"""
93
  return len(self.data)
94
+
95
 
96
  class CommonVoiceDataset:
97
  """Common Voice数据集解析器"""
98
+
99
  def __init__(self, tsv_path: str):
100
  """
101
  初始化数据集
102
+
103
  Args:
104
  json_path: voice.json文件的路径
105
  """
106
  self.tsv_path = tsv_path
107
  self.dataset_dir = os.path.dirname(tsv_path)
108
  self.voice_dir = os.path.join(self.dataset_dir, "clips")
109
+
110
  # 检查必要文件和文件夹是否存在
111
  assert os.path.exists(tsv_path), f"{tsv_path}文件不存在: {tsv_path}"
112
  assert os.path.exists(self.voice_dir), f"voice文件夹不存在: {self.voice_dir}"
113
+
114
  # 加载JSON数据
115
  self.data = []
116
+ with open(tsv_path, "r", encoding="utf-8") as f:
117
  f.readline()
118
  for line in f:
119
  line = line.strip()
 
122
  gt = splits[2]
123
  audio_path = os.path.join(self.voice_dir, audio_path)
124
  self.data.append({"audio_path": audio_path, "gt": gt})
125
+
126
  # 使用logging而不是print
127
  logger = logging.getLogger()
128
  logger.info(f"加载了 {len(self.data)} 条数据")
129
+
130
  def __iter__(self):
131
  """返回迭代器"""
132
  self.index = 0
133
  return self
134
+
135
  def __next__(self):
136
  """返回下一个数据项"""
137
  if self.index >= len(self.data):
138
  raise StopIteration
139
+
140
  item = self.data[self.index]
141
  audio_path = item["audio_path"]
142
  ground_truth = item["gt"]
143
+
144
  self.index += 1
145
  return audio_path, ground_truth
146
+
147
  def __len__(self):
148
  """返回数据集大小"""
149
  return len(self.data)
150
 
151
+
152
def get_args(argv=None):
    """Parse the command-line options of the WER test script.

    Args:
        argv: Optional list of argument strings to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``, which is what the
            script did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``dataset``, ``gt_path``,
        ``max_num``, ``model_type``, ``model_path`` and ``language``.
    """
    parser = argparse.ArgumentParser(prog="whisper", description="Test WER on dataset")
    parser.add_argument(
        "--dataset",
        "-d",
        type=str,
        required=True,
        choices=["aishell", "common_voice"],
        help="Test dataset",
    )
    parser.add_argument(
        "--gt_path",
        "-g",
        type=str,
        required=True,
        help="Test dataset ground truth file",
    )
    # -1 is the default "Maximum test data num".
    parser.add_argument(
        "--max_num", type=int, default=-1, required=False, help="Maximum test data num"
    )
    parser.add_argument(
        "--model_type",
        "-t",
        type=str,
        choices=["tiny", "base", "small", "large", "large-v3", "turbo"],
        required=True,
        help="model type, only support tiny, base and small currently",
    )
    # The help text lists the files the model loader actually requires; the
    # separate positional-embedding file is no longer part of a model folder.
    parser.add_argument(
        "--model_path",
        "-p",
        type=str,
        required=False,
        default="../models/models-ax650",
        help="model path for *.axmodel, *-tokens.txt, *_config.json",
    )
    parser.add_argument(
        "--language",
        "-l",
        type=str,
        required=False,
        default="zh",
        help="Target language, support en, zh, ja, and others. See languages.py for more options.",
    )
    return parser.parse_args(argv)
197
 
198
 
 
207
 
208
 
209
def min_distance(word1: str, word2: str) -> int:
    """Return the Levenshtein (edit) distance between two sequences.

    The distance is the minimum number of single-element insertions,
    deletions and substitutions that turn ``word1`` into ``word2``.

    Only two rows of the dynamic-programming table are kept alive at a time,
    so memory use is proportional to the shorter input instead of to the
    product of both lengths.

    Args:
        word1: First sequence.
        word2: Second sequence.

    Returns:
        The edit distance; ``0`` when both inputs are equal.
    """
    # The distance is symmetric, so keep the shorter input along the row.
    if len(word1) < len(word2):
        word1, word2 = word2, word1

    # Distance from the empty prefix of word1 to every prefix of word2.
    previous = list(range(len(word2) + 1))

    for i, left in enumerate(word1, start=1):
        # Turning the first i elements of word1 into "" costs i deletions.
        current = [i]
        for j, right in enumerate(word2, start=1):
            if left == right:
                # Matching elements cost nothing extra.
                current.append(previous[j - 1])
            else:
                replace = previous[j - 1]
                insert = current[j - 1]
                remove = previous[j]
                current.append(1 + min(replace, insert, remove))
        previous = current

    return previous[-1]
236
 
237
 
238
def remove_punctuation(text):
    """Return ``text`` with punctuation and underscores removed.

    Anything that is neither a word character nor whitespace is dropped, which
    covers both ASCII and Chinese punctuation. The underscore is listed
    separately because the regex class ``\\w`` would otherwise keep it.
    """
    return re.sub(r"[^\w\s]|_", "", text)
247
 
248
 
 
288
 
289
  hyp.append(hypothesis)
290
  references.append(reference)
291
+
292
  line_content = f"({n+1}/{max_data_num}) {os.path.basename(audio_path)} gt: {reference} predict: {hypothesis} WER: {character_error_rate}%"
293
  wer_file.write(line_content + "\n")
294
  logger.info(line_content)
 
302
  wer_file.write(f"Total WER: {total_character_error_rate}%")
303
  wer_file.close()
304
 
305
+
306
  if __name__ == "__main__":
307
  main()
python/whisper.py CHANGED
@@ -1,4 +1,4 @@
1
- import axengine as axe
2
  import numpy as np
3
  import librosa
4
  import os
@@ -9,27 +9,27 @@ from dataclasses import dataclass
9
  import zhconv
10
 
11
 
12
- NEG_INF = float("-inf")
13
-
14
  @dataclass
15
  class WhisperConfig:
16
- n_mels : int = 0
17
- sample_rate : int = 0
18
- n_fft : int = 0
19
- hop_length : int = 0
20
-
21
- sot : int = 0
22
- eot : int = 0
23
- blank_id : int = 0
24
- no_timestamps : int = 0
25
- no_speech : int = 0
26
- translate : int = 0
27
- transcribe : int = 0
28
- n_vocab : int = 0
29
- n_text_ctx : int = 0
30
- n_text_state : int = 0
31
-
32
- sot_sequence : np.ndarray = field(default_factory=lambda: np.array([0,0,0,0], dtype=np.int32))
 
 
33
 
34
 
35
  class Whisper:
@@ -38,35 +38,41 @@ class Whisper:
38
 
39
  self.language = language
40
  self.task = task
41
- self.encoder, self.decoder_main, self.decoder_loop, self.pe, self.tokenizer, model_config = \
42
- self.load_model(model_type, model_path, language, task)
 
43
  self.config = self.load_config(model_config)
44
 
45
-
46
  def load_model(self, model_type, model_path, language, task):
47
  encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
48
- decoder_main_path = f"{model_type}/{model_type}-decoder-main.axmodel"
49
- decoder_loop_path = f"{model_type}/{model_type}-decoder-loop.axmodel"
50
- pe_path = f"{model_type}/{model_type}-positional_embedding.bin"
51
  model_config_file = f"{model_type}/{model_type}_config.json"
 
52
 
53
- required_files = [os.path.join(model_path, i) for i in (encoder_path, decoder_main_path, decoder_loop_path, pe_path, model_config_file)]
 
 
 
54
  # Check file existence
55
  for i, file_path in enumerate(required_files):
56
  assert os.path.exists(file_path), f"{file_path} NOT exist"
57
 
58
  # Load encoder
59
- encoder = axe.InferenceSession(required_files[0], providers=['AxEngineExecutionProvider'])
 
 
60
  # Load decoder main
61
- decoder_main = axe.InferenceSession(required_files[1], providers=['AxEngineExecutionProvider'])
62
- # Load decoder loop
63
- decoder_loop = axe.InferenceSession(required_files[2], providers=['AxEngineExecutionProvider'])
64
- # Load position embedding
65
- pe = np.fromfile(required_files[3], dtype=np.float32)
66
  # Load tokens
67
- model_config = json.load(open(required_files[4], "r"))
68
- model_config["all_language_tokens"] = [int(i) for i in model_config["all_language_tokens"].split(",")]
69
- model_config["all_language_codes"] = [i for i in model_config["all_language_codes"].split(",")]
 
 
 
 
70
  tokenizer = get_tokenizer(
71
  model_config["is_multilingual"],
72
  num_languages=len(model_config["all_language_codes"]),
@@ -74,8 +80,9 @@ class Whisper:
74
  task=task,
75
  )
76
 
77
- return encoder, decoder_main, decoder_loop, pe, tokenizer, model_config
78
-
 
79
 
80
  def load_config(self, model_config):
81
  config = WhisperConfig
@@ -94,34 +101,46 @@ class Whisper:
94
  config.n_vocab = model_config["n_vocab"]
95
  config.n_text_ctx = model_config["n_text_ctx"]
96
  config.n_text_state = model_config["n_text_state"]
 
97
 
98
- lang_token = model_config["all_language_tokens"][model_config["all_language_codes"].index(self.language)]
99
- task_token = config.transcribe if self.task == "transcribe" else config.translate
100
- config.sot_sequence = np.array([config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32)
 
 
 
 
 
 
101
 
102
  return config
103
-
 
 
 
 
 
 
 
104
 
105
  def load_audio(self, audio: str):
106
  data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
107
  samples = np.ascontiguousarray(data)
108
  return samples, sample_rate
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
- def compute_feature(self, audio: np.ndarray, padding = 480000):
112
- if padding > 0:
113
- audio = np.concatenate((audio, np.zeros((padding,), dtype=np.float32)), axis=-1)
114
-
115
- mel = librosa.feature.melspectrogram(y=audio,
116
- sr=self.config.sample_rate,
117
- n_fft=self.config.n_fft,
118
- hop_length=self.config.hop_length,
119
- window="hann",
120
- center=True,
121
- pad_mode="reflect",
122
- power=2.0,
123
- n_mels=self.config.n_mels)
124
-
125
  log_spec = np.log10(np.maximum(mel, 1e-10))
126
  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
127
  mel = (log_spec + 4.0) / 4.0
@@ -129,31 +148,71 @@ class Whisper:
129
  target = 3000
130
  if mel.shape[1] > target:
131
  # -50 so that there are some zero tail paddings.
132
- mel = mel[:, : target]
133
  mel[:, -50:] = 0
134
 
135
  # We don't need to pad it to 30 seconds now!
136
  if mel.shape[1] < target:
137
- mel = np.concatenate((mel, np.zeros((self.config.n_mels, target - mel.shape[1]), dtype=np.float32)), axis=-1)
138
-
139
- return mel
140
-
141
-
142
- def supress_tokens(self, logits, is_initial):
143
- if is_initial:
144
- logits[self.config.eot] = NEG_INF
145
- logits[self.config.blank_id] = NEG_INF
146
-
147
- logits[self.config.no_timestamps] = NEG_INF
148
- logits[self.config.sot] = NEG_INF
149
- logits[self.config.no_speech] = NEG_INF
 
 
 
 
 
 
 
 
 
 
150
 
151
- if self.task == "transcribe":
152
- logits[self.config.translate] = NEG_INF
153
- else:
154
- logits[self.config.transcribe] = NEG_INF
155
- return logits
156
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
157
 
158
  def run(self, audio: Union[str, np.ndarray]) -> str:
159
  if isinstance(audio, str):
@@ -161,64 +220,56 @@ class Whisper:
161
 
162
  mel = self.compute_feature(audio)
163
 
164
- # Run encoder
165
- x = self.encoder.run(None, input_feed={"mel": mel[None, ...]})
166
- n_layer_cross_k, n_layer_cross_v = x
167
-
168
- # Run decoder_main
169
- x = self.decoder_main.run(None, input_feed={
170
- "tokens": self.config.sot_sequence[None, ...],
171
- "n_layer_cross_k": n_layer_cross_k,
172
- "n_layer_cross_v": n_layer_cross_v
173
- })
174
- logits, n_layer_self_k_cache, n_layer_self_v_cache = x
175
-
176
- # Decode token
177
- logits = logits[0, -1, :]
178
- logits = self.supress_tokens(logits, is_initial=True)
179
- # logits.tofile("logits.bin")
180
- max_token_id = np.argmax(logits)
181
- output_tokens = []
182
-
183
- # Position embedding offset
184
- offset = self.config.sot_sequence.shape[0]
185
-
186
- # Autoregressively run decoder until token meets EOT
187
- for i in range(self.config.n_text_ctx - self.config.sot_sequence.shape[0]):
188
- if max_token_id >= self.config.eot:
189
- break
190
-
191
- output_tokens.append(max_token_id)
192
-
193
- mask = np.zeros((self.config.n_text_ctx,), dtype=np.float32)
194
- mask[: self.config.n_text_ctx - offset - 1] = NEG_INF
195
-
196
- # Run decoder_loop
197
- x = self.decoder_loop.run(None, input_feed={
198
- "tokens": np.array([[output_tokens[-1]]], dtype=np.int32),
199
- "in_n_layer_self_k_cache": n_layer_self_k_cache,
200
- "in_n_layer_self_v_cache": n_layer_self_v_cache,
201
- "n_layer_cross_k": n_layer_cross_k,
202
- "n_layer_cross_v": n_layer_cross_v,
203
- "positional_embedding": self.pe[offset * self.config.n_text_state : (offset + 1) * self.config.n_text_state][None, ...],
204
- "mask": mask
205
- })
206
- logits, n_layer_self_k_cache, n_layer_self_v_cache = x
207
-
208
- # Decode token
209
  offset += 1
210
- logits = self.supress_tokens(logits.flatten(), is_initial=False)
211
- max_token_id = np.argmax(logits)
212
-
213
- text = self.tokenizer.decode(output_tokens)
 
 
 
 
 
 
214
 
215
  if self.language == "zh":
216
  try:
217
- sim_zh = zhconv.convert(text, 'zh-hans')
218
  return sim_zh
219
  except:
220
  return text
221
-
222
- return text
223
 
224
-
 
1
+ import axengine as axe
2
  import numpy as np
3
  import librosa
4
  import os
 
9
  import zhconv
10
 
11
 
 
 
12
@dataclass
class WhisperConfig:
    """Numeric settings of one Whisper model variant.

    Every field defaults to zero; ``Whisper.load_config`` overwrites the
    values with the entries of the model's JSON config.
    """

    # Feature-extraction parameters handed to librosa's mel spectrogram.
    n_mels: int = 0
    sample_rate: int = 0
    n_fft: int = 0
    hop_length: int = 0

    # Ids of special tokens in the vocabulary.
    sot: int = 0
    eot: int = 0
    blank_id: int = 0
    no_timestamps: int = 0
    no_speech: int = 0
    translate: int = 0
    transcribe: int = 0
    # Model dimensions.
    n_vocab: int = 0
    n_text_ctx: int = 0
    n_text_state: int = 0

    # Prompt fed to the decoder before generation starts:
    # [sot, language token, task token, no_timestamps].
    sot_sequence: np.ndarray = field(
        default_factory=lambda: np.array([0, 0, 0, 0], dtype=np.int32)
    )

    # Declared last so the positional order of the older fields is unchanged.
    # load_config assigns it and get_self_cache uses it as the layer count.
    n_text_layer: int = 0
33
 
34
 
35
  class Whisper:
 
38
 
39
  self.language = language
40
  self.task = task
41
+ self.encoder, self.decoder, self.tokenizer, model_config = self.load_model(
42
+ model_type, model_path, language, task
43
+ )
44
  self.config = self.load_config(model_config)
45
 
 
46
  def load_model(self, model_type, model_path, language, task):
47
  encoder_path = f"{model_type}/{model_type}-encoder.axmodel"
48
+ decoder_path = f"{model_type}/{model_type}-decoder.axmodel"
 
 
49
  model_config_file = f"{model_type}/{model_type}_config.json"
50
+ token_file = f"{model_type}/{model_type}-tokens.txt"
51
 
52
+ required_files = [
53
+ os.path.join(model_path, i)
54
+ for i in (encoder_path, decoder_path, model_config_file, token_file)
55
+ ]
56
  # Check file existence
57
  for i, file_path in enumerate(required_files):
58
  assert os.path.exists(file_path), f"{file_path} NOT exist"
59
 
60
  # Load encoder
61
+ encoder = axe.InferenceSession(
62
+ required_files[0], providers=["AxEngineExecutionProvider"]
63
+ )
64
  # Load decoder main
65
+ decoder = axe.InferenceSession(
66
+ required_files[1], providers=["AxEngineExecutionProvider"]
67
+ )
 
 
68
  # Load tokens
69
+ model_config = json.load(open(required_files[2], "r"))
70
+ model_config["all_language_tokens"] = [
71
+ int(i) for i in model_config["all_language_tokens"].split(",")
72
+ ]
73
+ model_config["all_language_codes"] = [
74
+ i for i in model_config["all_language_codes"].split(",")
75
+ ]
76
  tokenizer = get_tokenizer(
77
  model_config["is_multilingual"],
78
  num_languages=len(model_config["all_language_codes"]),
 
80
  task=task,
81
  )
82
 
83
+ self.id2token = self.load_tokens(required_files[3])
84
+
85
+ return encoder, decoder, tokenizer, model_config
86
 
87
  def load_config(self, model_config):
88
  config = WhisperConfig
 
101
  config.n_vocab = model_config["n_vocab"]
102
  config.n_text_ctx = model_config["n_text_ctx"]
103
  config.n_text_state = model_config["n_text_state"]
104
+ config.n_text_layer = model_config["n_text_layer"]
105
 
106
+ lang_token = model_config["all_language_tokens"][
107
+ model_config["all_language_codes"].index(self.language)
108
+ ]
109
+ task_token = (
110
+ config.transcribe if self.task == "transcribe" else config.translate
111
+ )
112
+ config.sot_sequence = np.array(
113
+ [config.sot, lang_token, task_token, config.no_timestamps], dtype=np.int32
114
+ )
115
 
116
  return config
117
+
118
+ def load_tokens(self, filename):
119
+ tokens = dict()
120
+ with open(filename, "r") as f:
121
+ for line in f:
122
+ t, i = line.split()
123
+ tokens[int(i)] = t
124
+ return tokens
125
 
126
  def load_audio(self, audio: str):
127
  data, sample_rate = librosa.load(audio, sr=self.config.sample_rate)
128
  samples = np.ascontiguousarray(data)
129
  return samples, sample_rate
130
 
131
+ def compute_feature(self, audio: np.ndarray):
132
+ mel = librosa.feature.melspectrogram(
133
+ y=audio,
134
+ sr=self.config.sample_rate,
135
+ n_fft=self.config.n_fft,
136
+ hop_length=self.config.hop_length,
137
+ window="hann",
138
+ center=True,
139
+ pad_mode="reflect",
140
+ power=2.0,
141
+ n_mels=self.config.n_mels,
142
+ )
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  log_spec = np.log10(np.maximum(mel, 1e-10))
145
  log_spec = np.maximum(log_spec, log_spec.max() - 8.0)
146
  mel = (log_spec + 4.0) / 4.0
 
148
  target = 3000
149
  if mel.shape[1] > target:
150
  # -50 so that there are some zero tail paddings.
151
+ mel = mel[:, :target]
152
  mel[:, -50:] = 0
153
 
154
  # We don't need to pad it to 30 seconds now!
155
  if mel.shape[1] < target:
156
+ mel = np.concatenate(
157
+ (
158
+ mel,
159
+ np.zeros(
160
+ (self.config.n_mels, target - mel.shape[1]), dtype=np.float32
161
+ ),
162
+ ),
163
+ axis=-1,
164
+ )
165
+
166
+ return mel[np.newaxis, ...]
167
+
168
+ def run_encoder(
169
+ self,
170
+ mel: np.ndarray,
171
+ ) -> List[np.ndarray]:
172
+ cross_kv = self.encoder.run(
173
+ None,
174
+ {
175
+ self.encoder.get_inputs()[0].name: mel,
176
+ },
177
+ )
178
+ return cross_kv
179
 
180
+ def run_decoder(self, inputs: List[np.ndarray]) -> List[np.ndarray]:
181
+ feed = {
182
+ self.decoder.get_inputs()[i].name: inputs[i] for i in range(len(inputs))
183
+ }
 
184
 
185
+ out = self.decoder.run(
186
+ None,
187
+ feed,
188
+ )
189
+ return out
190
+
191
+ def get_self_cache(self) -> List[np.ndarray]:
192
+ self_cache = []
193
+ batch_size = 1
194
+ for i in range(self.config.n_text_layer):
195
+ k = np.zeros(
196
+ (batch_size, self.config.n_text_ctx, self.config.n_text_state),
197
+ dtype=np.float32,
198
+ )
199
+ v = np.zeros(
200
+ (batch_size, self.config.n_text_ctx, self.config.n_text_state),
201
+ dtype=np.float32,
202
+ )
203
+ self_cache.extend([k, v])
204
+ return self_cache
205
+
206
+ def causal_mask_1d(self, n: int, L: int):
207
+ """
208
+ Returns a 1-D int mask of shape (L,) with:
209
+ 0 -> allowed
210
+ 1 -> masked (will be converted to -inf later)
211
+ """
212
+ mask = np.ones((L,), dtype=np.int32)
213
+ if n > 0:
214
+ mask[:n] = 0
215
+ return mask
216
 
217
  def run(self, audio: Union[str, np.ndarray]) -> str:
218
  if isinstance(audio, str):
 
220
 
221
  mel = self.compute_feature(audio)
222
 
223
+ cross_kv = self.run_encoder(mel)
224
+
225
+ self_kv = self.get_self_cache()
226
+
227
+ offset = np.array([0], dtype=np.int32)
228
+ for t in self.config.sot_sequence:
229
+ token = np.array([[t]], dtype=np.int32) # sot
230
+ mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
231
+
232
+ out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
233
+
234
+ for i in range(1, len(out)):
235
+ self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
236
+
237
+ offset += 1
238
+
239
+ idx = out[0][0, 0].argmax()
240
+
241
+ eot = self.config.eot
242
+
243
+ ans = []
244
+
245
+ while idx != eot and offset.item() < 100:
246
+ ans.append(idx)
247
+ token = np.array([[idx]], dtype=np.int32)
248
+
249
+ mask = self.causal_mask_1d(offset.item(), self.config.n_text_ctx)
250
+
251
+ out = self.run_decoder([token] + self_kv + cross_kv + [offset, mask])
252
+
253
+ for i in range(1, len(out)):
254
+ self_kv[i - 1][:, offset.item() : offset.item() + 1, :] = out[i]
255
+
 
 
 
 
 
 
 
 
 
 
 
 
256
  offset += 1
257
+ idx = out[0][0, 0].argmax()
258
+
259
+ # print(ans)
260
+
261
+ s = b""
262
+ for i in ans:
263
+ if i in self.id2token:
264
+ s += base64.b64decode(self.id2token[i])
265
+
266
+ text = s.decode().strip()
267
 
268
  if self.language == "zh":
269
  try:
270
+ sim_zh = zhconv.convert(text, "zh-hans")
271
  return sim_zh
272
  except:
273
  return text
 
 
274
 
275
+ return text