File size: 1,274 Bytes
493de78
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
import os
# Pin OpenMP and MKL BLAS thread pools to one thread for reproducible tests.
# This runs before the array libraries are imported below. setdefault only
# fills in variables that are missing, so values the user exported win.
for _thread_var in ("OMP_NUM_THREADS", "MKL_NUM_THREADS"):
    os.environ.setdefault(_thread_var, "1")
del _thread_var  # keep the loop variable out of the module namespace

xp_backend = "numpy"  # name of the active backend: "cupy" or "numpy"
_rng = None  # module-wide random Generator; rebuilt by set_seed()


def _new_generator(seed):
    """Return a Generator for the active backend, seeded with ``int(seed)``.

    CuPy uses its Philox bit generator, which CuPy names ``Philox4x3210``
    (there is no ``cupy.random.Philox``). NumPy uses PCG64. ``xp`` and
    ``xp_backend`` are read at call time, so call this only after the backend
    has been chosen.
    """
    if xp_backend == "cupy":
        return xp.random.Generator(xp.random.Philox4x3210(int(seed)))
    return xp.random.Generator(xp.random.PCG64(int(seed)))


try:
    import cupy as xp  # CuPy provides a NumPy-compatible API

    xp_backend = "cupy"
    _rng = _new_generator(0)
except Exception:
    # Broad on purpose: CuPy may be absent, or importable but unusable
    # (e.g. no device, or an older API without Generator). Any failure here
    # means "fall back to NumPy" instead of crashing at import time.
    import numpy as xp

    xp_backend = "numpy"
    _rng = _new_generator(0)


def set_seed(seed: int) -> None:
    """Replace the module-wide generator ``_rng`` with one seeded by *seed*.

    Uses the same bit generator as module initialisation: Philox4x3210 on
    CuPy, PCG64 on NumPy.

    Args:
        seed: Converted with ``int()``; values ``int()`` rejects raise.
    """
    global _rng
    _rng = _new_generator(seed)

def backend_name() -> str:
    """Return the name of the active array backend ("cupy" or "numpy")."""
    active = xp_backend
    return active

def as_xp(a):
    """Convert *a* to an array of the active backend using ``xp.asarray``.

    Array-likes (lists, scalars, host arrays) are converted. An input that is
    already an ``xp`` array of a compatible dtype is handed back by
    ``asarray`` without copying.
    """
    convert = xp.asarray
    return convert(a)

def to_cpu(a):
    """Return *a* as a host NumPy array if it is a CuPy array, else unchanged.

    Works whichever backend is active. If CuPy cannot be imported, nothing can
    be a CuPy array, so the input is returned as-is.
    """
    try:
        import cupy as cp
    except Exception:
        # Only the import is best-effort. CuPy may be missing or fail to load.
        # Errors from the conversion itself are no longer swallowed.
        return a
    if isinstance(a, cp.ndarray):
        # Use CuPy's own converter. The module's ``xp`` can be NumPy (when
        # CuPy imports but the backend fell back), and numpy has no asnumpy.
        return cp.asnumpy(a)
    return a

def randn(shape, dtype):
    """Draw standard-normal samples (mean 0, std 1) from the module's ``_rng``.

    Args:
        shape: Output shape. An int or a tuple of ints.
        dtype: Target dtype. Either the name of a dtype attribute on ``xp``
            (e.g. ``"float32"``, the original form), any other string
            ``xp.dtype`` accepts (e.g. ``"f4"``), or a dtype / scalar type.

    Returns:
        An ``xp`` array of the requested shape, cast to *dtype*. Samples are
        drawn in the generator's default precision, then cast.
    """
    # Keep the original ``getattr(xp, name)`` lookup for names the backend
    # module defines, and accept anything else ``xp.dtype`` understands.
    if isinstance(dtype, str) and hasattr(xp, dtype):
        target = getattr(xp, dtype)
    else:
        target = xp.dtype(dtype)

    if hasattr(_rng, "normal"):
        return _rng.normal(loc=0.0, scale=1.0, size=shape).astype(target)
    if hasattr(_rng, "standard_normal"):
        # Same distribution, still drawn from the seeded generator, so
        # set_seed() keeps controlling the output.
        return _rng.standard_normal(size=shape).astype(target)
    # Last resort: the backend's global RNG. set_seed() only rebinds _rng and
    # does not seed this, so results here are not reproducible via set_seed().
    dims = (shape,) if isinstance(shape, int) else tuple(shape)
    return xp.asarray(xp.random.randn(*dims), dtype=target)