| from keras.src import backend | |
def is_in_jax_tracing_scope(x=None):
    """Report whether `x` is a JAX tracer while the JAX backend is active.

    Args:
        x: Optional object to inspect. When `None`, a scalar produced by
            `backend.numpy.ones(())` is inspected instead (presumably so
            that a value created inside a traced function shows up as a
            tracer -- confirm against the JAX backend).

    Returns:
        `True` when `backend.backend()` is `"jax"` and some class in the
        MRO of the inspected object's class is named `"Tracer"` and is
        defined in a module whose name starts with `"jax"`. `False` in
        every other case, including any non-JAX backend.
    """
    if backend.backend() != "jax":
        return False

    probe = backend.numpy.ones(()) if x is None else x

    # Tracer classes are recognised by class name and module prefix while
    # walking the MRO, so subclasses of a jax `Tracer` match as well.
    return any(
        klass.__name__ == "Tracer" and klass.__module__.startswith("jax")
        for klass in probe.__class__.__mro__
    )