# Uploaded by Simo76 ("Upload 9 files", commit 493de78, verified).
import os
# Force single-threaded BLAS for reproducible tests (user can override).
# setdefault leaves any value the user already exported untouched; these are
# set before numpy/cupy are imported below.
os.environ.setdefault("OMP_NUM_THREADS", "1")
os.environ.setdefault("MKL_NUM_THREADS", "1")


def _make_rng(seed):
    """Return a ``Generator`` for the active backend ``xp`` seeded with ``seed``.

    CuPy exposes its counter-based Philox bit generator as ``Philox4x3210``
    (``cupy.random`` has no ``Philox``); NumPy uses ``PCG64`` here.
    """
    if xp_backend == "cupy":
        bit_generator = getattr(xp.random, "Philox", None) or xp.random.Philox4x3210
    else:
        bit_generator = xp.random.PCG64
    return xp.random.Generator(bit_generator(int(seed)))


try:
    import cupy as xp  # CuPy provides a NumPy-compatible API
    xp_backend = "cupy"
    _rng = _make_rng(0)
except Exception:
    # Deliberately broad: CuPy may be missing, or importable but unusable
    # (no driver/GPU). In every such case fall back to NumPy on the CPU.
    import numpy as xp
    xp_backend = "numpy"
    _rng = _make_rng(0)


def set_seed(seed: int):
    """Reseed the module-level generator used by ``randn``.

    Args:
        seed: Integer seed (converted with ``int``).
    """
    global _rng
    _rng = _make_rng(seed)
def backend_name() -> str:
    """Report which array backend was selected at import time: "cupy" or "numpy"."""
    name = xp_backend
    return name
def as_xp(a):
    """Convert ``a`` into an array of the active backend module ``xp``.

    Delegates to ``xp.asarray``, so an input that is already a suitable
    backend array is returned as-is rather than copied.
    """
    converted = xp.asarray(a)
    return converted
def to_cpu(a):
    """Return ``a`` as a host (CPU) object.

    CuPy arrays are copied to host memory as NumPy arrays; anything else
    (NumPy arrays, scalars, lists, ...) is returned unchanged. Does not depend
    on which backend ``xp`` ended up being selected.
    """
    try:
        import cupy as cp
    except Exception:
        # CuPy missing or failing to import: nothing can be a CuPy array.
        return a
    if isinstance(a, cp.ndarray):
        # cupy.asnumpy, not xp.asnumpy: xp may be numpy even when CuPy is
        # importable (e.g. GPU setup failed), and numpy has no asnumpy.
        return cp.asnumpy(a)
    return a
def randn(shape, dtype):
    """Draw standard-normal samples from the module's seeded generator ``_rng``.

    Args:
        shape: Output shape, as an int or a sequence of ints.
        dtype: Target dtype. Either the name of a dtype attribute on the
            backend module (e.g. ``"float32"``) or a dtype/type object
            accepted by ``astype`` (e.g. ``xp.float32``).

    Returns:
        A backend array of the requested shape, cast to ``dtype``.
    """
    # Names resolve through the backend module as before; dtype objects pass through.
    target = getattr(xp, dtype) if isinstance(dtype, str) else dtype
    if hasattr(_rng, "normal"):
        samples = _rng.normal(loc=0.0, scale=1.0, size=shape)
    elif hasattr(_rng, "standard_normal"):
        # Generators lacking ``normal`` still offer standard_normal; using it
        # keeps the draw tied to the seeded ``_rng`` (and thus to set_seed).
        samples = _rng.standard_normal(size=shape)
    else:
        # Last resort: the backend's legacy global RNG, which set_seed does
        # NOT control. Accept an int shape as well as a sequence.
        try:
            dims = tuple(shape)
        except TypeError:
            dims = (shape,)
        samples = xp.random.randn(*dims)
    return xp.asarray(samples).astype(target, copy=False)