agiformer / src /data /real_data.py
tefoteknik's picture
Upload src/data/real_data.py with huggingface_hub
1eae316 verified
## Developer: inkbytefo
## Modified: 2025-11-22
import torch
from torch.utils.data import Dataset, DataLoader
import os
import zipfile
import urllib.request
import numpy as np
class Enwik8Dataset(Dataset):
"""
Dataset for enwik8 (Hutter Prize).
Downloads and loads the first 100MB of Wikipedia XML dump.
"""
URL = "http://mattmahoney.net/dc/enwik8.zip"
FILE_NAME = "enwik8"
def __init__(self, data_dir: str, seq_len: int = 1024, split: str = 'train'):
self.seq_len = seq_len
self.data_dir = data_dir
self.file_path = os.path.join(data_dir, self.FILE_NAME)
if not os.path.exists(self.file_path):
self._download_and_extract()
# Load data into memory (100MB is small)
with open(self.file_path, 'rb') as f:
data = np.frombuffer(f.read(), dtype=np.uint8)
# Split: 90MB Train, 5MB Val, 5MB Test
n = len(data)
tr_split = int(n * 0.9)
val_split = int(n * 0.95)
if split == 'train':
self.data = data[:tr_split]
elif split == 'val':
self.data = data[tr_split:val_split]
else:
self.data = data[val_split:]
self.data = torch.from_numpy(self.data.copy()).long() # Copy to avoid negative stride issues if any
def _download_and_extract(self):
print(f"Downloading {self.URL}...")
zip_path = os.path.join(self.data_dir, "enwik8.zip")
urllib.request.urlretrieve(self.URL, zip_path)
print("Extracting...")
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
zip_ref.extractall(self.data_dir)
def __len__(self):
# Number of possible sequences
return len(self.data) - self.seq_len - 1
def __getitem__(self, idx):
# Random sampling is better for generalization in this context,
# but standard Dataset uses index.
# We can just return the slice at idx.
# Input: [idx : idx + seq_len]
# Target: [idx + 1 : idx + seq_len + 1] (Standard next token)
# But our model expects:
# Input: [idx : idx + seq_len]
# Target: [idx + patch_size : idx + seq_len + patch_size] (Next Patch)
# Wait, train.py handles the shifting for Next Patch.
# So we just return a chunk of length SEQ_LEN + PATCH_SIZE?
# Or just return SEQ_LEN and let train.py handle it?
# train.py expects (seq, _) and does:
# x = seq[:, :split_idx]
# y = seq[:, PATCH_SIZE:]
# So we need to provide a sequence of length SEQ_LEN (which includes the target part).
# Actually, if we want x to be 1024, we need 1024 + patch_size bytes?
# train.py: split_idx = seq.size(1) - PATCH_SIZE
# So if seq is 1024, x is 1020, y is 1020.
# Let's return exactly what's needed.
chunk = self.data[idx : idx + self.seq_len]
return chunk, chunk # Dummy target, train.py splits it
def get_enwik8_dataloader(data_dir: str, batch_size: int = 32, seq_len: int = 1024, split: str = 'train'):
dataset = Enwik8Dataset(data_dir, seq_len, split)
return DataLoader(dataset, batch_size=batch_size, shuffle=(split=='train'), num_workers=0)