import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
class TranslationDataset(Dataset):
    """Loads a tab-separated parallel corpus and encodes each pair
    into fixed-length sequences of token ids."""

    def __init__(self, file_path, src_tokenizer, tgt_tokenizer, max_len):
        self.data = []
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer
        self.max_len = max_len
        with open(file_path, 'r', encoding='utf-8') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue  # skip blank lines instead of crashing on the unpack
                src, tgt = line.split('\t')
                self.data.append((self.encode(src, self.src_tokenizer),
                                  self.encode(tgt, self.tgt_tokenizer)))

    def encode(self, sentence, tokenizer):
        # Whitespace-tokenize and map out-of-vocabulary words to <UNK>.
        tokens = sentence.split()
        ids = [tokenizer.word2int.get(token, tokenizer.word2int["<UNK>"]) for token in tokens]
        # Truncate to leave room for <BOS>/<EOS>, then pad up to max_len.
        ids = [tokenizer.word2int["<BOS>"]] + ids[:self.max_len - 2] + [tokenizer.word2int["<EOS>"]]
        ids += [tokenizer.word2int["<PAD>"]] * (self.max_len - len(ids))
        return ids

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        src, tgt = self.data[idx]
        return torch.tensor(src, dtype=torch.long), torch.tensor(tgt, dtype=torch.long)
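
# A minimal usage sketch: wiring the dataset into a DataLoader. The toy
# tokenizer, vocabulary, and the file name "data.tsv" below are assumptions
# for illustration, not part of the original; any object exposing a
# `word2int` dict with the four special tokens would work.
from types import SimpleNamespace

vocab = {"<PAD>": 0, "<BOS>": 1, "<EOS>": 2, "<UNK>": 3, "hello": 4, "world": 5}
toy_tokenizer = SimpleNamespace(word2int=vocab)

# Assumes data.tsv holds one "source\ttarget" pair per line.
dataset = TranslationDataset("data.tsv", toy_tokenizer, toy_tokenizer, max_len=32)
loader = DataLoader(dataset, batch_size=64, shuffle=True)

for src_batch, tgt_batch in loader:
    # Because every example is padded to max_len, the default collate
    # function stacks them into LongTensors of shape (batch_size, max_len).
    print(src_batch.shape, tgt_batch.shape)
    break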