Spaces:
Runtime error
Runtime error
File size: 6,748 Bytes
"""Pre-generate and disk-cache all PDE training datasets.

Run once before starting experiments:

    uv run prefetch_data.py [--skip-slow]

This eliminates per-experiment data generation overhead. The fast benchmarks
are generated concurrently, one thread each; ns_hre_2d is CPU-bound and runs
on its own afterwards (pass --skip-slow to skip it). Total time is dominated
by ns_hre_2d (10-20 min); all other benchmarks finish in under 5 min combined.

After this script completes, every train.py subprocess will load data
from disk in < 5 s instead of regenerating it from scratch.
"""
import os
import sys
import time
import threading
import numpy as np
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
# ── Locate cache dir (mirrors prepare.py) ─────────────────────────────────────
# Must point at the same directory prepare.py uses for its caches.
CACHE_DIR = os.path.join(
    os.path.expanduser("~"),
    ".cache",
    "sciml_autoresearch",
)
os.makedirs(CACHE_DIR, exist_ok=True)  # created at import time

# Shared by every worker thread so their progress lines never interleave.
_print_lock = threading.Lock()


def log(bm: str, msg: str):
    """Print *msg* tagged with benchmark *bm*, serialised across threads.

    The line is flushed immediately so progress appears in real time.
    """
    line = f" [{bm}] {msg}"
    with _print_lock:
        print(line, flush=True)
# ── Per-source prefetch functions ─────────────────────────────────────────────
def prefetch_ext(benchmark: str) -> str:
    """Cache the val and train splits of a ``data.benchmarks_ext`` benchmark.

    TASKS routes kdv_1d, wave_1d, darcy_2d, ns_2d, ns_hre_2d, swe_2d,
    allen_cahn_2d and mhd_2d through this function. A split is generated only
    when the cache path reported by ``data.benchmarks_ext`` does not exist
    yet; otherwise its on-disk size is logged and it is skipped.

    Args:
        benchmark: Benchmark name understood by ``data.benchmarks_ext``.

    Returns:
        A one-line summary naming the splits generated in this run, or
        saying "already cached" when nothing had to be generated.
    """
    # Project data module, imported here rather than at file top (like the
    # other prefetch_* helpers).
    from data.benchmarks_ext import (
        _get_ext_train, _load_or_gen_ext_val,
        _get_ext_val_cache, _get_ext_train_cache_path, N_TRAIN,
    )

    generated = []  # splits produced in this run

    # Val
    val_path = _get_ext_val_cache(benchmark)
    if os.path.exists(val_path):
        log(benchmark, f"val already cached ({Path(val_path).stat().st_size // 1024}KB)")
    else:
        log(benchmark, "generating val setβ¦")
        t0 = time.time()
        _load_or_gen_ext_val(benchmark)
        elapsed = time.time() - t0
        # Verify the generator really persisted to the path we check above;
        # otherwise every future run silently regenerates this split.
        if os.path.exists(val_path):
            log(benchmark, f"val cached in {elapsed:.0f}s β {Path(val_path).name}")
        else:
            log(benchmark, f"WARNING: val generated in {elapsed:.0f}s but no cache file at {val_path}")
        generated.append("val")

    # Train
    train_path = _get_ext_train_cache_path(benchmark)
    if os.path.exists(train_path):
        log(benchmark, f"train already cached ({Path(train_path).stat().st_size // 1024 // 1024}MB)")
    else:
        log(benchmark, f"generating train set ({N_TRAIN} samples)β¦")
        t0 = time.time()
        _get_ext_train(benchmark)
        elapsed = time.time() - t0
        if os.path.exists(train_path):
            log(benchmark, f"train cached in {elapsed:.0f}s β {Path(train_path).name}")
        else:
            log(benchmark, f"WARNING: train generated in {elapsed:.0f}s but no cache file at {train_path}")
        generated.append("train")

    # Avoid an empty "bm:  β" summary when both splits were already on disk.
    if not generated:
        return f"{benchmark}: already cached β"
    return f"{benchmark}: {', '.join(generated)} β"
def prefetch_sim(benchmark: str) -> str:
    """Cache the val and train splits of a ``data.simulations`` benchmark.

    Splits are processed in the order val, train. A split is generated only
    when ``data.simulations._cache_path(benchmark, split)`` does not exist yet.

    NOTE(review): in TASKS only euler_1d is routed here. swe_2d, allen_cahn_2d
    and ns_hre_2d may also be simulation benchmarks, but TASKS currently sends
    them to prefetch_ext. Confirm which data module owns them.

    Args:
        benchmark: Benchmark name understood by ``data.simulations``.

    Returns:
        A one-line summary naming the splits generated in this run, or
        saying "already cached" when nothing had to be generated.
    """
    from data.simulations import _load_or_generate, _cache_path

    generated = []  # splits produced in this run
    for split in ("val", "train"):
        path = _cache_path(benchmark, split)
        if os.path.exists(path):
            size_mb = Path(path).stat().st_size // 1024 // 1024
            log(benchmark, f"{split} already cached ({size_mb}MB)")
            continue

        log(benchmark, f"generating {split} setβ¦")
        t0 = time.time()
        _load_or_generate(benchmark, split)
        elapsed = time.time() - t0
        # Verify the generator really persisted to the path we check above;
        # otherwise every future run silently regenerates this split.
        if os.path.exists(path):
            log(benchmark, f"{split} cached in {elapsed:.0f}s β {Path(path).name}")
        else:
            log(benchmark, f"WARNING: {split} generated in {elapsed:.0f}s but no cache file at {path}")
        generated.append(split)

    # Avoid an empty "bm:  β" summary when both splits were already on disk.
    if not generated:
        return f"{benchmark}: already cached β"
    return f"{benchmark}: {', '.join(generated)} β"
def prefetch_burgers() -> str:
    """Warm the burgers_1d val split and write its train split to disk.

    prepare.py already disk-caches the val split, so it is only warmed here.
    The train split is kept in memory by prepare.py, so it is generated here
    and saved as ``burgers_1d_train_N{GRID_SIZE}.npz`` in prepare.py's
    CACHE_DIR. If that file exists, generation is skipped.

    NOTE(review): confirm that prepare.py/train.py actually load this file.
    If nothing reads it, the train split is still regenerated per experiment.

    Returns:
        A fixed one-line summary for the final report.
    """
    # CACHE_DIR here is prepare.py's own value; it shadows this script's
    # module-level CACHE_DIR, which is meant to mirror it.
    from data.prepare import (
        _generate_dataset, _load_or_gen_val, N_TRAIN, TRAIN_SEED, CACHE_DIR, GRID_SIZE,
    )

    cache_train = os.path.join(CACHE_DIR, f"burgers_1d_train_N{GRID_SIZE}.npz")

    # Val (prepare.py already caches this, just warm it)
    log("burgers_1d", "warming val cacheβ¦")
    _load_or_gen_val("burgers_1d")
    log("burgers_1d", "val ready")

    # Train
    if os.path.exists(cache_train):
        size_mb = Path(cache_train).stat().st_size // 1024 // 1024
        log("burgers_1d", f"train already cached ({size_mb}MB)")
    else:
        log("burgers_1d", f"generating train set ({N_TRAIN} samples)β¦")
        t0 = time.time()
        inputs, targets = _generate_dataset("burgers_1d", N_TRAIN, TRAIN_SEED)
        # The exists() check above treats any file at cache_train as complete,
        # so write to a temp file and atomically rename it into place. An
        # interrupted write then never leaves a truncated .npz under the final
        # name. Passing a file object keeps np.savez from appending ".npz".
        tmp_path = cache_train + ".tmp"
        try:
            with open(tmp_path, "wb") as fh:
                np.savez(fh, inputs=inputs, targets=targets)
            os.replace(tmp_path, cache_train)
        finally:
            # Only still present if something failed before the rename.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        log("burgers_1d", f"train cached in {time.time()-t0:.0f}s β {Path(cache_train).name}")
    return "burgers_1d: val, train β"
# ── Main ──────────────────────────────────────────────────────────────────────
# Each entry is (label, prefetch function, positional args for that function).
# NOTE(review): prefetch_sim's docstring also names swe_2d, allen_cahn_2d and
# ns_hre_2d, but they are routed through prefetch_ext here. Confirm which data
# module actually owns them.
TASKS = [("burgers_1d", prefetch_burgers, ())]
TASKS += [
    (name, prefetch_ext, (name,))
    for name in (
        "kdv_1d",
        "wave_1d",
        "darcy_2d",
        "ns_2d",
        "ns_hre_2d",
        "swe_2d",
        "allen_cahn_2d",
        "mhd_2d",
    )
]
TASKS.append(("euler_1d", prefetch_sim, ("euler_1d",)))

# ns_hre_2d is CPU-bound and takes 10-20 min, so it runs alone after the
# parallel batch instead of competing with it for CPU.
SERIAL_TASKS = {"ns_hre_2d"}

# Split TASKS (order preserved) into the concurrent batch and the serial tail.
PARALLEL_TASKS = [(label, fn, args) for label, fn, args in TASKS if label not in SERIAL_TASKS]
SLOW_TASKS = [(label, fn, args) for label, fn, args in TASKS if label in SERIAL_TASKS]
if __name__ == "__main__":
    # Entry point: prefetch every benchmark, print a summary, and exit
    # non-zero if any benchmark failed so callers/CI can detect it.
    skip_slow = "--skip-slow" in sys.argv

    print(f"\nPrefetching all PDE datasets β {CACHE_DIR}")
    print(f"Parallel: {[t[0] for t in PARALLEL_TASKS]}")
    if skip_slow:
        print(f"Skipping slow: {[t[0] for t in SLOW_TASKS]} (pass without --skip-slow to include)")
    else:
        print(f"Serial (slow): {[t[0] for t in SLOW_TASKS]}")
    print()

    t_total = time.time()
    results = []
    failed = []  # labels whose prefetch raised; drives final message + exit code

    # Run fast benchmarks concurrently. These are threads, so they only overlap
    # where the generators release the GIL (e.g. inside NumPy). NOTE(review):
    # consider ProcessPoolExecutor if generation is pure-Python CPU work.
    # max(1, ...) avoids ValueError if every task is ever marked serial.
    with ThreadPoolExecutor(max_workers=max(1, len(PARALLEL_TASKS))) as pool:
        futures = {pool.submit(fn, *args): label for label, fn, args in PARALLEL_TASKS}
        for fut in as_completed(futures):
            label = futures[fut]
            try:
                results.append(fut.result())
            except Exception as e:
                failed.append(label)
                results.append(f"{label}: ERROR β {type(e).__name__}: {e}")

    # Run slow benchmarks serially (give them full CPU)
    if not skip_slow:
        for label, fn, args in SLOW_TASKS:
            try:
                results.append(fn(*args))
            except Exception as e:
                failed.append(label)
                results.append(f"{label}: ERROR β {type(e).__name__}: {e}")

    print(f"\n{'β'*60}")
    print(f"Done in {time.time()-t_total:.0f}s\n")
    for r in results:
        print(f" {r}")

    # Only claim success when it is true: failures and skipped slow tasks
    # leave some datasets uncached.
    if failed:
        print(f"\n{len(failed)} benchmark(s) failed: {failed}. Re-run to retry.")
        sys.exit(1)
    if skip_slow:
        print(f"\nFast datasets cached. Skipped (not cached): {[t[0] for t in SLOW_TASKS]}")
    else:
        print("\nAll datasets cached. Experiments will now load data in <5s.")
|