File size: 9,262 Bytes
8cb233a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
"""
Data download and tokenization pipeline for H4 Polytopic Attention experiments.

Supports multiple datasets with automatic download and caching:
  - synthetic: Fibonacci-structured phrases (no download needed)
  - shakespeare: Tiny Shakespeare (~1MB character-level text)
  - tinystories: TinyStories from HuggingFace (real children's stories)

For every dataset, load_and_prepare() returns the same tuple:
    (train_data, val_data, vocab_size, stoi, itos)
"""

import os
import sys
import json
import torch
import urllib.request

# Cache directory for downloaded corpora: a "data" folder located in the
# parent of the directory that contains this module.
DATA_DIR = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "..",
    "data",
)

# Registry of supported datasets, keyed by the name accepted by load_dataset().
# "source" says how the text is obtained; the remaining keys are specific to
# that source (download URL, HuggingFace path and splits, cache file names).
DATASETS = {
    "synthetic": dict(
        source="synthetic",
        description="Fibonacci-structured phrases (built-in)",
    ),
    "shakespeare": dict(
        source="url",
        url="https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt",
        filename="shakespeare.txt",
        description="Tiny Shakespeare (~1MB, character-level)",
    ),
    "tinystories": dict(
        source="huggingface",
        path="roneneldan/TinyStories",
        split="train",
        val_split="validation",
        filename="tinystories.txt",
        val_filename="tinystories_val.txt",
        description="TinyStories (HuggingFace, real children's stories)",
        # Fallback URL if the HF datasets library is not installed.
        # None: the corpus is too large for a raw URL fallback.
        fallback_url=None,
    ),
}


def _ensure_data_dir():
    """Make sure the DATA_DIR cache directory exists (no-op when it already does)."""
    os.makedirs(name=DATA_DIR, exist_ok=True)


def _download_url(url, filepath):
    """Download a file from URL using urllib (stdlib).

    The payload is first written to a temporary sibling file
    (``<filepath>.part``) and only moved onto ``filepath`` once the transfer
    has completed. Callers treat the mere existence of ``filepath`` as a valid
    cache entry, so an interrupted download must never leave a truncated file
    at that path.

    Args:
        url: Location to fetch (any scheme urllib can open).
        filepath: Destination path for the downloaded file.

    Returns:
        True if the file was downloaded and saved, False on any failure
        (the error is printed, not raised).
    """
    print(f"Downloading {url} ...")
    partial_path = f"{filepath}.part"
    try:
        urllib.request.urlretrieve(url, partial_path)
        # Same-directory rename: the destination appears complete or not at all.
        os.replace(partial_path, filepath)
        print(f"  Saved to {filepath} ({os.path.getsize(filepath)} bytes)")
        return True
    except Exception as e:
        print(f"  Download failed: {e}")
        # Best-effort cleanup of whatever was written before the failure.
        try:
            os.remove(partial_path)
        except OSError:
            pass
        return False


def _generate_synthetic_text():
    """Build the deterministic synthetic corpus.

    Walks 200 steps of the Fibonacci recurrence starting from (1, 1). At each
    step the first number of the pair selects a phrase (modulo the number of
    phrases) and the second decides how often it is repeated (1 to 3 times).

    Returns:
        The concatenation of all generated segments as a single string.
    """
    phrases = (
        "the golden ratio appears in nature ",
        "fibonacci numbers grow exponentially ",
        "symmetry underlies all of physics ",
        "the icosahedron has twenty faces ",
        "phi equals one plus one over phi ",
        "geometry is the language of space ",
        "five fold symmetry cannot tile a plane ",
        "the dodecahedron has twelve faces ",
    )
    phrase_count = len(phrases)

    segments = []
    selector, repeater = 1, 1
    for _ in range(200):
        repeats = repeater % 3 + 1
        segments.append(phrases[selector % phrase_count] * repeats)
        selector, repeater = repeater, selector + repeater
    return "".join(segments)


def _load_shakespeare():
    """Return the Tiny Shakespeare corpus, downloading it on first use.

    The text is cached under DATA_DIR; a cached copy is used when present.

    Returns:
        The corpus as a string, or None when no cached copy exists and the
        download fails.
    """
    _ensure_data_dir()
    spec = DATASETS['shakespeare']
    local_path = os.path.join(DATA_DIR, spec['filename'])

    # Only hit the network when there is no cached copy.
    needs_download = not os.path.exists(local_path)
    if needs_download and not _download_url(spec['url'], local_path):
        print("Shakespeare download failed, falling back to synthetic data.")
        return None

    with open(local_path, 'r', encoding='utf-8') as handle:
        corpus = handle.read()
    print(f"Loaded Shakespeare: {len(corpus):,} chars")
    return corpus


def _collect_text(items, limit):
    """Join newline-terminated ``item['text']`` values up to ``limit`` characters.

    Collection stops as soon as ``limit`` characters have been gathered and the
    result is cut to exactly ``limit``; if the items run out first, everything
    collected is returned.
    """
    pieces = []
    total = 0
    for item in items:
        piece = item['text'] + "\n"
        pieces.append(piece)
        total += len(piece)
        if total >= limit:
            break
    return "".join(pieces)[:limit]


def _write_text_atomic(path, text):
    """Write ``text`` to ``path`` via a temporary sibling file and a rename.

    Raises:
        OSError: if the file cannot be written; no partial file is left at
            ``path`` in that case.
    """
    tmp_path = f"{path}.tmp"
    try:
        with open(tmp_path, 'w', encoding='utf-8') as f:
            f.write(text)
        os.replace(tmp_path, path)
    except OSError:
        # Best-effort removal of the temp file before re-raising.
        try:
            os.remove(tmp_path)
        except OSError:
            pass
        raise


def _load_tinystories():
    """Load TinyStories from HuggingFace datasets or cached files.

    Cached train/val text files under DATA_DIR are used when both exist.
    Otherwise the dataset is loaded through the HuggingFace ``datasets``
    library, truncated (5M chars for train, 500k for validation), and cached.
    Caching is best-effort: if the cache files cannot be written, the loaded
    text is still returned.

    Returns:
        (train_text, val_text) on success, or None when the ``datasets``
        library is missing or loading fails.
    """
    _ensure_data_dir()
    cfg = DATASETS['tinystories']
    train_path = os.path.join(DATA_DIR, cfg['filename'])
    val_path = os.path.join(DATA_DIR, cfg['val_filename'])

    # Check cache first
    if os.path.exists(train_path) and os.path.exists(val_path):
        with open(train_path, 'r', encoding='utf-8') as f:
            train_text = f.read()
        with open(val_path, 'r', encoding='utf-8') as f:
            val_text = f.read()
        print(f"Loaded TinyStories from cache: train={len(train_text):,} chars, val={len(val_text):,} chars")
        return train_text, val_text

    # Try HuggingFace datasets library
    try:
        from datasets import load_dataset as hf_load_dataset
        print("Loading TinyStories from HuggingFace (this may take a while)...")
        ds = hf_load_dataset(cfg['path'])

        # Extract text — TinyStories has a 'text' field
        # Limit to first 5M chars for manageability on CPU
        MAX_CHARS = 5_000_000
        train_text = _collect_text(ds[cfg['split']], MAX_CHARS)
        val_text = _collect_text(ds[cfg['val_split']], MAX_CHARS // 10)

        # Cache to disk. A write failure must not throw away data that was
        # already loaded, nor leave a truncated file for a later cache hit.
        try:
            _write_text_atomic(train_path, train_text)
            _write_text_atomic(val_path, val_text)
        except OSError as e:
            print(f"Warning: could not cache TinyStories to disk: {e}")
            print(f"TinyStories loaded (not cached): train={len(train_text):,} chars, val={len(val_text):,} chars")
        else:
            print(f"TinyStories loaded and cached: train={len(train_text):,} chars, val={len(val_text):,} chars")
        return train_text, val_text

    except ImportError:
        print("HuggingFace 'datasets' library not installed.")
        print("Install with: pip install datasets")
        print("Falling back to synthetic data.")
        return None
    except Exception as e:
        print(f"Failed to load TinyStories: {e}")
        print("Falling back to synthetic data.")
        return None


def prepare_char_dataset(text, val_text=None):
    """Turn raw text into character-level training tensors.

    The vocabulary is the sorted set of characters found in ``text`` (and in
    ``val_text`` when given). Without ``val_text`` the encoded text is split
    90/10 into train and validation parts; with it, each text is encoded
    separately and no split is performed.

    Args:
        text: Training text, or the full corpus when ``val_text`` is None.
        val_text: Optional pre-split validation text.

    Returns:
        (train_data, val_data, vocab_size, stoi, itos), where the data entries
        are 1-D ``torch.long`` tensors of character ids and ``stoi``/``itos``
        map characters to ids and back.
    """
    presplit = val_text is not None
    corpus = text + val_text if presplit else text

    alphabet = sorted(set(corpus))
    stoi = {ch: idx for idx, ch in enumerate(alphabet)}
    itos = dict(enumerate(alphabet))

    def encode(s):
        return torch.tensor([stoi[ch] for ch in s], dtype=torch.long)

    if presplit:
        return encode(text), encode(val_text), len(alphabet), stoi, itos

    ids = encode(text)
    cut = int(0.9 * len(ids))
    return ids[:cut], ids[cut:], len(alphabet), stoi, itos


def load_dataset(name='shakespeare'):
    """Fetch the raw text for the dataset called ``name``.

    Intended as a replacement for train_cpu.py's load_text_data().

    Args:
        name: 'synthetic', 'shakespeare', or 'tinystories'

    Returns:
        A single string for datasets without a predefined split,
        a (train_text, val_text) tuple for pre-split datasets, or
        None when the name is unknown or loading fails (the caller is
        expected to fall back to synthetic data).
    """
    if name == 'synthetic':
        return _generate_synthetic_text()
    if name == 'shakespeare':
        return _load_shakespeare()
    if name == 'tinystories':
        return _load_tinystories()

    # No loader matched the requested name.
    print(f"Unknown dataset: {name}. Available: {list(DATASETS.keys())}")
    return None


def load_and_prepare(name='shakespeare'):
    """Run the whole pipeline: load the named dataset and tokenize it.

    If the dataset cannot be loaded, the synthetic corpus is used instead.

    Args:
        name: Dataset name understood by load_dataset().

    Returns:
        (train_data, val_data, vocab_size, stoi, itos)
    """
    loaded = load_dataset(name)

    # Loading failed: substitute the built-in synthetic corpus.
    if loaded is None:
        print("Using synthetic fallback data.")
        return prepare_char_dataset(_generate_synthetic_text())

    # A plain string has no predefined split and gets the 90/10 split.
    if not isinstance(loaded, tuple):
        return prepare_char_dataset(loaded)

    # Pre-split dataset (e.g., TinyStories): keep its own train/val partition.
    train_text, val_text = loaded
    return prepare_char_dataset(train_text, val_text)


def list_datasets():
    """Print every registered dataset with its description.

    For datasets obtained from a URL or from HuggingFace, the size of the
    cached file named by the entry's 'filename' is appended when it exists
    under DATA_DIR.
    """
    print("Available datasets:")
    for label, spec in DATASETS.items():
        cache_note = ""
        # Only downloadable sources can have a cache file on disk.
        if spec['source'] in ('url', 'huggingface'):
            cache_file = os.path.join(DATA_DIR, spec.get('filename', ''))
            if os.path.exists(cache_file):
                cache_note = f" [cached: {os.path.getsize(cache_file):,} bytes]"
        print(f"  {label:15s}{spec['description']}{cache_note}")


if __name__ == '__main__':
    # Command-line entry point: either list the datasets or prepare one and
    # print a short summary of the resulting tensors.
    import argparse

    cli = argparse.ArgumentParser(description='Prepare datasets for H4 experiments')
    cli.add_argument(
        'dataset',
        nargs='?',
        default='shakespeare',
        choices=list(DATASETS.keys()),
        help='Dataset to prepare (default: shakespeare)',
    )
    cli.add_argument('--list', action='store_true', help='List available datasets')
    opts = cli.parse_args()

    if opts.list:
        list_datasets()
        sys.exit(0)

    train_data, val_data, vocab_size, stoi, itos = load_and_prepare(opts.dataset)

    # Decode the first 80 training tokens back to characters as a sanity check.
    preview = ''.join(itos[i] for i in train_data[:80].tolist())
    print(f"\nDataset: {opts.dataset}")
    print(f"Vocab size: {vocab_size}")
    print(f"Train tokens: {len(train_data):,}")
    print(f"Val tokens: {len(val_data):,}")
    print(f"Sample chars: {preview}")