File size: 4,034 Bytes
9f104bf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
"""Quickstart: Gaussian likelihood for one Gaia photometric measurement.

This is a compact reference script for collaborators loading the released
isochrone emulator from Hugging Face. It follows the same pattern as the main
README examples:

1. download a bundle with ``Emulator.from_pretrained(...)``;
2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
3. explicitly normalize physical inputs into the bundle's canonical space;
4. explicitly denormalize canonical outputs back to physical magnitudes;
5. evaluate a simple diagonal Gaussian log-likelihood inside a single outer ``jax.jit``.

The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector is
absolute ``[G_mag, BP_mag, RP_mag]``. If your data are apparent magnitudes,
add the distance modulus, extinction, calibration offsets, or any other
nuisance terms to your own likelihood, wrapped around the emulator prediction.

Related examples in the source repository:

- examples/basic/02_load_bundle_predict.py
- examples/basic/04_use_bundle_in_map_fit.py
"""

from __future__ import annotations

import jax
import jax.numpy as jnp
import numpy as np

from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree

REPO_ID = "RozanskiT/isochrones-mlp"
REVISION = None
CACHE_DIR = ".emuspec_cache"

OUTPUT_LEAF = "magnitudes"
OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")

# One trial isochrone point: log10(age [yr]), EEP, [Fe/H].
THETA_PHYSICAL = np.asarray([9.4, 300.0, 0.0], dtype=np.float32)

# One example absolute photometric measurement in G, BP, RP.
# Replace these with your own absolute magnitudes and uncertainties.
OBSERVED_MAGNITUDES = np.asarray([6.94, 7.52, 6.21], dtype=np.float32)
OBSERVED_SIGMA_MAG = np.asarray([0.03, 0.03, 0.03], dtype=np.float32)


def main() -> None:
    """Run the quickstart: one emulator prediction and one Gaussian log-likelihood.

    Steps:
      1. validate the hard-coded observation and uncertainty vectors;
      2. load the bundle ``REPO_ID`` at ``REVISION`` (cached under
         ``CACHE_DIR``) and freeze an un-jitted apply function;
      3. normalize ``THETA_PHYSICAL`` with the bundle's reference input
         scaling, apply the emulator, and denormalize the prediction with the
         reference output scaling, keeping the ``OUTPUT_LEAF`` leaf;
      4. evaluate the diagonal Gaussian log-likelihood of
         ``OBSERVED_MAGNITUDES`` given ``OBSERVED_SIGMA_MAG`` inside a single
         ``jax.jit``;
      5. print the model magnitudes, the observations, and the log-likelihood.

    Raises:
        ValueError: if the observation or uncertainty vector does not have one
            entry per ``OUTPUT_CHANNELS`` element, if any uncertainty is not
            finite and strictly positive, if the bundle lacks
            ``reference_scaling_inputs``/``reference_scaling_outputs``, or if
            the emulator output shape does not match the observations.
    """
    n_channels = len(OUTPUT_CHANNELS)
    # Validate the user-editable inputs before the bundle download, so a typo
    # fails immediately instead of after loading and evaluating the model.
    for label, values in (
        ("OBSERVED_MAGNITUDES", OBSERVED_MAGNITUDES),
        ("OBSERVED_SIGMA_MAG", OBSERVED_SIGMA_MAG),
    ):
        if np.shape(values) != (n_channels,):
            raise ValueError(
                f"{label} must have shape ({n_channels},) to match "
                f"OUTPUT_CHANNELS, got {np.shape(values)}."
            )
    # A zero, negative, or non-finite sigma would silently turn the
    # log-likelihood into inf/nan rather than failing loudly.
    if not np.all(np.isfinite(OBSERVED_SIGMA_MAG) & (OBSERVED_SIGMA_MAG > 0)):
        raise ValueError(
            "OBSERVED_SIGMA_MAG must contain finite, strictly positive values, "
            f"got {np.asarray(OBSERVED_SIGMA_MAG).tolist()}."
        )

    emu = Emulator.from_pretrained(
        REPO_ID,
        revision=REVISION,
        cache_dir=CACHE_DIR,
        verbose=True,
    )
    # Not jitted here: the whole likelihood below is jitted once instead.
    apply_magnitudes = emu.make_frozen_apply(jit=False)

    ref_inputs = emu.reference_scaling_inputs
    ref_outputs = emu.reference_scaling_outputs
    if ref_inputs is None or ref_outputs is None:
        raise ValueError(
            "This likelihood example requires reference_scaling_inputs and "
            "reference_scaling_outputs in the bundle metadata."
        )

    y_obs = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
    y_err = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)

    def predict_magnitudes(theta):
        """Predict physical magnitudes; jit the outer objective, not this helper."""
        # Add a leading batch axis of size 1 for the emulator call.
        x_physical = {"parameters": theta[None, :]}
        x_scaled = normalize_tree(
            x_physical,
            ref_inputs["min_tree"],
            ref_inputs["max_tree"],
        )
        y_scaled = apply_magnitudes(x_scaled)
        y_physical = denormalize_tree(
            y_scaled,
            ref_outputs["min_tree"],
            ref_outputs["max_tree"],
        )
        # Drop the batch axis added above.
        return y_physical[OUTPUT_LEAF][0]

    @jax.jit
    def evaluate_likelihood(theta):
        """Return (model magnitudes, diagonal Gaussian log-likelihood) at ``theta``."""
        y_model = predict_magnitudes(theta)
        # Array shapes are static under jit, so this Python check runs once
        # at trace time. Without it, an output leaf of the wrong size (e.g. a
        # single channel) would silently broadcast against the observations.
        if y_model.shape != y_obs.shape:
            raise ValueError(
                f"Emulator output leaf {OUTPUT_LEAF!r} has shape "
                f"{y_model.shape}, expected {y_obs.shape} to match the "
                "observed magnitudes."
            )
        resid = (y_obs - y_model) / y_err
        # Normalization term: sum_i log(2 * pi * sigma_i^2).
        log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * y_err**2))
        log_likelihood = -0.5 * (jnp.sum(resid**2) + log_norm)
        return y_model, log_likelihood

    theta = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
    model_magnitudes_jax, logp_jax = evaluate_likelihood(theta)
    model_magnitudes = np.asarray(jax.block_until_ready(model_magnitudes_jax))
    logp = float(jax.block_until_ready(logp_jax))

    def as_display_list(values) -> list[float]:
        # float32 -> Python float exposes representation noise (9.4 would
        # print as 9.399999618530273); round for display only.
        return [round(float(v), 6) for v in values]

    print("theta_physical [age, eep, feh]:", as_display_list(THETA_PHYSICAL))
    print("model absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, model_magnitudes, strict=True):
        print(f"  {name}: {value:.6f}")
    print("observed absolute magnitudes:")
    for name, value in zip(OUTPUT_CHANNELS, OBSERVED_MAGNITUDES, strict=True):
        print(f"  {name}: {value:.6f}")
    print("sigma_mag:", as_display_list(OBSERVED_SIGMA_MAG))
    print("log_likelihood:", f"{logp:.6f}")


# Run the quickstart only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()