Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import torch | |
class DataLoader:
    """Load a text corpus from disk and serve batches from it.

    Two access styles are offered:

    * ``len(loader)`` / ``loader[i]`` walk the loaded text in consecutive
      chunks of ``batch_size`` elements.
    * ``get_batch`` draws random ``block_size``-long windows from
      ``train_data`` / ``val_data`` and returns them together with their
      shifted-by-one targets.

    NOTE(review): nothing in this class fills ``train_data`` / ``val_data``;
    presumably the caller assigns encoded 1-D tensors to them after
    ``load_data`` -- confirm against the call site.
    """

    def __init__(self, data_path):
        """Remember ``data_path``; no file access happens until ``load_data``."""
        self.data_path = data_path
        self.batch_size = None
        self.block_size = None
        self.device = 'cpu'
        self.split = None
        self.data = None
        # Order in which ``__getitem__`` visits positions of ``self.data``.
        self.indexes = None
        self.train_data = None
        self.val_data = None

    def load_data(self, block_size=128, split=0.8, batch_size=64, device='cpu'):
        """Read the whole file at ``data_path`` and store batching settings.

        Args:
            block_size: Window length used by ``get_batch``.
            split: Stored on ``self.split``; this method does not itself
                partition the data into train/validation parts.
            batch_size: Number of elements per batch.
            device: Stored on ``self.device``. ``get_batch`` uses its own
                ``device`` argument, not this attribute.
        """
        # Explicit encoding so the result does not depend on the platform locale.
        with open(self.data_path, 'r', encoding='utf-8') as f:
            data = f.read()
        self.block_size = block_size
        self.batch_size = batch_size
        self.device = device
        self.split = split
        self.data = data
        self.indexes = np.arange(len(data))

    def __len__(self):
        """Return the number of ``batch_size`` chunks needed to cover the data.

        Raises:
            ValueError: If no data has been loaded.
        """
        if self.data is None:
            raise ValueError('Data not loaded')
        return int(np.ceil(len(self.data) / self.batch_size))

    def __getitem__(self, index):
        """Return the ``index``-th consecutive chunk of the data as an array.

        Negative indices count from the last batch.

        Raises:
            ValueError: If no data has been loaded.
            IndexError: If ``index`` is outside ``range(len(self))``. This is
                also what terminates ``for batch in loader`` iteration.
        """
        if self.data is None:
            raise ValueError('Data not loaded')
        if self.indexes is None:
            # Data was assigned directly rather than through ``load_data``.
            self.indexes = np.arange(len(self.data))
        n_batches = len(self)
        if index < 0:
            index += n_batches
        if not 0 <= index < n_batches:
            raise IndexError('batch index out of range')
        start = index * self.batch_size
        chunk = self.indexes[start:start + self.batch_size]
        return np.array([self.data[i] for i in chunk])

    def get_batch(self, split, device='cpu'):
        """Sample a random batch of input/target windows.

        Args:
            split: ``'train'`` selects ``train_data``; any other value selects
                ``val_data``.
            device: Device the returned tensors are moved to.

        Returns:
            A tuple ``(x, y)`` of stacked windows, where each row of ``y`` is
            the corresponding row of ``x`` shifted one position to the right
            in the underlying data.

        Raises:
            ValueError: If no data has been loaded, if the selected split has
                not been assigned, or if it is not longer than ``block_size``.
        """
        if self.data is None:
            raise ValueError('Data not loaded')
        data = self.train_data if split == 'train' else self.val_data
        if data is None:
            raise ValueError(
                'train_data/val_data not set; assign encoded tensors '
                'before calling get_batch'
            )
        max_start = len(data) - self.block_size
        if max_start <= 0:
            raise ValueError('Selected split must be longer than block_size')
        # Plain ints keep the slicing below independent of tensor indexing rules.
        starts = torch.randint(max_start, (self.batch_size,)).tolist()
        x = torch.stack([data[i:i + self.block_size] for i in starts])
        y = torch.stack([data[i + 1:i + self.block_size + 1] for i in starts])
        return x.to(device), y.to(device)
class Encoder:
    """Map text to integer ids and back, at character or word granularity.

    The vocabulary is the sorted set of distinct characters (``type='char'``)
    or whitespace-separated words (``type='word'``) found in ``data``.
    """

    def __init__(self, data, type='char'):
        """Build the vocabulary from ``data``.

        Args:
            data: Text the vocabulary is derived from.
            type: ``'char'`` or ``'word'``.

        Raises:
            ValueError: If ``type`` is neither ``'char'`` nor ``'word'``.
        """
        self.data = data
        self.type = type
        self.vocab_size = None
        if type == 'char':
            self.chars = sorted(set(data))
            vocab = self.chars
        elif type == 'word':
            # De-duplicate so ids are dense and vocab_size counts distinct
            # words rather than every token occurrence in the corpus.
            self.words = sorted(set(data.split()))
            vocab = self.words
        else:
            raise ValueError('Type must be either "char" or "word"')
        self.stoi = {token: i for i, token in enumerate(vocab)}
        self.itos = dict(enumerate(vocab))
        self.vocab_size = len(vocab)

    def encode(self, string: str):
        """Return the ids of ``string``'s tokens as a 1-D ``torch.long`` tensor.

        Raises:
            KeyError: If a token is not in the vocabulary.
            ValueError: If ``self.type`` is neither ``'char'`` nor ``'word'``.
        """
        if self.type == 'char':
            tokens = string
        elif self.type == 'word':
            tokens = string.split()
        else:
            raise ValueError('Type must be either "char" or "word"')
        # Explicit dtype: torch.tensor([]) would otherwise come back as float32.
        return torch.tensor([self.stoi[t] for t in tokens], dtype=torch.long)

    def decode(self, ids: list):
        """Turn a sequence of ids back into text.

        ``ids`` may be a list of ints or a 1-D tensor/array of integer ids.
        Characters are concatenated directly; words are joined with a single
        space.

        Raises:
            KeyError: If an id is not in the vocabulary.
            ValueError: If ``self.type`` is neither ``'char'`` nor ``'word'``.
        """
        if self.type == 'char':
            separator = ''
        elif self.type == 'word':
            separator = ' '
        else:
            raise ValueError('Type must be either "char" or "word"')
        # int() normalises 0-d tensors / numpy ints, which do not hash like ints.
        return separator.join(self.itos[int(i)] for i in ids)