File size: 6,748 Bytes
54fa103
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
"""Pre-generate and disk-cache all PDE training datasets.

Run once before starting experiments:
    uv run prefetch_data.py              # all benchmarks
    uv run prefetch_data.py --skip-slow  # everything except ns_hre_2d

This eliminates per-experiment data generation overhead. The fast benchmarks
are generated concurrently, one thread each; the slow ns_hre_2d benchmark
(10-20 min, which dominates the total time) then runs on its own so it gets
the full CPU. All other benchmarks finish in under 5 min combined.

After this script completes, every train.py subprocess will load data
from disk in < 5s instead of regenerating from scratch.
"""

import os
import sys
import time
import threading
import numpy as np
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

# ── Locate cache dir (mirrors prepare.py) ─────────────────────────────────────
# Created at import time; exist_ok makes repeated runs harmless.
CACHE_DIR = os.path.join(os.path.expanduser("~"), ".cache", "sciml_autoresearch")
os.makedirs(CACHE_DIR, exist_ok=True)

# Shared by every worker thread so each progress line is printed whole.
_print_lock = threading.Lock()

def log(bm: str, msg: str):
    """Print one indented progress line tagged with benchmark ``bm``.

    The line is emitted (and flushed) while holding ``_print_lock`` so output
    from concurrent threads never interleaves mid-line.
    """
    line = "  [{}] {}".format(bm, msg)
    with _print_lock:
        print(line, flush=True)


# ── Per-source prefetch functions ─────────────────────────────────────────────

def prefetch_ext(benchmark: str) -> str:
    """Ensure the val and train splits of a ``data.benchmarks_ext`` benchmark are on disk.

    A split whose cache file already exists is skipped; otherwise the matching
    ``data.benchmarks_ext`` loader is called to generate it. Which benchmarks
    are routed to this function is decided by ``TASKS``.

    Args:
        benchmark: Benchmark name understood by ``data.benchmarks_ext``.

    Returns:
        One-line summary for the final report.
    """
    from data.benchmarks_ext import (
        _get_ext_train, _load_or_gen_ext_val,
        _get_ext_val_cache, _get_ext_train_cache_path, N_TRAIN,
    )
    results = []

    # Val (size reported in KB; presumably much smaller than train — TODO confirm)
    val_path = _get_ext_val_cache(benchmark)
    if os.path.exists(val_path):
        log(benchmark, f"val already cached ({Path(val_path).stat().st_size // 1024}KB)")
    else:
        log(benchmark, "generating val set…")
        t0 = time.time()
        # NOTE(review): assumed to write val_path as a side effect; the
        # exists() check above relies on that to skip work on re-runs.
        _load_or_gen_ext_val(benchmark)
        log(benchmark, f"val cached in {time.time()-t0:.0f}s → {Path(val_path).name}")
    results.append("val")

    # Train
    train_path = _get_ext_train_cache_path(benchmark)
    if os.path.exists(train_path):
        log(benchmark, f"train already cached ({Path(train_path).stat().st_size // 1024 // 1024}MB)")
    else:
        log(benchmark, f"generating train set ({N_TRAIN} samples)…")
        t0 = time.time()
        # NOTE(review): assumed to write train_path as a side effect (see val above).
        _get_ext_train(benchmark)
        log(benchmark, f"train cached in {time.time()-t0:.0f}s → {Path(train_path).name}")
    results.append("train")

    return f"{benchmark}: {', '.join(results)} ✓"


def prefetch_sim(benchmark: str) -> str:
    """Ensure the val and train splits of a ``data.simulations`` benchmark are on disk.

    For each split the cache path comes from ``data.simulations._cache_path``.
    Existing files are skipped; missing ones are produced by
    ``_load_or_generate``. Which benchmarks are routed here is decided by
    ``TASKS`` (which currently sends only euler_1d).

    Args:
        benchmark: Benchmark name understood by ``data.simulations``.

    Returns:
        One-line summary for the final report.
    """
    from data.simulations import _load_or_generate, _cache_path

    results = []
    for split in ("val", "train"):
        path = _cache_path(benchmark, split)
        if os.path.exists(path):
            size_mb = Path(path).stat().st_size // 1024 // 1024
            log(benchmark, f"{split} already cached ({size_mb}MB)")
        else:
            log(benchmark, f"generating {split} set…")
            t0 = time.time()
            # NOTE(review): assumed to write `path` as a side effect so that
            # the exists() check above short-circuits on later runs.
            _load_or_generate(benchmark, split)
            log(benchmark, f"{split} cached in {time.time()-t0:.0f}s → {Path(path).name}")
        results.append(split)

    return f"{benchmark}: {', '.join(results)} ✓"


def prefetch_burgers() -> str:
    """Ensure burgers_1d val and train splits are cached on disk.

    prepare.py only does in-memory caching of the training set, so the train
    split is generated here and saved to
    ``<prepare.CACHE_DIR>/burgers_1d_train_N<GRID_SIZE>.npz`` with ``inputs``
    and ``targets`` arrays. The val split is delegated to prepare.py's own
    ``_load_or_gen_val``.

    Returns:
        One-line summary for the final report.
    """
    # CACHE_DIR here is prepare.py's value and shadows this module's constant.
    from data.prepare import (
        _generate_dataset, _load_or_gen_val, N_TRAIN, TRAIN_SEED, CACHE_DIR, GRID_SIZE
    )

    cache_train = os.path.join(CACHE_DIR, f"burgers_1d_train_N{GRID_SIZE}.npz")

    # Val (prepare.py already caches this, just warm it)
    log("burgers_1d", "warming val cache…")
    _load_or_gen_val("burgers_1d")
    log("burgers_1d", "val ready")

    # Train
    if os.path.exists(cache_train):
        size_mb = Path(cache_train).stat().st_size // 1024 // 1024
        log("burgers_1d", f"train already cached ({size_mb}MB)")
    else:
        log("burgers_1d", f"generating train set ({N_TRAIN} samples)…")
        t0 = time.time()
        inputs, targets = _generate_dataset("burgers_1d", N_TRAIN, TRAIN_SEED)
        # Write to a per-process temp file in the same directory, then rename
        # it into place atomically. An interrupted run therefore never leaves
        # a truncated .npz that the exists() check above would accept as a
        # valid cache. Passing a file object stops np.savez from appending
        # its own ".npz" suffix to the temp name.
        os.makedirs(CACHE_DIR, exist_ok=True)
        tmp_path = f"{cache_train}.{os.getpid()}.tmp"
        try:
            with open(tmp_path, "wb") as fh:
                np.savez(fh, inputs=inputs, targets=targets)
            os.replace(tmp_path, cache_train)
        except BaseException:
            # Best-effort cleanup of the partial file, then re-raise the real error.
            try:
                os.remove(tmp_path)
            except OSError:
                pass
            raise
        log("burgers_1d", f"train cached in {time.time()-t0:.0f}s → {Path(cache_train).name}")

    return "burgers_1d: val, train ✓"




# ── Main ──────────────────────────────────────────────────────────────────────

# Work list consumed by the __main__ block. Each entry is
# (label, prefetch function, positional args for that function).
# NOTE(review): confirm each benchmark is routed to the loader that owns it
# (prefetch_ext -> data.benchmarks_ext, prefetch_sim -> data.simulations).
TASKS = [
    ("burgers_1d",    prefetch_burgers, ()),
    ("kdv_1d",        prefetch_ext,     ("kdv_1d",)),
    ("wave_1d",       prefetch_ext,     ("wave_1d",)),
    ("darcy_2d",      prefetch_ext,     ("darcy_2d",)),
    ("ns_2d",         prefetch_ext,     ("ns_2d",)),
    ("ns_hre_2d",     prefetch_ext,     ("ns_hre_2d",)),
    ("swe_2d",        prefetch_ext,     ("swe_2d",)),
    ("allen_cahn_2d", prefetch_ext,     ("allen_cahn_2d",)),
    ("mhd_2d",        prefetch_ext,     ("mhd_2d",)),
    ("euler_1d",      prefetch_sim,     ("euler_1d",)),
]

# Labels run one at a time after the thread pool drains. ns_hre_2d takes
# 10-20 min and is CPU-bound, so it gets the machine to itself.
SERIAL_TASKS = {"ns_hre_2d"}
PARALLEL_TASKS = [(label, fn, args) for label, fn, args in TASKS if label not in SERIAL_TASKS]
SLOW_TASKS = [(label, fn, args) for label, fn, args in TASKS if label in SERIAL_TASKS]

if __name__ == "__main__":
    import traceback

    skip_slow = "--skip-slow" in sys.argv
    print(f"\nPrefetching all PDE datasets → {CACHE_DIR}")
    print(f"Parallel: {[t[0] for t in PARALLEL_TASKS]}")
    if skip_slow:
        print(f"Skipping slow: {[t[0] for t in SLOW_TASKS]} (pass without --skip-slow to include)")
    else:
        print(f"Serial (slow): {[t[0] for t in SLOW_TASKS]}")
    print()

    t_total = time.time()
    # label -> one-line summary. Keyed by label so the final report is printed
    # in TASKS order rather than in nondeterministic thread-completion order.
    results = {}
    failed = set()

    # Run fast benchmarks in parallel. max(1, ...) keeps the pool valid if every
    # task is ever moved into SERIAL_TASKS (max_workers=0 raises ValueError).
    with ThreadPoolExecutor(max_workers=max(1, len(PARALLEL_TASKS))) as pool:
        futures = {pool.submit(fn, *args): label for label, fn, args in PARALLEL_TASKS}
        for fut in as_completed(futures):
            label = futures[fut]
            try:
                results[label] = fut.result()
            except Exception as e:
                # Keep going so one broken benchmark doesn't block the rest,
                # but surface the full traceback rather than only str(e).
                failed.add(label)
                results[label] = f"{label}: ERROR — {e}"
                log(label, f"FAILED\n{traceback.format_exc()}")

    # Run slow benchmarks serially (give them full CPU)
    if not skip_slow:
        for label, fn, args in SLOW_TASKS:
            try:
                results[label] = fn(*args)
            except Exception as e:
                failed.add(label)
                results[label] = f"{label}: ERROR — {e}"
                log(label, f"FAILED\n{traceback.format_exc()}")

    print(f"\n{'━'*60}")
    print(f"Done in {time.time()-t_total:.0f}s\n")
    for label, _fn, _args in TASKS:
        if label in results:
            print(f"  {results[label]}")

    if failed:
        failed_labels = [label for label, _fn, _args in TASKS if label in failed]
        print(f"\n{len(failed_labels)} benchmark(s) failed: {failed_labels}. Fix and re-run.")
        # Non-zero exit so callers/CI don't mistake a partial cache for success.
        sys.exit(1)
    if skip_slow:
        print(f"\nSkipped {[t[0] for t in SLOW_TASKS]}; re-run without --skip-slow to cache them.")
    else:
        print("\nAll datasets cached. Experiments will now load data in <5s.")