# Source: data.py uploaded to Hugging Face by user 2264K (commit bbed392, verified)
"""
Phase 2-A Data: Vision (Moving MNIST) + Audio (WavJEPA) + Text (TinyStories)
Reuses pre-computed features from Phase 1 and Phase 1-B.
"""
import os
import torch
import numpy as np
from torch.utils.data import Dataset, DataLoader
class VisionDataset(Dataset):
    """Moving MNIST patches (from Phase 1).

    The raw ``.npy`` is memory-mapped (consistent with :class:`AudioDataset`)
    and each sequence is converted to float32 on access, instead of
    materializing the entire dataset as float32 up front — the standard
    Moving MNIST file is uint8 (frames, sequences, 64, 64), so an eager
    float copy is ~4x the on-disk footprint.
    """

    def __init__(self, path="../phase1/mnist_test_seq.npy", n_frames=5,
                 patch_size=16, mask_ratio=0.5):
        # Raw layout is (frames, sequences, H, W); kept memory-mapped,
        # normalization happens per item in __getitem__.
        self.data = np.load(path, mmap_mode='r')
        self.n_frames = n_frames
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        # Frames are 64x64 -> (64 // patch_size)^2 patches per frame.
        self.total_patches = n_frames * (64 // patch_size) ** 2

    def __len__(self):
        # One item per sequence (second axis of the raw array).
        return self.data.shape[1]

    def __getitem__(self, idx):
        # Copy only this sequence out of the mmap, then scale to [0, 1].
        frames = torch.from_numpy(
            np.asarray(self.data[:self.n_frames, idx], dtype=np.float32)
        ) / 255.0
        patches = self._to_patches(frames)
        n_mask = int(self.total_patches * self.mask_ratio)
        # Random mask; sorted so downstream gather order is deterministic
        # given the permutation.
        mask_idx = torch.randperm(self.total_patches)[:n_mask].sort().values
        return {
            "all_patches": patches,
            "target_patches": patches[mask_idx],
            "mask_idx": mask_idx,
        }

    def _to_patches(self, frames):
        """Split (n_frames, 64, 64) frames into flat (n_patches, p*p) patches."""
        p = self.patch_size
        patches = frames.unfold(1, p, p).unfold(2, p, p)
        return patches.contiguous().view(-1, p * p)
class AudioDataset(Dataset):
    """Pre-computed WavJEPA features (from Phase 1-B)."""

    def __init__(self, path="../phase1b/audio_features/audio_features.npy", mask_ratio=0.5):
        # Memory-map so the (clips, tokens, dim) feature file is not
        # loaded eagerly.
        self.features = np.load(path, mmap_mode='r')
        _, self.n_tokens, self.feat_dim = self.features.shape
        self.mask_ratio = mask_ratio

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        # Materialize just this clip out of the mmap.
        feats = torch.FloatTensor(np.array(self.features[idx]))
        n_masked = int(self.n_tokens * self.mask_ratio)
        mask_idx, _ = torch.randperm(self.n_tokens)[:n_masked].sort()
        return {
            "all_features": feats,
            "target_features": feats[mask_idx],
            "mask_idx": mask_idx,
        }
class TextDataset(Dataset):
    """TinyStories (reused)."""

    def __init__(self, path="../phase1/TinyStoriesV2-GPT4-train.txt",
                 seq_len=128, vocab_size=10000):
        # Local import: tiktoken is only needed when this dataset is built.
        import tiktoken
        self.seq_len = seq_len
        self.enc = tiktoken.get_encoding("gpt2")
        with open(path, "r", encoding="utf-8") as f:
            raw = f.read()
        # Drop token ids outside the reduced vocabulary.
        kept = [t for t in self.enc.encode(raw) if t < vocab_size]
        self.tokens = torch.LongTensor(kept)
        print(f"TinyStories: {len(self.tokens):,} tokens")

    def __len__(self):
        # Number of valid (input, next-token target) windows; never negative.
        n = len(self.tokens) - self.seq_len - 1
        return n if n > 0 else 0

    def __getitem__(self, idx):
        # Target is the input window shifted one token to the right.
        start, stop = idx, idx + self.seq_len
        return self.tokens[start:stop], self.tokens[start + 1:stop + 1]
def collate_vision(batch):
    """Stack a list of vision samples into batched tensors."""
    keys = ("all_patches", "target_patches", "mask_idx")
    return {k: torch.stack([sample[k] for sample in batch]) for k in keys}
def collate_audio(batch):
    """Stack a list of audio samples into batched tensors."""
    keys = ("all_features", "target_features", "mask_idx")
    return {k: torch.stack([sample[k] for sample in batch]) for k in keys}
class TriModalDataLoader:
    """Yields (vision, audio, text) triples. Shortest dataset cycles."""

    def __init__(self, v_ds, a_ds, t_ds, batch_size=32):
        self.v_loader = DataLoader(v_ds, batch_size=batch_size, shuffle=True,
                                   collate_fn=collate_vision, drop_last=True, pin_memory=True)
        self.a_loader = DataLoader(a_ds, batch_size=batch_size, shuffle=True,
                                   collate_fn=collate_audio, drop_last=True, pin_memory=True)
        self.t_loader = DataLoader(t_ds, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True)
        # One "epoch" spans the smaller of the vision/audio loaders;
        # the text loader simply restarts whenever it runs out.
        self.n_batches = min(len(self.v_loader), len(self.a_loader))

    @staticmethod
    def _next_or_restart(it, loader):
        """Advance `it`; on exhaustion restart from `loader`.

        Returns (batch, iterator) so the caller can keep the live iterator.
        """
        try:
            return next(it), it
        except StopIteration:
            fresh = iter(loader)
            return next(fresh), fresh

    def __iter__(self):
        loaders = [self.v_loader, self.a_loader, self.t_loader]
        iters = [iter(ld) for ld in loaders]
        for _ in range(self.n_batches):
            batches = []
            for j, loader in enumerate(loaders):
                batch, iters[j] = self._next_or_restart(iters[j], loader)
                batches.append(batch)
            # (vision_batch, audio_batch, text_batch)
            yield tuple(batches)

    def __len__(self):
        return self.n_batches
if __name__ == "__main__":
    # Smoke test with synthetic datasets — needs no data files and no
    # external model. (The previous `from model import CONFIG` was unused
    # and made this self-contained check fail without model.py present.)
    class FakeVision(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            p = torch.randn(80, 256)
            m = torch.arange(40)
            return {"all_patches": p, "target_patches": p[:40], "mask_idx": m}

    class FakeAudio(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            f = torch.randn(200, 768)
            m = torch.arange(100)
            return {"all_features": f, "target_features": f[:100], "mask_idx": m}

    class FakeText(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            return torch.randint(0, 10000, (128,)), torch.randint(0, 10000, (128,))

    loader = TriModalDataLoader(FakeVision(), FakeAudio(), FakeText(), batch_size=4)
    # One pass is enough to verify shapes and collation.
    for v, a, t in loader:
        print(f"V: {v['all_patches'].shape}, A: {a['all_features'].shape}, T: {t[0].shape}")
        break
    print("DataLoader OK")