|
|
import torch |
|
|
import torch.nn as nn |
|
|
import numpy as np |
|
|
from torch.nn.utils.rnn import pad_sequence |
|
|
from torch.utils.data import DataLoader, Dataset |
|
|
from collections import Counter |
|
|
from itertools import chain |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Number of input tokens per training example; each dataset item carries
# one extra token so input/target pairs can be formed by shifting.
seq_length = 10


# Examples per mini-batch fed to the DataLoader.
batch_size = 32
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TextDataset(Dataset):
    """Dataset of fixed-length token-index sequences drawn from a text corpus.

    Each item is a 1-D ``LongTensor`` of ``seq_length + 1`` token indices, so
    a language-model training loop can use ``item[:-1]`` as input and
    ``item[1:]`` as the shifted target.
    """

    def __init__(self, text, vocab=None, seq_length=seq_length):
        """Tokenize *text* on whitespace and build (or reuse) a vocabulary.

        Args:
            text: Raw corpus string; split on whitespace into tokens.
            vocab: Optional pre-built token -> index mapping. When omitted, a
                new vocabulary is built with '<pad>' = 0 and '<eos>' = 1.
                NOTE(review): if a vocab is supplied, tokens absent from it
                raise KeyError below — callers must pass a complete vocab.
            seq_length: Number of input tokens per item (one extra token is
                included in each item as the shifted target).
        """
        self.tokens = text.split()

        if vocab:
            self.vocab = vocab
        else:
            # Reserve the special tokens at fixed, well-known indices.
            self.vocab = {'<pad>': 0, '<eos>': 1}
            # Counter iterates tokens in first-occurrence order, so index
            # assignment is deterministic for a given corpus. Only the keys
            # are needed; the counts themselves are unused.
            for token in Counter(self.tokens):
                self.vocab[token] = len(self.vocab)

        # Reverse mapping for decoding model output back into tokens.
        self.index2token = {index: token for token, index in self.vocab.items()}

        # The whole corpus as a flat list of vocabulary indices.
        self.indexed_tokens = [self.vocab[token] for token in self.tokens]

        self.seq_length = seq_length

    def __len__(self):
        # Each item needs seq_length + 1 tokens, so the last usable start
        # index must leave seq_length + 1 tokens available. The previous
        # formula, N // seq_length, could yield a final item one token
        # short (whenever N is a multiple of seq_length), producing ragged
        # tensors that crash the DataLoader's default collate.
        return (len(self.indexed_tokens) - 1) // self.seq_length

    def __getitem__(self, idx):
        """Return tokens [idx*L, idx*L + L] as a LongTensor of length L + 1."""
        start_idx = idx * self.seq_length
        end_idx = start_idx + self.seq_length + 1
        sequence = self.indexed_tokens[start_idx:end_idx]

        return torch.tensor(sequence, dtype=torch.long)
|
|
|
|
|
|
|
|
# Load the raw corpus and build the dataset (and its vocabulary) from it.
with open('tiny-shakespeare.txt', 'r', encoding='utf-8') as corpus_file:
    text = corpus_file.read()


dataset = TextDataset(text)

# Vocabulary size drives embedding / output-layer dimensions downstream.
vocab_size = len(dataset.vocab)

# Shuffle items each epoch for training.
train_data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
|
|
|
|
|
|
|
|
def create_padding_mask(seq, pad_idx=None):
    """Build a boolean padding mask for attention over a batch of sequences.

    Args:
        seq: LongTensor of token indices — assumed shape (seq_len, batch),
            since the result is transposed; TODO confirm against the model's
            expected mask layout.
        pad_idx: Index of the padding token. Defaults to the module-level
            ``dataset``'s '<pad>' entry, preserving the original behavior
            (and the original's dependence on that global).

    Returns:
        BoolTensor with the first two dims of *seq* swapped; True wherever
        *seq* equals the padding index.
    """
    if pad_idx is None:
        # Backward-compatible fallback on the global dataset's vocabulary.
        pad_idx = dataset.vocab['<pad>']
    return (seq == pad_idx).transpose(0, 1)
|
|
|
|
|
|
|
|
|