Upload 10 files
Browse files
- LICENSE.txt +201 -0
- checkpoints_translation/checkpoint_epoch_1_valloss_4.8841.pt +3 -0
- checkpoints_translation/checkpoint_epoch_2_valloss_4.3551.pt +3 -0
- checkpoints_translation/checkpoint_epoch_3_valloss_4.1226.pt +3 -0
- checkpoints_translation/checkpoint_epoch_4_valloss_nan.pt +3 -0
- checkpoints_translation/checkpoint_epoch_5_valloss_nan.pt +3 -0
- opus_en_zh_tokenizer.json +0 -0
- translate_train.py +255 -0
- translator_loader.py +195 -0
- verify_cuda.py +13 -0
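To use this upload outside the web UI, the artifacts can be fetched from the Hub. A minimal sketch with huggingface_hub is shown below; "<user>/<repo>" is a placeholder for this repository's id, which is not stated in this commit.

# Hedged sketch: pull the tokenizer and the lowest-loss checkpoint from this repo.
# "<user>/<repo>" is a placeholder; substitute the actual repository id.
from huggingface_hub import hf_hub_download

tokenizer_path = hf_hub_download(repo_id="<user>/<repo>", filename="opus_en_zh_tokenizer.json")
checkpoint_path = hf_hub_download(
    repo_id="<user>/<repo>",
    filename="checkpoints_translation/checkpoint_epoch_3_valloss_4.1226.pt",
)
print(tokenizer_path, checkpoint_path)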
LICENSE.txt
ADDED
@@ -0,0 +1,201 @@
                                 Apache License
                           Version 2.0, January 2004
                        http://www.apache.org/licenses/

   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

   1. Definitions.

      "License" shall mean the terms and conditions for use, reproduction,
      and distribution as defined by Sections 1 through 9 of this document.

      "Licensor" shall mean the copyright owner or entity authorized by
      the copyright owner that is granting the License.

      "Legal Entity" shall mean the union of the acting entity and all
      other entities that control, are controlled by, or are under common
      control with that entity. For the purposes of this definition,
      "control" means (i) the power, direct or indirect, to cause the
      direction or management of such entity, whether by contract or
      otherwise, or (ii) ownership of fifty percent (50%) or more of the
      outstanding shares, or (iii) beneficial ownership of such entity.

      "You" (or "Your") shall mean an individual or Legal Entity
      exercising permissions granted by this License.

      "Source" form shall mean the preferred form for making modifications,
      including but not limited to software source code, documentation
      source, and configuration files.

      "Object" form shall mean any form resulting from mechanical
      transformation or translation of a Source form, including but
      not limited to compiled object code, generated documentation,
      and conversions to other media types.

      "Work" shall mean the work of authorship, whether in Source or
      Object form, made available under the License, as indicated by a
      copyright notice that is included in or attached to the work
      (an example is provided in the Appendix below).

      "Derivative Works" shall mean any work, whether in Source or Object
      form, that is based on (or derived from) the Work and for which the
      editorial revisions, annotations, elaborations, or other modifications
      represent, as a whole, an original work of authorship. For the purposes
      of this License, Derivative Works shall not include works that remain
      separable from, or merely link (or bind by name) to the interfaces of,
      the Work and Derivative Works thereof.

      "Contribution" shall mean any work of authorship, including
      the original version of the Work and any modifications or additions
      to that Work or Derivative Works thereof, that is intentionally
      submitted to Licensor for inclusion in the Work by the copyright owner
      or by an individual or Legal Entity authorized to submit on behalf of
      the copyright owner. For the purposes of this definition, "submitted"
      means any form of electronic, verbal, or written communication sent
      to the Licensor or its representatives, including but not limited to
      communication on electronic mailing lists, source code control systems,
      and issue tracking systems that are managed by, or on behalf of, the
      Licensor for the purpose of discussing and improving the Work, but
      excluding communication that is conspicuously marked or otherwise
      designated in writing by the copyright owner as "Not a Contribution."

      "Contributor" shall mean Licensor and any individual or Legal Entity
      on behalf of whom a Contribution has been received by Licensor and
      subsequently incorporated within the Work.

   2. Grant of Copyright License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      copyright license to reproduce, prepare Derivative Works of,
      publicly display, publicly perform, sublicense, and distribute the
      Work and such Derivative Works in Source or Object form.

   3. Grant of Patent License. Subject to the terms and conditions of
      this License, each Contributor hereby grants to You a perpetual,
      worldwide, non-exclusive, no-charge, royalty-free, irrevocable
      (except as stated in this section) patent license to make, have made,
      use, offer to sell, sell, import, and otherwise transfer the Work,
      where such license applies only to those patent claims licensable
      by such Contributor that are necessarily infringed by their
      Contribution(s) alone or by combination of their Contribution(s)
      with the Work to which such Contribution(s) was submitted. If You
      institute patent litigation against any entity (including a
      cross-claim or counterclaim in a lawsuit) alleging that the Work
      or a Contribution incorporated within the Work constitutes direct
      or contributory patent infringement, then any patent licenses
      granted to You under this License for that Work shall terminate
      as of the date such litigation is filed.

   4. Redistribution. You may reproduce and distribute copies of the
      Work or Derivative Works thereof in any medium, with or without
      modifications, and in Source or Object form, provided that You
      meet the following conditions:

      (a) You must give any other recipients of the Work or
          Derivative Works a copy of this License; and

      (b) You must cause any modified files to carry prominent notices
          stating that You changed the files; and

      (c) You must retain, in the Source form of any Derivative Works
          that You distribute, all copyright, patent, trademark, and
          attribution notices from the Source form of the Work,
          excluding those notices that do not pertain to any part of
          the Derivative Works; and

      (d) If the Work includes a "NOTICE" text file as part of its
          distribution, then any Derivative Works that You distribute must
          include a readable copy of the attribution notices contained
          within such NOTICE file, excluding those notices that do not
          pertain to any part of the Derivative Works, in at least one
          of the following places: within a NOTICE text file distributed
          as part of the Derivative Works; within the Source form or
          documentation, if provided along with the Derivative Works; or,
          within a display generated by the Derivative Works, if and
          wherever such third-party notices normally appear. The contents
          of the NOTICE file are for informational purposes only and
          do not modify the License. You may add Your own attribution
          notices within Derivative Works that You distribute, alongside
          or as an addendum to the NOTICE text from the Work, provided
          that such additional attribution notices cannot be construed
          as modifying the License.

      You may add Your own copyright statement to Your modifications and
      may provide additional or different license terms and conditions
      for use, reproduction, or distribution of Your modifications, or
      for any such Derivative Works as a whole, provided Your use,
      reproduction, and distribution of the Work otherwise complies with
      the conditions stated in this License.

   5. Submission of Contributions. Unless You explicitly state otherwise,
      any Contribution intentionally submitted for inclusion in the Work
      by You to the Licensor shall be under the terms and conditions of
      this License, without any additional terms or conditions.
      Notwithstanding the above, nothing herein shall supersede or modify
      the terms of any separate license agreement you may have executed
      with Licensor regarding such Contributions.

   6. Trademarks. This License does not grant permission to use the trade
      names, trademarks, service marks, or product names of the Licensor,
      except as required for reasonable and customary use in describing the
      origin of the Work and reproducing the content of the NOTICE file.

   7. Disclaimer of Warranty. Unless required by applicable law or
      agreed to in writing, Licensor provides the Work (and each
      Contributor provides its Contributions) on an "AS IS" BASIS,
      WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
      implied, including, without limitation, any warranties or conditions
      of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
      PARTICULAR PURPOSE. You are solely responsible for determining the
      appropriateness of using or redistributing the Work and assume any
      risks associated with Your exercise of permissions under this License.

   8. Limitation of Liability. In no event and under no legal theory,
      whether in tort (including negligence), contract, or otherwise,
      unless required by applicable law (such as deliberate and grossly
      negligent acts) or agreed to in writing, shall any Contributor be
      liable to You for damages, including any direct, indirect, special,
      incidental, or consequential damages of any character arising as a
      result of this License or out of the use or inability to use the
      Work (including but not limited to damages for loss of goodwill,
      work stoppage, computer failure or malfunction, or any and all
      other commercial damages or losses), even if such Contributor
      has been advised of the possibility of such damages.

   9. Accepting Warranty or Additional Liability. While redistributing
      the Work or Derivative Works thereof, You may choose to offer,
      and charge a fee for, acceptance of support, warranty, indemnity,
      or other liability obligations and/or rights consistent with this
      License. However, in accepting such obligations, You may act only
      on Your own behalf and on Your sole responsibility, not on behalf
      of any other Contributor, and only if You agree to indemnify,
      defend, and hold each Contributor harmless for any liability
      incurred by, or claims asserted against, such Contributor by reason
      of your accepting any such warranty or additional liability.

   END OF TERMS AND CONDITIONS

   APPENDIX: How to apply the Apache License to your work.

      To apply the Apache License to your work, attach the following
      boilerplate notice, with the fields enclosed by brackets "[]"
      replaced with your own identifying information. (Don't include
      the brackets!) The text should be enclosed in the appropriate
      comment syntax for the file format. We also recommend that a
      file or class name and description of purpose be included on the
      same "printed page" as the copyright notice for easier
      identification within third-party archives.

   Copyright 2025 Timur Hromek

   Licensed under the Apache License, Version 2.0 (the "License");
   you may not use this file except in compliance with the License.
   You may obtain a copy of the License at

       http://www.apache.org/licenses/LICENSE-2.0

   Unless required by applicable law or agreed to in writing, software
   distributed under the License is distributed on an "AS IS" BASIS,
   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
   See the License for the specific language governing permissions and
   limitations under the License.
checkpoints_translation/checkpoint_epoch_1_valloss_4.8841.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:c2debe09caf3aa76cf341d4985eb301077ee5b790fc4d5bc8348b45b93797316
size 260563794
checkpoints_translation/checkpoint_epoch_2_valloss_4.3551.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:97472d527a0abd2777adea5d272d8c98d2a1db30e7f3de809e639d91fd31b35a
size 260563794
checkpoints_translation/checkpoint_epoch_3_valloss_4.1226.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9dee8be3fb660152565e108f09ac660afa0defaef8b5c50a43218bad39777cc
size 260563794
checkpoints_translation/checkpoint_epoch_4_valloss_nan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:bf50ebd062d5abf0ce49936d7fbe129c271807b70dc6a9db5a5c305b2f668606
size 260562255
checkpoints_translation/checkpoint_epoch_5_valloss_nan.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3e77b63efc8ba8ca6dc99b25e7ed28e2796bed38f5f67ee5572bc4aa897b9b12
size 260562255
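Each of the five .pt entries above is a Git LFS pointer; the ~260 MB checkpoint payloads live in LFS storage. Per save_checkpoint in translate_train.py below, every checkpoint is a torch.save dict with the keys epoch, model_state_dict, optimizer_state_dict, and loss; the epoch 4 and 5 files record a NaN validation loss in their names. A minimal sketch of inspecting one checkpoint, assuming the actual LFS object (not the 3-line pointer file) has been downloaded:

import torch

# Inspect the keys written by Trainer.save_checkpoint in translate_train.py.
ckpt = torch.load("checkpoints_translation/checkpoint_epoch_3_valloss_4.1226.pt", map_location="cpu")
print(sorted(ckpt.keys()))           # ['epoch', 'loss', 'model_state_dict', 'optimizer_state_dict']
print(ckpt["epoch"], ckpt["loss"])   # 0-based epoch index and the validation loss at save time

# The weights alone are enough for inference; the optimizer state only matters for resuming training.
num_values = sum(t.numel() for t in ckpt["model_state_dict"].values())
print(f"{num_values:,} values in the model state dict")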
opus_en_zh_tokenizer.json
ADDED
The diff for this file is too large to render.
See raw diff
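opus_en_zh_tokenizer.json is the shared BPE tokenizer that translate_train.py trains over both the English and Chinese sides of opus-100 (32,000-entry vocabulary, Whitespace pre-tokenizer, special tokens <unk>, <pad>, <s>, </s>). A minimal sketch of using it directly with the tokenizers library:

from tokenizers import Tokenizer

# Load the tokenizer file added in this commit and round-trip a sentence.
tokenizer = Tokenizer.from_file("opus_en_zh_tokenizer.json")
print(tokenizer.get_vocab_size())

encoding = tokenizer.encode("This technology is changing the world.")
print(encoding.tokens)                 # BPE pieces
print(tokenizer.decode(encoding.ids))  # back to text (the scripts add <s>/</s> around these ids themselves)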
translate_train.py
ADDED
@@ -0,0 +1,255 @@
import os
import math
import logging
from pathlib import Path
from datetime import datetime

import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader, Dataset

from datasets import load_dataset
from tokenizers import Tokenizer
from tokenizers.models import BPE
from tokenizers.trainers import BpeTrainer
from tokenizers.pre_tokenizers import Whitespace
from tqdm import tqdm

# --- Configuration ---
CONFIG = {
    "SRC_LANG": "en",
    "TGT_LANG": "zh",
    "TOKENIZER_FILE": "opus_en_zh_tokenizer.json",
    "MAX_SEQ_LEN": 128,
    "VOCAB_SIZE": 32000,
    "DIM": 256,
    "ENCODER_LAYERS": 4,
    "DECODER_LAYERS": 4,
    "N_HEADS": 8,
    "FF_DIM": 512,
    "DROPOUT": 0.1,
    "BATCH_SIZE": 64,
    "LEARNING_RATE": 5e-4,
    "NUM_EPOCHS": 5,
    "CHECKPOINT_DIR": "checkpoints_translation",
}

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Tokenizer Manager ---
class TokenizerManager:
    # ... (No changes needed in this class)
    def __init__(self, config):
        self.config = config
        self.tokenizer_path = Path(self.config["TOKENIZER_FILE"])
        self.special_tokens = ["<unk>", "<pad>", "<s>", "</s>"]
    def get_text_iterator(self):
        dataset = load_dataset(f"Helsinki-NLP/opus-100", f"{self.config['SRC_LANG']}-{self.config['TGT_LANG']}", split="train", streaming=True)
        for item in dataset: yield item['translation'][self.config['SRC_LANG']]; yield item['translation'][self.config['TGT_LANG']]
    def train_tokenizer(self):
        logging.info("Training a new tokenizer...")
        tokenizer = Tokenizer(BPE(unk_token="<unk>")); tokenizer.pre_tokenizer = Whitespace()
        trainer = BpeTrainer(vocab_size=self.config["VOCAB_SIZE"], special_tokens=self.special_tokens)
        tokenizer.train_from_iterator(self.get_text_iterator(), trainer=trainer)
        tokenizer.save(str(self.tokenizer_path)); logging.info(f"Tokenizer trained and saved to {self.tokenizer_path}")
        return tokenizer
    def get_tokenizer(self):
        if not self.tokenizer_path.exists(): return self.train_tokenizer()
        logging.info(f"Loading existing tokenizer from {self.tokenizer_path}")
        return Tokenizer.from_file(str(self.tokenizer_path))

# --- Dataset and Dataloader ---
class OpusDataset(Dataset):
    # ... (No changes needed in this class)
    def __init__(self, tokenizer, config, split="train"):
        self.tokenizer = tokenizer; self.config = config
        dataset = load_dataset(f"Helsinki-NLP/opus-100", f"{config['SRC_LANG']}-{config['TGT_LANG']}", split=split)
        self.pairs = [item['translation'] for item in dataset]
        self.src_lang, self.tgt_lang, self.max_len = config["SRC_LANG"], config["TGT_LANG"], config["MAX_SEQ_LEN"]
        self.bos_id, self.eos_id, self.pad_id = tokenizer.token_to_id("<s>"), tokenizer.token_to_id("</s>"), tokenizer.token_to_id("<pad>")
    def __len__(self): return len(self.pairs)
    def __getitem__(self, idx):
        pair = self.pairs[idx]
        src_text, tgt_text = pair[self.src_lang], pair[self.tgt_lang]
        src_tokens = [self.bos_id] + self.tokenizer.encode(src_text).ids + [self.eos_id]
        tgt_tokens = [self.bos_id] + self.tokenizer.encode(tgt_text).ids + [self.eos_id]
        return {"src": torch.tensor(src_tokens[:self.max_len], dtype=torch.long), "tgt": torch.tensor(tgt_tokens[:self.max_len], dtype=torch.long)}

class PadCollate:
    # ... (No changes needed in this class)
    def __init__(self, pad_id): self.pad_id = pad_id
    def __call__(self, batch):
        src_batch, tgt_batch = [item["src"] for item in batch], [item["tgt"] for item in batch]
        src_padded = pad_sequence(src_batch, batch_first=True, padding_value=self.pad_id)
        tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=self.pad_id)
        return {"src": src_padded, "tgt": tgt_padded}

# --- Model Architecture ---
class PositionalEncoding(nn.Module):

    def __init__(self, dim, dropout, max_len=5000):
        super().__init__(); self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1); div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, 1, dim); pe[:, 0, 0::2] = torch.sin(position * div_term); pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)
    def forward(self, x): x = x + self.pe[:x.size(0)]; return self.dropout(x)

class TranslationTransformer(nn.Module):

    def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim); self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
        self.transformer = nn.Transformer(d_model=dim, nhead=n_heads, num_encoder_layers=encoder_layers, num_decoder_layers=decoder_layers, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
        self.generator = nn.Linear(dim, vocab_size)
    def _generate_mask(self, src, tgt, pad_id):
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
        src_padding_mask, tgt_padding_mask = (src == pad_id), (tgt == pad_id)
        return tgt_mask, src_padding_mask, tgt_padding_mask
    def forward(self, src, tgt, pad_id):
        src_emb = self.pos_encoder((self.embedding(src) * math.sqrt(CONFIG["DIM"])).permute(1, 0, 2)).permute(1, 0, 2)
        tgt_emb = self.pos_encoder((self.embedding(tgt) * math.sqrt(CONFIG["DIM"])).permute(1, 0, 2)).permute(1, 0, 2)
        tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
        output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask)
        return self.generator(output)

# --- Trainer ---
class Trainer:
    def __init__(self, model, tokenizer, config):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.optimizer = torch.optim.AdamW(model.parameters(), lr=config["LEARNING_RATE"])
        self.pad_id = tokenizer.token_to_id("<pad>")
        self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
        self.scaler = torch.cuda.amp.GradScaler(enabled=(self.device.type == 'cuda'))
        self.checkpoint_dir = Path(config["CHECKPOINT_DIR"])
        self.checkpoint_dir.mkdir(exist_ok=True)

    def train_epoch(self, dataloader):
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(dataloader, desc=f"Epoch {self.current_epoch+1}/{self.config['NUM_EPOCHS']} Training")
        for batch in progress_bar:
            src, tgt = batch["src"].to(self.device), batch["tgt"].to(self.device)
            tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]
            self.optimizer.zero_grad(set_to_none=True)
            with torch.amp.autocast(device_type=self.device.type, enabled=(self.device.type == 'cuda')):
                logits = self.model(src, tgt_input, self.pad_id)
                loss = self.criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
            self.scaler.scale(loss).backward()
            self.scaler.step(self.optimizer)
            self.scaler.update()
            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
        return total_loss / len(dataloader)

    # <<< NEW METHOD: For validation and testing >>>
    def evaluate(self, dataloader, description="Evaluating"):
        self.model.eval()
        total_loss = 0
        with torch.no_grad():
            progress_bar = tqdm(dataloader, desc=description)
            for batch in progress_bar:
                src, tgt = batch["src"].to(self.device), batch["tgt"].to(self.device)
                tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]
                logits = self.model(src, tgt_input, self.pad_id)
                loss = self.criterion(logits.view(-1, logits.size(-1)), tgt_output.reshape(-1))
                total_loss += loss.item()
                progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
        return total_loss / len(dataloader)

    def save_checkpoint(self, epoch, val_loss):
        filename = f"checkpoint_epoch_{epoch+1}_valloss_{val_loss:.4f}.pt"
        path = self.checkpoint_dir / filename
        torch.save({'epoch': epoch, 'model_state_dict': self.model.state_dict(),
                    'optimizer_state_dict': self.optimizer.state_dict(), 'loss': val_loss}, path)
        logging.info(f"Checkpoint saved to {path}")


    def train(self, train_loader, val_loader):
        for epoch in range(self.config["NUM_EPOCHS"]):
            self.current_epoch = epoch
            logging.info(f"--- Starting Epoch {epoch + 1}/{self.config['NUM_EPOCHS']} ---")
            train_loss = self.train_epoch(train_loader)
            val_loss = self.evaluate(val_loader, description=f"Epoch {epoch+1}/{self.config['NUM_EPOCHS']} Validation")
            logging.info(f"Epoch {epoch+1} -> Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
            self.save_checkpoint(epoch, val_loss)
            self.translate("This is a test to see how the model is learning.")

    def translate(self, src_sentence: str):
        self.model.eval()
        src_tokens = [self.tokenizer.token_to_id("<s>")] + self.tokenizer.encode(src_sentence).ids + [self.tokenizer.token_to_id("</s>")]
        src = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
        tgt_tokens = [self.tokenizer.token_to_id("<s>")]
        with torch.no_grad():
            for _ in range(self.config["MAX_SEQ_LEN"]):
                tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
                logits = self.model(src, tgt_input, self.pad_id)
                next_token_id = logits[:, -1, :].argmax(dim=-1).item()
                tgt_tokens.append(next_token_id)
                if next_token_id == self.tokenizer.token_to_id("</s>"): break
        translated_text = self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)
        logging.info(f"Source: '{src_sentence}'")
        logging.info(f"Translated: '{translated_text}'")

def main():
    # CUDA health check to confirm the local driver and toolkit setup before training.
    print("-" * 50)
    print("CUDA Health Check:")
    if torch.cuda.is_available():
        print(f"✅ CUDA is available.")
        print(f"   PyTorch Version: {torch.__version__}")
        print(f"   CUDA Version PyTorch was built with: {torch.version.cuda}")
        print(f"   Number of GPUs: {torch.cuda.device_count()}")
        print(f"   Current GPU Name: {torch.cuda.get_device_name(0)}")
    else:
        print(f"❌ CUDA is NOT available.")
        print(f"   PyTorch will run on CPU, which will be very slow.")
        print(f"   ACTION: Ensure you have installed PyTorch with CUDA support. See https://pytorch.org/get-started/locally/")
    print("-" * 50)

    tokenizer_manager = TokenizerManager(CONFIG)
    tokenizer = tokenizer_manager.get_tokenizer()
    CONFIG["VOCAB_SIZE"] = tokenizer.get_vocab_size()

    logging.info("Loading and preparing datasets...")
    train_dataset = OpusDataset(tokenizer, CONFIG, split="train")
    val_dataset = OpusDataset(tokenizer, CONFIG, split="validation")
    test_dataset = OpusDataset(tokenizer, CONFIG, split="test")
    logging.info(f"Dataset sizes -> Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")

    pad_id = tokenizer.token_to_id("<pad>")
    collate_fn = PadCollate(pad_id)
    num_workers = 0 if os.name == 'nt' else os.cpu_count() // 2

    train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())
    val_loader = DataLoader(val_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())
    test_loader = DataLoader(test_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())

    model = TranslationTransformer(vocab_size=CONFIG["VOCAB_SIZE"], dim=CONFIG["DIM"], n_heads=CONFIG["N_HEADS"],
                                   encoder_layers=CONFIG["ENCODER_LAYERS"], decoder_layers=CONFIG["DECODER_LAYERS"],
                                   ff_dim=CONFIG["FF_DIM"], dropout=CONFIG["DROPOUT"], max_len=CONFIG["MAX_SEQ_LEN"])

    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    logging.info(f"Model initialized. Total trainable parameters: {total_params:,}")

    trainer = Trainer(model, tokenizer, CONFIG)


    trainer.train(train_loader, val_loader)

    # Final evaluation on the held-out test set.
    logging.info("\n--- Training Complete. Evaluating on Test Set... ---")
    test_loss = trainer.evaluate(test_loader, description="Final Test Evaluation")
    logging.info(f"Final Test Loss: {test_loss:.4f}")

    logging.info("\n--- Final Translation Examples ---")
    trainer.translate("The European Economic Area was created in 1994.")
    trainer.translate("What is your name?")
    trainer.translate("This technology is changing the world.")

if __name__ == "__main__":
    main()
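The loop above trains under mixed precision without gradient clipping, and the epoch 4 and 5 checkpoints record a NaN validation loss. One common mitigation, shown here only as a hedged sketch and not part of the committed script, is to unscale the gradients and clip their global norm before the optimizer step inside train_epoch:

import torch

def clipped_amp_step(scaler, optimizer, model, loss, max_norm=1.0):
    """One AMP optimizer step with global-norm gradient clipping (hypothetical helper)."""
    scaler.scale(loss).backward()
    # Undo the AMP loss scaling so the clip threshold applies to the true gradients.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
    scaler.step(optimizer)   # GradScaler skips the step if inf/NaN gradients are detected
    scaler.update()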
translator_loader.py
ADDED
@@ -0,0 +1,195 @@
import torch
import torch.nn as nn
from pathlib import Path
import math
import logging
import re

# --- Setup ---
# Configure logging to be minimal for inference
logging.basicConfig(level=logging.INFO, format='%(message)s')

# --- Configuration (Must match the training script) ---
CONFIG = {
    "SRC_LANG": "en",
    "TGT_LANG": "zh",
    "TOKENIZER_FILE": "opus_en_zh_tokenizer.json",
    "MAX_SEQ_LEN": 128,
    "DIM": 256,
    "ENCODER_LAYERS": 4,
    "DECODER_LAYERS": 4,
    "N_HEADS": 8,
    "FF_DIM": 512,
    "DROPOUT": 0.1,
    "CHECKPOINT_DIR": "checkpoints_translation",
}


class PositionalEncoding(nn.Module):
    def __init__(self, dim, dropout, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
        pe = torch.zeros(max_len, 1, dim)
        pe[:, 0, 0::2] = torch.sin(position * div_term)
        pe[:, 0, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, x):
        x = x + self.pe[:x.size(0)]
        return self.dropout(x)

class TranslationTransformer(nn.Module):
    def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, dim)
        self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
        self.transformer = nn.Transformer(
            d_model=dim, nhead=n_heads, num_encoder_layers=encoder_layers,
            num_decoder_layers=decoder_layers, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True
        )
        self.generator = nn.Linear(dim, vocab_size)

    def _generate_mask(self, src, tgt, pad_id):
        tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
        src_padding_mask = (src == pad_id)
        tgt_padding_mask = (tgt == pad_id)
        return tgt_mask, src_padding_mask, tgt_padding_mask

    def forward(self, src, tgt, pad_id):
        src_emb = self.pos_encoder((self.embedding(src) * math.sqrt(CONFIG["DIM"])).permute(1, 0, 2)).permute(1, 0, 2)
        tgt_emb = self.pos_encoder((self.embedding(tgt) * math.sqrt(CONFIG["DIM"])).permute(1, 0, 2)).permute(1, 0, 2)
        tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
        output = self.transformer(
            src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask
        )
        return self.generator(output)

# We need to import the Tokenizer class to load the tokenizer file
from tokenizers import Tokenizer

class Translator:
    def __init__(self, config):
        self.config = config
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        logging.info(f"Using device: {self.device}")

        # Load the trained tokenizer
        tokenizer_path = Path(self.config["TOKENIZER_FILE"])
        if not tokenizer_path.exists():
            raise FileNotFoundError(f"Tokenizer file not found at {tokenizer_path}. Please run the training script first.")
        self.tokenizer = Tokenizer.from_file(str(tokenizer_path))

        # Get special token IDs
        self.bos_id = self.tokenizer.token_to_id("<s>")
        self.eos_id = self.tokenizer.token_to_id("</s>")
        self.pad_id = self.tokenizer.token_to_id("<pad>")

        # Initialize the model structure
        self.model = TranslationTransformer(
            vocab_size=self.tokenizer.get_vocab_size(),
            dim=self.config["DIM"], n_heads=self.config["N_HEADS"],
            encoder_layers=self.config["ENCODER_LAYERS"], decoder_layers=self.config["DECODER_LAYERS"],
            ff_dim=self.config["FF_DIM"], dropout=self.config["DROPOUT"], max_len=self.config["MAX_SEQ_LEN"]
        )
        self.model.to(self.device)

    def load_best_checkpoint(self):
        """Finds and loads the checkpoint with the lowest validation loss."""
        checkpoint_dir = Path(self.config["CHECKPOINT_DIR"])
        if not checkpoint_dir.exists():
            raise FileNotFoundError(f"Checkpoint directory not found at {checkpoint_dir}.")

        best_loss = float('inf')
        best_checkpoint_path = None

        for chk_path in checkpoint_dir.glob("*.pt"):
            # Use regex to find the validation loss in the filename
            match = re.search(r'valloss_([\d.]+)\.pt', chk_path.name)
            if match:
                val_loss = float(match.group(1))
                if val_loss < best_loss:
                    best_loss = val_loss
                    best_checkpoint_path = chk_path

        if best_checkpoint_path is None:
            raise FileNotFoundError(f"No valid checkpoints found in {checkpoint_dir}. Checkpoint names must be like '...valloss_x.xxxx.pt'.")

        logging.info(f"Loading best model from: {best_checkpoint_path} (Validation Loss: {best_loss:.4f})")
        checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
        self.model.load_state_dict(checkpoint['model_state_dict'])

        # Set the model to evaluation mode. This is crucial!
        # It disables layers like Dropout for consistent inference.
        self.model.eval()

    def translate(self, src_sentence: str):
        """Translates a single English sentence to Chinese using greedy decoding."""
        if not src_sentence.strip():
            return ""

        # Prepare the input
        src_tokens = [self.bos_id] + self.tokenizer.encode(src_sentence).ids + [self.eos_id]
        src = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

        # Start decoding
        tgt_tokens = [self.bos_id]

        with torch.no_grad():  # Disable gradient calculation for efficiency
            for _ in range(self.config["MAX_SEQ_LEN"]):
                tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)

                # Get model predictions
                logits = self.model(src, tgt_input, self.pad_id)

                # Get the most likely next token (greedy decoding)
                next_token_id = logits[:, -1, :].argmax(dim=-1).item()
                tgt_tokens.append(next_token_id)

                # Stop if the end-of-sentence token is generated
                if next_token_id == self.eos_id:
                    break

        # Decode the generated token IDs back to a string
        translated_text = self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)
        return translated_text

def interactive_session():
    """Runs the main interactive translation loop."""
    try:
        translator = Translator(CONFIG)
        translator.load_best_checkpoint()
    except FileNotFoundError as e:
        logging.error(f"Error initializing translator: {e}")
        logging.error("Please make sure you have run the training script and have a valid tokenizer and checkpoint file.")
        return

    print("\n--- ZHEN - 1 Translator ---")
    print("Type an English sentence and press Enter.")
    print("Type 'quit' or 'exit' to close the program.")

    while True:
        try:
            source_text = input("\nEnglish > ")
            if source_text.lower() in ['quit', 'exit', 'q']:
                print("Exiting translator. Goodbye!")
                break

            if not source_text:
                continue

            translated_text = translator.translate(source_text)
            print(f"Chinese < {translated_text}")

        except KeyboardInterrupt:
            print("\nExiting translator. Goodbye!")
            break
        except Exception as e:
            logging.error(f"An unexpected error occurred: {e}")


if __name__ == "__main__":
    interactive_session()
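translator_loader.py is written as an interactive CLI, but the Translator class can also be used programmatically. A minimal sketch, assuming opus_en_zh_tokenizer.json and checkpoints_translation/ sit next to the script:

# Programmatic use of the classes defined in translator_loader.py above.
from translator_loader import CONFIG, Translator

translator = Translator(CONFIG)     # loads the tokenizer and builds the (untrained) model
translator.load_best_checkpoint()   # picks the lowest valloss_*.pt in checkpoints_translation/
print(translator.translate("What is your name?"))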
verify_cuda.py
ADDED
@@ -0,0 +1,13 @@
import torch

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")

if torch.cuda.is_available():
    print(f"CUDA version PyTorch was built with: {torch.version.cuda}")
    print(f"Number of GPUs: {torch.cuda.device_count()}")
    print(f"Current device: {torch.cuda.current_device()}")
    print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
else:
    print("\n❌ PyTorch cannot find CUDA.")
    print("   Follow the 'Foolproof Plan' to fix your environment.")
|