Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
8948419
1
Parent(s):
8c4ce15
add fireredasr correctly
Browse files- fireredasr +0 -1
- fireredasr/LICENSE +201 -0
- fireredasr/README.md +160 -0
- fireredasr/examples/fireredasr +1 -0
- fireredasr/examples/inference_fireredasr_aed.sh +33 -0
- fireredasr/examples/inference_fireredasr_llm.sh +32 -0
- fireredasr/examples/pretrained_models +1 -0
- fireredasr/examples/wav/IT0011W0001.wav +0 -0
- fireredasr/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav +0 -0
- fireredasr/examples/wav/text +4 -0
- fireredasr/examples/wav/wav.scp +4 -0
- fireredasr/fireredasr/.speech2text.py.swp +0 -0
- fireredasr/fireredasr/data/asr_feat.py +107 -0
- fireredasr/fireredasr/data/token_dict.py +59 -0
- fireredasr/fireredasr/models/.fireredasr.py.swp +0 -0
- fireredasr/fireredasr/models/fireredasr.py +125 -0
- fireredasr/fireredasr/models/fireredasr_aed.py +35 -0
- fireredasr/fireredasr/models/fireredasr_llm.py +272 -0
- fireredasr/fireredasr/models/module/adapter.py +30 -0
- fireredasr/fireredasr/models/module/conformer_encoder.py +322 -0
- fireredasr/fireredasr/models/module/transformer_decoder.py +299 -0
- fireredasr/fireredasr/speech2text.py +105 -0
- fireredasr/fireredasr/tokenizer/aed_tokenizer.py +67 -0
- fireredasr/fireredasr/tokenizer/llm_tokenizer.py +105 -0
- fireredasr/fireredasr/utils/param.py +13 -0
- fireredasr/fireredasr/utils/wer.py +303 -0
- fireredasr/pretrained_models/README.md +1 -0
- fireredasr/requirements.txt +8 -0
fireredasr
DELETED
|
@@ -1 +0,0 @@
|
|
| 1 |
-
Subproject commit 1eadb81b66eca948cd492bc0aeedd786333c049d
|
|
|
|
|
|
fireredasr/LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
fireredasr/README.md
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<div align="center">
|
| 2 |
+
<h1>FireRedASR: Open-Source Industrial-Grade
|
| 3 |
+
<br>
|
| 4 |
+
Automatic Speech Recognition Models</h1>
|
| 5 |
+
|
| 6 |
+
</div>
|
| 7 |
+
|
| 8 |
+
[[Paper]](https://arxiv.org/pdf/2501.14350)
|
| 9 |
+
[[Model]](https://huggingface.co/fireredteam)
|
| 10 |
+
[[Blog]](https://fireredteam.github.io/demos/firered_asr/)
|
| 11 |
+
|
| 12 |
+
FireRedASR is a family of open-source industrial-grade automatic speech recognition (ASR) models supporting Mandarin, Chinese dialects and English, achieving a new state-of-the-art (SOTA) on public Mandarin ASR benchmarks, while also offering outstanding singing lyrics recognition capability.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
## 🔥 News
|
| 16 |
+
- [2025/02/17] We release [FireRedASR-LLM-L](https://huggingface.co/fireredteam/FireRedASR-LLM-L/tree/main) model weights.
|
| 17 |
+
- [2025/01/24] We release [technical report](https://arxiv.org/pdf/2501.14350), [blog](https://fireredteam.github.io/demos/firered_asr/), and [FireRedASR-AED-L](https://huggingface.co/fireredteam/FireRedASR-AED-L/tree/main) model weights.
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## Method
|
| 21 |
+
|
| 22 |
+
FireRedASR is designed to meet diverse requirements in superior performance and optimal efficiency across various applications. It comprises two variants:
|
| 23 |
+
- FireRedASR-LLM: Designed to achieve state-of-the-art (SOTA) performance and to enable seamless end-to-end speech interaction. It adopts an Encoder-Adapter-LLM framework leveraging large language model (LLM) capabilities.
|
| 24 |
+
- FireRedASR-AED: Designed to balance high performance and computational efficiency and to serve as an effective speech representation module in LLM-based speech models. It utilizes an Attention-based Encoder-Decoder (AED) architecture.
|
| 25 |
+
|
| 26 |
+

|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
## Evaluation
|
| 30 |
+
Results are reported in Character Error Rate (CER%) for Chinese and Word Error Rate (WER%) for English.
|
| 31 |
+
|
| 32 |
+
### Evaluation on Public Mandarin ASR Benchmarks
|
| 33 |
+
| Model | #Params | aishell1 | aishell2 | ws\_net | ws\_meeting | Average-4 |
|
| 34 |
+
|:----------------:|:-------:|:--------:|:--------:|:--------:|:-----------:|:---------:|
|
| 35 |
+
| FireRedASR-LLM | 8.3B | 0.76 | 2.15 | 4.60 | 4.67 | 3.05 |
|
| 36 |
+
| FireRedASR-AED | 1.1B | 0.55 | 2.52 | 4.88 | 4.76 | 3.18 |
|
| 37 |
+
| Seed-ASR | 12B+ | 0.68 | 2.27 | 4.66 | 5.69 | 3.33 |
|
| 38 |
+
| Qwen-Audio | 8.4B | 1.30 | 3.10 | 9.50 | 10.87 | 6.19 |
|
| 39 |
+
| SenseVoice-L | 1.6B | 2.09 | 3.04 | 6.01 | 6.73 | 4.47 |
|
| 40 |
+
| Whisper-Large-v3 | 1.6B | 5.14 | 4.96 | 10.48 | 18.87 | 9.86 |
|
| 41 |
+
| Paraformer-Large | 0.2B | 1.68 | 2.85 | 6.74 | 6.97 | 4.56 |
|
| 42 |
+
|
| 43 |
+
`ws` means WenetSpeech.
|
| 44 |
+
|
| 45 |
+
### Evaluation on Public Chinese Dialect and English ASR Benchmarks
|
| 46 |
+
|Test Set | KeSpeech | LibriSpeech test-clean | LibriSpeech test-other |
|
| 47 |
+
| :------------:| :------: | :--------------------: | :----------------------:|
|
| 48 |
+
|FireRedASR-LLM | 3.56 | 1.73 | 3.67 |
|
| 49 |
+
|FireRedASR-AED | 4.48 | 1.93 | 4.44 |
|
| 50 |
+
|Previous SOTA Results | 6.70 | 1.82 | 3.50 |
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## Usage
|
| 54 |
+
Download model files from [huggingface](https://huggingface.co/fireredteam) and place them in the folder `pretrained_models`.
|
| 55 |
+
|
| 56 |
+
If you want to use `FireRedASR-LLM-L`, you also need to download [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) and place it in the folder `pretrained_models`. Then, go to folder `FireRedASR-LLM-L` and run `$ ln -s ../Qwen2-7B-Instruct`
|
| 57 |
+
|
| 58 |
+
|
| 59 |
+
### Setup
|
| 60 |
+
Create a Python environment and install dependencies
|
| 61 |
+
```bash
|
| 62 |
+
$ git clone https://github.com/FireRedTeam/FireRedASR.git
|
| 63 |
+
$ conda create --name fireredasr python=3.10
|
| 64 |
+
$ pip install -r requirements.txt
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
Set up Linux PATH and PYTHONPATH
|
| 68 |
+
```
|
| 69 |
+
$ export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
|
| 70 |
+
$ export PYTHONPATH=$PWD/:$PYTHONPATH
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
Convert audio to 16kHz 16-bit PCM format
|
| 74 |
+
```
|
| 75 |
+
ffmpeg -i input_audio -ar 16000 -ac 1 -acodec pcm_s16le -f wav output.wav
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
### Quick Start
|
| 79 |
+
```bash
|
| 80 |
+
$ cd examples
|
| 81 |
+
$ bash inference_fireredasr_aed.sh
|
| 82 |
+
$ bash inference_fireredasr_llm.sh
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
### Command-line Usage
|
| 86 |
+
```bash
|
| 87 |
+
$ speech2text.py --help
|
| 88 |
+
$ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "aed" --model_dir pretrained_models/FireRedASR-AED-L
|
| 89 |
+
$ speech2text.py --wav_path examples/wav/BAC009S0764W0121.wav --asr_type "llm" --model_dir pretrained_models/FireRedASR-LLM-L
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
### Python Usage
|
| 93 |
+
```python
|
| 94 |
+
from fireredasr.models.fireredasr import FireRedAsr
|
| 95 |
+
|
| 96 |
+
batch_uttid = ["BAC009S0764W0121"]
|
| 97 |
+
batch_wav_path = ["examples/wav/BAC009S0764W0121.wav"]
|
| 98 |
+
|
| 99 |
+
# FireRedASR-AED
|
| 100 |
+
model = FireRedAsr.from_pretrained("aed", "pretrained_models/FireRedASR-AED-L")
|
| 101 |
+
results = model.transcribe(
|
| 102 |
+
batch_uttid,
|
| 103 |
+
batch_wav_path,
|
| 104 |
+
{
|
| 105 |
+
"use_gpu": 1,
|
| 106 |
+
"beam_size": 3,
|
| 107 |
+
"nbest": 1,
|
| 108 |
+
"decode_max_len": 0,
|
| 109 |
+
"softmax_smoothing": 1.25,
|
| 110 |
+
"aed_length_penalty": 0.6,
|
| 111 |
+
"eos_penalty": 1.0
|
| 112 |
+
}
|
| 113 |
+
)
|
| 114 |
+
print(results)
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
# FireRedASR-LLM
|
| 118 |
+
model = FireRedAsr.from_pretrained("llm", "pretrained_models/FireRedASR-LLM-L")
|
| 119 |
+
results = model.transcribe(
|
| 120 |
+
batch_uttid,
|
| 121 |
+
batch_wav_path,
|
| 122 |
+
{
|
| 123 |
+
"use_gpu": 1,
|
| 124 |
+
"beam_size": 3,
|
| 125 |
+
"decode_max_len": 0,
|
| 126 |
+
"decode_min_len": 0,
|
| 127 |
+
"repetition_penalty": 3.0,
|
| 128 |
+
"llm_length_penalty": 1.0,
|
| 129 |
+
"temperature": 1.0
|
| 130 |
+
}
|
| 131 |
+
)
|
| 132 |
+
print(results)
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
## Usage Tips
|
| 136 |
+
### Batch Beam Search
|
| 137 |
+
- When performing batch beam search with FireRedASR-LLM, please ensure that the input lengths of the utterances are similar. If there are significant differences in utterance lengths, shorter utterances may experience repetition issues. You can either sort your dataset by length or set `batch_size` to 1 to avoid the repetition issue.
|
| 138 |
+
|
| 139 |
+
### Input Length Limitations
|
| 140 |
+
- FireRedASR-AED supports audio input up to 60s. Input longer than 60s may cause hallucination issues, and input exceeding 200s will trigger positional encoding errors.
|
| 141 |
+
- FireRedASR-LLM supports audio input up to 30s. The behavior for longer input is currently unknown.
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
## Acknowledgements
|
| 145 |
+
Thanks to the following open-source works:
|
| 146 |
+
- [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
|
| 147 |
+
- [icefall/ASR_LLM](https://github.com/k2-fsa/icefall/tree/master/egs/speech_llm/ASR_LLM)
|
| 148 |
+
- [WeNet](https://github.com/wenet-e2e/wenet)
|
| 149 |
+
- [Speech-Transformer](https://github.com/kaituoxu/Speech-Transformer)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
## Citation
|
| 153 |
+
```bibtex
|
| 154 |
+
@article{xu2025fireredasr,
|
| 155 |
+
title={FireRedASR: Open-Source Industrial-Grade Mandarin Speech Recognition Models from Encoder-Decoder to LLM Integration},
|
| 156 |
+
author={Xu, Kai-Tuo and Xie, Feng-Long and Tang, Xu and Hu, Yao},
|
| 157 |
+
journal={arXiv preprint arXiv:2501.14350},
|
| 158 |
+
year={2025}
|
| 159 |
+
}
|
| 160 |
+
```
|
fireredasr/examples/fireredasr
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
../fireredasr
|
fireredasr/examples/inference_fireredasr_aed.sh
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
|
| 4 |
+
export PYTHONPATH=$PWD/:$PYTHONPATH
|
| 5 |
+
|
| 6 |
+
# model_dir includes model.pth.tar, cmvn.ark, dict.txt
|
| 7 |
+
model_dir=$PWD/pretrained_models/FireRedASR-AED-L
|
| 8 |
+
|
| 9 |
+
# Support several input format
|
| 10 |
+
wavs="--wav_path wav/BAC009S0764W0121.wav"
|
| 11 |
+
wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav"
|
| 12 |
+
wavs="--wav_dir wav/"
|
| 13 |
+
wavs="--wav_scp wav/wav.scp"
|
| 14 |
+
|
| 15 |
+
out="out/aed-l-asr.txt"
|
| 16 |
+
|
| 17 |
+
decode_args="
|
| 18 |
+
--batch_size 2 --beam_size 3 --nbest 1
|
| 19 |
+
--decode_max_len 0 --softmax_smoothing 1.25 --aed_length_penalty 0.6
|
| 20 |
+
--eos_penalty 1.0
|
| 21 |
+
"
|
| 22 |
+
|
| 23 |
+
mkdir -p $(dirname $out)
|
| 24 |
+
set -x
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 28 |
+
speech2text.py --asr_type "aed" --model_dir $model_dir $decode_args $wavs --output $out
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
ref="wav/text"
|
| 32 |
+
wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 0 --ref $ref --hyp $out > $out.wer 2>&1
|
| 33 |
+
tail -n8 $out.wer
|
fireredasr/examples/inference_fireredasr_llm.sh
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
|
| 3 |
+
export PATH=$PWD/fireredasr/:$PWD/fireredasr/utils/:$PATH
|
| 4 |
+
export PYTHONPATH=$PWD/:$PYTHONPATH
|
| 5 |
+
|
| 6 |
+
# model_dir includes model.pth.tar, asr_encoder.pth.tar, cmvn.ark, Qwen2-7B-Instruct
|
| 7 |
+
model_dir=$PWD/pretrained_models/FireRedASR-LLM-L
|
| 8 |
+
|
| 9 |
+
# Support several input format
|
| 10 |
+
wavs="--wav_path wav/BAC009S0764W0121.wav"
|
| 11 |
+
wavs="--wav_paths wav/BAC009S0764W0121.wav wav/IT0011W0001.wav wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav wav/TEST_MEETING_T0000000001_S00000.wav"
|
| 12 |
+
wavs="--wav_dir wav/"
|
| 13 |
+
wavs="--wav_scp wav/wav.scp"
|
| 14 |
+
|
| 15 |
+
out="out/llm-l-asr.txt"
|
| 16 |
+
|
| 17 |
+
decode_args="
|
| 18 |
+
--batch_size 1 --beam_size 3 --decode_max_len 0 --decode_min_len 0
|
| 19 |
+
--repetition_penalty 3.0 --llm_length_penalty 1.0 --temperature 1.0
|
| 20 |
+
"
|
| 21 |
+
|
| 22 |
+
mkdir -p $(dirname $out)
|
| 23 |
+
set -x
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 27 |
+
speech2text.py --asr_type "llm" --model_dir $model_dir $decode_args $wavs --output $out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
ref="wav/text"
|
| 31 |
+
wer.py --print_sentence_wer 1 --do_tn 0 --rm_special 1 --ref $ref --hyp $out > $out.wer 2>&1
|
| 32 |
+
tail -n8 $out.wer
|
fireredasr/examples/pretrained_models
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
../pretrained_models
|
fireredasr/examples/wav/IT0011W0001.wav
ADDED
|
Binary file (63.8 kB). View file
|
|
|
fireredasr/examples/wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
ADDED
|
Binary file (57.6 kB). View file
|
|
|
fireredasr/examples/wav/text
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BAC009S0764W0121 甚至 出现 交易 几乎 停滞 的 情况
|
| 2 |
+
IT0011W0001 换一首歌
|
| 3 |
+
TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 我有的时候说不清楚你们知道吗
|
| 4 |
+
TEST_MEETING_T0000000001_S00000 好首先说一下刚才这个经理说完的这个销售问题咱再说一下咱们的商场问题首先咱们商场上半年业这个先各部门儿汇报一下就是业绩
|
fireredasr/examples/wav/wav.scp
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
BAC009S0764W0121 wav/BAC009S0764W0121.wav
|
| 2 |
+
IT0011W0001 wav/IT0011W0001.wav
|
| 3 |
+
TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000 wav/TEST_NET_Y0000000000_-KTKHdZ2fb8_S00000.wav
|
| 4 |
+
TEST_MEETING_T0000000001_S00000 wav/TEST_MEETING_T0000000001_S00000.wav
|
fireredasr/fireredasr/.speech2text.py.swp
ADDED
|
Binary file (12.3 kB). View file
|
|
|
fireredasr/fireredasr/data/asr_feat.py
ADDED
|
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
|
| 4 |
+
import kaldiio
|
| 5 |
+
import kaldi_native_fbank as knf
|
| 6 |
+
import numpy as np
|
| 7 |
+
import torch
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class ASRFeatExtractor:
|
| 11 |
+
def __init__(self, kaldi_cmvn_file):
|
| 12 |
+
self.cmvn = CMVN(kaldi_cmvn_file) if kaldi_cmvn_file != "" else None
|
| 13 |
+
self.fbank = KaldifeatFbank(num_mel_bins=80, frame_length=25,
|
| 14 |
+
frame_shift=10, dither=0.0)
|
| 15 |
+
|
| 16 |
+
def __call__(self, wav_paths):
|
| 17 |
+
feats = []
|
| 18 |
+
durs = []
|
| 19 |
+
for wav_path in wav_paths:
|
| 20 |
+
sample_rate, wav_np = kaldiio.load_mat(wav_path)
|
| 21 |
+
dur = wav_np.shape[0] / sample_rate
|
| 22 |
+
fbank = self.fbank((sample_rate, wav_np))
|
| 23 |
+
if self.cmvn is not None:
|
| 24 |
+
fbank = self.cmvn(fbank)
|
| 25 |
+
fbank = torch.from_numpy(fbank).float()
|
| 26 |
+
feats.append(fbank)
|
| 27 |
+
durs.append(dur)
|
| 28 |
+
lengths = torch.tensor([feat.size(0) for feat in feats]).long()
|
| 29 |
+
feats_pad = self.pad_feat(feats, 0.0)
|
| 30 |
+
return feats_pad, lengths, durs
|
| 31 |
+
|
| 32 |
+
def pad_feat(self, xs, pad_value):
|
| 33 |
+
# type: (List[Tensor], int) -> Tensor
|
| 34 |
+
n_batch = len(xs)
|
| 35 |
+
max_len = max([xs[i].size(0) for i in range(n_batch)])
|
| 36 |
+
pad = torch.ones(n_batch, max_len, *xs[0].size()[1:]).to(xs[0].device).to(xs[0].dtype).fill_(pad_value)
|
| 37 |
+
for i in range(n_batch):
|
| 38 |
+
pad[i, :xs[i].size(0)] = xs[i]
|
| 39 |
+
return pad
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class CMVN:
|
| 45 |
+
def __init__(self, kaldi_cmvn_file):
|
| 46 |
+
self.dim, self.means, self.inverse_std_variences = \
|
| 47 |
+
self.read_kaldi_cmvn(kaldi_cmvn_file)
|
| 48 |
+
|
| 49 |
+
def __call__(self, x, is_train=False):
|
| 50 |
+
assert x.shape[-1] == self.dim, "CMVN dim mismatch"
|
| 51 |
+
out = x - self.means
|
| 52 |
+
out = out * self.inverse_std_variences
|
| 53 |
+
return out
|
| 54 |
+
|
| 55 |
+
def read_kaldi_cmvn(self, kaldi_cmvn_file):
|
| 56 |
+
assert os.path.exists(kaldi_cmvn_file)
|
| 57 |
+
stats = kaldiio.load_mat(kaldi_cmvn_file)
|
| 58 |
+
assert stats.shape[0] == 2
|
| 59 |
+
dim = stats.shape[-1] - 1
|
| 60 |
+
count = stats[0, dim]
|
| 61 |
+
assert count >= 1
|
| 62 |
+
floor = 1e-20
|
| 63 |
+
means = []
|
| 64 |
+
inverse_std_variences = []
|
| 65 |
+
for d in range(dim):
|
| 66 |
+
mean = stats[0, d] / count
|
| 67 |
+
means.append(mean.item())
|
| 68 |
+
varience = (stats[1, d] / count) - mean*mean
|
| 69 |
+
if varience < floor:
|
| 70 |
+
varience = floor
|
| 71 |
+
istd = 1.0 / math.sqrt(varience)
|
| 72 |
+
inverse_std_variences.append(istd)
|
| 73 |
+
return dim, np.array(means), np.array(inverse_std_variences)
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class KaldifeatFbank:
|
| 78 |
+
def __init__(self, num_mel_bins=80, frame_length=25, frame_shift=10,
|
| 79 |
+
dither=1.0):
|
| 80 |
+
self.dither = dither
|
| 81 |
+
opts = knf.FbankOptions()
|
| 82 |
+
opts.frame_opts.dither = dither
|
| 83 |
+
opts.mel_opts.num_bins = num_mel_bins
|
| 84 |
+
opts.frame_opts.snip_edges = True
|
| 85 |
+
opts.mel_opts.debug_mel = False
|
| 86 |
+
self.opts = opts
|
| 87 |
+
|
| 88 |
+
def __call__(self, wav, is_train=False):
|
| 89 |
+
if type(wav) is str:
|
| 90 |
+
sample_rate, wav_np = kaldiio.load_mat(wav)
|
| 91 |
+
elif type(wav) in [tuple, list] and len(wav) == 2:
|
| 92 |
+
sample_rate, wav_np = wav
|
| 93 |
+
assert len(wav_np.shape) == 1
|
| 94 |
+
|
| 95 |
+
dither = self.dither if is_train else 0.0
|
| 96 |
+
self.opts.frame_opts.dither = dither
|
| 97 |
+
fbank = knf.OnlineFbank(self.opts)
|
| 98 |
+
|
| 99 |
+
fbank.accept_waveform(sample_rate, wav_np.tolist())
|
| 100 |
+
feat = []
|
| 101 |
+
for i in range(fbank.num_frames_ready):
|
| 102 |
+
feat.append(fbank.get_frame(i))
|
| 103 |
+
if len(feat) == 0:
|
| 104 |
+
print("Check data, len(feat) == 0", wav, flush=True)
|
| 105 |
+
return np.zeros((0, self.opts.mel_opts.num_bins))
|
| 106 |
+
feat = np.vstack(feat)
|
| 107 |
+
return feat
|
fireredasr/fireredasr/data/token_dict.py
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class TokenDict:
|
| 5 |
+
def __init__(self, dict_path, unk=""):
|
| 6 |
+
assert dict_path != ""
|
| 7 |
+
self.id2word, self.word2id = self.read_dict(dict_path)
|
| 8 |
+
self.unk = unk
|
| 9 |
+
assert unk == "" or unk in self.word2id
|
| 10 |
+
self.unkid = self.word2id[unk] if unk else -1
|
| 11 |
+
|
| 12 |
+
def get(self, key, default):
|
| 13 |
+
if type(default) == str:
|
| 14 |
+
default = self.word2id[default]
|
| 15 |
+
return self.word2id.get(key, default)
|
| 16 |
+
|
| 17 |
+
def __getitem__(self, key):
|
| 18 |
+
if type(key) == str:
|
| 19 |
+
if self.unk:
|
| 20 |
+
return self.word2id.get(key, self.word2id[self.unk])
|
| 21 |
+
else:
|
| 22 |
+
return self.word2id[key]
|
| 23 |
+
elif type(key) == int:
|
| 24 |
+
return self.id2word[key]
|
| 25 |
+
else:
|
| 26 |
+
raise TypeError("Key should be str or int")
|
| 27 |
+
|
| 28 |
+
def __len__(self):
|
| 29 |
+
return len(self.id2word)
|
| 30 |
+
|
| 31 |
+
def __contains__(self, query):
|
| 32 |
+
if type(query) == str:
|
| 33 |
+
return query in self.word2id
|
| 34 |
+
elif type(query) == int:
|
| 35 |
+
return query in self.id2word
|
| 36 |
+
else:
|
| 37 |
+
raise TypeError("query should be str or int")
|
| 38 |
+
|
| 39 |
+
def read_dict(self, dict_path):
|
| 40 |
+
id2word, word2id = [], {}
|
| 41 |
+
with open(dict_path, encoding='utf8') as f:
|
| 42 |
+
for i, line in enumerate(f):
|
| 43 |
+
tokens = line.strip().split()
|
| 44 |
+
if len(tokens) >= 2:
|
| 45 |
+
word, index = tokens[0], int(tokens[1])
|
| 46 |
+
elif len(tokens) == 1:
|
| 47 |
+
word, index = tokens[0], i
|
| 48 |
+
else: # empty line or space
|
| 49 |
+
logging.info(f"Find empty line or space '{line.strip()}' in {dict_path}:L{i}, set to ' '")
|
| 50 |
+
word, index = " ", i
|
| 51 |
+
assert len(id2word) == index
|
| 52 |
+
assert len(word2id) == index
|
| 53 |
+
if word == "<space>":
|
| 54 |
+
logging.info(f"NOTE: Find <space> in {dict_path}:L{i} and convert it to ' '")
|
| 55 |
+
word = " "
|
| 56 |
+
word2id[word] = index
|
| 57 |
+
id2word.append(word)
|
| 58 |
+
assert len(id2word) == len(word2id)
|
| 59 |
+
return id2word, word2id
|
fireredasr/fireredasr/models/.fireredasr.py.swp
ADDED
|
Binary file (16.4 kB). View file
|
|
|
fireredasr/fireredasr/models/fireredasr.py
ADDED
|
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
|
| 6 |
+
from fireredasr.data.asr_feat import ASRFeatExtractor
|
| 7 |
+
from fireredasr.models.fireredasr_aed import FireRedAsrAed
|
| 8 |
+
from fireredasr.models.fireredasr_llm import FireRedAsrLlm
|
| 9 |
+
from fireredasr.tokenizer.aed_tokenizer import ChineseCharEnglishSpmTokenizer
|
| 10 |
+
from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class FireRedAsr:
|
| 14 |
+
@classmethod
|
| 15 |
+
def from_pretrained(cls, asr_type, model_dir):
|
| 16 |
+
assert asr_type in ["aed", "llm"]
|
| 17 |
+
|
| 18 |
+
cmvn_path = os.path.join(model_dir, "cmvn.ark")
|
| 19 |
+
feat_extractor = ASRFeatExtractor(cmvn_path)
|
| 20 |
+
|
| 21 |
+
if asr_type == "aed":
|
| 22 |
+
model_path = os.path.join(model_dir, "model.pth.tar")
|
| 23 |
+
dict_path =os.path.join(model_dir, "dict.txt")
|
| 24 |
+
spm_model = os.path.join(model_dir, "train_bpe1000.model")
|
| 25 |
+
model = load_fireredasr_aed_model(model_path)
|
| 26 |
+
tokenizer = ChineseCharEnglishSpmTokenizer(dict_path, spm_model)
|
| 27 |
+
elif asr_type == "llm":
|
| 28 |
+
model_path = os.path.join(model_dir, "model.pth.tar")
|
| 29 |
+
encoder_path = os.path.join(model_dir, "asr_encoder.pth.tar")
|
| 30 |
+
llm_dir = os.path.join(model_dir, "Qwen2-7B-Instruct")
|
| 31 |
+
model, tokenizer = load_firered_llm_model_and_tokenizer(
|
| 32 |
+
model_path, encoder_path, llm_dir)
|
| 33 |
+
model.eval()
|
| 34 |
+
return cls(asr_type, feat_extractor, model, tokenizer)
|
| 35 |
+
|
| 36 |
+
def __init__(self, asr_type, feat_extractor, model, tokenizer):
|
| 37 |
+
self.asr_type = asr_type
|
| 38 |
+
self.feat_extractor = feat_extractor
|
| 39 |
+
self.model = model
|
| 40 |
+
self.tokenizer = tokenizer
|
| 41 |
+
|
| 42 |
+
@torch.no_grad()
|
| 43 |
+
def transcribe(self, batch_uttid, batch_wav_path, args={}):
|
| 44 |
+
feats, lengths, durs = self.feat_extractor(batch_wav_path)
|
| 45 |
+
total_dur = sum(durs)
|
| 46 |
+
if args.get("use_gpu", False):
|
| 47 |
+
feats, lengths = feats.cuda(), lengths.cuda()
|
| 48 |
+
self.model.cuda()
|
| 49 |
+
else:
|
| 50 |
+
self.model.cpu()
|
| 51 |
+
|
| 52 |
+
if self.asr_type == "aed":
|
| 53 |
+
start_time = time.time()
|
| 54 |
+
|
| 55 |
+
hyps = self.model.transcribe(
|
| 56 |
+
feats, lengths,
|
| 57 |
+
args.get("beam_size", 1),
|
| 58 |
+
args.get("nbest", 1),
|
| 59 |
+
args.get("decode_max_len", 0),
|
| 60 |
+
args.get("softmax_smoothing", 1.0),
|
| 61 |
+
args.get("aed_length_penalty", 0.0),
|
| 62 |
+
args.get("eos_penalty", 1.0)
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
elapsed = time.time() - start_time
|
| 66 |
+
rtf= elapsed / total_dur if total_dur > 0 else 0
|
| 67 |
+
|
| 68 |
+
results = []
|
| 69 |
+
for uttid, wav, hyp in zip(batch_uttid, batch_wav_path, hyps):
|
| 70 |
+
hyp = hyp[0] # only return 1-best
|
| 71 |
+
hyp_ids = [int(id) for id in hyp["yseq"].cpu()]
|
| 72 |
+
text = self.tokenizer.detokenize(hyp_ids)
|
| 73 |
+
results.append({"uttid": uttid, "text": text, "wav": wav,
|
| 74 |
+
"rtf": f"{rtf:.4f}"})
|
| 75 |
+
return results
|
| 76 |
+
|
| 77 |
+
elif self.asr_type == "llm":
|
| 78 |
+
input_ids, attention_mask, _, _ = \
|
| 79 |
+
LlmTokenizerWrapper.preprocess_texts(
|
| 80 |
+
origin_texts=[""]*feats.size(0), tokenizer=self.tokenizer,
|
| 81 |
+
max_len=128, decode=True)
|
| 82 |
+
if args.get("use_gpu", False):
|
| 83 |
+
input_ids = input_ids.cuda()
|
| 84 |
+
attention_mask = attention_mask.cuda()
|
| 85 |
+
start_time = time.time()
|
| 86 |
+
|
| 87 |
+
generated_ids = self.model.transcribe(
|
| 88 |
+
feats, lengths, input_ids, attention_mask,
|
| 89 |
+
args.get("beam_size", 1),
|
| 90 |
+
args.get("decode_max_len", 0),
|
| 91 |
+
args.get("decode_min_len", 0),
|
| 92 |
+
args.get("repetition_penalty", 1.0),
|
| 93 |
+
args.get("llm_length_penalty", 0.0),
|
| 94 |
+
args.get("temperature", 1.0)
|
| 95 |
+
)
|
| 96 |
+
|
| 97 |
+
elapsed = time.time() - start_time
|
| 98 |
+
rtf= elapsed / total_dur if total_dur > 0 else 0
|
| 99 |
+
texts = self.tokenizer.batch_decode(generated_ids,
|
| 100 |
+
skip_special_tokens=True)
|
| 101 |
+
results = []
|
| 102 |
+
for uttid, wav, text in zip(batch_uttid, batch_wav_path, texts):
|
| 103 |
+
results.append({"uttid": uttid, "text": text, "wav": wav,
|
| 104 |
+
"rtf": f"{rtf:.4f}"})
|
| 105 |
+
return results
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def load_fireredasr_aed_model(model_path):
|
| 110 |
+
package = torch.load(model_path, map_location=lambda storage, loc: storage)
|
| 111 |
+
print("model args:", package["args"])
|
| 112 |
+
model = FireRedAsrAed.from_args(package["args"])
|
| 113 |
+
model.load_state_dict(package["model_state_dict"], strict=True)
|
| 114 |
+
return model
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def load_firered_llm_model_and_tokenizer(model_path, encoder_path, llm_dir):
|
| 118 |
+
package = torch.load(model_path, map_location=lambda storage, loc: storage)
|
| 119 |
+
package["args"].encoder_path = encoder_path
|
| 120 |
+
package["args"].llm_dir = llm_dir
|
| 121 |
+
print("model args:", package["args"])
|
| 122 |
+
model = FireRedAsrLlm.from_args(package["args"])
|
| 123 |
+
model.load_state_dict(package["model_state_dict"], strict=False)
|
| 124 |
+
tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(llm_dir)
|
| 125 |
+
return model, tokenizer
|
fireredasr/fireredasr/models/fireredasr_aed.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from fireredasr.models.module.conformer_encoder import ConformerEncoder
|
| 4 |
+
from fireredasr.models.module.transformer_decoder import TransformerDecoder
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class FireRedAsrAed(torch.nn.Module):
|
| 8 |
+
@classmethod
|
| 9 |
+
def from_args(cls, args):
|
| 10 |
+
return cls(args)
|
| 11 |
+
|
| 12 |
+
def __init__(self, args):
|
| 13 |
+
super().__init__()
|
| 14 |
+
self.sos_id = args.sos_id
|
| 15 |
+
self.eos_id = args.eos_id
|
| 16 |
+
|
| 17 |
+
self.encoder = ConformerEncoder(
|
| 18 |
+
args.idim, args.n_layers_enc, args.n_head, args.d_model,
|
| 19 |
+
args.residual_dropout, args.dropout_rate,
|
| 20 |
+
args.kernel_size, args.pe_maxlen)
|
| 21 |
+
|
| 22 |
+
self.decoder = TransformerDecoder(
|
| 23 |
+
args.sos_id, args.eos_id, args.pad_id, args.odim,
|
| 24 |
+
args.n_layers_dec, args.n_head, args.d_model,
|
| 25 |
+
args.residual_dropout, args.pe_maxlen)
|
| 26 |
+
|
| 27 |
+
def transcribe(self, padded_input, input_lengths,
|
| 28 |
+
beam_size=1, nbest=1, decode_max_len=0,
|
| 29 |
+
softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
|
| 30 |
+
enc_outputs, _, enc_mask = self.encoder(padded_input, input_lengths)
|
| 31 |
+
nbest_hyps = self.decoder.batch_beam_search(
|
| 32 |
+
enc_outputs, enc_mask,
|
| 33 |
+
beam_size, nbest, decode_max_len,
|
| 34 |
+
softmax_smoothing, length_penalty, eos_penalty)
|
| 35 |
+
return nbest_hyps
|
fireredasr/fireredasr/models/fireredasr_llm.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import re
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
from transformers import AutoModelForCausalLM
|
| 9 |
+
|
| 10 |
+
from fireredasr.models.fireredasr_aed import FireRedAsrAed
|
| 11 |
+
from fireredasr.models.module.adapter import Adapter
|
| 12 |
+
from fireredasr.tokenizer.llm_tokenizer import DEFAULT_SPEECH_TOKEN, IGNORE_TOKEN_ID
|
| 13 |
+
from fireredasr.tokenizer.llm_tokenizer import LlmTokenizerWrapper
|
| 14 |
+
from fireredasr.utils.param import count_model_parameters
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class FireRedAsrLlm(nn.Module):
|
| 18 |
+
@classmethod
|
| 19 |
+
def load_encoder(cls, model_path):
|
| 20 |
+
assert os.path.exists(model_path)
|
| 21 |
+
package = torch.load(model_path, map_location=lambda storage, loc: storage)
|
| 22 |
+
model = FireRedAsrAed.from_args(package["args"])
|
| 23 |
+
if "model_state_dict" in package:
|
| 24 |
+
model.load_state_dict(package["model_state_dict"], strict=False)
|
| 25 |
+
encoder = model.encoder
|
| 26 |
+
encoder_dim = encoder.odim
|
| 27 |
+
return encoder, encoder_dim
|
| 28 |
+
|
| 29 |
+
@classmethod
|
| 30 |
+
def from_args(cls, args):
|
| 31 |
+
logging.info(args)
|
| 32 |
+
logging.info("Build FireRedAsrLlm")
|
| 33 |
+
# Build Speech Encoder
|
| 34 |
+
encoder, encoder_dim = cls.load_encoder(args.encoder_path)
|
| 35 |
+
count_model_parameters(encoder)
|
| 36 |
+
if args.freeze_encoder:
|
| 37 |
+
logging.info(f"Frezee encoder")
|
| 38 |
+
for name, param in encoder.named_parameters():
|
| 39 |
+
param.requires_grad = False
|
| 40 |
+
encoder.eval()
|
| 41 |
+
|
| 42 |
+
if args.use_flash_attn:
|
| 43 |
+
attn_implementation = "flash_attention_2"
|
| 44 |
+
if args.use_fp16:
|
| 45 |
+
torch_dtype = torch.float16
|
| 46 |
+
else:
|
| 47 |
+
torch_dtype = torch.float32
|
| 48 |
+
else:
|
| 49 |
+
attn_implementation = "eager"
|
| 50 |
+
if args.use_fp16:
|
| 51 |
+
torch_dtype = torch.float16
|
| 52 |
+
else:
|
| 53 |
+
torch_dtype = torch.float32
|
| 54 |
+
|
| 55 |
+
# Build LLM
|
| 56 |
+
llm = AutoModelForCausalLM.from_pretrained(
|
| 57 |
+
args.llm_dir,
|
| 58 |
+
attn_implementation=attn_implementation,
|
| 59 |
+
torch_dtype=torch_dtype,
|
| 60 |
+
)
|
| 61 |
+
count_model_parameters(llm)
|
| 62 |
+
|
| 63 |
+
# LLM Freeze or LoRA
|
| 64 |
+
llm_dim = llm.config.hidden_size
|
| 65 |
+
if args.freeze_llm:
|
| 66 |
+
logging.info(f"Frezee LLM")
|
| 67 |
+
for name, param in llm.named_parameters():
|
| 68 |
+
param.requires_grad = False
|
| 69 |
+
llm.eval()
|
| 70 |
+
else:
|
| 71 |
+
if args.use_lora:
|
| 72 |
+
from peft import LoraConfig, get_peft_model
|
| 73 |
+
lora_config = LoraConfig(
|
| 74 |
+
r=64,
|
| 75 |
+
lora_alpha=16,
|
| 76 |
+
target_modules=[
|
| 77 |
+
"q_proj",
|
| 78 |
+
"k_proj",
|
| 79 |
+
"v_proj",
|
| 80 |
+
"o_proj",
|
| 81 |
+
"up_proj",
|
| 82 |
+
"gate_proj",
|
| 83 |
+
"down_proj",
|
| 84 |
+
],
|
| 85 |
+
lora_dropout=0.05,
|
| 86 |
+
task_type="CAUSAL_LM",
|
| 87 |
+
)
|
| 88 |
+
llm = get_peft_model(llm, lora_config)
|
| 89 |
+
llm.print_trainable_parameters()
|
| 90 |
+
|
| 91 |
+
tokenizer = LlmTokenizerWrapper.build_llm_tokenizer(args.llm_dir)
|
| 92 |
+
assert tokenizer.pad_token_id == tokenizer.convert_tokens_to_ids("<|endoftext|>")
|
| 93 |
+
llm.config.pad_token_id = tokenizer.pad_token_id
|
| 94 |
+
llm.config.bos_token_id = tokenizer.convert_tokens_to_ids("<|im_start|>")
|
| 95 |
+
llm.config.eos_token_id = tokenizer.convert_tokens_to_ids("<|im_end|>")
|
| 96 |
+
llm.config.default_speech_token_id = tokenizer.convert_tokens_to_ids(
|
| 97 |
+
DEFAULT_SPEECH_TOKEN
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Build projector
|
| 101 |
+
encoder_projector = Adapter(
|
| 102 |
+
encoder_dim, llm_dim, args.encoder_downsample_rate)
|
| 103 |
+
count_model_parameters(encoder_projector)
|
| 104 |
+
|
| 105 |
+
return cls(encoder, llm, encoder_projector,
|
| 106 |
+
args.freeze_encoder, args.freeze_llm)
|
| 107 |
+
|
| 108 |
+
def __init__(self, encoder, llm, encoder_projector,
|
| 109 |
+
freeze_encoder, freeze_llm):
|
| 110 |
+
super().__init__()
|
| 111 |
+
self.encoder = encoder
|
| 112 |
+
self.llm = llm
|
| 113 |
+
self.encoder_projector = encoder_projector
|
| 114 |
+
# args
|
| 115 |
+
self.freeze_encoder = freeze_encoder
|
| 116 |
+
self.freeze_llm = freeze_llm
|
| 117 |
+
self.llm_config = llm.config
|
| 118 |
+
|
| 119 |
+
def transcribe(self, padded_feat, feat_lengths, padded_input_ids, attention_mask,
|
| 120 |
+
beam_size=1, decode_max_len=0, decode_min_len=0,
|
| 121 |
+
repetition_penalty=1.0, llm_length_penalty=1.0, temperature=1.0):
|
| 122 |
+
encoder_outs, enc_lengths, enc_mask = self.encoder(padded_feat, feat_lengths)
|
| 123 |
+
speech_features, speech_lens = self.encoder_projector(encoder_outs, enc_lengths)
|
| 124 |
+
inputs_embeds = self.llm.get_input_embeddings()(padded_input_ids)
|
| 125 |
+
|
| 126 |
+
inputs_embeds, attention_mask, _ = \
|
| 127 |
+
self._merge_input_ids_with_speech_features(
|
| 128 |
+
speech_features.to(inputs_embeds.dtype), inputs_embeds, padded_input_ids, attention_mask,
|
| 129 |
+
speech_lens=speech_lens
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
max_new_tokens = speech_features.size(1) if decode_max_len < 1 else decode_max_len
|
| 133 |
+
max_new_tokens = max(1, max_new_tokens)
|
| 134 |
+
|
| 135 |
+
generated_ids = self.llm.generate(
|
| 136 |
+
inputs_embeds=inputs_embeds,
|
| 137 |
+
max_new_tokens=max_new_tokens,
|
| 138 |
+
num_beams=beam_size,
|
| 139 |
+
do_sample=False,
|
| 140 |
+
min_length=decode_min_len,
|
| 141 |
+
top_p=1.0,
|
| 142 |
+
repetition_penalty=repetition_penalty,
|
| 143 |
+
length_penalty=llm_length_penalty,
|
| 144 |
+
temperature=temperature,
|
| 145 |
+
bos_token_id=self.llm.config.bos_token_id,
|
| 146 |
+
eos_token_id=self.llm.config.eos_token_id,
|
| 147 |
+
pad_token_id=self.llm.config.pad_token_id,
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
return generated_ids
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _merge_input_ids_with_speech_features(
|
| 154 |
+
self, speech_features, inputs_embeds, input_ids, attention_mask, labels=None,
|
| 155 |
+
speech_lens=None
|
| 156 |
+
):
|
| 157 |
+
"""
|
| 158 |
+
Modified from: https://github.com/k2-fsa/icefall/blob/master/egs/speech_llm/ASR_LLM/whisper_llm_zh/model.py
|
| 159 |
+
"""
|
| 160 |
+
speech_lens = None
|
| 161 |
+
num_speechs, speech_len, embed_dim = speech_features.shape
|
| 162 |
+
batch_size, sequence_length = input_ids.shape
|
| 163 |
+
left_padding = not torch.sum(
|
| 164 |
+
input_ids[:, -1] == torch.tensor(self.llm.config.pad_token_id)
|
| 165 |
+
)
|
| 166 |
+
# 1. Create a mask to know where special speech tokens are
|
| 167 |
+
special_speech_token_mask = input_ids == self.llm.config.default_speech_token_id
|
| 168 |
+
num_special_speech_tokens = torch.sum(special_speech_token_mask, dim=-1)
|
| 169 |
+
# Compute the maximum embed dimension
|
| 170 |
+
max_embed_dim = (
|
| 171 |
+
num_special_speech_tokens.max() * (speech_len - 1)
|
| 172 |
+
) + sequence_length
|
| 173 |
+
batch_indices, non_speech_indices = torch.where(
|
| 174 |
+
input_ids != self.llm.config.default_speech_token_id
|
| 175 |
+
)
|
| 176 |
+
|
| 177 |
+
# 2. Compute the positions where text should be written
|
| 178 |
+
# Calculate new positions for text tokens in merged speech-text sequence.
|
| 179 |
+
# `special_speech_token_mask` identifies speech tokens. Each speech token will be replaced by `nb_text_tokens_per_speechs - 1` text tokens.
|
| 180 |
+
# `torch.cumsum` computes how each speech token shifts subsequent text token positions.
|
| 181 |
+
# - 1 to adjust for zero-based indexing, as `cumsum` inherently increases indices by one.
|
| 182 |
+
new_token_positions = (
|
| 183 |
+
torch.cumsum((special_speech_token_mask * (speech_len - 1) + 1), -1) - 1
|
| 184 |
+
) # (N,U)
|
| 185 |
+
nb_speech_pad = max_embed_dim - 1 - new_token_positions[:, -1]
|
| 186 |
+
if left_padding:
|
| 187 |
+
new_token_positions += nb_speech_pad[:, None] # offset for left padding
|
| 188 |
+
text_to_overwrite = new_token_positions[batch_indices, non_speech_indices]
|
| 189 |
+
|
| 190 |
+
# 3. Create the full embedding, already padded to the maximum position
|
| 191 |
+
final_embedding = torch.zeros(
|
| 192 |
+
batch_size,
|
| 193 |
+
max_embed_dim,
|
| 194 |
+
embed_dim,
|
| 195 |
+
dtype=inputs_embeds.dtype,
|
| 196 |
+
device=inputs_embeds.device,
|
| 197 |
+
)
|
| 198 |
+
final_attention_mask = torch.zeros(
|
| 199 |
+
batch_size,
|
| 200 |
+
max_embed_dim,
|
| 201 |
+
dtype=attention_mask.dtype,
|
| 202 |
+
device=inputs_embeds.device,
|
| 203 |
+
)
|
| 204 |
+
if labels is not None:
|
| 205 |
+
final_labels = torch.full(
|
| 206 |
+
(batch_size, max_embed_dim),
|
| 207 |
+
IGNORE_TOKEN_ID,
|
| 208 |
+
dtype=input_ids.dtype,
|
| 209 |
+
device=input_ids.device,
|
| 210 |
+
)
|
| 211 |
+
# In case the Vision model or the Language model has been offloaded to CPU, we need to manually
|
| 212 |
+
# set the corresponding tensors into their correct target device.
|
| 213 |
+
target_device = inputs_embeds.device
|
| 214 |
+
batch_indices, non_speech_indices, text_to_overwrite = (
|
| 215 |
+
batch_indices.to(target_device),
|
| 216 |
+
non_speech_indices.to(target_device),
|
| 217 |
+
text_to_overwrite.to(target_device),
|
| 218 |
+
)
|
| 219 |
+
attention_mask = attention_mask.to(target_device)
|
| 220 |
+
|
| 221 |
+
# 4. Fill the embeddings based on the mask. If we have ["hey" "<speech>", "how", "are"]
|
| 222 |
+
# we need to index copy on [0, 577, 578, 579] for the text and [1:576] for the speech features
|
| 223 |
+
final_embedding[batch_indices, text_to_overwrite] = inputs_embeds[
|
| 224 |
+
batch_indices, non_speech_indices
|
| 225 |
+
]
|
| 226 |
+
final_attention_mask[batch_indices, text_to_overwrite] = attention_mask[
|
| 227 |
+
batch_indices, non_speech_indices
|
| 228 |
+
]
|
| 229 |
+
if labels is not None:
|
| 230 |
+
final_labels[batch_indices, text_to_overwrite] = labels[
|
| 231 |
+
batch_indices, non_speech_indices
|
| 232 |
+
]
|
| 233 |
+
|
| 234 |
+
# 5. Fill the embeddings corresponding to the speechs. Anything that is not `text_positions` needs filling (#29835)
|
| 235 |
+
speech_to_overwrite = torch.full(
|
| 236 |
+
(batch_size, max_embed_dim),
|
| 237 |
+
True,
|
| 238 |
+
dtype=torch.bool,
|
| 239 |
+
device=inputs_embeds.device,
|
| 240 |
+
)
|
| 241 |
+
speech_to_overwrite[batch_indices, text_to_overwrite] = False
|
| 242 |
+
if speech_lens is not None:
|
| 243 |
+
speech_pad_position = speech_to_overwrite.cumsum(-1) <= speech_lens[:, None]
|
| 244 |
+
speech_to_overwrite &= speech_to_overwrite.cumsum(-1) - 1 >= nb_speech_pad[
|
| 245 |
+
:, None
|
| 246 |
+
].to(target_device)
|
| 247 |
+
|
| 248 |
+
if speech_to_overwrite.sum() != speech_features.shape[:-1].numel():
|
| 249 |
+
raise ValueError(
|
| 250 |
+
f"The input provided to the model are wrong. The number of speech tokens is {torch.sum(special_speech_token_mask)} while"
|
| 251 |
+
f" the number of speech given to the model is {num_speechs}. This prevents correct indexing and breaks batch generation."
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
final_embedding[speech_to_overwrite] = (
|
| 255 |
+
speech_features.contiguous().reshape(-1, embed_dim).to(target_device)
|
| 256 |
+
)
|
| 257 |
+
if speech_lens is not None:
|
| 258 |
+
speech_to_overwrite &= speech_pad_position
|
| 259 |
+
final_attention_mask |= speech_to_overwrite
|
| 260 |
+
|
| 261 |
+
# 6. Mask out the embedding at padding positions, as we later use the past_key_value value to determine the non-attended tokens.
|
| 262 |
+
batch_indices, pad_indices = torch.where(
|
| 263 |
+
input_ids == self.llm.config.pad_token_id
|
| 264 |
+
)
|
| 265 |
+
indices_to_mask = new_token_positions[batch_indices, pad_indices]
|
| 266 |
+
|
| 267 |
+
final_embedding[batch_indices, indices_to_mask] = 0
|
| 268 |
+
|
| 269 |
+
if labels is None:
|
| 270 |
+
final_labels = None
|
| 271 |
+
|
| 272 |
+
return final_embedding, final_attention_mask, final_labels #, position_ids
|
fireredasr/fireredasr/models/module/adapter.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class Adapter(nn.Module):
|
| 6 |
+
def __init__(self, encoder_dim, llm_dim, downsample_rate=2):
|
| 7 |
+
super().__init__()
|
| 8 |
+
self.ds = downsample_rate
|
| 9 |
+
self.linear1 = nn.Linear(encoder_dim * downsample_rate, llm_dim)
|
| 10 |
+
self.relu = nn.ReLU()
|
| 11 |
+
self.linear2 = nn.Linear(llm_dim, llm_dim)
|
| 12 |
+
|
| 13 |
+
def forward(self, x, x_lens):
|
| 14 |
+
batch_size, seq_len, feat_dim = x.size()
|
| 15 |
+
num_frames_to_discard = seq_len % self.ds
|
| 16 |
+
if num_frames_to_discard > 0:
|
| 17 |
+
x = x[:, :-num_frames_to_discard, :]
|
| 18 |
+
seq_len = x.size(1)
|
| 19 |
+
|
| 20 |
+
x = x.contiguous()
|
| 21 |
+
x = x.view(
|
| 22 |
+
batch_size, seq_len // self.ds, feat_dim * self.ds
|
| 23 |
+
)
|
| 24 |
+
|
| 25 |
+
x = self.linear1(x)
|
| 26 |
+
x = self.relu(x)
|
| 27 |
+
x = self.linear2(x)
|
| 28 |
+
|
| 29 |
+
new_x_lens = torch.clamp(x_lens, max=seq_len) // self.ds
|
| 30 |
+
return x, new_x_lens
|
fireredasr/fireredasr/models/module/conformer_encoder.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class ConformerEncoder(nn.Module):
|
| 7 |
+
def __init__(self, idim, n_layers, n_head, d_model,
|
| 8 |
+
residual_dropout=0.1, dropout_rate=0.1, kernel_size=33,
|
| 9 |
+
pe_maxlen=5000):
|
| 10 |
+
super().__init__()
|
| 11 |
+
self.odim = d_model
|
| 12 |
+
|
| 13 |
+
self.input_preprocessor = Conv2dSubsampling(idim, d_model)
|
| 14 |
+
self.positional_encoding = RelPositionalEncoding(d_model)
|
| 15 |
+
self.dropout = nn.Dropout(residual_dropout)
|
| 16 |
+
|
| 17 |
+
self.layer_stack = nn.ModuleList()
|
| 18 |
+
for l in range(n_layers):
|
| 19 |
+
block = RelPosEmbConformerBlock(d_model, n_head,
|
| 20 |
+
residual_dropout,
|
| 21 |
+
dropout_rate, kernel_size)
|
| 22 |
+
self.layer_stack.append(block)
|
| 23 |
+
|
| 24 |
+
def forward(self, padded_input, input_lengths, pad=True):
|
| 25 |
+
if pad:
|
| 26 |
+
padded_input = F.pad(padded_input,
|
| 27 |
+
(0, 0, 0, self.input_preprocessor.context - 1), 'constant', 0.0)
|
| 28 |
+
src_mask = self.padding_position_is_0(padded_input, input_lengths)
|
| 29 |
+
|
| 30 |
+
embed_output, input_lengths, src_mask = self.input_preprocessor(padded_input, src_mask)
|
| 31 |
+
enc_output = self.dropout(embed_output)
|
| 32 |
+
|
| 33 |
+
pos_emb = self.dropout(self.positional_encoding(embed_output))
|
| 34 |
+
|
| 35 |
+
enc_outputs = []
|
| 36 |
+
for enc_layer in self.layer_stack:
|
| 37 |
+
enc_output = enc_layer(enc_output, pos_emb, slf_attn_mask=src_mask,
|
| 38 |
+
pad_mask=src_mask)
|
| 39 |
+
enc_outputs.append(enc_output)
|
| 40 |
+
|
| 41 |
+
return enc_output, input_lengths, src_mask
|
| 42 |
+
|
| 43 |
+
def padding_position_is_0(self, padded_input, input_lengths):
|
| 44 |
+
N, T = padded_input.size()[:2]
|
| 45 |
+
mask = torch.ones((N, T)).to(padded_input.device)
|
| 46 |
+
for i in range(N):
|
| 47 |
+
mask[i, input_lengths[i]:] = 0
|
| 48 |
+
mask = mask.unsqueeze(dim=1)
|
| 49 |
+
return mask.to(torch.uint8)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class RelPosEmbConformerBlock(nn.Module):
|
| 53 |
+
def __init__(self, d_model, n_head,
|
| 54 |
+
residual_dropout=0.1,
|
| 55 |
+
dropout_rate=0.1, kernel_size=33):
|
| 56 |
+
super().__init__()
|
| 57 |
+
self.ffn1 = ConformerFeedForward(d_model, dropout_rate)
|
| 58 |
+
self.mhsa = RelPosMultiHeadAttention(n_head, d_model,
|
| 59 |
+
residual_dropout)
|
| 60 |
+
self.conv = ConformerConvolution(d_model, kernel_size,
|
| 61 |
+
dropout_rate)
|
| 62 |
+
self.ffn2 = ConformerFeedForward(d_model, dropout_rate)
|
| 63 |
+
self.layer_norm = nn.LayerNorm(d_model)
|
| 64 |
+
|
| 65 |
+
def forward(self, x, pos_emb, slf_attn_mask=None, pad_mask=None):
|
| 66 |
+
out = 0.5 * x + 0.5 * self.ffn1(x)
|
| 67 |
+
out = self.mhsa(out, out, out, pos_emb, mask=slf_attn_mask)[0]
|
| 68 |
+
out = self.conv(out, pad_mask)
|
| 69 |
+
out = 0.5 * out + 0.5 * self.ffn2(out)
|
| 70 |
+
out = self.layer_norm(out)
|
| 71 |
+
return out
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
class Swish(nn.Module):
|
| 75 |
+
def forward(self, x):
|
| 76 |
+
return x * torch.sigmoid(x)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class Conv2dSubsampling(nn.Module):
|
| 80 |
+
def __init__(self, idim, d_model, out_channels=32):
|
| 81 |
+
super().__init__()
|
| 82 |
+
self.conv = nn.Sequential(
|
| 83 |
+
nn.Conv2d(1, out_channels, 3, 2),
|
| 84 |
+
nn.ReLU(),
|
| 85 |
+
nn.Conv2d(out_channels, out_channels, 3, 2),
|
| 86 |
+
nn.ReLU(),
|
| 87 |
+
)
|
| 88 |
+
subsample_idim = ((idim - 1) // 2 - 1) // 2
|
| 89 |
+
self.out = nn.Linear(out_channels * subsample_idim, d_model)
|
| 90 |
+
|
| 91 |
+
self.subsampling = 4
|
| 92 |
+
left_context = right_context = 3 # both exclude currect frame
|
| 93 |
+
self.context = left_context + 1 + right_context # 7
|
| 94 |
+
|
| 95 |
+
def forward(self, x, x_mask):
|
| 96 |
+
x = x.unsqueeze(1)
|
| 97 |
+
x = self.conv(x)
|
| 98 |
+
N, C, T, D = x.size()
|
| 99 |
+
x = self.out(x.transpose(1, 2).contiguous().view(N, T, C * D))
|
| 100 |
+
mask = x_mask[:, :, :-2:2][:, :, :-2:2]
|
| 101 |
+
input_lengths = mask[:, -1, :].sum(dim=-1)
|
| 102 |
+
return x, input_lengths, mask
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class RelPositionalEncoding(torch.nn.Module):
|
| 106 |
+
def __init__(self, d_model, max_len=5000):
|
| 107 |
+
super().__init__()
|
| 108 |
+
pe_positive = torch.zeros(max_len, d_model, requires_grad=False)
|
| 109 |
+
pe_negative = torch.zeros(max_len, d_model, requires_grad=False)
|
| 110 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 111 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 112 |
+
-(torch.log(torch.tensor(10000.0)).item()/d_model))
|
| 113 |
+
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
| 114 |
+
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
| 115 |
+
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
| 116 |
+
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
| 117 |
+
|
| 118 |
+
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
| 119 |
+
pe_negative = pe_negative[1:].unsqueeze(0)
|
| 120 |
+
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
| 121 |
+
self.register_buffer('pe', pe)
|
| 122 |
+
|
| 123 |
+
def forward(self, x):
|
| 124 |
+
# Tmax = 2 * max_len - 1
|
| 125 |
+
Tmax, T = self.pe.size(1), x.size(1)
|
| 126 |
+
pos_emb = self.pe[:, Tmax // 2 - T + 1 : Tmax // 2 + T].clone().detach()
|
| 127 |
+
return pos_emb
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class ConformerFeedForward(nn.Module):
|
| 131 |
+
def __init__(self, d_model, dropout_rate=0.1):
|
| 132 |
+
super().__init__()
|
| 133 |
+
pre_layer_norm = nn.LayerNorm(d_model)
|
| 134 |
+
linear_expand = nn.Linear(d_model, d_model*4)
|
| 135 |
+
nonlinear = Swish()
|
| 136 |
+
dropout_pre = nn.Dropout(dropout_rate)
|
| 137 |
+
linear_project = nn.Linear(d_model*4, d_model)
|
| 138 |
+
dropout_post = nn.Dropout(dropout_rate)
|
| 139 |
+
self.net = nn.Sequential(pre_layer_norm,
|
| 140 |
+
linear_expand,
|
| 141 |
+
nonlinear,
|
| 142 |
+
dropout_pre,
|
| 143 |
+
linear_project,
|
| 144 |
+
dropout_post)
|
| 145 |
+
|
| 146 |
+
def forward(self, x):
|
| 147 |
+
residual = x
|
| 148 |
+
output = self.net(x)
|
| 149 |
+
output = output + residual
|
| 150 |
+
return output
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
class ConformerConvolution(nn.Module):
|
| 154 |
+
def __init__(self, d_model, kernel_size=33, dropout_rate=0.1):
|
| 155 |
+
super().__init__()
|
| 156 |
+
assert kernel_size % 2 == 1
|
| 157 |
+
self.pre_layer_norm = nn.LayerNorm(d_model)
|
| 158 |
+
self.pointwise_conv1 = nn.Conv1d(d_model, d_model*4, kernel_size=1, bias=False)
|
| 159 |
+
self.glu = F.glu
|
| 160 |
+
self.padding = (kernel_size - 1) // 2
|
| 161 |
+
self.depthwise_conv = nn.Conv1d(d_model*2, d_model*2,
|
| 162 |
+
kernel_size, stride=1,
|
| 163 |
+
padding=self.padding,
|
| 164 |
+
groups=d_model*2, bias=False)
|
| 165 |
+
self.batch_norm = nn.LayerNorm(d_model*2)
|
| 166 |
+
self.swish = Swish()
|
| 167 |
+
self.pointwise_conv2 = nn.Conv1d(d_model*2, d_model, kernel_size=1, bias=False)
|
| 168 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 169 |
+
|
| 170 |
+
def forward(self, x, mask=None):
|
| 171 |
+
residual = x
|
| 172 |
+
out = self.pre_layer_norm(x)
|
| 173 |
+
out = out.transpose(1, 2)
|
| 174 |
+
if mask is not None:
|
| 175 |
+
out.masked_fill_(mask.ne(1), 0.0)
|
| 176 |
+
out = self.pointwise_conv1(out)
|
| 177 |
+
out = F.glu(out, dim=1)
|
| 178 |
+
out = self.depthwise_conv(out)
|
| 179 |
+
|
| 180 |
+
out = out.transpose(1, 2)
|
| 181 |
+
out = self.swish(self.batch_norm(out))
|
| 182 |
+
out = out.transpose(1, 2)
|
| 183 |
+
|
| 184 |
+
out = self.dropout(self.pointwise_conv2(out))
|
| 185 |
+
if mask is not None:
|
| 186 |
+
out.masked_fill_(mask.ne(1), 0.0)
|
| 187 |
+
out = out.transpose(1, 2)
|
| 188 |
+
return out + residual
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
class EncoderMultiHeadAttention(nn.Module):
|
| 192 |
+
def __init__(self, n_head, d_model,
|
| 193 |
+
residual_dropout=0.1):
|
| 194 |
+
super().__init__()
|
| 195 |
+
assert d_model % n_head == 0
|
| 196 |
+
self.n_head = n_head
|
| 197 |
+
self.d_k = d_model // n_head
|
| 198 |
+
self.d_v = self.d_k
|
| 199 |
+
|
| 200 |
+
self.w_qs = nn.Linear(d_model, n_head * self.d_k, bias=False)
|
| 201 |
+
self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
|
| 202 |
+
self.w_vs = nn.Linear(d_model, n_head * self.d_v, bias=False)
|
| 203 |
+
|
| 204 |
+
self.layer_norm_q = nn.LayerNorm(d_model)
|
| 205 |
+
self.layer_norm_k = nn.LayerNorm(d_model)
|
| 206 |
+
self.layer_norm_v = nn.LayerNorm(d_model)
|
| 207 |
+
|
| 208 |
+
self.attention = ScaledDotProductAttention(temperature=self.d_k ** 0.5)
|
| 209 |
+
self.fc = nn.Linear(n_head * self.d_v, d_model, bias=False)
|
| 210 |
+
self.dropout = nn.Dropout(residual_dropout)
|
| 211 |
+
|
| 212 |
+
def forward(self, q, k, v, mask=None):
|
| 213 |
+
sz_b, len_q = q.size(0), q.size(1)
|
| 214 |
+
|
| 215 |
+
residual = q
|
| 216 |
+
q, k, v = self.forward_qkv(q, k, v)
|
| 217 |
+
|
| 218 |
+
output, attn = self.attention(q, k, v, mask=mask)
|
| 219 |
+
|
| 220 |
+
output = self.forward_output(output, residual, sz_b, len_q)
|
| 221 |
+
return output, attn
|
| 222 |
+
|
| 223 |
+
def forward_qkv(self, q, k, v):
|
| 224 |
+
d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
|
| 225 |
+
sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
|
| 226 |
+
|
| 227 |
+
q = self.layer_norm_q(q)
|
| 228 |
+
k = self.layer_norm_k(k)
|
| 229 |
+
v = self.layer_norm_v(v)
|
| 230 |
+
|
| 231 |
+
q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
|
| 232 |
+
k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
|
| 233 |
+
v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
|
| 234 |
+
q = q.transpose(1, 2)
|
| 235 |
+
k = k.transpose(1, 2)
|
| 236 |
+
v = v.transpose(1, 2)
|
| 237 |
+
return q, k, v
|
| 238 |
+
|
| 239 |
+
def forward_output(self, output, residual, sz_b, len_q):
|
| 240 |
+
output = output.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
|
| 241 |
+
fc_out = self.fc(output)
|
| 242 |
+
output = self.dropout(fc_out)
|
| 243 |
+
output = output + residual
|
| 244 |
+
return output
|
| 245 |
+
|
| 246 |
+
|
| 247 |
+
class ScaledDotProductAttention(nn.Module):
|
| 248 |
+
def __init__(self, temperature):
|
| 249 |
+
super().__init__()
|
| 250 |
+
self.temperature = temperature
|
| 251 |
+
self.dropout = nn.Dropout(0.0)
|
| 252 |
+
self.INF = float('inf')
|
| 253 |
+
|
| 254 |
+
def forward(self, q, k, v, mask=None):
|
| 255 |
+
attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
|
| 256 |
+
output, attn = self.forward_attention(attn, v, mask)
|
| 257 |
+
return output, attn
|
| 258 |
+
|
| 259 |
+
def forward_attention(self, attn, v, mask=None):
|
| 260 |
+
if mask is not None:
|
| 261 |
+
mask = mask.unsqueeze(1)
|
| 262 |
+
mask = mask.eq(0)
|
| 263 |
+
attn = attn.masked_fill(mask, -self.INF)
|
| 264 |
+
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
|
| 265 |
+
else:
|
| 266 |
+
attn = torch.softmax(attn, dim=-1)
|
| 267 |
+
|
| 268 |
+
d_attn = self.dropout(attn)
|
| 269 |
+
output = torch.matmul(d_attn, v)
|
| 270 |
+
|
| 271 |
+
return output, attn
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
class RelPosMultiHeadAttention(EncoderMultiHeadAttention):
|
| 275 |
+
def __init__(self, n_head, d_model,
|
| 276 |
+
residual_dropout=0.1):
|
| 277 |
+
super().__init__(n_head, d_model,
|
| 278 |
+
residual_dropout)
|
| 279 |
+
d_k = d_model // n_head
|
| 280 |
+
self.scale = 1.0 / (d_k ** 0.5)
|
| 281 |
+
self.linear_pos = nn.Linear(d_model, n_head * d_k, bias=False)
|
| 282 |
+
self.pos_bias_u = nn.Parameter(torch.FloatTensor(n_head, d_k))
|
| 283 |
+
self.pos_bias_v = nn.Parameter(torch.FloatTensor(n_head, d_k))
|
| 284 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
| 285 |
+
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
| 286 |
+
|
| 287 |
+
def _rel_shift(self, x):
|
| 288 |
+
N, H, T1, T2 = x.size()
|
| 289 |
+
zero_pad = torch.zeros((N, H, T1, 1), device=x.device, dtype=x.dtype)
|
| 290 |
+
x_padded = torch.cat([zero_pad, x], dim=-1)
|
| 291 |
+
|
| 292 |
+
x_padded = x_padded.view(N, H, T2 + 1, T1)
|
| 293 |
+
x = x_padded[:, :, 1:].view_as(x)
|
| 294 |
+
x = x[:, :, :, : x.size(-1) // 2 + 1]
|
| 295 |
+
return x
|
| 296 |
+
|
| 297 |
+
def forward(self, q, k, v, pos_emb, mask=None):
|
| 298 |
+
sz_b, len_q = q.size(0), q.size(1)
|
| 299 |
+
|
| 300 |
+
residual = q
|
| 301 |
+
q, k, v = self.forward_qkv(q, k, v)
|
| 302 |
+
|
| 303 |
+
q = q.transpose(1, 2)
|
| 304 |
+
n_batch_pos = pos_emb.size(0)
|
| 305 |
+
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.n_head, self.d_k)
|
| 306 |
+
p = p.transpose(1, 2)
|
| 307 |
+
|
| 308 |
+
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
| 309 |
+
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
| 310 |
+
|
| 311 |
+
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
| 312 |
+
|
| 313 |
+
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
| 314 |
+
matrix_bd = self._rel_shift(matrix_bd)
|
| 315 |
+
|
| 316 |
+
attn_scores = matrix_ac + matrix_bd
|
| 317 |
+
attn_scores.mul_(self.scale)
|
| 318 |
+
|
| 319 |
+
output, attn = self.attention.forward_attention(attn_scores, v, mask=mask)
|
| 320 |
+
|
| 321 |
+
output = self.forward_output(output, residual, sz_b, len_q)
|
| 322 |
+
return output, attn
|
fireredasr/fireredasr/models/module/transformer_decoder.py
ADDED
|
@@ -0,0 +1,299 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List, Optional, Dict
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class TransformerDecoder(nn.Module):
|
| 10 |
+
def __init__(
|
| 11 |
+
self, sos_id, eos_id, pad_id, odim,
|
| 12 |
+
n_layers, n_head, d_model,
|
| 13 |
+
residual_dropout=0.1, pe_maxlen=5000):
|
| 14 |
+
super().__init__()
|
| 15 |
+
self.INF = 1e10
|
| 16 |
+
# parameters
|
| 17 |
+
self.pad_id = pad_id
|
| 18 |
+
self.sos_id = sos_id
|
| 19 |
+
self.eos_id = eos_id
|
| 20 |
+
self.n_layers = n_layers
|
| 21 |
+
|
| 22 |
+
# Components
|
| 23 |
+
self.tgt_word_emb = nn.Embedding(odim, d_model, padding_idx=self.pad_id)
|
| 24 |
+
self.positional_encoding = PositionalEncoding(d_model, max_len=pe_maxlen)
|
| 25 |
+
self.dropout = nn.Dropout(residual_dropout)
|
| 26 |
+
|
| 27 |
+
self.layer_stack = nn.ModuleList()
|
| 28 |
+
for l in range(n_layers):
|
| 29 |
+
block = DecoderLayer(d_model, n_head, residual_dropout)
|
| 30 |
+
self.layer_stack.append(block)
|
| 31 |
+
|
| 32 |
+
self.tgt_word_prj = nn.Linear(d_model, odim, bias=False)
|
| 33 |
+
self.layer_norm_out = nn.LayerNorm(d_model)
|
| 34 |
+
|
| 35 |
+
self.tgt_word_prj.weight = self.tgt_word_emb.weight
|
| 36 |
+
self.scale = (d_model ** 0.5)
|
| 37 |
+
|
| 38 |
+
def batch_beam_search(self, encoder_outputs, src_masks,
|
| 39 |
+
beam_size=1, nbest=1, decode_max_len=0,
|
| 40 |
+
softmax_smoothing=1.0, length_penalty=0.0, eos_penalty=1.0):
|
| 41 |
+
B = beam_size
|
| 42 |
+
N, Ti, H = encoder_outputs.size()
|
| 43 |
+
device = encoder_outputs.device
|
| 44 |
+
maxlen = decode_max_len if decode_max_len > 0 else Ti
|
| 45 |
+
assert eos_penalty > 0.0 and eos_penalty <= 1.0
|
| 46 |
+
|
| 47 |
+
# Init
|
| 48 |
+
encoder_outputs = encoder_outputs.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, Ti, H)
|
| 49 |
+
src_mask = src_masks.unsqueeze(1).repeat(1, B, 1, 1).view(N*B, -1, Ti)
|
| 50 |
+
ys = torch.ones(N*B, 1).fill_(self.sos_id).long().to(device)
|
| 51 |
+
caches: List[Optional[Tensor]] = []
|
| 52 |
+
for _ in range(self.n_layers):
|
| 53 |
+
caches.append(None)
|
| 54 |
+
scores = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(device)
|
| 55 |
+
scores = scores.repeat(N).view(N*B, 1)
|
| 56 |
+
is_finished = torch.zeros_like(scores)
|
| 57 |
+
|
| 58 |
+
# Autoregressive Prediction
|
| 59 |
+
for t in range(maxlen):
|
| 60 |
+
tgt_mask = self.ignored_target_position_is_0(ys, self.pad_id)
|
| 61 |
+
|
| 62 |
+
dec_output = self.dropout(
|
| 63 |
+
self.tgt_word_emb(ys) * self.scale +
|
| 64 |
+
self.positional_encoding(ys))
|
| 65 |
+
|
| 66 |
+
i = 0
|
| 67 |
+
for dec_layer in self.layer_stack:
|
| 68 |
+
dec_output = dec_layer.forward(
|
| 69 |
+
dec_output, encoder_outputs,
|
| 70 |
+
tgt_mask, src_mask,
|
| 71 |
+
cache=caches[i])
|
| 72 |
+
caches[i] = dec_output
|
| 73 |
+
i += 1
|
| 74 |
+
|
| 75 |
+
dec_output = self.layer_norm_out(dec_output)
|
| 76 |
+
|
| 77 |
+
t_logit = self.tgt_word_prj(dec_output[:, -1])
|
| 78 |
+
t_scores = F.log_softmax(t_logit / softmax_smoothing, dim=-1)
|
| 79 |
+
|
| 80 |
+
if eos_penalty != 1.0:
|
| 81 |
+
t_scores[:, self.eos_id] *= eos_penalty
|
| 82 |
+
|
| 83 |
+
t_topB_scores, t_topB_ys = torch.topk(t_scores, k=B, dim=1)
|
| 84 |
+
t_topB_scores = self.set_finished_beam_score_to_zero(t_topB_scores, is_finished)
|
| 85 |
+
t_topB_ys = self.set_finished_beam_y_to_eos(t_topB_ys, is_finished)
|
| 86 |
+
|
| 87 |
+
# Accumulated
|
| 88 |
+
scores = scores + t_topB_scores
|
| 89 |
+
|
| 90 |
+
# Pruning
|
| 91 |
+
scores = scores.view(N, B*B)
|
| 92 |
+
scores, topB_score_ids = torch.topk(scores, k=B, dim=1)
|
| 93 |
+
scores = scores.view(-1, 1)
|
| 94 |
+
|
| 95 |
+
topB_row_number_in_each_B_rows_of_ys = torch.div(topB_score_ids, B).view(N*B)
|
| 96 |
+
stride = B * torch.arange(N).view(N, 1).repeat(1, B).view(N*B).to(device)
|
| 97 |
+
topB_row_number_in_ys = topB_row_number_in_each_B_rows_of_ys.long() + stride.long()
|
| 98 |
+
|
| 99 |
+
# Update ys
|
| 100 |
+
ys = ys[topB_row_number_in_ys]
|
| 101 |
+
t_ys = torch.gather(t_topB_ys.view(N, B*B), dim=1, index=topB_score_ids).view(N*B, 1)
|
| 102 |
+
ys = torch.cat((ys, t_ys), dim=1)
|
| 103 |
+
|
| 104 |
+
# Update caches
|
| 105 |
+
new_caches: List[Optional[Tensor]] = []
|
| 106 |
+
for cache in caches:
|
| 107 |
+
if cache is not None:
|
| 108 |
+
new_caches.append(cache[topB_row_number_in_ys])
|
| 109 |
+
caches = new_caches
|
| 110 |
+
|
| 111 |
+
# Update finished state
|
| 112 |
+
is_finished = t_ys.eq(self.eos_id)
|
| 113 |
+
if is_finished.sum().item() == N*B:
|
| 114 |
+
break
|
| 115 |
+
|
| 116 |
+
# Length penalty (follow GNMT)
|
| 117 |
+
scores = scores.view(N, B)
|
| 118 |
+
ys = ys.view(N, B, -1)
|
| 119 |
+
ys_lengths = self.get_ys_lengths(ys)
|
| 120 |
+
if length_penalty > 0.0:
|
| 121 |
+
penalty = torch.pow((5+ys_lengths.float())/(5.0+1), length_penalty)
|
| 122 |
+
scores /= penalty
|
| 123 |
+
nbest_scores, nbest_ids = torch.topk(scores, k=int(nbest), dim=1)
|
| 124 |
+
nbest_scores = -1.0 * nbest_scores
|
| 125 |
+
index = nbest_ids + B * torch.arange(N).view(N, 1).to(device).long()
|
| 126 |
+
nbest_ys = ys.view(N*B, -1)[index.view(-1)]
|
| 127 |
+
nbest_ys = nbest_ys.view(N, nbest_ids.size(1), -1)
|
| 128 |
+
nbest_ys_lengths = ys_lengths.view(N*B)[index.view(-1)].view(N, -1)
|
| 129 |
+
|
| 130 |
+
# result
|
| 131 |
+
nbest_hyps: List[List[Dict[str, Tensor]]] = []
|
| 132 |
+
for n in range(N):
|
| 133 |
+
n_nbest_hyps: List[Dict[str, Tensor]] = []
|
| 134 |
+
for i, score in enumerate(nbest_scores[n]):
|
| 135 |
+
new_hyp = {
|
| 136 |
+
"yseq": nbest_ys[n, i, 1:nbest_ys_lengths[n, i]]
|
| 137 |
+
}
|
| 138 |
+
n_nbest_hyps.append(new_hyp)
|
| 139 |
+
nbest_hyps.append(n_nbest_hyps)
|
| 140 |
+
return nbest_hyps
|
| 141 |
+
|
| 142 |
+
def ignored_target_position_is_0(self, padded_targets, ignore_id):
|
| 143 |
+
mask = torch.ne(padded_targets, ignore_id)
|
| 144 |
+
mask = mask.unsqueeze(dim=1)
|
| 145 |
+
T = padded_targets.size(-1)
|
| 146 |
+
upper_tri_0_mask = self.upper_triangular_is_0(T).unsqueeze(0).to(mask.dtype)
|
| 147 |
+
upper_tri_0_mask = upper_tri_0_mask.to(mask.dtype).to(mask.device)
|
| 148 |
+
return mask.to(torch.uint8) & upper_tri_0_mask.to(torch.uint8)
|
| 149 |
+
|
| 150 |
+
def upper_triangular_is_0(self, size):
|
| 151 |
+
ones = torch.ones(size, size)
|
| 152 |
+
tri_left_ones = torch.tril(ones)
|
| 153 |
+
return tri_left_ones.to(torch.uint8)
|
| 154 |
+
|
| 155 |
+
def set_finished_beam_score_to_zero(self, scores, is_finished):
|
| 156 |
+
NB, B = scores.size()
|
| 157 |
+
is_finished = is_finished.float()
|
| 158 |
+
mask_score = torch.tensor([0.0] + [-self.INF]*(B-1)).float().to(scores.device)
|
| 159 |
+
mask_score = mask_score.view(1, B).repeat(NB, 1)
|
| 160 |
+
return scores * (1 - is_finished) + mask_score * is_finished
|
| 161 |
+
|
| 162 |
+
def set_finished_beam_y_to_eos(self, ys, is_finished):
|
| 163 |
+
is_finished = is_finished.long()
|
| 164 |
+
return ys * (1 - is_finished) + self.eos_id * is_finished
|
| 165 |
+
|
| 166 |
+
def get_ys_lengths(self, ys):
|
| 167 |
+
N, B, Tmax = ys.size()
|
| 168 |
+
ys_lengths = torch.sum(torch.ne(ys, self.eos_id), dim=-1)
|
| 169 |
+
return ys_lengths.int()
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class DecoderLayer(nn.Module):
|
| 174 |
+
def __init__(self, d_model, n_head, dropout):
|
| 175 |
+
super().__init__()
|
| 176 |
+
self.self_attn_norm = nn.LayerNorm(d_model)
|
| 177 |
+
self.self_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
|
| 178 |
+
|
| 179 |
+
self.cross_attn_norm = nn.LayerNorm(d_model)
|
| 180 |
+
self.cross_attn = DecoderMultiHeadAttention(d_model, n_head, dropout)
|
| 181 |
+
|
| 182 |
+
self.mlp_norm = nn.LayerNorm(d_model)
|
| 183 |
+
self.mlp = PositionwiseFeedForward(d_model, d_model*4, dropout)
|
| 184 |
+
|
| 185 |
+
def forward(self, dec_input, enc_output, self_attn_mask, cross_attn_mask,
|
| 186 |
+
cache=None):
|
| 187 |
+
x = dec_input
|
| 188 |
+
residual = x
|
| 189 |
+
x = self.self_attn_norm(x)
|
| 190 |
+
if cache is not None:
|
| 191 |
+
xq = x[:, -1:, :]
|
| 192 |
+
residual = residual[:, -1:, :]
|
| 193 |
+
self_attn_mask = self_attn_mask[:, -1:, :]
|
| 194 |
+
else:
|
| 195 |
+
xq = x
|
| 196 |
+
x = self.self_attn(xq, x, x, mask=self_attn_mask)
|
| 197 |
+
x = residual + x
|
| 198 |
+
|
| 199 |
+
residual = x
|
| 200 |
+
x = self.cross_attn_norm(x)
|
| 201 |
+
x = self.cross_attn(x, enc_output, enc_output, mask=cross_attn_mask)
|
| 202 |
+
x = residual + x
|
| 203 |
+
|
| 204 |
+
residual = x
|
| 205 |
+
x = self.mlp_norm(x)
|
| 206 |
+
x = residual + self.mlp(x)
|
| 207 |
+
|
| 208 |
+
if cache is not None:
|
| 209 |
+
x = torch.cat([cache, x], dim=1)
|
| 210 |
+
|
| 211 |
+
return x
|
| 212 |
+
|
| 213 |
+
|
| 214 |
+
class DecoderMultiHeadAttention(nn.Module):
|
| 215 |
+
def __init__(self, d_model, n_head, dropout=0.1):
|
| 216 |
+
super().__init__()
|
| 217 |
+
self.d_model = d_model
|
| 218 |
+
self.n_head = n_head
|
| 219 |
+
self.d_k = d_model // n_head
|
| 220 |
+
|
| 221 |
+
self.w_qs = nn.Linear(d_model, n_head * self.d_k)
|
| 222 |
+
self.w_ks = nn.Linear(d_model, n_head * self.d_k, bias=False)
|
| 223 |
+
self.w_vs = nn.Linear(d_model, n_head * self.d_k)
|
| 224 |
+
|
| 225 |
+
self.attention = DecoderScaledDotProductAttention(
|
| 226 |
+
temperature=self.d_k ** 0.5)
|
| 227 |
+
self.fc = nn.Linear(n_head * self.d_k, d_model)
|
| 228 |
+
self.dropout = nn.Dropout(dropout)
|
| 229 |
+
|
| 230 |
+
def forward(self, q, k, v, mask=None):
|
| 231 |
+
bs = q.size(0)
|
| 232 |
+
|
| 233 |
+
q = self.w_qs(q).view(bs, -1, self.n_head, self.d_k)
|
| 234 |
+
k = self.w_ks(k).view(bs, -1, self.n_head, self.d_k)
|
| 235 |
+
v = self.w_vs(v).view(bs, -1, self.n_head, self.d_k)
|
| 236 |
+
q = q.transpose(1, 2)
|
| 237 |
+
k = k.transpose(1, 2)
|
| 238 |
+
v = v.transpose(1, 2)
|
| 239 |
+
|
| 240 |
+
if mask is not None:
|
| 241 |
+
mask = mask.unsqueeze(1)
|
| 242 |
+
|
| 243 |
+
output = self.attention(q, k, v, mask=mask)
|
| 244 |
+
|
| 245 |
+
output = output.transpose(1, 2).contiguous().view(bs, -1, self.d_model)
|
| 246 |
+
output = self.fc(output)
|
| 247 |
+
output = self.dropout(output)
|
| 248 |
+
|
| 249 |
+
return output
|
| 250 |
+
|
| 251 |
+
|
| 252 |
+
class DecoderScaledDotProductAttention(nn.Module):
|
| 253 |
+
def __init__(self, temperature):
|
| 254 |
+
super().__init__()
|
| 255 |
+
self.temperature = temperature
|
| 256 |
+
self.INF = float("inf")
|
| 257 |
+
|
| 258 |
+
def forward(self, q, k, v, mask=None):
|
| 259 |
+
attn = torch.matmul(q, k.transpose(2, 3)) / self.temperature
|
| 260 |
+
if mask is not None:
|
| 261 |
+
mask = mask.eq(0)
|
| 262 |
+
attn = attn.masked_fill(mask, -self.INF)
|
| 263 |
+
attn = torch.softmax(attn, dim=-1).masked_fill(mask, 0.0)
|
| 264 |
+
else:
|
| 265 |
+
attn = torch.softmax(attn, dim=-1)
|
| 266 |
+
output = torch.matmul(attn, v)
|
| 267 |
+
return output
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
class PositionwiseFeedForward(nn.Module):
|
| 271 |
+
def __init__(self, d_model, d_ff, dropout=0.1):
|
| 272 |
+
super().__init__()
|
| 273 |
+
self.w_1 = nn.Linear(d_model, d_ff)
|
| 274 |
+
self.act = nn.GELU()
|
| 275 |
+
self.w_2 = nn.Linear(d_ff, d_model)
|
| 276 |
+
self.dropout = nn.Dropout(dropout)
|
| 277 |
+
|
| 278 |
+
def forward(self, x):
|
| 279 |
+
output = self.w_2(self.act(self.w_1(x)))
|
| 280 |
+
output = self.dropout(output)
|
| 281 |
+
return output
|
| 282 |
+
|
| 283 |
+
|
| 284 |
+
class PositionalEncoding(nn.Module):
|
| 285 |
+
def __init__(self, d_model, max_len=5000):
|
| 286 |
+
super().__init__()
|
| 287 |
+
assert d_model % 2 == 0
|
| 288 |
+
pe = torch.zeros(max_len, d_model, requires_grad=False)
|
| 289 |
+
position = torch.arange(0, max_len).unsqueeze(1).float()
|
| 290 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() *
|
| 291 |
+
-(torch.log(torch.tensor(10000.0)).item()/d_model))
|
| 292 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
| 293 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
| 294 |
+
pe = pe.unsqueeze(0)
|
| 295 |
+
self.register_buffer('pe', pe)
|
| 296 |
+
|
| 297 |
+
def forward(self, x):
|
| 298 |
+
length = x.size(1)
|
| 299 |
+
return self.pe[:, :length].clone().detach()
|
fireredasr/fireredasr/speech2text.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import glob
|
| 5 |
+
import os
|
| 6 |
+
import sys
|
| 7 |
+
|
| 8 |
+
from fireredasr.models.fireredasr import FireRedAsr
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
parser = argparse.ArgumentParser()
|
| 12 |
+
parser.add_argument('--asr_type', type=str, required=True, choices=["aed", "llm"])
|
| 13 |
+
parser.add_argument('--model_dir', type=str, required=True)
|
| 14 |
+
|
| 15 |
+
# Input / Output
|
| 16 |
+
parser.add_argument("--wav_path", type=str)
|
| 17 |
+
parser.add_argument("--wav_paths", type=str, nargs="*")
|
| 18 |
+
parser.add_argument("--wav_dir", type=str)
|
| 19 |
+
parser.add_argument("--wav_scp", type=str)
|
| 20 |
+
parser.add_argument("--output", type=str)
|
| 21 |
+
|
| 22 |
+
# Decode Options
|
| 23 |
+
parser.add_argument('--use_gpu', type=int, default=1)
|
| 24 |
+
parser.add_argument("--batch_size", type=int, default=1)
|
| 25 |
+
parser.add_argument("--beam_size", type=int, default=1)
|
| 26 |
+
parser.add_argument("--decode_max_len", type=int, default=0)
|
| 27 |
+
# FireRedASR-AED
|
| 28 |
+
parser.add_argument("--nbest", type=int, default=1)
|
| 29 |
+
parser.add_argument("--softmax_smoothing", type=float, default=1.0)
|
| 30 |
+
parser.add_argument("--aed_length_penalty", type=float, default=0.0)
|
| 31 |
+
parser.add_argument("--eos_penalty", type=float, default=1.0)
|
| 32 |
+
# FireRedASR-LLM
|
| 33 |
+
parser.add_argument("--decode_min_len", type=int, default=0)
|
| 34 |
+
parser.add_argument("--repetition_penalty", type=float, default=1.0)
|
| 35 |
+
parser.add_argument("--llm_length_penalty", type=float, default=0.0)
|
| 36 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def main(args):
|
| 40 |
+
wavs = get_wav_info(args)
|
| 41 |
+
fout = open(args.output, "w") if args.output else None
|
| 42 |
+
|
| 43 |
+
model = FireRedAsr.from_pretrained(args.asr_type, args.model_dir)
|
| 44 |
+
|
| 45 |
+
batch_uttid = []
|
| 46 |
+
batch_wav_path = []
|
| 47 |
+
for i, wav in enumerate(wavs):
|
| 48 |
+
uttid, wav_path = wav
|
| 49 |
+
batch_uttid.append(uttid)
|
| 50 |
+
batch_wav_path.append(wav_path)
|
| 51 |
+
if len(batch_wav_path) < args.batch_size and i != len(wavs) - 1:
|
| 52 |
+
continue
|
| 53 |
+
|
| 54 |
+
results = model.transcribe(
|
| 55 |
+
batch_uttid,
|
| 56 |
+
batch_wav_path,
|
| 57 |
+
{
|
| 58 |
+
"use_gpu": args.use_gpu,
|
| 59 |
+
"beam_size": args.beam_size,
|
| 60 |
+
"nbest": args.nbest,
|
| 61 |
+
"decode_max_len": args.decode_max_len,
|
| 62 |
+
"softmax_smoothing": args.softmax_smoothing,
|
| 63 |
+
"aed_length_penalty": args.aed_length_penalty,
|
| 64 |
+
"eos_penalty": args.eos_penalty,
|
| 65 |
+
"decode_min_len": args.decode_min_len,
|
| 66 |
+
"repetition_penalty": args.repetition_penalty,
|
| 67 |
+
"llm_length_penalty": args.llm_length_penalty,
|
| 68 |
+
"temperature": args.temperature
|
| 69 |
+
}
|
| 70 |
+
)
|
| 71 |
+
|
| 72 |
+
for result in results:
|
| 73 |
+
print(result)
|
| 74 |
+
if fout is not None:
|
| 75 |
+
fout.write(f"{result['uttid']}\t{result['text']}\n")
|
| 76 |
+
|
| 77 |
+
batch_uttid = []
|
| 78 |
+
batch_wav_path = []
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_wav_info(args):
|
| 82 |
+
"""
|
| 83 |
+
Returns:
|
| 84 |
+
wavs: list of (uttid, wav_path)
|
| 85 |
+
"""
|
| 86 |
+
base = lambda p: os.path.basename(p).replace(".wav", "")
|
| 87 |
+
if args.wav_path:
|
| 88 |
+
wavs = [(base(args.wav_path), args.wav_path)]
|
| 89 |
+
elif args.wav_paths and len(args.wav_paths) >= 1:
|
| 90 |
+
wavs = [(base(p), p) for p in sorted(args.wav_paths)]
|
| 91 |
+
elif args.wav_scp:
|
| 92 |
+
wavs = [line.strip().split() for line in open(args.wav_scp)]
|
| 93 |
+
elif args.wav_dir:
|
| 94 |
+
wavs = glob.glob(f"{args.wav_dir}/**/*.wav", recursive=True)
|
| 95 |
+
wavs = [(base(p), p) for p in sorted(wavs)]
|
| 96 |
+
else:
|
| 97 |
+
raise ValueError("Please provide valid wav info")
|
| 98 |
+
print(f"#wavs={len(wavs)}")
|
| 99 |
+
return wavs
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
if __name__ == "__main__":
|
| 103 |
+
args = parser.parse_args()
|
| 104 |
+
print(args)
|
| 105 |
+
main(args)
|
fireredasr/fireredasr/tokenizer/aed_tokenizer.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
import sentencepiece as spm
|
| 5 |
+
|
| 6 |
+
from fireredasr.data.token_dict import TokenDict
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class ChineseCharEnglishSpmTokenizer:
|
| 10 |
+
"""
|
| 11 |
+
- One Chinese char is a token.
|
| 12 |
+
- Split English word into SPM and one piece is a token.
|
| 13 |
+
- Ignore ' ' between Chinese char
|
| 14 |
+
- Replace ' ' between English word with "▁" by spm_model
|
| 15 |
+
- Need to put SPM piece into dict file
|
| 16 |
+
- If not set spm_model, will use English char and <space>
|
| 17 |
+
"""
|
| 18 |
+
SPM_SPACE = "▁"
|
| 19 |
+
|
| 20 |
+
def __init__(self, dict_path, spm_model, unk="<unk>", space="<space>"):
|
| 21 |
+
self.dict = TokenDict(dict_path, unk=unk)
|
| 22 |
+
self.space = space
|
| 23 |
+
if spm_model:
|
| 24 |
+
self.sp = spm.SentencePieceProcessor()
|
| 25 |
+
self.sp.Load(spm_model)
|
| 26 |
+
else:
|
| 27 |
+
self.sp = None
|
| 28 |
+
print("[WRAN] Not set spm_model, will use English char")
|
| 29 |
+
print("[WARN] Please check how to deal with ' '(space)")
|
| 30 |
+
if self.space not in self.dict:
|
| 31 |
+
print("Please add <space> to your dict, or it will be <unk>")
|
| 32 |
+
|
| 33 |
+
def tokenize(self, text, replace_punc=True):
|
| 34 |
+
#if text == "":
|
| 35 |
+
# logging.info(f"empty text")
|
| 36 |
+
text = text.upper()
|
| 37 |
+
tokens = []
|
| 38 |
+
if replace_punc:
|
| 39 |
+
text = re.sub("[,。?!,\.?!]", " ", text)
|
| 40 |
+
pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])')
|
| 41 |
+
parts = pattern.split(text.strip())
|
| 42 |
+
parts = [p for p in parts if len(p.strip()) > 0]
|
| 43 |
+
for part in parts:
|
| 44 |
+
if pattern.fullmatch(part) is not None:
|
| 45 |
+
tokens.append(part)
|
| 46 |
+
else:
|
| 47 |
+
if self.sp:
|
| 48 |
+
for piece in self.sp.EncodeAsPieces(part.strip()):
|
| 49 |
+
tokens.append(piece)
|
| 50 |
+
else:
|
| 51 |
+
for char in part.strip():
|
| 52 |
+
tokens.append(char if char != " " else self.space)
|
| 53 |
+
tokens_id = []
|
| 54 |
+
for token in tokens:
|
| 55 |
+
tokens_id.append(self.dict.get(token, self.dict.unk))
|
| 56 |
+
return tokens, tokens_id
|
| 57 |
+
|
| 58 |
+
def detokenize(self, inputs, join_symbol="", replace_spm_space=True):
|
| 59 |
+
"""inputs is ids or tokens, do not need self.sp"""
|
| 60 |
+
if len(inputs) > 0 and type(inputs[0]) == int:
|
| 61 |
+
tokens = [self.dict[id] for id in inputs]
|
| 62 |
+
else:
|
| 63 |
+
tokens = inputs
|
| 64 |
+
s = f"{join_symbol}".join(tokens)
|
| 65 |
+
if replace_spm_space:
|
| 66 |
+
s = s.replace(self.SPM_SPACE, ' ').strip()
|
| 67 |
+
return s
|
fireredasr/fireredasr/tokenizer/llm_tokenizer.py
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from transformers import AutoTokenizer
|
| 5 |
+
from transformers.trainer_pt_utils import LabelSmoother
|
| 6 |
+
|
| 7 |
+
DEFAULT_SPEECH_TOKEN = "<speech>"
|
| 8 |
+
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class LlmTokenizerWrapper:
|
| 12 |
+
@classmethod
|
| 13 |
+
def build_llm_tokenizer(cls, llm_path, use_flash_attn=False):
|
| 14 |
+
tokenizer = AutoTokenizer.from_pretrained(llm_path)
|
| 15 |
+
if use_flash_attn:
|
| 16 |
+
tokenizer.padding_side = "left"
|
| 17 |
+
else:
|
| 18 |
+
tokenizer.padding_side = "right"
|
| 19 |
+
special_tokens_dict = {"additional_special_tokens": [DEFAULT_SPEECH_TOKEN]}
|
| 20 |
+
tokenizer.add_special_tokens(special_tokens_dict)
|
| 21 |
+
return tokenizer
|
| 22 |
+
|
| 23 |
+
@classmethod
|
| 24 |
+
def clean_text(cls, origin_text):
|
| 25 |
+
"""remove punc, remove space between Chinese and keep space between English"""
|
| 26 |
+
# remove punc
|
| 27 |
+
text = re.sub("[,。?!,\.!?《》()\·“”、\\/]", "", origin_text)
|
| 28 |
+
# merge space
|
| 29 |
+
text = re.sub("\s+", " ", text)
|
| 30 |
+
|
| 31 |
+
# remove space between Chinese and keep space between English
|
| 32 |
+
pattern = re.compile(r'([\u3400-\u4dbf\u4e00-\u9fff])') # Chinese
|
| 33 |
+
parts = pattern.split(text.strip())
|
| 34 |
+
parts = [p for p in parts if len(p.strip()) > 0]
|
| 35 |
+
text = "".join(parts)
|
| 36 |
+
text = text.strip()
|
| 37 |
+
|
| 38 |
+
text = text.lower()
|
| 39 |
+
return text
|
| 40 |
+
|
| 41 |
+
@classmethod
|
| 42 |
+
def preprocess_texts(cls, origin_texts, tokenizer, max_len, decode=False):
|
| 43 |
+
messages = []
|
| 44 |
+
clean_texts = []
|
| 45 |
+
for i, origin_text in enumerate(origin_texts):
|
| 46 |
+
text = cls.clean_text(origin_text)
|
| 47 |
+
clean_texts.append(text)
|
| 48 |
+
text = text if not decode else ""
|
| 49 |
+
message = [
|
| 50 |
+
{"role": "user", "content": f"{DEFAULT_SPEECH_TOKEN}请转写音频为文字"},
|
| 51 |
+
{"role": "assistant", "content": text},
|
| 52 |
+
]
|
| 53 |
+
messages.append(message)
|
| 54 |
+
|
| 55 |
+
texts = []
|
| 56 |
+
if not decode:
|
| 57 |
+
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{ '<|im_end|>'}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
| 58 |
+
else:
|
| 59 |
+
TEMPLATE = "{% for message in messages %}{{'<|im_start|>' + message['role'] + '\n' + message['content']}}{% if loop.last %}{{''}}{% else %}{{ '<|im_end|>\n' }}{% endif %}{% endfor %}"
|
| 60 |
+
for i, msg in enumerate(messages):
|
| 61 |
+
texts.append(
|
| 62 |
+
tokenizer.apply_chat_template(
|
| 63 |
+
msg,
|
| 64 |
+
tokenize=True,
|
| 65 |
+
chat_template=TEMPLATE,
|
| 66 |
+
add_generation_prompt=False,
|
| 67 |
+
padding="longest",
|
| 68 |
+
max_length=max_len,
|
| 69 |
+
truncation=True,
|
| 70 |
+
)
|
| 71 |
+
)
|
| 72 |
+
|
| 73 |
+
# Padding texts
|
| 74 |
+
max_len_texts = max([len(text) for text in texts])
|
| 75 |
+
if tokenizer.padding_side == "right":
|
| 76 |
+
texts = [
|
| 77 |
+
text + [tokenizer.pad_token_id] * (max_len_texts - len(text))
|
| 78 |
+
for text in texts
|
| 79 |
+
]
|
| 80 |
+
else:
|
| 81 |
+
texts = [
|
| 82 |
+
[tokenizer.pad_token_id] * (max_len_texts - len(text)) + text
|
| 83 |
+
for text in texts
|
| 84 |
+
]
|
| 85 |
+
input_ids = torch.tensor(texts, dtype=torch.int)
|
| 86 |
+
|
| 87 |
+
target_ids = input_ids.clone()
|
| 88 |
+
target_ids[target_ids == tokenizer.pad_token_id] = IGNORE_TOKEN_ID
|
| 89 |
+
|
| 90 |
+
# first get the indices of the tokens
|
| 91 |
+
mask_prompt = True
|
| 92 |
+
if mask_prompt:
|
| 93 |
+
mask_indices = torch.where(
|
| 94 |
+
input_ids == tokenizer.convert_tokens_to_ids("assistant")
|
| 95 |
+
)
|
| 96 |
+
for i in range(mask_indices[0].size(0)):
|
| 97 |
+
row = mask_indices[0][i]
|
| 98 |
+
col = mask_indices[1][i]
|
| 99 |
+
target_ids[row, : col + 2] = IGNORE_TOKEN_ID
|
| 100 |
+
|
| 101 |
+
attention_mask = input_ids.ne(tokenizer.pad_token_id)
|
| 102 |
+
|
| 103 |
+
target_ids = target_ids.type(torch.LongTensor)
|
| 104 |
+
input_ids = input_ids.type(torch.LongTensor)
|
| 105 |
+
return input_ids, attention_mask, target_ids, clean_texts
|
fireredasr/fireredasr/utils/param.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def count_model_parameters(model):
|
| 7 |
+
if not isinstance(model, torch.nn.Module):
|
| 8 |
+
return 0, 0
|
| 9 |
+
name = f"{model.__class__.__name__} {model.__class__}"
|
| 10 |
+
num = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 11 |
+
size = num * 4.0 / 1024.0 / 1024.0 # float32, MB
|
| 12 |
+
logging.info(f"#param of {name} is {num} = {size:.1f} MB (float32)")
|
| 13 |
+
return num, size
|
fireredasr/fireredasr/utils/wer.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
|
| 3 |
+
import argparse
|
| 4 |
+
import re
|
| 5 |
+
from collections import OrderedDict
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
parser = argparse.ArgumentParser()
|
| 9 |
+
parser.add_argument("--ref", type=str, required=True)
|
| 10 |
+
parser.add_argument("--hyp", type=str, required=True)
|
| 11 |
+
parser.add_argument("--print_sentence_wer", type=int, default=0)
|
| 12 |
+
parser.add_argument("--do_tn", type=int, default=0, help="simple tn by cn2an")
|
| 13 |
+
parser.add_argument("--rm_special", type=int, default=0, help="remove <\|.*?\|>")
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def main(args):
|
| 17 |
+
uttid2refs = read_uttid2tokens(args.ref, args.do_tn, args.rm_special)
|
| 18 |
+
uttid2hyps = read_uttid2tokens(args.hyp, args.do_tn, args.rm_special)
|
| 19 |
+
uttid2wer_info, wer_stat, en_dig_stat = compute_uttid2wer_info(
|
| 20 |
+
uttid2refs, uttid2hyps, args.print_sentence_wer)
|
| 21 |
+
wer_stat.print()
|
| 22 |
+
en_dig_stat.print()
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
def read_uttid2tokens(filename, do_tn=False, rm_special=False):
|
| 26 |
+
print(f">>> Read uttid to tokens: {filename}", flush=True)
|
| 27 |
+
uttid2tokens = OrderedDict()
|
| 28 |
+
uttid2text = read_uttid2text(filename, do_tn, rm_special)
|
| 29 |
+
for uttid, text in uttid2text.items():
|
| 30 |
+
tokens = text2tokens(text)
|
| 31 |
+
uttid2tokens[uttid] = tokens
|
| 32 |
+
return uttid2tokens
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def read_uttid2text(filename, do_tn=False, rm_special=False):
|
| 36 |
+
uttid2text = OrderedDict()
|
| 37 |
+
with open(filename, "r", encoding="utf8") as fin:
|
| 38 |
+
for i, line in enumerate(fin):
|
| 39 |
+
cols = line.split()
|
| 40 |
+
if len(cols) == 0:
|
| 41 |
+
print("[WARN] empty line, continue", i, flush=True)
|
| 42 |
+
continue
|
| 43 |
+
assert cols[0] not in uttid2text, f"repeated uttid: {line}"
|
| 44 |
+
if len(cols) == 1:
|
| 45 |
+
uttid2text[cols[0]] = ""
|
| 46 |
+
continue
|
| 47 |
+
txt = " ".join(cols[1:])
|
| 48 |
+
if rm_special:
|
| 49 |
+
txt = " ".join([t for t in re.split("<\|.*?\|>", txt) if t.strip() != ""])
|
| 50 |
+
if do_tn:
|
| 51 |
+
import cn2an
|
| 52 |
+
txt = cn2an.transform(txt, "an2cn")
|
| 53 |
+
uttid2text[cols[0]] = txt
|
| 54 |
+
return uttid2text
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def text2tokens(text):
|
| 58 |
+
PUNCTUATIONS = ",。?!,\.?!"#$%&'()*+-/:;<=>@[\]^_`{|}~⦅⦆「」、 、〃〈〉《》「」『』【】〔〕〖〗〘〙〚〛〜〝〞〟〰〾〿–—‘’‛“”„‟…‧﹏﹑﹔·。\":" + "()\[\]{}/;`|=+"
|
| 59 |
+
if text == "":
|
| 60 |
+
return []
|
| 61 |
+
tokens = []
|
| 62 |
+
|
| 63 |
+
text = re.sub("<unk>", "", text)
|
| 64 |
+
text = re.sub(r"[%s]+" % PUNCTUATIONS, " ", text)
|
| 65 |
+
|
| 66 |
+
pattern = re.compile(r'([\u4e00-\u9fff])')
|
| 67 |
+
parts = pattern.split(text.strip().upper())
|
| 68 |
+
parts = [p for p in parts if len(p.strip()) > 0]
|
| 69 |
+
for part in parts:
|
| 70 |
+
if pattern.fullmatch(part) is not None:
|
| 71 |
+
tokens.append(part)
|
| 72 |
+
else:
|
| 73 |
+
for word in part.strip().split():
|
| 74 |
+
tokens.append(word)
|
| 75 |
+
return tokens
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
def compute_uttid2wer_info(refs, hyps, print_sentence_wer=False):
|
| 79 |
+
print(f">>> Compute uttid to wer info", flush=True)
|
| 80 |
+
|
| 81 |
+
uttid2wer_info = OrderedDict()
|
| 82 |
+
wer_stat = WerStats()
|
| 83 |
+
en_dig_stat = EnDigStats()
|
| 84 |
+
|
| 85 |
+
for uttid, ref in refs.items():
|
| 86 |
+
if uttid not in hyps:
|
| 87 |
+
print(f"[WARN] No hyp for {uttid}", flush=True)
|
| 88 |
+
continue
|
| 89 |
+
hyp = hyps[uttid]
|
| 90 |
+
|
| 91 |
+
if len(hyp) - len(ref) >= 8:
|
| 92 |
+
print(f"[BidLengthDiff]: {uttid} {len(ref)} {len(hyp)}#{' '.join(ref)}#{' '.join(hyp)}")
|
| 93 |
+
#continue
|
| 94 |
+
|
| 95 |
+
wer_info = compute_one_wer_info(ref, hyp)
|
| 96 |
+
uttid2wer_info[uttid] = wer_info
|
| 97 |
+
ns = count_english_ditgit(ref, hyp, wer_info)
|
| 98 |
+
wer_stat.add(wer_info)
|
| 99 |
+
en_dig_stat.add(*ns)
|
| 100 |
+
if print_sentence_wer:
|
| 101 |
+
print(f"{uttid} {wer_info}")
|
| 102 |
+
|
| 103 |
+
return uttid2wer_info, wer_stat, en_dig_stat
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
COST_SUB = 3
|
| 107 |
+
COST_DEL = 3
|
| 108 |
+
COST_INS = 3
|
| 109 |
+
|
| 110 |
+
ALIGN_CRT = 0
|
| 111 |
+
ALIGN_SUB = 1
|
| 112 |
+
ALIGN_DEL = 2
|
| 113 |
+
ALIGN_INS = 3
|
| 114 |
+
ALIGN_END = 4
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def compute_one_wer_info(ref, hyp):
|
| 118 |
+
"""Impl minimum edit distance and backtrace.
|
| 119 |
+
Args:
|
| 120 |
+
ref, hyp: List[str]
|
| 121 |
+
Returns:
|
| 122 |
+
WerInfo
|
| 123 |
+
"""
|
| 124 |
+
ref_len = len(ref)
|
| 125 |
+
hyp_len = len(hyp)
|
| 126 |
+
|
| 127 |
+
class _DpPoint:
|
| 128 |
+
def __init__(self, cost, align):
|
| 129 |
+
self.cost = cost
|
| 130 |
+
self.align = align
|
| 131 |
+
|
| 132 |
+
dp = []
|
| 133 |
+
for i in range(0, ref_len + 1):
|
| 134 |
+
dp.append([])
|
| 135 |
+
for j in range(0, hyp_len + 1):
|
| 136 |
+
dp[-1].append(_DpPoint(i * j, ALIGN_CRT))
|
| 137 |
+
|
| 138 |
+
# Initialize
|
| 139 |
+
for i in range(1, hyp_len + 1):
|
| 140 |
+
dp[0][i].cost = dp[0][i - 1].cost + COST_INS;
|
| 141 |
+
dp[0][i].align = ALIGN_INS
|
| 142 |
+
for i in range(1, ref_len + 1):
|
| 143 |
+
dp[i][0].cost = dp[i - 1][0].cost + COST_DEL
|
| 144 |
+
dp[i][0].align = ALIGN_DEL
|
| 145 |
+
|
| 146 |
+
# DP
|
| 147 |
+
for i in range(1, ref_len + 1):
|
| 148 |
+
for j in range(1, hyp_len + 1):
|
| 149 |
+
min_cost = 0
|
| 150 |
+
min_align = ALIGN_CRT
|
| 151 |
+
if hyp[j - 1] == ref[i - 1]:
|
| 152 |
+
min_cost = dp[i - 1][j - 1].cost
|
| 153 |
+
min_align = ALIGN_CRT
|
| 154 |
+
else:
|
| 155 |
+
min_cost = dp[i - 1][j - 1].cost + COST_SUB
|
| 156 |
+
min_align = ALIGN_SUB
|
| 157 |
+
|
| 158 |
+
del_cost = dp[i - 1][j].cost + COST_DEL
|
| 159 |
+
if del_cost < min_cost:
|
| 160 |
+
min_cost = del_cost
|
| 161 |
+
min_align = ALIGN_DEL
|
| 162 |
+
|
| 163 |
+
ins_cost = dp[i][j - 1].cost + COST_INS
|
| 164 |
+
if ins_cost < min_cost:
|
| 165 |
+
min_cost = ins_cost
|
| 166 |
+
min_align = ALIGN_INS
|
| 167 |
+
|
| 168 |
+
dp[i][j].cost = min_cost
|
| 169 |
+
dp[i][j].align = min_align
|
| 170 |
+
|
| 171 |
+
# Backtrace
|
| 172 |
+
crt = sub = ins = det = 0
|
| 173 |
+
i = ref_len
|
| 174 |
+
j = hyp_len
|
| 175 |
+
align = []
|
| 176 |
+
while i > 0 or j > 0:
|
| 177 |
+
if dp[i][j].align == ALIGN_CRT:
|
| 178 |
+
align.append((i, j, ALIGN_CRT))
|
| 179 |
+
i -= 1
|
| 180 |
+
j -= 1
|
| 181 |
+
crt += 1
|
| 182 |
+
elif dp[i][j].align == ALIGN_SUB:
|
| 183 |
+
align.append((i, j, ALIGN_SUB))
|
| 184 |
+
i -= 1
|
| 185 |
+
j -= 1
|
| 186 |
+
sub += 1
|
| 187 |
+
elif dp[i][j].align == ALIGN_DEL:
|
| 188 |
+
align.append((i, j, ALIGN_DEL))
|
| 189 |
+
i -= 1
|
| 190 |
+
det += 1
|
| 191 |
+
elif dp[i][j].align == ALIGN_INS:
|
| 192 |
+
align.append((i, j, ALIGN_INS))
|
| 193 |
+
j -= 1
|
| 194 |
+
ins += 1
|
| 195 |
+
|
| 196 |
+
err = sub + det + ins
|
| 197 |
+
align.reverse()
|
| 198 |
+
wer_info = WerInfo(ref_len, err, crt, sub, det, ins, align)
|
| 199 |
+
return wer_info
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class WerInfo:
|
| 204 |
+
def __init__(self, ref, err, crt, sub, dele, ins, ali):
|
| 205 |
+
self.r = ref
|
| 206 |
+
self.e = err
|
| 207 |
+
self.c = crt
|
| 208 |
+
self.s = sub
|
| 209 |
+
self.d = dele
|
| 210 |
+
self.i = ins
|
| 211 |
+
self.ali = ali
|
| 212 |
+
r = max(self.r, 1)
|
| 213 |
+
self.wer = 100.0 * (self.s + self.d + self.i) / r
|
| 214 |
+
|
| 215 |
+
def __repr__(self):
|
| 216 |
+
s = f"wer {self.wer:.2f} ref {self.r:2d} sub {self.s:2d} del {self.d:2d} ins {self.i:2d}"
|
| 217 |
+
return s
|
| 218 |
+
|
| 219 |
+
|
| 220 |
+
class WerStats:
|
| 221 |
+
def __init__(self):
|
| 222 |
+
self.infos = []
|
| 223 |
+
|
| 224 |
+
def add(self, wer_info):
|
| 225 |
+
self.infos.append(wer_info)
|
| 226 |
+
|
| 227 |
+
def print(self):
|
| 228 |
+
r = sum(info.r for info in self.infos)
|
| 229 |
+
if r <= 0:
|
| 230 |
+
print(f"REF len is {r}, check")
|
| 231 |
+
r = 1
|
| 232 |
+
s = sum(info.s for info in self.infos)
|
| 233 |
+
d = sum(info.d for info in self.infos)
|
| 234 |
+
i = sum(info.i for info in self.infos)
|
| 235 |
+
se = 100.0 * s / r
|
| 236 |
+
de = 100.0 * d / r
|
| 237 |
+
ie = 100.0 * i / r
|
| 238 |
+
wer = 100.0 * (s + d + i) / r
|
| 239 |
+
sen = max(len(self.infos), 1)
|
| 240 |
+
errsen = sum(info.e > 0 for info in self.infos)
|
| 241 |
+
ser = 100.0 * errsen / sen
|
| 242 |
+
print("-"*80)
|
| 243 |
+
print(f"ref{r:6d} sub{s:6d} del{d:6d} ins{i:6d}")
|
| 244 |
+
print(f"WER{wer:6.2f} sub{se:6.2f} del{de:6.2f} ins{ie:6.2f}")
|
| 245 |
+
print(f"SER{ser:6.2f} = {errsen} / {sen}")
|
| 246 |
+
print("-"*80)
|
| 247 |
+
|
| 248 |
+
|
| 249 |
+
class EnDigStats:
|
| 250 |
+
def __init__(self):
|
| 251 |
+
self.n_en_word = 0
|
| 252 |
+
self.n_en_correct = 0
|
| 253 |
+
self.n_dig_word = 0
|
| 254 |
+
self.n_dig_correct = 0
|
| 255 |
+
|
| 256 |
+
def add(self, n_en_word, n_en_correct, n_dig_word, n_dig_correct):
|
| 257 |
+
self.n_en_word += n_en_word
|
| 258 |
+
self.n_en_correct += n_en_correct
|
| 259 |
+
self.n_dig_word += n_dig_word
|
| 260 |
+
self.n_dig_correct += n_dig_correct
|
| 261 |
+
|
| 262 |
+
def print(self):
|
| 263 |
+
print(f"English #word={self.n_en_word}, #correct={self.n_en_correct}\n"
|
| 264 |
+
f"Digit #word={self.n_dig_word}, #correct={self.n_dig_correct}")
|
| 265 |
+
print("-"*80)
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
|
| 269 |
+
def count_english_ditgit(ref, hyp, wer_info):
|
| 270 |
+
patt_en = "[a-zA-Z\.\-\']+"
|
| 271 |
+
patt_dig = "[0-9]+"
|
| 272 |
+
patt_cjk = re.compile(r'([\u4e00-\u9fff])')
|
| 273 |
+
n_en_word = 0
|
| 274 |
+
n_en_correct = 0
|
| 275 |
+
n_dig_word = 0
|
| 276 |
+
n_dig_correct = 0
|
| 277 |
+
ali = wer_info.ali
|
| 278 |
+
for i, token in enumerate(ref):
|
| 279 |
+
if re.match(patt_en, token):
|
| 280 |
+
n_en_word += 1
|
| 281 |
+
for y in ali:
|
| 282 |
+
if y[0] == i+1 and y[2] == ALIGN_CRT:
|
| 283 |
+
j = y[1] - 1
|
| 284 |
+
n_en_correct += 1
|
| 285 |
+
break
|
| 286 |
+
if re.match(patt_dig, token):
|
| 287 |
+
n_dig_word += 1
|
| 288 |
+
for y in ali:
|
| 289 |
+
if y[0] == i+1 and y[2] == ALIGN_CRT:
|
| 290 |
+
j = y[1] - 1
|
| 291 |
+
n_dig_correct += 1
|
| 292 |
+
break
|
| 293 |
+
if not re.match(patt_cjk, token) and not re.match(patt_en, token) \
|
| 294 |
+
and not re.match(patt_dig, token):
|
| 295 |
+
print("[WiredChar]:", token)
|
| 296 |
+
return n_en_word, n_en_correct, n_dig_word, n_dig_correct
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
if __name__ == "__main__":
|
| 301 |
+
args = parser.parse_args()
|
| 302 |
+
print(args, flush=True)
|
| 303 |
+
main(args)
|
fireredasr/pretrained_models/README.md
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
Put pretrained models here.
|
fireredasr/requirements.txt
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cn2an>=0.5.23
|
| 2 |
+
kaldiio>=2.18.0
|
| 3 |
+
kaldi_native_fbank>=1.15
|
| 4 |
+
numpy>=1.26.1
|
| 5 |
+
peft>=0.13.2
|
| 6 |
+
sentencepiece
|
| 7 |
+
torch>=2.0.0
|
| 8 |
+
transformers>=4.46.3
|