File size: 3,336 Bytes
1eae316
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
## Developer: inkbytefo
## Modified: 2025-11-22

import torch
from torch.utils.data import Dataset, DataLoader
import os
import zipfile
import urllib.request
import numpy as np

class Enwik8Dataset(Dataset):
    """
    Byte-level dataset for enwik8 (Hutter Prize).

    Downloads and loads the first 100MB of a Wikipedia XML dump. On first use
    the archive is fetched from ``URL`` into ``data_dir`` and extracted. The
    ``enwik8`` file is then read fully into memory and split by byte offset:
    the first 90% is train, the next 5% is val and the last 5% is test.

    Each item is a contiguous window of ``seq_len`` bytes starting at the item
    index, returned as an int64 tensor.

    Attributes:
        seq_len: Window length in bytes.
        data_dir: Directory holding (or receiving) the ``enwik8`` file.
        file_path: Full path of the ``enwik8`` file.
        data: 1-D uint8 tensor with the bytes of the selected split.
    """
    URL = "http://mattmahoney.net/dc/enwik8.zip"
    FILE_NAME = "enwik8"
    SPLITS = ('train', 'val', 'test')

    def __init__(self, data_dir: str, seq_len: int = 1024, split: str = 'train'):
        """
        Args:
            data_dir: Directory containing ``enwik8`` (downloaded there if missing).
            seq_len: Length in bytes of each returned window; must be positive.
            split: One of ``'train'``, ``'val'``, ``'test'``.

        Raises:
            ValueError: If ``split`` is not a known split or ``seq_len`` <= 0.
        """
        # Validate before any (slow) download so a typo fails fast.
        if split not in self.SPLITS:
            raise ValueError(f"split must be one of {self.SPLITS}, got {split!r}")
        if seq_len <= 0:
            raise ValueError(f"seq_len must be positive, got {seq_len}")

        self.seq_len = seq_len
        self.data_dir = data_dir
        self.file_path = os.path.join(data_dir, self.FILE_NAME)

        if not os.path.exists(self.file_path):
            self._download_and_extract()

        # Load data into memory (100MB is small).
        with open(self.file_path, 'rb') as f:
            raw = np.frombuffer(f.read(), dtype=np.uint8)

        # Split: 90% train, 5% val, 5% test (by byte offset).
        n = len(raw)
        tr_split = int(n * 0.9)
        val_split = int(n * 0.95)
        bounds = {
            'train': (0, tr_split),
            'val': (tr_split, val_split),
            'test': (val_split, n),
        }
        start, stop = bounds[split]

        # np.frombuffer gives a read-only view of the bytes object, so copy the
        # slice into a writable array before handing it to torch. The copy also
        # lets the full 100MB buffer be released.
        # Bytes stay uint8 (1 byte each) instead of int64 (8 bytes each);
        # windows are widened to int64 in __getitem__.
        self.data = torch.from_numpy(raw[start:stop].copy())

    def _download_and_extract(self):
        """
        Download the enwik8 zip into ``data_dir`` and extract it there.

        Creates ``data_dir`` first because ``urlretrieve`` cannot write into a
        missing directory. The zip file is left on disk after extraction.
        """
        # An empty data_dir means "current directory"; os.makedirs('') would raise.
        os.makedirs(self.data_dir or '.', exist_ok=True)

        print(f"Downloading {self.URL}...")
        zip_path = os.path.join(self.data_dir, "enwik8.zip")
        urllib.request.urlretrieve(self.URL, zip_path)

        print("Extracting...")
        with zipfile.ZipFile(zip_path, 'r') as zip_ref:
            zip_ref.extractall(self.data_dir)

    def __len__(self):
        """
        Number of full windows: one per start offset in [0, len(data) - seq_len].

        Returns 0 (never a negative number) when the split is shorter than
        one window.
        """
        return max(0, len(self.data) - self.seq_len + 1)

    def __getitem__(self, idx):
        """
        Return the ``seq_len``-byte window starting at ``idx`` as int64.

        Returns ``(chunk, chunk)``: the second element is only a placeholder
        target. Per the notes in the original code, train.py unpacks
        ``(seq, _)`` and derives inputs and next-patch targets itself:
        ``split_idx = seq.size(1) - PATCH_SIZE``, ``x = seq[:, :split_idx]``,
        ``y = seq[:, PATCH_SIZE:]``. So with ``seq_len=1024`` the model sees
        ``1024 - PATCH_SIZE`` positions (TODO confirm against train.py).
        """
        chunk = self.data[idx : idx + self.seq_len].long()
        return chunk, chunk

def get_enwik8_dataloader(data_dir: str, batch_size: int = 32, seq_len: int = 1024,
                          split: str = 'train', num_workers: int = 0):
    """
    Build a DataLoader over one split of enwik8.

    Args:
        data_dir: Directory holding (or to receive) the ``enwik8`` file.
        batch_size: Number of windows per batch.
        seq_len: Length in bytes of each window.
        split: Split name passed to ``Enwik8Dataset``. Only ``'train'`` is shuffled.
        num_workers: DataLoader worker processes. The default of 0 loads in the
            main process, matching the previously hard-coded value.

    Returns:
        A ``DataLoader`` yielding batches of the dataset's items.
    """
    dataset = Enwik8Dataset(data_dir, seq_len, split)
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=(split == 'train'),
        num_workers=num_workers,
    )