from keras.src import backend


def is_in_jax_tracing_scope(x=None):
    """Return whether ``x`` (or a freshly created tensor) is a JAX tracer.

    Outside the "jax" backend this is always ``False``.

    Args:
        x: Optional value to inspect. When omitted (``None``), a scalar
            tensor is built with ``backend.numpy.ones(())`` and inspected
            instead. Presumably this reveals whether the calling code is
            currently being traced (e.g. under ``jax.jit``); confirm
            against JAX's staging semantics.

    Returns:
        ``True`` if some class in the inspected value's MRO is named
        ``"Tracer"`` and is defined in a module whose name starts with
        ``"jax"``; otherwise ``False``.
    """
    # Guard clause: only the JAX backend can produce tracer objects.
    if backend.backend() != "jax":
        return False

    probe = backend.numpy.ones(()) if x is None else x

    # Match by class name and module prefix rather than by isinstance, so
    # this module never has to import jax directly.
    return any(
        klass.__name__ == "Tracer" and klass.__module__.startswith("jax")
        for klass in probe.__class__.__mro__
    )