| """Minimal stub for jax to avoid import errors during static checks. | |
| This is a lightweight shim; it does NOT implement real JAX functionality. | |
| """ | |
| from typing import Any | |
def numpy():
    """Return the installed ``numpy`` module.

    The import runs inside the function body, so numpy is loaded when this
    function is called rather than when the stub module itself is imported.

    Returns:
        The ``numpy`` module object.
    """
    import numpy as numpy_module

    return numpy_module
def device_put(x: Any, device: Any = None) -> Any:
    """Return ``x`` unchanged; this stub performs no device placement.

    Args:
        x: Any value. It is handed back as-is, without copying or conversion.
        device: Accepted for call compatibility and otherwise ignored.

    Returns:
        The same object that was passed as ``x``.
    """
    del device  # deliberately unused by the stub
    return x
class lax:
    """Namespace-style class that exposes a single ``add`` helper.

    The class is used as a plain container of functions; it holds no state.
    """

    @staticmethod
    def add(a: Any, b: Any) -> Any:
        """Return ``a + b`` using Python's ``+`` operator.

        Declared as a static method so the call behaves the same whether it
        is made on the class (``lax.add(a, b)``) or on an instance
        (``lax().add(a, b)``); without the decorator the instance form would
        bind the instance as ``a`` and fail with a ``TypeError``.

        Args:
            a: Left operand.
            b: Right operand.

        Returns:
            Whatever ``a + b`` evaluates to for the given operands.
        """
        return a + b