"""Quickstart: Gaussian likelihood for one Gaia photometric measurement.

This is a compact reference script for collaborators who load the released
isochrone emulator from Hugging Face. It follows the same pattern as the main
README examples:

1. download a bundle with ``Emulator.from_pretrained(...)``;
2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
3. explicitly normalize physical inputs into the bundle's canonical space;
4. explicitly denormalize canonical outputs back to physical magnitudes;
5. evaluate a simple diagonal Gaussian log likelihood inside one outer jit.

The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector holds
absolute magnitudes ``[G_mag, BP_mag, RP_mag]``. If your data are apparent
magnitudes, add the distance modulus, extinction, calibration offsets, or any
other nuisance terms to your own likelihood, wrapped around the emulator
prediction.

Related examples in the source repository:

- examples/basic/02_load_bundle_predict.py
- examples/basic/04_use_bundle_in_map_fit.py
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np

from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree

# Hugging Face repository that holds the emulator bundle (passed to from_pretrained).
REPO_ID = "RozanskiT/isochrones-mlp"
# Passed to from_pretrained as ``revision``; None presumably selects the
# repository's default revision — confirm against the toolkit docs.
REVISION = None
# Local cache directory for the downloaded bundle (passed as ``cache_dir``).
CACHE_DIR = ".emuspec_cache"
# Key of the emulator output tree that holds the predicted magnitudes.
OUTPUT_LEAF = "magnitudes"
# Channel labels used when printing results. NOTE(review): assumed to match the
# order of the channels in OUTPUT_LEAF; verify against the bundle metadata.
OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")

# One trial isochrone point: log10(age [yr]), EEP, [Fe/H].
THETA_PHYSICAL = np.asarray([9.4, 300.0, 0.0], dtype=np.float32)

# One example absolute photometric measurement in G, BP, RP.
# Replace these with your own absolute magnitudes and uncertainties.
OBSERVED_MAGNITUDES = np.asarray([6.94, 7.52, 6.21], dtype=np.float32)
# One-sigma uncertainties (mag), one per channel of OBSERVED_MAGNITUDES.
OBSERVED_SIGMA_MAG = np.asarray([0.03, 0.03, 0.03], dtype=np.float32)


def _rounded_list(values, ndigits: int = 6) -> list[float]:
    """Return ``values`` as a flat list of Python floats rounded to ``ndigits``.

    ``tolist()`` on a float32 array exposes binary round-off (``9.4`` prints
    as ``9.399999618530273``). Rounding keeps the printed summary readable
    and consistent with the ``.6f`` formatting used for the magnitudes.
    """
    return [round(float(v), ndigits) for v in np.ravel(values)]


def _validate_measurement(observed, sigma, n_channels: int) -> None:
    """Check that a measurement can enter a diagonal Gaussian likelihood.

    Args:
        observed: observed magnitudes, expected to be 1-D of length ``n_channels``.
        sigma: per-channel uncertainties, same shape as ``observed``.
        n_channels: number of emulator output channels.

    Raises:
        ValueError: if either array has the wrong shape, if any observed
            magnitude is non-finite, or if any uncertainty is non-finite or not
            strictly positive. A zero sigma would turn both the residual and
            ``log(sigma**2)`` into inf/nan.
    """
    observed = np.asarray(observed)
    sigma = np.asarray(sigma)
    expected_shape = (n_channels,)
    if observed.shape != expected_shape or sigma.shape != expected_shape:
        raise ValueError(
            f"Expected observed magnitudes and uncertainties of shape "
            f"{expected_shape}, got {observed.shape} and {sigma.shape}."
        )
    if not np.all(np.isfinite(observed)):
        raise ValueError("Observed magnitudes must all be finite.")
    if not (np.all(np.isfinite(sigma)) and np.all(sigma > 0)):
        raise ValueError(
            "Magnitude uncertainties must be finite and strictly positive."
        )


def main() -> None:
    """Load the emulator, predict magnitudes at THETA_PHYSICAL and print log L.

    Raises:
        ValueError: if the measurement constants are malformed, if the bundle
            lacks reference scaling metadata, or if the emulator output leaf
            does not match the shape of the measurement vector.
    """
    # Fail fast on bad measurement constants, before downloading the bundle.
    _validate_measurement(
        OBSERVED_MAGNITUDES, OBSERVED_SIGMA_MAG, len(OUTPUT_CHANNELS)
    )

    emu = Emulator.from_pretrained(
        REPO_ID,
        revision=REVISION,
        cache_dir=CACHE_DIR,
        verbose=True,
    )
    # jit=False: the whole objective is jitted once below (evaluate_likelihood),
    # so the inner apply stays a plain traceable function.
    apply_magnitudes = emu.make_frozen_apply(jit=False)

    ref_inputs = emu.reference_scaling_inputs
    ref_outputs = emu.reference_scaling_outputs
    if ref_inputs is None or ref_outputs is None:
        raise ValueError(
            "This likelihood example requires reference_scaling_inputs and "
            "reference_scaling_outputs in the bundle metadata."
        )

    y_obs = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
    y_err = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)

    def predict_magnitudes(theta):
        """Predict physical magnitudes; jit the outer objective, not this helper."""
        # Add a leading batch axis of size 1 for the emulator.
        x_physical = {"parameters": theta[None, :]}
        x_scaled = normalize_tree(
            x_physical,
            ref_inputs["min_tree"],
            ref_inputs["max_tree"],
        )
        y_scaled = apply_magnitudes(x_scaled)
        y_physical = denormalize_tree(
            y_scaled,
            ref_outputs["min_tree"],
            ref_outputs["max_tree"],
        )
        y_model = y_physical[OUTPUT_LEAF][0]
        # Shapes are static under jit, so this runs once at trace time and
        # prevents the residual from silently broadcasting against y_obs.
        if y_model.shape != y_obs.shape:
            raise ValueError(
                f"Emulator leaf {OUTPUT_LEAF!r} has per-sample shape "
                f"{y_model.shape}, but the measurement vector has shape "
                f"{y_obs.shape}."
            )
        return y_model

    @jax.jit
    def evaluate_likelihood(theta):
        """Return (model magnitudes, diagonal Gaussian log likelihood) at theta.

        ln L = -1/2 * sum_i [((y_i - m_i) / sigma_i)**2 + ln(2 * pi * sigma_i**2)]
        """
        y_model = predict_magnitudes(theta)
        resid = (y_obs - y_model) / y_err
        log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * y_err**2))
        log_likelihood = -0.5 * (jnp.sum(resid**2) + log_norm)
        return y_model, log_likelihood

    theta = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
    model_magnitudes_jax, logp_jax = evaluate_likelihood(theta)
    model_magnitudes = np.asarray(jax.block_until_ready(model_magnitudes_jax))
    logp = float(jax.block_until_ready(logp_jax))

    print("theta_physical [age, eep, feh]:", _rounded_list(THETA_PHYSICAL))
    print("model absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, model_magnitudes, strict=True):
        print(f" {name}: {value:.6f}")
    print("observed absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, OBSERVED_MAGNITUDES, strict=True):
        print(f" {name}: {value:.6f}")
    print("sigma_mag:", _rounded_list(OBSERVED_SIGMA_MAG))
    print("log_likelihood:", f"{logp:.6f}")


if __name__ == "__main__":
    main()