| """
|
| Phase 2-A Data: Vision (Moving MNIST) + Audio (WavJEPA) + Text (TinyStories)
|
| Reuses pre-computed features from Phase 1 and Phase 1-B.
|
| """
|
|
|
| import os
|
| import torch
|
| import numpy as np
|
| from torch.utils.data import Dataset, DataLoader
|
|
|
|
|
class VisionDataset(Dataset):
    """Moving MNIST patches (from Phase 1).

    Loads the raw Moving MNIST array (frames, sequences, 64, 64), scales
    pixels to [0, 1], and serves each sequence as flattened non-overlapping
    patches plus a random masked-target subset for masked prediction.
    """

    def __init__(self, path="../phase1/mnist_test_seq.npy", n_frames=5,
                 patch_size=16, mask_ratio=0.5):
        raw = np.load(path)
        # Normalize uint8 pixel values into [0, 1] floats.
        self.data = torch.FloatTensor(raw) / 255.0
        self.n_frames = n_frames
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio
        # Non-overlapping patches per clip, assuming 64x64 frames.
        self.total_patches = n_frames * (64 // patch_size) ** 2

    def __len__(self):
        # One item per sequence (second axis of the raw array).
        return self.data.shape[1]

    def __getitem__(self, idx):
        clip = self.data[:self.n_frames, idx]
        patches = self._to_patches(clip)
        # Choose a random, sorted subset of patch indices as prediction targets.
        n_mask = int(self.total_patches * self.mask_ratio)
        mask_idx = torch.randperm(self.total_patches)[:n_mask].sort().values
        return {
            "all_patches": patches,
            "target_patches": patches[mask_idx],
            "mask_idx": mask_idx,
        }

    def _to_patches(self, clip):
        # (T, 64, 64) -> (T * (64/p)^2, p*p): tile each frame into p x p
        # squares, then flatten every square into one row.
        p = self.patch_size
        tiles = clip.unfold(1, p, p).unfold(2, p, p)
        return tiles.contiguous().view(-1, p * p)
|
|
|
|
|
class AudioDataset(Dataset):
    """Pre-computed WavJEPA features (from Phase 1-B).

    Serves one (n_tokens, feat_dim) feature matrix per clip, together with
    a random masked-target subset for masked prediction.
    """

    def __init__(self, path="../phase1b/audio_features/audio_features.npy", mask_ratio=0.5):
        # Memory-map so large feature files are not read into RAM up front.
        self.features = np.load(path, mmap_mode='r')
        _, self.n_tokens, self.feat_dim = self.features.shape
        self.mask_ratio = mask_ratio

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        # np.array() materializes the mmap slice before tensor conversion.
        feats = torch.FloatTensor(np.array(self.features[idx]))
        # Choose a random, sorted subset of token indices as prediction targets.
        n_mask = int(self.n_tokens * self.mask_ratio)
        mask_idx = torch.randperm(self.n_tokens)[:n_mask].sort().values
        return {
            "all_features": feats,
            "target_features": feats[mask_idx],
            "mask_idx": mask_idx,
        }
|
|
|
|
|
class TextDataset(Dataset):
    """TinyStories (reused).

    Tokenizes the whole file with the GPT-2 BPE encoding, discards any
    token id >= vocab_size, and serves overlapping (input, target) windows
    of length seq_len for next-token prediction.
    """

    def __init__(self, path="../phase1/TinyStoriesV2-GPT4-train.txt",
                 seq_len=128, vocab_size=10000):
        import tiktoken
        self.seq_len = seq_len
        self.enc = tiktoken.get_encoding("gpt2")

        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        tokens = self.enc.encode(text)
        # Clamp the vocabulary: drop ids outside the model's embedding table.
        # NOTE(review): filtering (rather than remapping) removes those
        # tokens from the stream entirely — presumably intentional from
        # Phase 1; confirm against the training setup.
        tokens = [t for t in tokens if t < vocab_size]
        self.tokens = torch.LongTensor(tokens)
        print(f"TinyStories: {len(self.tokens):,} tokens")

    def __len__(self):
        # The last valid start index is len(tokens) - seq_len - 1, since the
        # target window needs tokens[idx+1 : idx+seq_len+1]. That gives
        # len(tokens) - seq_len samples; the previous `- seq_len - 1` was an
        # off-by-one that silently dropped the final valid window.
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        # x: tokens[idx : idx+seq_len]; y: the same window shifted right by one.
        x = self.tokens[idx:idx + self.seq_len]
        y = self.tokens[idx + 1:idx + self.seq_len + 1]
        return x, y
|
|
|
|
|
def collate_vision(batch):
    """Stack a list of per-sample vision dicts into batched tensors."""
    keys = ("all_patches", "target_patches", "mask_idx")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}
|
|
|
|
|
def collate_audio(batch):
    """Stack a list of per-sample audio dicts into batched tensors."""
    keys = ("all_features", "target_features", "mask_idx")
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}
|
|
|
|
|
class TriModalDataLoader:
    """Yields (vision, audio, text) triples. Shortest dataset cycles.

    Epoch length is min(len(vision loader), len(audio loader)); any loader
    that drains before that (typically text) is restarted mid-epoch.
    """

    def __init__(self, v_ds, a_ds, t_ds, batch_size=32):
        common = dict(batch_size=batch_size, shuffle=True,
                      drop_last=True, pin_memory=True)
        self.v_loader = DataLoader(v_ds, collate_fn=collate_vision, **common)
        self.a_loader = DataLoader(a_ds, collate_fn=collate_audio, **common)
        # Text uses the default collate (tensor pairs stack natively).
        self.t_loader = DataLoader(t_ds, **common)

        self.n_batches = min(len(self.v_loader), len(self.a_loader))

    @staticmethod
    def _next_cycling(it, loader):
        """Advance *it*, restarting it from *loader* when exhausted."""
        try:
            return next(it), it
        except StopIteration:
            it = iter(loader)
            return next(it), it

    def __iter__(self):
        loaders = (self.v_loader, self.a_loader, self.t_loader)
        iters = [iter(loader) for loader in loaders]

        for _ in range(self.n_batches):
            triple = []
            for i, loader in enumerate(loaders):
                batch, iters[i] = self._next_cycling(iters[i], loader)
                triple.append(batch)
            yield tuple(triple)

    def __len__(self):
        return self.n_batches
|
|
|
|
|
if __name__ == "__main__":
    # Smoke test with synthetic datasets only — the previous
    # `from model import CONFIG` was never used and needlessly imported a
    # project module (with its side effects) just to run this check.

    class FakeVision(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            p = torch.randn(80, 256)
            m = torch.arange(40)
            return {"all_patches": p, "target_patches": p[:40], "mask_idx": m}

    class FakeAudio(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            f = torch.randn(200, 768)
            m = torch.arange(100)
            return {"all_features": f, "target_features": f[:100], "mask_idx": m}

    class FakeText(Dataset):
        def __len__(self): return 100
        def __getitem__(self, i):
            return torch.randint(0, 10000, (128,)), torch.randint(0, 10000, (128,))

    loader = TriModalDataLoader(FakeVision(), FakeAudio(), FakeText(), batch_size=4)
    # Print the first triple's shapes to confirm batching/collation works.
    for v, a, t in loader:
        print(f"V: {v['all_patches'].shape}, A: {a['all_features'].shape}, T: {t[0].shape}")
        break
    print("DataLoader OK")
|
|
|