# isochrones-mlp — example usage script from the Hugging Face model card
# (source: RozanskiT/isochrones-mlp, README.md, commit bf7cd89; license: MIT)
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import time 

from astro_emulators_toolkit import Emulator

script_dir = Path(__file__).parent.resolve()

# ------------------------------------------------------------------------------
# Model description and data scaling info for physical prediction
# ------------------------------------------------------------------------------

DEFAULT_INPUTS = ("age", "eep", "feh")
DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")

# Min-max scaling bounds: columns 0-2 correspond to the inputs (age, eep, feh),
# columns 3-5 to the targets (G_mag, BP_mag, RP_mag).
MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32)
MAX_VAL = np.array([10.2993574, 454.0, 0.59522903, 15.0175705, 18.4394169, 13.6201954], dtype=np.float32)


# ------------------------------------------------------------------------------
# Load pretrained emulator bundle from Hugging Face and build a physical predictor
# ------------------------------------------------------------------------------

print("Attempting to load pretrained emulator bundle from Hugging Face...")
repo_id = "RozanskiT/isochrones-mlp"
try:
    emu = Emulator.from_pretrained(
        repo_id,
        cache_dir=script_dir / ".emuspec_cache",
    )
except Exception as exc:
    # Previously this only printed and fell through, leaving `emu` undefined
    # so the script later crashed with an unrelated NameError. Fail fast and
    # keep the original cause chained for debugging.
    print(f"Hugging Face load failed ({exc}).")
    raise SystemExit(1) from exc
else:
    print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")

# ------------------------------------------------------------------------------
# Build a physical predictor that scales inputs and applies the frozen model
# ------------------------------------------------------------------------------

def build_physical_predictor(emu: "Emulator", min_val=None, max_val=None):
    """Return a jitted predictor mapping physical inputs to physical outputs.

    The returned callable min-max-normalizes the 3 physical inputs
    (age, eep, feh), applies the frozen emulator network, and rescales the
    3 normalized outputs back to physical magnitudes.

    Parameters
    ----------
    emu : Emulator
        A loaded emulator exposing ``make_frozen_apply_fn``.
    min_val, max_val : array-like of length 6, optional
        Scaling bounds — first 3 entries for the inputs, last 3 for the
        targets. Default to the module-level ``MIN_VAL`` / ``MAX_VAL``.

    Returns
    -------
    callable
        ``predict_physical(x_physical) -> y_physical`` (jitted).
    """
    lo = np.asarray(MIN_VAL if min_val is None else min_val, dtype=np.float32)
    hi = np.asarray(MAX_VAL if max_val is None else max_val, dtype=np.float32)

    frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)
    # Pre-place the scaling constants on the accelerator once, outside the jit.
    x_min = jax.device_put(lo[:3])
    x_scale = jax.device_put(hi[:3] - lo[:3])
    y_min = jax.device_put(lo[3:])
    y_scale = jax.device_put(hi[3:] - lo[3:])

    @jax.jit
    def predict_physical(x_physical):
        # Normalize inputs, run the frozen model, de-normalize outputs.
        x_norm = (x_physical - x_min) / x_scale
        y_norm = frozen_apply(x_norm)
        return y_norm * y_scale + y_min

    return predict_physical

predict_physical = build_physical_predictor(emu)


# ------------------------------------------------------------------------------
# Make some physical inputs
# ------------------------------------------------------------------------------

no_points = 1000
# Batch of `no_points` input rows with 3 features (age, eep, feh).
# (The old comment claimed "10 input points" — the batch is `no_points` rows.)
batch_of_predictions = np.zeros((no_points, 3))
batch_of_predictions[:, 0] = 9.4  # age
batch_of_predictions[:, 1] = np.linspace(202, 454, no_points)  # eep
batch_of_predictions[:, 2] = 0.0  # feh

# Simplified domain check: every input column must lie inside the training
# range. Use explicit raises instead of `assert`, which is stripped under -O.
for col, label in enumerate(("Age", "EEP", "FeH")):
    column = batch_of_predictions[:, col]
    if np.any(column < MIN_VAL[col]) or np.any(column > MAX_VAL[col]):
        raise ValueError(f"{label} out of domain")

# Move to JAX (e.g. GPU when available).
batch_of_predictions = jnp.array(batch_of_predictions)

# ------------------------------------------------------------------------------
# Create a predictor that uses the frozen model but scales physical inputs, then predict on the batch and time it
# This could be extended to include a distance, extinction, by analitical model
# ------------------------------------------------------------------------------

# Time the first call (includes XLA compilation) and a second, cached call.
t_start = time.perf_counter()
first_result = jax.block_until_ready(predict_physical(batch_of_predictions))
y_pred_first = np.asarray(first_result)
t_compiled = time.perf_counter()

second_result = jax.block_until_ready(predict_physical(batch_of_predictions))
y_pred_second = np.asarray(second_result)
t_cached = time.perf_counter()

# Summarize timings and prediction shape
print(f"First call (compile + run): {t_compiled - t_start:.6f} s")
print(f"Second call (run only): {t_cached - t_compiled:.6f} s")
print(f"Predictions size: {y_pred_second.shape}")


# Color-magnitude diagram at the left and magnitude vs batch index at the right
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

ax_cmd = axs[0]
ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
ax_cmd.set_xlabel("BP - RP")
ax_cmd.set_ylabel("G")
ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
ax_cmd.grid(alpha=0.25)
ax_cmd.invert_yaxis()  # Magnitudes are brighter when smaller, so invert y-axis for CMD

ax_step = axs[1]
# One curve per target. Let matplotlib cycle colors — previously all three
# lines used the same fixed color, so the legend could not distinguish them.
for target_idx, target_name in enumerate(DEFAULT_TARGETS):
    ax_step.plot(y_pred_second[:, target_idx], "-", alpha=0.9, label=f"Pred {target_name}")
ax_step.set_xlabel("Batch Index")
ax_step.set_ylabel("Magnitude")
# The old title referenced the stale loop variable (always the last target);
# the panel actually shows all predicted magnitudes.
ax_step.set_title("Predicted magnitudes vs Batch Index")
ax_step.legend()
ax_step.grid(alpha=0.25)
plt.tight_layout()
plt.show()