isochrones-mlp / reference_likelihood.py
RozanskiT's picture
Replace repo contents
9f104bf verified
"""Quickstart: Gaussian likelihood for one Gaia photometric measurement.
This is a compact reference script for collaborators loading the released
isochrone emulator from Hugging Face. It follows the same pattern as the main
README examples:
1. download a bundle with ``Emulator.from_pretrained(...)``;
2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
3. explicitly normalize physical inputs into the bundle's canonical space;
4. explicitly denormalize canonical outputs back to physical magnitudes;
5. evaluate a simple diagonal Gaussian log likelihood inside a single outer ``jax.jit``.
The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector holds
the absolute magnitudes ``[G_mag, BP_mag, RP_mag]``. If your data are apparent
magnitudes, add the distance modulus, extinction, calibration offsets, and any
other nuisance terms to the emulator prediction in your own likelihood.
Related examples in the source repository:
- examples/basic/02_load_bundle_predict.py
- examples/basic/04_use_bundle_in_map_fit.py
"""
from __future__ import annotations
import jax
import jax.numpy as jnp
import numpy as np
from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree
# Emulator bundle location, forwarded to Emulator.from_pretrained.
REPO_ID = "RozanskiT/isochrones-mlp"
REVISION = None  # forwarded as `revision=`; None leaves the choice to the loader
CACHE_DIR = ".emuspec_cache"  # forwarded as `cache_dir=`

# Output leaf read from the denormalized prediction, and the labels used when
# printing its entries (in order).
OUTPUT_LEAF = "magnitudes"
OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")

# Trial isochrone point in physical units: log10(age [yr]), EEP, [Fe/H].
THETA_PHYSICAL = np.array([9.4, 300.0, 0.0], dtype=np.float32)

# Example absolute G, BP, RP magnitudes and their Gaussian sigmas.
# Replace these with your own absolute magnitudes and uncertainties.
OBSERVED_MAGNITUDES = np.array([6.94, 7.52, 6.21], dtype=np.float32)
OBSERVED_SIGMA_MAG = np.array([0.03, 0.03, 0.03], dtype=np.float32)
def _validate_observations(
    magnitudes: np.ndarray,
    sigma: np.ndarray,
    channels: tuple[str, ...],
) -> None:
    """Check that observed magnitudes and uncertainties define a usable likelihood.

    Args:
        magnitudes: Observed absolute magnitudes, one per entry of ``channels``.
        sigma: Gaussian uncertainties, one per entry of ``channels``.
        channels: Output channel names the observations are compared against.

    Raises:
        ValueError: If either array is not 1-D with one entry per channel, if
            any magnitude is non-finite, or if any uncertainty is non-finite or
            not strictly positive.
    """
    magnitudes = np.asarray(magnitudes)
    sigma = np.asarray(sigma)
    expected_shape = (len(channels),)
    if magnitudes.shape != expected_shape or sigma.shape != expected_shape:
        raise ValueError(
            f"Observed magnitudes and uncertainties must both have shape "
            f"{expected_shape} to match channels {channels}; got "
            f"{magnitudes.shape} and {sigma.shape}."
        )
    if not np.all(np.isfinite(magnitudes)):
        raise ValueError(
            f"Observed magnitudes must be finite; got {magnitudes.tolist()}."
        )
    # sigma <= 0 turns log(2*pi*sigma**2) into -inf/NaN and breaks the residual
    # division, which would silently yield a meaningless log likelihood.
    if not np.all(np.isfinite(sigma) & (sigma > 0)):
        raise ValueError(
            f"Magnitude uncertainties must be finite and > 0; got {sigma.tolist()}."
        )


def _gaussian_log_likelihood(y_obs, y_model, y_err):
    """Return the diagonal Gaussian log likelihood of ``y_obs`` given ``y_model``.

    Computes ``-0.5 * (sum(((y_obs - y_model) / y_err)**2)
    + sum(log(2*pi*y_err**2)))`` with ``jax.numpy``, so it can be traced
    inside ``jax.jit``.
    """
    resid = (y_obs - y_model) / y_err
    log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * y_err**2))
    return -0.5 * (jnp.sum(resid**2) + log_norm)


def _rounded_list(values, decimals: int = 6) -> list[float]:
    """Return ``values`` as a list of Python floats rounded to ``decimals`` places.

    ``ndarray.tolist()`` on float32 data widens to float64 and exposes
    round-off (9.4 prints as 9.399999618530273). Rounding keeps the printed
    output consistent with the ``:.6f`` formatting used elsewhere in main().
    """
    return np.round(np.asarray(values, dtype=np.float64), decimals).tolist()


def main() -> None:
    """Load the emulator, evaluate the log likelihood at THETA_PHYSICAL, and print it.

    Raises:
        ValueError: If the observation constants are malformed, if the bundle
            lacks reference scaling metadata, or if the emulator output does
            not have one value per entry of OUTPUT_CHANNELS.
    """
    # Fail fast on bad (user-edited) constants, before any download or tracing.
    _validate_observations(OBSERVED_MAGNITUDES, OBSERVED_SIGMA_MAG, OUTPUT_CHANNELS)

    emu = Emulator.from_pretrained(
        REPO_ID,
        revision=REVISION,
        cache_dir=CACHE_DIR,
        verbose=True,
    )
    # Not jitted here: the single outer jit below compiles the whole
    # normalize -> apply -> denormalize -> likelihood chain at once.
    apply_magnitudes = emu.make_frozen_apply(jit=False)
    ref_inputs = emu.reference_scaling_inputs
    ref_outputs = emu.reference_scaling_outputs
    if ref_inputs is None or ref_outputs is None:
        raise ValueError(
            "This likelihood example requires reference_scaling_inputs and "
            "reference_scaling_outputs in the bundle metadata."
        )

    y_obs = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
    y_err = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)

    def predict_magnitudes(theta):
        """Predict physical magnitudes; jit the outer objective, not this helper."""
        # Add a leading axis of size 1; it is removed again by the [0] below.
        x_physical = {"parameters": theta[None, :]}
        x_scaled = normalize_tree(
            x_physical,
            ref_inputs["min_tree"],
            ref_inputs["max_tree"],
        )
        y_scaled = apply_magnitudes(x_scaled)
        y_physical = denormalize_tree(
            y_scaled,
            ref_outputs["min_tree"],
            ref_outputs["max_tree"],
        )
        return y_physical[OUTPUT_LEAF][0]

    @jax.jit
    def evaluate_likelihood(theta):
        y_model = predict_magnitudes(theta)
        # Shapes are static under jit, so this Python check runs at trace time.
        # Without it, a length-1 or scalar output would broadcast silently
        # against y_obs, and other mismatches would surface as opaque errors.
        if y_model.shape != y_obs.shape:
            raise ValueError(
                f"Emulator output leaf {OUTPUT_LEAF!r} has per-sample shape "
                f"{y_model.shape}, but {len(OUTPUT_CHANNELS)} channels "
                f"{OUTPUT_CHANNELS} were expected."
            )
        return y_model, _gaussian_log_likelihood(y_obs, y_model, y_err)

    theta = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
    model_magnitudes_jax, logp_jax = evaluate_likelihood(theta)
    model_magnitudes = np.asarray(jax.block_until_ready(model_magnitudes_jax))
    logp = float(jax.block_until_ready(logp_jax))

    print("theta_physical [age, eep, feh]:", _rounded_list(THETA_PHYSICAL))
    print("model absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, model_magnitudes, strict=True):
        print(f" {name}: {value:.6f}")
    print("observed absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, OBSERVED_MAGNITUDES, strict=True):
        print(f" {name}: {value:.6f}")
    print("sigma_mag:", _rounded_list(OBSERVED_SIGMA_MAG))
    print("log_likelihood:", f"{logp:.6f}")


if __name__ == "__main__":
    main()