"""
Phase 2-A Data: Vision (Moving MNIST) + Audio (WavJEPA) + Text (TinyStories)

Reuses pre-computed features from Phase 1 and Phase 1-B.
"""

import os

import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader


def _sample_mask_indices(n_items, mask_ratio):
    """Return a sorted random subset of ``range(n_items)``.

    The subset holds ``int(n_items * mask_ratio)`` distinct indices, drawn
    with ``torch.randperm`` and returned in ascending order.
    """
    n_mask = int(n_items * mask_ratio)
    return torch.randperm(n_items)[:n_mask].sort().values


def _stack_fields(batch, keys):
    """Stack, per key, the tensors of a list of sample dicts."""
    return {key: torch.stack([sample[key] for sample in batch]) for key in keys}


class VisionDataset(Dataset):
    """Moving MNIST patches (from Phase 1).

    The ``.npy`` file is indexed as ``(frame, sequence, height, width)``:
    one item is the first ``n_frames`` frames of one sequence, cut into
    non-overlapping ``patch_size`` x ``patch_size`` patches and flattened.
    Pixel values are divided by 255.

    Args:
        path: Location of the ``.npy`` array.
        n_frames: Number of leading frames used per sequence.
        patch_size: Edge length of the square patches.
        mask_ratio: Fraction of the patches selected as prediction targets.
    """

    def __init__(self, path="../phase1/mnist_test_seq.npy", n_frames=5,
                 patch_size=16, mask_ratio=0.5):
        data = np.load(path)
        # Convert once and scale in place so only one float copy of the
        # whole array is ever alive.
        self.data = torch.as_tensor(data, dtype=torch.float32).div_(255.0)
        self.n_frames = n_frames
        self.patch_size = patch_size
        self.mask_ratio = mask_ratio

        # Derive the patch count from the array actually loaded, so that
        # mask indices always stay inside what _to_patches() produces
        # (unfold drops any remainder, hence the floor divisions).
        height, width = self.data.shape[-2:]
        frames_used = min(n_frames, self.data.shape[0])
        self.total_patches = (
            frames_used * (height // patch_size) * (width // patch_size)
        )

    def __len__(self):
        # One item per sequence (second axis of the array).
        return self.data.shape[1]

    def __getitem__(self, idx):
        """Return all patches of sequence ``idx`` plus a fresh random mask.

        Returns:
            dict with ``all_patches`` (every flattened patch),
            ``target_patches`` (the masked subset) and ``mask_idx``
            (sorted indices of that subset).
        """
        frames = self.data[:self.n_frames, idx]
        patches = self._to_patches(frames)
        mask_idx = _sample_mask_indices(self.total_patches, self.mask_ratio)
        return {
            "all_patches": patches,
            "target_patches": patches[mask_idx],
            "mask_idx": mask_idx,
        }

    def _to_patches(self, frames):
        """Cut ``(n, H, W)`` frames into flattened ``p*p`` patches."""
        p = self.patch_size
        # (n, H, W) -> (n, H // p, W // p, p, p)
        tiles = frames.unfold(1, p, p).unfold(2, p, p)
        return tiles.contiguous().view(-1, p * p)


class AudioDataset(Dataset):
    """Pre-computed WavJEPA features (from Phase 1-B).

    The ``.npy`` file is memory-mapped read-only and indexed as
    ``(item, token, feature)``; each item is copied into a float tensor
    on access.

    Args:
        path: Location of the ``.npy`` feature array.
        mask_ratio: Fraction of the tokens selected as prediction targets.
    """

    def __init__(self, path="../phase1b/audio_features/audio_features.npy",
                 mask_ratio=0.5):
        self.features = np.load(path, mmap_mode='r')
        self.n_tokens = self.features.shape[1]
        self.feat_dim = self.features.shape[2]
        self.mask_ratio = mask_ratio

    def __len__(self):
        return self.features.shape[0]

    def __getitem__(self, idx):
        """Return the features of item ``idx`` plus a fresh random mask.

        Returns:
            dict with ``all_features``, ``target_features`` (the masked
            subset) and ``mask_idx`` (sorted indices of that subset).
        """
        # np.array() copies the slice out of the memory map before it is
        # handed to torch.
        features = torch.as_tensor(
            np.array(self.features[idx]), dtype=torch.float32
        )
        mask_idx = _sample_mask_indices(self.n_tokens, self.mask_ratio)
        return {
            "all_features": features,
            "target_features": features[mask_idx],
            "mask_idx": mask_idx,
        }


class TextDataset(Dataset):
    """TinyStories (reused).

    The whole file is read and encoded with the GPT-2 ``tiktoken``
    encoding; token ids ``>= vocab_size`` are dropped. Items are
    next-token pairs ``(x, y)`` of length ``seq_len`` where ``y`` is ``x``
    shifted by one position.

    Args:
        path: Location of the text corpus.
        seq_len: Length of each returned token window.
        vocab_size: Exclusive upper bound on the token ids that are kept.
    """

    def __init__(self, path="../phase1/TinyStoriesV2-GPT4-train.txt",
                 seq_len=128, vocab_size=10000):
        import tiktoken

        self.seq_len = seq_len
        self.enc = tiktoken.get_encoding("gpt2")
        with open(path, "r", encoding="utf-8") as f:
            text = f.read()
        # NOTE(review): tiktoken's encode() rejects text that spells a
        # special token (e.g. "<|endoftext|>") by default -- confirm the
        # corpus contains none, or pass allowed_special/disallowed_special.
        tokens = self.enc.encode(text)
        tokens = [t for t in tokens if t < vocab_size]
        self.tokens = torch.LongTensor(tokens)
        print(f"TinyStories: {len(self.tokens):,} tokens")

    def __len__(self):
        # A window starting at idx needs tokens[idx : idx + seq_len + 1],
        # so the last valid start is len(tokens) - seq_len - 1, which
        # makes len(tokens) - seq_len valid starts in total.
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx):
        x = self.tokens[idx:idx + self.seq_len]
        y = self.tokens[idx + 1:idx + self.seq_len + 1]
        return x, y


def collate_vision(batch):
    """Stack a list of ``VisionDataset``-style sample dicts into one dict."""
    return _stack_fields(batch, ("all_patches", "target_patches", "mask_idx"))


def collate_audio(batch):
    """Stack a list of ``AudioDataset``-style sample dicts into one dict."""
    return _stack_fields(batch, ("all_features", "target_features", "mask_idx"))


class TriModalDataLoader:
    """Yields (vision, audio, text) batch triples.

    One pass yields ``min(len(vision loader), len(audio loader))`` triples.
    The text loader does not take part in that minimum; whenever a loader
    runs out before the pass ends it is restarted (reshuffled) and keeps
    going.

    Args:
        v_ds: Dataset whose samples ``collate_vision`` can stack.
        a_ds: Dataset whose samples ``collate_audio`` can stack.
        t_ds: Dataset batched with the default collate function.
        batch_size: Batch size shared by the three loaders.
    """

    def __init__(self, v_ds, a_ds, t_ds, batch_size=32):
        self.v_loader = DataLoader(v_ds, batch_size=batch_size, shuffle=True,
                                   collate_fn=collate_vision, drop_last=True,
                                   pin_memory=True)
        self.a_loader = DataLoader(a_ds, batch_size=batch_size, shuffle=True,
                                   collate_fn=collate_audio, drop_last=True,
                                   pin_memory=True)
        self.t_loader = DataLoader(t_ds, batch_size=batch_size, shuffle=True,
                                   drop_last=True, pin_memory=True)
        # Pass length is set by the smaller of the vision/audio loaders.
        self.n_batches = min(len(self.v_loader), len(self.a_loader))

    def __iter__(self):
        """Yield ``n_batches`` triples ``(v_batch, a_batch, t_batch)``.

        Raises:
            RuntimeError: If a loader is exhausted and still yields nothing
                after being restarted (it has no full batch at all).
        """
        loaders = {
            "vision": self.v_loader,
            "audio": self.a_loader,
            "text": self.t_loader,
        }
        # All iterators are created before the first batch is drawn.
        iterators = {name: iter(loader) for name, loader in loaders.items()}

        for _ in range(self.n_batches):
            triple = []
            for name, loader in loaders.items():
                try:
                    batch = next(iterators[name])
                except StopIteration:
                    # Exhausted mid-pass: start this loader over.
                    iterators[name] = iter(loader)
                    try:
                        batch = next(iterators[name])
                    except StopIteration:
                        # Letting StopIteration escape a generator would
                        # surface as an opaque "generator raised
                        # StopIteration"; name the culprit instead.
                        raise RuntimeError(
                            f"{name} loader produced no batches "
                            "(dataset smaller than one batch with "
                            "drop_last=True?)"
                        ) from None
                triple.append(batch)
            yield tuple(triple)

    def __len__(self):
        return self.n_batches


def _smoke_test():
    """Run one tri-modal batch over synthetic datasets and print shapes."""

    class FakeVision(Dataset):
        def __len__(self):
            return 100

        def __getitem__(self, i):
            p = torch.randn(80, 256)
            m = torch.arange(40)
            return {"all_patches": p, "target_patches": p[:40], "mask_idx": m}

    class FakeAudio(Dataset):
        def __len__(self):
            return 100

        def __getitem__(self, i):
            f = torch.randn(200, 768)
            m = torch.arange(100)
            return {"all_features": f, "target_features": f[:100], "mask_idx": m}

    class FakeText(Dataset):
        def __len__(self):
            return 100

        def __getitem__(self, i):
            return (torch.randint(0, 10000, (128,)),
                    torch.randint(0, 10000, (128,)))

    loader = TriModalDataLoader(FakeVision(), FakeAudio(), FakeText(),
                                batch_size=4)
    for v, a, t in loader:
        print(f"V: {v['all_patches'].shape}, A: {a['all_features'].shape}, T: {t[0].shape}")
        break
    print("DataLoader OK")


if __name__ == "__main__":
    # NOTE(review): CONFIG is not used by the smoke test -- confirm whether
    # this import is still needed.
    from model import CONFIG  # noqa: F401

    # Quick test with synthetic
    _smoke_test()