|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import os |
|
|
import zipfile |
|
|
import urllib.request |
|
|
import numpy as np |
|
|
|
|
|
class Enwik8Dataset(Dataset):
    """
    Byte-level dataset for enwik8 (Hutter Prize).

    Downloads (if missing) and loads enwik8, the first 100MB of an English
    Wikipedia XML dump, and splits it 90% / 5% / 5% into train / val / test.

    Each item is an ``(input, target)`` pair of ``seq_len`` byte values
    (``torch.long``), where ``target`` is ``input`` shifted one byte ahead,
    i.e. the next-byte prediction target for every input position.
    """

    URL = "http://mattmahoney.net/dc/enwik8.zip"
    FILE_NAME = "enwik8"

    def __init__(self, data_dir: str, seq_len: int = 1024, split: str = 'train'):
        """
        Args:
            data_dir: Directory holding (or to receive) the extracted ``enwik8`` file.
            seq_len: Number of bytes in each input/target sequence.
            split: One of ``'train'``, ``'val'`` or ``'test'``.

        Raises:
            ValueError: If ``split`` is not one of the recognised split names.
        """
        # Validate before any (potentially 100MB) download happens.
        if split not in ('train', 'val', 'test'):
            raise ValueError(f"split must be 'train', 'val' or 'test', got {split!r}")

        self.seq_len = seq_len
        self.data_dir = data_dir
        self.file_path = os.path.join(data_dir, self.FILE_NAME)

        if not os.path.exists(self.file_path):
            self._download_and_extract()

        with open(self.file_path, 'rb') as f:
            data = np.frombuffer(f.read(), dtype=np.uint8)

        # 90% / 5% / 5% split; integer arithmetic avoids float rounding at
        # the boundaries (e.g. n * 0.95 landing just below an integer).
        n = len(data)
        tr_split = n * 9 // 10
        val_split = n * 19 // 20

        if split == 'train':
            self.data = data[:tr_split]
        elif split == 'val':
            self.data = data[tr_split:val_split]
        else:
            self.data = data[val_split:]

        # np.frombuffer returns a read-only view; astype() makes a writable
        # int64 copy in a single step (instead of copy() followed by .long()).
        self.data = torch.from_numpy(self.data.astype(np.int64))

    def _download_and_extract(self):
        """Download ``enwik8.zip`` into ``data_dir`` and extract it there."""
        # urlretrieve cannot create missing parent directories.
        os.makedirs(self.data_dir or ".", exist_ok=True)

        print(f"Downloading {self.URL}...")
        zip_path = os.path.join(self.data_dir, "enwik8.zip")
        urllib.request.urlretrieve(self.URL, zip_path)

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.data_dir)

    def __len__(self):
        """Number of start positions whose input and shifted target both fit."""
        # The last valid start is len(data) - seq_len - 1, because the target
        # window extends one byte further than the input. Clamp to 0 for
        # splits shorter than seq_len + 1 bytes.
        return max(0, len(self.data) - self.seq_len)

    def __getitem__(self, idx):
        """
        Return ``(x, y)``: ``x = data[idx : idx + seq_len]`` and ``y`` the same
        window shifted one byte ahead (the next-byte targets for ``x``).

        Raises:
            IndexError: If ``idx`` is outside ``[-len(self), len(self))``. This
                also lets plain ``for x, y in dataset`` iteration terminate.
        """
        length = len(self)
        pos = idx + length if idx < 0 else idx
        if not 0 <= pos < length:
            raise IndexError(f"index {idx} out of range for dataset of length {length}")

        x = self.data[pos : pos + self.seq_len]
        y = self.data[pos + 1 : pos + self.seq_len + 1]
        return x, y
|
|
|
|
|
def get_enwik8_dataloader(data_dir: str, batch_size: int = 32, seq_len: int = 1024, split: str = 'train'):
    """
    Build a DataLoader over one split of enwik8.

    Args:
        data_dir: Directory passed through to ``Enwik8Dataset``.
        batch_size: Number of sequences per batch.
        seq_len: Length of each sequence, passed through to ``Enwik8Dataset``.
        split: Split name passed through to ``Enwik8Dataset``.

    Returns:
        A DataLoader that shuffles only when ``split == 'train'`` and loads
        data in the main process (``num_workers=0``).
    """
    enwik8 = Enwik8Dataset(data_dir, seq_len, split)
    shuffle_batches = split == 'train'
    return DataLoader(
        enwik8,
        batch_size=batch_size,
        shuffle=shuffle_batches,
        num_workers=0,
    )
|
|
|