TimurHromek committed on
Commit
c093feb
·
verified ·
1 Parent(s): 32d5e19

Upload 10 files

Browse files
LICENSE.txt ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2025 Timur Hromek
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
checkpoints_translation/checkpoint_epoch_1_valloss_4.8841.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c2debe09caf3aa76cf341d4985eb301077ee5b790fc4d5bc8348b45b93797316
+ size 260563794
checkpoints_translation/checkpoint_epoch_2_valloss_4.3551.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:97472d527a0abd2777adea5d272d8c98d2a1db30e7f3de809e639d91fd31b35a
+ size 260563794
checkpoints_translation/checkpoint_epoch_3_valloss_4.1226.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a9dee8be3fb660152565e108f09ac660afa0defaef8b5c50a43218bad39777cc
+ size 260563794
checkpoints_translation/checkpoint_epoch_4_valloss_nan.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bf50ebd062d5abf0ce49936d7fbe129c271807b70dc6a9db5a5c305b2f668606
+ size 260562255
checkpoints_translation/checkpoint_epoch_5_valloss_nan.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3e77b63efc8ba8ca6dc99b25e7ed28e2796bed38f5f67ee5572bc4aa897b9b12
+ size 260562255
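The checkpoint files above are stored as Git LFS pointers: each records the spec version, a `sha256` oid, and the byte size of the real object. After pulling the actual checkpoint, its integrity can be checked against the pointer with nothing but the standard library. A minimal sketch (the function names are ours, not part of this repo):

```python
import hashlib
from pathlib import Path

def parse_lfs_pointer(text: str) -> dict:
    """Split a Git LFS pointer file into its 'version', 'oid', and 'size' fields."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

def verify_lfs_object(pointer_text: str, local_path) -> bool:
    """Check a downloaded file against the oid and size recorded in its pointer."""
    fields = parse_lfs_pointer(pointer_text)
    expected_oid = fields["oid"].split(":", 1)[1]  # strip the "sha256:" prefix
    expected_size = int(fields["size"])
    path = Path(local_path)
    if path.stat().st_size != expected_size:
        return False
    digest = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
            digest.update(chunk)
    return digest.hexdigest() == expected_oid
```

A cheap size check first avoids hashing a file that is obviously truncated.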
opus_en_zh_tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
translate_train.py ADDED
@@ -0,0 +1,255 @@
+ import os
+ import math
+ import logging
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.utils.rnn import pad_sequence
+ from torch.utils.data import DataLoader, Dataset
+
+ from datasets import load_dataset
+ from tokenizers import Tokenizer
+ from tokenizers.models import BPE
+ from tokenizers.trainers import BpeTrainer
+ from tokenizers.pre_tokenizers import Whitespace
+ from tqdm import tqdm
+
+ # --- Configuration ---
+ CONFIG = {
+     "SRC_LANG": "en",
+     "TGT_LANG": "zh",
+     "TOKENIZER_FILE": "opus_en_zh_tokenizer.json",
+     "MAX_SEQ_LEN": 128,
+     "VOCAB_SIZE": 32000,
+     "DIM": 256,
+     "ENCODER_LAYERS": 4,
+     "DECODER_LAYERS": 4,
+     "N_HEADS": 8,
+     "FF_DIM": 512,
+     "DROPOUT": 0.1,
+     "BATCH_SIZE": 64,
+     "LEARNING_RATE": 5e-4,
+     "NUM_EPOCHS": 5,
+     "CHECKPOINT_DIR": "checkpoints_translation",
+ }
+
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ # --- Tokenizer Manager ---
+ class TokenizerManager:
+     def __init__(self, config):
+         self.config = config
+         self.tokenizer_path = Path(self.config["TOKENIZER_FILE"])
+         self.special_tokens = ["<unk>", "<pad>", "<s>", "</s>"]
+
+     def get_text_iterator(self):
+         # Stream the corpus and yield both sides of every pair so the shared
+         # tokenizer sees English and Chinese text.
+         dataset = load_dataset("Helsinki-NLP/opus-100", f"{self.config['SRC_LANG']}-{self.config['TGT_LANG']}", split="train", streaming=True)
+         for item in dataset:
+             yield item['translation'][self.config['SRC_LANG']]
+             yield item['translation'][self.config['TGT_LANG']]
+
+     def train_tokenizer(self):
+         logging.info("Training a new tokenizer...")
+         tokenizer = Tokenizer(BPE(unk_token="<unk>"))
+         tokenizer.pre_tokenizer = Whitespace()
+         trainer = BpeTrainer(vocab_size=self.config["VOCAB_SIZE"], special_tokens=self.special_tokens)
+         tokenizer.train_from_iterator(self.get_text_iterator(), trainer=trainer)
+         tokenizer.save(str(self.tokenizer_path))
+         logging.info(f"Tokenizer trained and saved to {self.tokenizer_path}")
+         return tokenizer
+
+     def get_tokenizer(self):
+         if not self.tokenizer_path.exists():
+             return self.train_tokenizer()
+         logging.info(f"Loading existing tokenizer from {self.tokenizer_path}")
+         return Tokenizer.from_file(str(self.tokenizer_path))
+
+ # --- Dataset and Dataloader ---
+ class OpusDataset(Dataset):
+     def __init__(self, tokenizer, config, split="train"):
+         self.tokenizer = tokenizer
+         self.config = config
+         dataset = load_dataset("Helsinki-NLP/opus-100", f"{config['SRC_LANG']}-{config['TGT_LANG']}", split=split)
+         self.pairs = [item['translation'] for item in dataset]
+         self.src_lang, self.tgt_lang, self.max_len = config["SRC_LANG"], config["TGT_LANG"], config["MAX_SEQ_LEN"]
+         self.bos_id = tokenizer.token_to_id("<s>")
+         self.eos_id = tokenizer.token_to_id("</s>")
+         self.pad_id = tokenizer.token_to_id("<pad>")
+
+     def __len__(self):
+         return len(self.pairs)
+
+     def __getitem__(self, idx):
+         pair = self.pairs[idx]
+         src_text, tgt_text = pair[self.src_lang], pair[self.tgt_lang]
+         # Truncate before appending </s> so every sequence still ends with EOS.
+         src_tokens = ([self.bos_id] + self.tokenizer.encode(src_text).ids)[:self.max_len - 1] + [self.eos_id]
+         tgt_tokens = ([self.bos_id] + self.tokenizer.encode(tgt_text).ids)[:self.max_len - 1] + [self.eos_id]
+         return {"src": torch.tensor(src_tokens, dtype=torch.long), "tgt": torch.tensor(tgt_tokens, dtype=torch.long)}
+
+ class PadCollate:
+     def __init__(self, pad_id):
+         self.pad_id = pad_id
+
+     def __call__(self, batch):
+         src_batch = [item["src"] for item in batch]
+         tgt_batch = [item["tgt"] for item in batch]
+         src_padded = pad_sequence(src_batch, batch_first=True, padding_value=self.pad_id)
+         tgt_padded = pad_sequence(tgt_batch, batch_first=True, padding_value=self.pad_id)
+         return {"src": src_padded, "tgt": tgt_padded}
+
+ # --- Model Architecture ---
+ class PositionalEncoding(nn.Module):
+     """Sinusoidal positional encoding; expects input shaped (seq_len, batch, dim)."""
+     def __init__(self, dim, dropout, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
+         pe = torch.zeros(max_len, 1, dim)
+         pe[:, 0, 0::2] = torch.sin(position * div_term)
+         pe[:, 0, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0)]
+         return self.dropout(x)
+
+ class TranslationTransformer(nn.Module):
+     def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
+         super().__init__()
+         self.dim = dim
+         self.embedding = nn.Embedding(vocab_size, dim)
+         self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
+         self.transformer = nn.Transformer(d_model=dim, nhead=n_heads, num_encoder_layers=encoder_layers, num_decoder_layers=decoder_layers, dim_feedforward=ff_dim, dropout=dropout, batch_first=True)
+         self.generator = nn.Linear(dim, vocab_size)
+
+     def _generate_mask(self, src, tgt, pad_id):
+         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
+         src_padding_mask, tgt_padding_mask = (src == pad_id), (tgt == pad_id)
+         return tgt_mask, src_padding_mask, tgt_padding_mask
+
+     def _embed(self, tokens):
+         # PositionalEncoding works on (seq, batch, dim), so permute around it and
+         # scale embeddings by sqrt(dim) as in the original Transformer.
+         emb = self.embedding(tokens) * math.sqrt(self.dim)
+         return self.pos_encoder(emb.permute(1, 0, 2)).permute(1, 0, 2)
+
+     def forward(self, src, tgt, pad_id):
+         src_emb = self._embed(src)
+         tgt_emb = self._embed(tgt)
+         tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
+         output = self.transformer(src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask, tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask)
+         return self.generator(output)
+
+ # --- Trainer ---
+ class Trainer:
+     def __init__(self, model, tokenizer, config):
+         self.model = model
+         self.tokenizer = tokenizer
+         self.config = config
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.optimizer = torch.optim.AdamW(model.parameters(), lr=config["LEARNING_RATE"])
+         self.pad_id = tokenizer.token_to_id("<pad>")
+         self.criterion = nn.CrossEntropyLoss(ignore_index=self.pad_id)
+         self.scaler = torch.cuda.amp.GradScaler(enabled=(self.device.type == 'cuda'))
+         self.checkpoint_dir = Path(config["CHECKPOINT_DIR"])
+         self.checkpoint_dir.mkdir(exist_ok=True)
+         self.current_epoch = 0
+
+     def train_epoch(self, dataloader):
+         self.model.train()
+         total_loss = 0
+         progress_bar = tqdm(dataloader, desc=f"Epoch {self.current_epoch+1}/{self.config['NUM_EPOCHS']} Training")
+         for batch in progress_bar:
+             src, tgt = batch["src"].to(self.device), batch["tgt"].to(self.device)
+             # Teacher forcing: the decoder sees tgt[:-1] and is scored against tgt[1:].
+             tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]
+             self.optimizer.zero_grad(set_to_none=True)
+             with torch.amp.autocast(device_type=self.device.type, enabled=(self.device.type == 'cuda')):
+                 logits = self.model(src, tgt_input, self.pad_id)
+                 loss = self.criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
+             self.scaler.scale(loss).backward()
+             # Unscale first so the clip threshold applies to the true gradients.
+             # Clipping guards against the gradient blow-ups that can drive the
+             # loss to NaN, as seen in the epoch 4/5 checkpoints.
+             self.scaler.unscale_(self.optimizer)
+             torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
+             self.scaler.step(self.optimizer)
+             self.scaler.update()
+             total_loss += loss.item()
+             progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
+         return total_loss / len(dataloader)
+
+     def evaluate(self, dataloader, description="Evaluating"):
+         """Shared evaluation loop for the validation and test splits."""
+         self.model.eval()
+         total_loss = 0
+         with torch.no_grad():
+             progress_bar = tqdm(dataloader, desc=description)
+             for batch in progress_bar:
+                 src, tgt = batch["src"].to(self.device), batch["tgt"].to(self.device)
+                 tgt_input, tgt_output = tgt[:, :-1], tgt[:, 1:]
+                 logits = self.model(src, tgt_input, self.pad_id)
+                 loss = self.criterion(logits.reshape(-1, logits.size(-1)), tgt_output.reshape(-1))
+                 total_loss += loss.item()
+                 progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})
+         return total_loss / len(dataloader)
+
+     def save_checkpoint(self, epoch, val_loss):
+         filename = f"checkpoint_epoch_{epoch+1}_valloss_{val_loss:.4f}.pt"
+         path = self.checkpoint_dir / filename
+         torch.save({'epoch': epoch, 'model_state_dict': self.model.state_dict(),
+                     'optimizer_state_dict': self.optimizer.state_dict(), 'loss': val_loss}, path)
+         logging.info(f"Checkpoint saved to {path}")
+
+     def train(self, train_loader, val_loader):
+         for epoch in range(self.config["NUM_EPOCHS"]):
+             self.current_epoch = epoch
+             logging.info(f"--- Starting Epoch {epoch + 1}/{self.config['NUM_EPOCHS']} ---")
+             train_loss = self.train_epoch(train_loader)
+             val_loss = self.evaluate(val_loader, description=f"Epoch {epoch+1}/{self.config['NUM_EPOCHS']} Validation")
+             logging.info(f"Epoch {epoch+1} -> Train Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}")
+             self.save_checkpoint(epoch, val_loss)
+             self.translate("This is a test to see how the model is learning.")
+
+     def translate(self, src_sentence: str):
+         """Greedy-decode one sentence; used as a quick per-epoch sanity check."""
+         self.model.eval()
+         src_tokens = [self.tokenizer.token_to_id("<s>")] + self.tokenizer.encode(src_sentence).ids + [self.tokenizer.token_to_id("</s>")]
+         src = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+         tgt_tokens = [self.tokenizer.token_to_id("<s>")]
+         with torch.no_grad():
+             for _ in range(self.config["MAX_SEQ_LEN"]):
+                 tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+                 logits = self.model(src, tgt_input, self.pad_id)
+                 next_token_id = logits[:, -1, :].argmax(dim=-1).item()
+                 tgt_tokens.append(next_token_id)
+                 if next_token_id == self.tokenizer.token_to_id("</s>"):
+                     break
+         translated_text = self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)
+         logging.info(f"Source: '{src_sentence}'")
+         logging.info(f"Translated: '{translated_text}'")
+
+ def main():
+     # Quick CUDA health check before committing to a long training run.
+     print("-" * 50)
+     print("CUDA Health Check:")
+     if torch.cuda.is_available():
+         print("✅ CUDA is available.")
+         print(f"   PyTorch Version: {torch.__version__}")
+         print(f"   CUDA Version PyTorch was built with: {torch.version.cuda}")
+         print(f"   Number of GPUs: {torch.cuda.device_count()}")
+         print(f"   Current GPU Name: {torch.cuda.get_device_name(0)}")
+     else:
+         print("❌ CUDA is NOT available.")
+         print("   PyTorch will run on CPU, which will be very slow.")
+         print("   ACTION: Ensure you have installed PyTorch with CUDA support. See https://pytorch.org/get-started/locally/")
+     print("-" * 50)
+
+     tokenizer_manager = TokenizerManager(CONFIG)
+     tokenizer = tokenizer_manager.get_tokenizer()
+     CONFIG["VOCAB_SIZE"] = tokenizer.get_vocab_size()
+
+     logging.info("Loading and preparing datasets...")
+     train_dataset = OpusDataset(tokenizer, CONFIG, split="train")
+     val_dataset = OpusDataset(tokenizer, CONFIG, split="validation")
+     test_dataset = OpusDataset(tokenizer, CONFIG, split="test")
+     logging.info(f"Dataset sizes -> Train: {len(train_dataset)}, Validation: {len(val_dataset)}, Test: {len(test_dataset)}")
+
+     pad_id = tokenizer.token_to_id("<pad>")
+     collate_fn = PadCollate(pad_id)
+     # Multi-process data loading is unreliable on Windows, so fall back to 0 workers there.
+     num_workers = 0 if os.name == 'nt' else os.cpu_count() // 2
+
+     train_loader = DataLoader(train_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=True, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())
+     val_loader = DataLoader(val_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())
+     test_loader = DataLoader(test_dataset, batch_size=CONFIG["BATCH_SIZE"], shuffle=False, collate_fn=collate_fn, num_workers=num_workers, pin_memory=torch.cuda.is_available())
+
+     model = TranslationTransformer(vocab_size=CONFIG["VOCAB_SIZE"], dim=CONFIG["DIM"], n_heads=CONFIG["N_HEADS"],
+                                    encoder_layers=CONFIG["ENCODER_LAYERS"], decoder_layers=CONFIG["DECODER_LAYERS"],
+                                    ff_dim=CONFIG["FF_DIM"], dropout=CONFIG["DROPOUT"], max_len=CONFIG["MAX_SEQ_LEN"])
+
+     total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+     logging.info(f"Model initialized. Total trainable parameters: {total_params:,}")
+
+     trainer = Trainer(model, tokenizer, CONFIG)
+     trainer.train(train_loader, val_loader)
+
+     # Final evaluation on the held-out test split.
+     logging.info("\n--- Training Complete. Evaluating on Test Set... ---")
+     test_loss = trainer.evaluate(test_loader, description="Final Test Evaluation")
+     logging.info(f"Final Test Loss: {test_loss:.4f}")
+
+     logging.info("\n--- Final Translation Examples ---")
+     trainer.translate("The European Economic Area was created in 1994.")
+     trainer.translate("What is your name?")
+     trainer.translate("This technology is changing the world.")
+
+ if __name__ == "__main__":
+     main()
translator_loader.py ADDED
@@ -0,0 +1,195 @@
+ import math
+ import logging
+ import re
+ from pathlib import Path
+
+ import torch
+ import torch.nn as nn
+ from tokenizers import Tokenizer
+
+ # --- Setup ---
+ # Configure logging to be minimal for inference
+ logging.basicConfig(level=logging.INFO, format='%(message)s')
+
+ # --- Configuration (must match the training script) ---
+ CONFIG = {
+     "SRC_LANG": "en",
+     "TGT_LANG": "zh",
+     "TOKENIZER_FILE": "opus_en_zh_tokenizer.json",
+     "MAX_SEQ_LEN": 128,
+     "DIM": 256,
+     "ENCODER_LAYERS": 4,
+     "DECODER_LAYERS": 4,
+     "N_HEADS": 8,
+     "FF_DIM": 512,
+     "DROPOUT": 0.1,
+     "CHECKPOINT_DIR": "checkpoints_translation",
+ }
+
+
+ class PositionalEncoding(nn.Module):
+     """Sinusoidal positional encoding; expects input shaped (seq_len, batch, dim)."""
+     def __init__(self, dim, dropout, max_len=5000):
+         super().__init__()
+         self.dropout = nn.Dropout(p=dropout)
+         position = torch.arange(max_len).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, dim, 2) * (-math.log(10000.0) / dim))
+         pe = torch.zeros(max_len, 1, dim)
+         pe[:, 0, 0::2] = torch.sin(position * div_term)
+         pe[:, 0, 1::2] = torch.cos(position * div_term)
+         self.register_buffer('pe', pe)
+
+     def forward(self, x):
+         x = x + self.pe[:x.size(0)]
+         return self.dropout(x)
+
+ class TranslationTransformer(nn.Module):
+     def __init__(self, vocab_size, dim, n_heads, encoder_layers, decoder_layers, ff_dim, dropout, max_len):
+         super().__init__()
+         self.dim = dim
+         self.embedding = nn.Embedding(vocab_size, dim)
+         self.pos_encoder = PositionalEncoding(dim, dropout, max_len)
+         self.transformer = nn.Transformer(
+             d_model=dim, nhead=n_heads, num_encoder_layers=encoder_layers,
+             num_decoder_layers=decoder_layers, dim_feedforward=ff_dim,
+             dropout=dropout, batch_first=True
+         )
+         self.generator = nn.Linear(dim, vocab_size)
+
+     def _generate_mask(self, src, tgt, pad_id):
+         tgt_mask = self.transformer.generate_square_subsequent_mask(tgt.shape[1], device=tgt.device)
+         src_padding_mask = (src == pad_id)
+         tgt_padding_mask = (tgt == pad_id)
+         return tgt_mask, src_padding_mask, tgt_padding_mask
+
+     def forward(self, src, tgt, pad_id):
+         # PositionalEncoding works on (seq, batch, dim), so permute around it and
+         # scale embeddings by sqrt(dim) as in the original Transformer.
+         src_emb = self.pos_encoder((self.embedding(src) * math.sqrt(self.dim)).permute(1, 0, 2)).permute(1, 0, 2)
+         tgt_emb = self.pos_encoder((self.embedding(tgt) * math.sqrt(self.dim)).permute(1, 0, 2)).permute(1, 0, 2)
+         tgt_mask, src_padding_mask, tgt_padding_mask = self._generate_mask(src, tgt, pad_id)
+         output = self.transformer(
+             src_emb, tgt_emb, tgt_mask=tgt_mask, src_key_padding_mask=src_padding_mask,
+             tgt_key_padding_mask=tgt_padding_mask, memory_key_padding_mask=src_padding_mask
+         )
+         return self.generator(output)
+
+ class Translator:
+     def __init__(self, config):
+         self.config = config
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         logging.info(f"Using device: {self.device}")
+
+         # Load the trained tokenizer
+         tokenizer_path = Path(self.config["TOKENIZER_FILE"])
+         if not tokenizer_path.exists():
+             raise FileNotFoundError(f"Tokenizer file not found at {tokenizer_path}. Please run the training script first.")
+         self.tokenizer = Tokenizer.from_file(str(tokenizer_path))
+
+         # Get special token IDs
+         self.bos_id = self.tokenizer.token_to_id("<s>")
+         self.eos_id = self.tokenizer.token_to_id("</s>")
+         self.pad_id = self.tokenizer.token_to_id("<pad>")
+
+         # Initialize the model structure
+         self.model = TranslationTransformer(
+             vocab_size=self.tokenizer.get_vocab_size(),
+             dim=self.config["DIM"], n_heads=self.config["N_HEADS"],
+             encoder_layers=self.config["ENCODER_LAYERS"], decoder_layers=self.config["DECODER_LAYERS"],
+             ff_dim=self.config["FF_DIM"], dropout=self.config["DROPOUT"], max_len=self.config["MAX_SEQ_LEN"]
+         )
+         self.model.to(self.device)
+
+     def load_best_checkpoint(self):
+         """Finds and loads the checkpoint with the lowest validation loss."""
+         checkpoint_dir = Path(self.config["CHECKPOINT_DIR"])
+         if not checkpoint_dir.exists():
+             raise FileNotFoundError(f"Checkpoint directory not found at {checkpoint_dir}.")
+
+         best_loss = float('inf')
+         best_checkpoint_path = None
+
+         for chk_path in checkpoint_dir.glob("*.pt"):
+             # Use regex to find the validation loss in the filename. Filenames
+             # like 'valloss_nan.pt' do not match [\d.]+ and are skipped, which
+             # is exactly what we want for the diverged epochs.
+             match = re.search(r'valloss_([\d.]+)\.pt', chk_path.name)
+             if match:
+                 val_loss = float(match.group(1))
+                 if val_loss < best_loss:
+                     best_loss = val_loss
+                     best_checkpoint_path = chk_path
+
+         if best_checkpoint_path is None:
+             raise FileNotFoundError(f"No valid checkpoints found in {checkpoint_dir}. Checkpoint names must be like '...valloss_x.xxxx.pt'.")
+
+         logging.info(f"Loading best model from: {best_checkpoint_path} (Validation Loss: {best_loss:.4f})")
+         checkpoint = torch.load(best_checkpoint_path, map_location=self.device)
+         self.model.load_state_dict(checkpoint['model_state_dict'])
+
+         # Set the model to evaluation mode. This is crucial: it disables
+         # layers like Dropout for consistent inference.
+         self.model.eval()
+
+     def translate(self, src_sentence: str):
+         """Translates a single English sentence to Chinese using greedy decoding."""
+         if not src_sentence.strip():
+             return ""
+
+         # Prepare the input
+         src_tokens = [self.bos_id] + self.tokenizer.encode(src_sentence).ids + [self.eos_id]
+         src = torch.tensor(src_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+
+         # Start decoding
+         tgt_tokens = [self.bos_id]
+
+         with torch.no_grad():  # Disable gradient calculation for efficiency
+             for _ in range(self.config["MAX_SEQ_LEN"]):
+                 tgt_input = torch.tensor(tgt_tokens, dtype=torch.long).unsqueeze(0).to(self.device)
+
+                 # Get model predictions
+                 logits = self.model(src, tgt_input, self.pad_id)
+
+                 # Get the most likely next token (greedy decoding)
+                 next_token_id = logits[:, -1, :].argmax(dim=-1).item()
+                 tgt_tokens.append(next_token_id)
+
+                 # Stop if the end-of-sentence token is generated
+                 if next_token_id == self.eos_id:
+                     break
+
+         # Decode the generated token IDs back to a string
+         translated_text = self.tokenizer.decode(tgt_tokens, skip_special_tokens=True)
+         return translated_text
+
+ def interactive_session():
+     """Runs the main interactive translation loop."""
+     try:
+         translator = Translator(CONFIG)
+         translator.load_best_checkpoint()
+     except FileNotFoundError as e:
+         logging.error(f"Error initializing translator: {e}")
+         logging.error("Please make sure you have run the training script and have a valid tokenizer and checkpoint file.")
+         return
+
+     print("\n--- ZHEN - 1 Translator ---")
+     print("Type an English sentence and press Enter.")
+     print("Type 'quit' or 'exit' to close the program.")
+
+     while True:
+         try:
+             source_text = input("\nEnglish > ")
+             if source_text.lower() in ['quit', 'exit', 'q']:
+                 print("Exiting translator. Goodbye!")
+                 break
+
+             if not source_text:
+                 continue
+
+             translated_text = translator.translate(source_text)
+             print(f"Chinese < {translated_text}")
+
+         except KeyboardInterrupt:
+             print("\nExiting translator. Goodbye!")
+             break
+         except Exception as e:
+             logging.error(f"An unexpected error occurred: {e}")
+
+
+ if __name__ == "__main__":
+     interactive_session()
verify_cuda.py ADDED
@@ -0,0 +1,13 @@
+ import torch
+
+ print(f"PyTorch version: {torch.__version__}")
+ print(f"CUDA available: {torch.cuda.is_available()}")
+
+ if torch.cuda.is_available():
+     print(f"CUDA version PyTorch was built with: {torch.version.cuda}")
+     print(f"Number of GPUs: {torch.cuda.device_count()}")
+     print(f"Current device: {torch.cuda.current_device()}")
+     print(f"Device name: {torch.cuda.get_device_name(torch.cuda.current_device())}")
+ else:
+     print("\n❌ PyTorch cannot find CUDA.")
+     print("   Follow the 'Foolproof Plan' to fix your environment.")