isochrones-mlp / README.md
RozanskiT's picture
Update README.md
bf7cd89 verified
---
license: mit
---
```python
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp
import time
from astro_emulators_toolkit import Emulator
# Directory containing this script; used below as the parent of the local model cache.
# NOTE(review): `__file__` is undefined in notebooks / interactive sessions, so
# this example expects to be run as a script file.
script_dir = Path(__file__).parent.resolve()
# ------------------------------------------------------------------------------
# Model description and data scaling info for physical prediction
# ------------------------------------------------------------------------------
# Names of the emulator inputs and outputs, in column order.
DEFAULT_INPUTS = ("age", "eep", "feh")
DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")
# Per-column minima / maxima used for min-max scaling. Columns are the inputs
# followed by the targets, i.e. DEFAULT_INPUTS + DEFAULT_TARGETS: the predictor
# below uses [:3] to normalise inputs and [3:] to de-normalise outputs.
# NOTE(review): presumably the ranges of the training grid — confirm.
MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32)
MAX_VAL = np.array([1.02993574e01, 4.54000000e02, 5.95229030e-01, 1.50175705e01, 1.84394169e01, 1.36201954e01], dtype=np.float32)
# ------------------------------------------------------------------------------
# Load pretrained emulator bundle from Hugging Face and build a physical predictor
# ------------------------------------------------------------------------------
print("Attempting to load pretrained emulator bundle from Hugging Face...")
repo_id = "RozanskiT/isochrones-mlp"
try:
    # Downloads (or reuses) the bundle under a cache directory next to this script.
    emu = Emulator.from_pretrained(
        repo_id,
        cache_dir=script_dir / ".emuspec_cache",
    )
except Exception as exc:
    print(f"Hugging Face load failed ({exc}).")
    # Nothing below can run without `emu`: re-raise here instead of continuing
    # and failing later with a confusing NameError.
    raise
else:
    print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")
# ------------------------------------------------------------------------------
# Build a physical predictor that scales inputs and applies the frozen model
# ------------------------------------------------------------------------------
def build_physical_predictor(emu: "Emulator", *, min_val=None, max_val=None, n_inputs=3):
    """Return a jitted predictor that maps physical inputs to physical outputs.

    The returned function min-max normalises its inputs with the first
    ``n_inputs`` entries of the scaling bounds. It then applies the emulator's
    frozen model (``emu.make_frozen_apply_fn(postprocess=True, jit=False)``).
    Finally it maps the normalised outputs back to physical units with the
    remaining entries of the bounds.

    Args:
        emu: Loaded emulator exposing ``make_frozen_apply_fn``.
        min_val: 1-D per-column minima, input columns first, then target
            columns. Defaults to the module-level ``MIN_VAL``.
        max_val: 1-D per-column maxima in the same order. Defaults to the
            module-level ``MAX_VAL``.
        n_inputs: Number of leading columns of the bounds that describe the
            inputs; the rest describe the outputs.

    Returns:
        ``predict_physical(x_physical)``, compiled with ``jax.jit``. The last
        axis of ``x_physical`` is broadcast against the ``n_inputs`` input
        bounds, so it must have length ``n_inputs``.

    Raises:
        ValueError: If the bounds are not 1-D arrays of equal length, if
            ``n_inputs`` leaves no input or no output columns, or if any maximum
            is not strictly greater than its minimum (the scale would be <= 0).
    """
    lo = np.asarray(MIN_VAL if min_val is None else min_val, dtype=np.float32)
    hi = np.asarray(MAX_VAL if max_val is None else max_val, dtype=np.float32)
    if lo.ndim != 1 or lo.shape != hi.shape:
        raise ValueError("min_val and max_val must be 1-D arrays of the same length")
    if not 0 < n_inputs < lo.size:
        raise ValueError(f"n_inputs must be between 1 and {lo.size - 1}, got {n_inputs}")
    if np.any(hi <= lo):
        raise ValueError("every max_val entry must be strictly greater than its min_val entry")

    # Not jitted on its own: the outer jax.jit below compiles it together with
    # the scaling arithmetic.
    frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)
    x_min = jax.device_put(lo[:n_inputs])
    x_scale = jax.device_put(hi[:n_inputs] - lo[:n_inputs])
    y_min = jax.device_put(lo[n_inputs:])
    y_scale = jax.device_put(hi[n_inputs:] - lo[n_inputs:])

    @jax.jit
    def predict_physical(x_physical):
        # Physical units -> [0, 1] for values inside the bounds.
        x_norm = (x_physical - x_min) / x_scale
        y_norm = frozen_apply(x_norm)
        # Undo the min-max scaling on the model outputs.
        return y_norm * y_scale + y_min

    return predict_physical
predict_physical = build_physical_predictor(emu)
# ------------------------------------------------------------------------------
# Make some physical inputs
# ------------------------------------------------------------------------------
no_points = 1000
# One isochrone: fixed age and [Fe/H], with EEP swept linearly across 202..454.
batch_of_predictions = np.zeros((no_points, 3))  # no_points input rows with 3 features (age, eep, feh)
batch_of_predictions[:, 0] = 9.4  # age
batch_of_predictions[:, 1] = np.linspace(202, 454, no_points)  # eep
batch_of_predictions[:, 2] = 0.0  # feh
# Simplified domain check against the scaling bounds. Raise explicitly rather
# than using `assert`, which is stripped when Python runs with -O. The
# `not all(inside)` form also rejects NaN, as the original asserts did.
for col, feature in enumerate(DEFAULT_INPUTS):
    values = batch_of_predictions[:, col]
    inside = (values >= MIN_VAL[col]) & (values <= MAX_VAL[col])
    if not np.all(inside):
        raise ValueError(f"{feature} out of domain [{MIN_VAL[col]}, {MAX_VAL[col]}]")
# Move to JAX (e.g. onto the GPU when available).
batch_of_predictions = jnp.array(batch_of_predictions)
# ------------------------------------------------------------------------------
# Predict on the batch with the physical-unit predictor and time it.
# This could be extended to include distance and extinction via an analytical model.
# ------------------------------------------------------------------------------
# A bit of timing info to see how fast the predictions are after the initial compilation.
def _timed_predict(x):
    """Run ``predict_physical`` on ``x``; return ``(numpy_result, elapsed_seconds)``.

    The clock stops only after ``jax.block_until_ready`` and the host copy, so
    the measured time includes the device work rather than just the dispatch.
    """
    start = time.perf_counter()
    result = np.asarray(jax.block_until_ready(predict_physical(x)))
    return result, time.perf_counter() - start


# The first call traces and compiles; the second reuses the compiled function.
y_pred_first, first_elapsed = _timed_predict(batch_of_predictions)
y_pred_second, second_elapsed = _timed_predict(batch_of_predictions)

# Summarize timings and prediction shape
print(f"First call (compile + run): {first_elapsed:.6f} s")
print(f"Second call (run only): {second_elapsed:.6f} s")
print(f"Predictions size: {y_pred_second.shape}")
# Color-magnitude diagram on the left, magnitudes vs. batch index (EEP step) on the right
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: colour-magnitude diagram, BP - RP against G.
ax_cmd = axs[0]
ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
ax_cmd.set_xlabel("BP - RP")
ax_cmd.set_ylabel("G")
ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
ax_cmd.grid(alpha=0.25)
ax_cmd.invert_yaxis()  # Magnitudes are brighter when smaller, so invert y-axis for CMD

# Right panel: every predicted band against its position in the batch.
ax_step = axs[1]
for band_name, band_values in zip(DEFAULT_TARGETS, y_pred_second.T):
    # No fixed colour: let matplotlib's colour cycle distinguish the bands.
    # A single shared colour made the curves and the legend unreadable.
    ax_step.plot(band_values, "-", alpha=0.9, label=f"Pred {band_name}")
ax_step.set_xlabel("Batch Index")
ax_step.set_ylabel("Magnitude")
# Name all bands; the old title used the leaked loop index and named only the last one.
ax_step.set_title("Predicted magnitudes vs Batch Index")
ax_step.invert_yaxis()  # same magnitude convention as the CMD panel
ax_step.legend()
ax_step.grid(alpha=0.25)
plt.tight_layout()
plt.show()
```