"""Array-backend shim: prefer CuPy, fall back to NumPy.

Exposes the selected array module as ``xp``, its name as ``xp_backend``,
and a module-level seeded random generator used by :func:`randn`.
"""
import os

# Force single-threaded BLAS for reproducible tests (user can override).
# setdefault() leaves any value the user already exported untouched; the
# defaults are installed before the array library is imported below.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")

xp_backend = "numpy"
_rng = None

try:
    import cupy as xp  # CuPy provides a NumPy-compatible API

    xp_backend = "cupy"
    _rng = xp.random.Generator(xp.random.Philox(0))
except Exception:
    # Deliberately broad: any failure while importing CuPy or building its
    # generator selects the NumPy backend instead of aborting the import.
    import numpy as xp

    xp_backend = "numpy"
    _rng = xp.random.Generator(xp.random.PCG64(0))


def set_seed(seed: int) -> None:
    """Replace the module-level generator with one seeded by ``int(seed)``.

    The bit generator matches the one chosen at import time: ``Philox`` for
    the CuPy backend, ``PCG64`` for NumPy.
    """
    global _rng
    bit_generator = xp.random.Philox if xp_backend == "cupy" else xp.random.PCG64
    _rng = xp.random.Generator(bit_generator(int(seed)))


def backend_name() -> str:
    """Return the active backend name, ``"cupy"`` or ``"numpy"``."""
    return xp_backend


def as_xp(a):
    """Convert ``a`` to an array of the active backend via ``xp.asarray``."""
    return xp.asarray(a)


def to_cpu(a):
    """Return ``a`` as a NumPy array if it is a CuPy array, else ``a`` as is.

    Best-effort: if CuPy cannot be imported or the conversion fails, the
    input is returned unchanged rather than raising.
    """
    try:
        import cupy as cp

        if isinstance(a, cp.ndarray):
            # Use the locally imported CuPy module: ``xp`` may be NumPy
            # (which has no ``asnumpy``) when the module fell back at import.
            return cp.asnumpy(a)
    except Exception:
        pass
    return a


def randn(shape, dtype):
    """Draw standard-normal samples of ``shape`` from the seeded generator.

    Args:
        shape: Output shape passed to the generator as ``size``.
        dtype: Either the name of a dtype attribute on the backend module
            (e.g. ``"float32"``) or a dtype object accepted by ``astype``.

    Returns:
        A backend array of the requested shape cast to ``dtype``.
    """
    # Accept both "float32"-style names and ready-made dtype objects.
    target = getattr(xp, dtype) if isinstance(dtype, str) else dtype

    if hasattr(_rng, "normal"):
        return _rng.normal(loc=0.0, scale=1.0, size=shape).astype(target)
    if hasattr(_rng, "standard_normal"):
        # Still drawn from the generator that set_seed() controls, so results
        # stay reproducible when the generator exposes no ``normal`` method.
        return _rng.standard_normal(size=shape).astype(target)

    # Last resort: global random state, which set_seed() does not control.
    dims = (shape,) if isinstance(shape, int) else tuple(shape)
    return xp.asarray(xp.random.randn(*dims), dtype=target)