|
|
--- |
|
|
license: mit |
|
|
--- |
|
|
``` |
|
|
from pathlib import Path |
|
|
|
|
|
import numpy as np |
|
|
import matplotlib.pyplot as plt |
|
|
import jax |
|
|
import jax.numpy as jnp |
|
|
import time |
|
|
|
|
|
from astro_emulators_toolkit import Emulator |
|
|
|
|
|
# Absolute directory that holds this script.
script_dir = Path(__file__).parent.resolve()


# ------------------------------------------------------------------------------
# Column names and min-max bounds that translate between physical units and the
# normalised space the frozen model operates in
# ------------------------------------------------------------------------------

DEFAULT_INPUTS = ("age", "eep", "feh")
DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")

# Bounds are laid out as the three input columns (age, eep, feh) followed by the
# three target columns (G, BP, RP). NOTE(review): presumably the min/max of the
# training set -- confirm against the training pipeline.
MIN_VAL = np.array(
    [
        5.8619833,    # age
        202.0,        # eep
        -0.87977487,  # feh
        -2.3778718,   # G_mag
        -2.4398916,   # BP_mag
        -2.2926207,   # RP_mag
    ],
    dtype=np.float32,
)
MAX_VAL = np.array(
    [
        10.2993574,   # age
        454.0,        # eep
        0.59522903,   # feh
        15.0175705,   # G_mag
        18.4394169,   # BP_mag
        13.6201954,   # RP_mag
    ],
    dtype=np.float32,
)
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------------------
# Load the pretrained emulator bundle from Hugging Face
# ------------------------------------------------------------------------------

print("Attempting to load pretrained emulator bundle from Hugging Face...")
repo_id = "RozanskiT/isochrones-mlp"
try:
    # The cache directory lives next to this script.
    emu = Emulator.from_pretrained(
        repo_id,
        cache_dir=script_dir / ".emuspec_cache",
    )
except Exception as exc:
    print(f"Hugging Face load failed ({exc}).")
    # Everything below needs `emu`; re-raise with the original traceback rather
    # than continuing and failing later with an unrelated NameError.
    raise
print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")
|
|
|
|
|
|
|
|
# ------------------------------------------------------------------------------
# Build a physical predictor that scales inputs and applies the frozen model
# ------------------------------------------------------------------------------

def build_physical_predictor(emu: "Emulator", *, min_val=None, max_val=None, n_inputs=None):
    """Return a jitted predictor that maps physical inputs to physical outputs.

    The returned function min-max normalises its inputs with the first
    ``n_inputs`` entries of ``min_val``/``max_val``, runs the emulator's frozen
    model (built with ``postprocess=True``), and maps the model's outputs back
    to physical units with the remaining entries of the bounds.

    Args:
        emu: Loaded emulator; only its ``make_frozen_apply_fn`` method is used.
        min_val: 1-D per-column lower bounds, input columns first and target
            columns after. Defaults to the module-level ``MIN_VAL``.
        max_val: 1-D per-column upper bounds in the same layout as
            ``min_val``. Defaults to the module-level ``MAX_VAL``.
        n_inputs: Number of leading bound columns that belong to the inputs.
            Defaults to ``len(DEFAULT_INPUTS)``.

    Returns:
        A ``jax.jit``-compiled ``predict_physical(x_physical)``. The last axis
        of ``x_physical`` must hold the ``n_inputs`` input columns.

    Raises:
        ValueError: If the bounds are not 1-D arrays of equal shape, or if
            ``n_inputs`` does not leave at least one input and one target column.
    """
    if min_val is None:
        min_val = MIN_VAL
    if max_val is None:
        max_val = MAX_VAL
    if n_inputs is None:
        n_inputs = len(DEFAULT_INPUTS)

    lo = np.asarray(min_val, dtype=np.float32)
    hi = np.asarray(max_val, dtype=np.float32)
    if lo.ndim != 1 or lo.shape != hi.shape:
        raise ValueError(
            f"min_val and max_val must be 1-D with equal shapes, got {lo.shape} and {hi.shape}"
        )
    if not 0 < n_inputs < lo.shape[0]:
        raise ValueError(
            f"n_inputs must be between 1 and {lo.shape[0] - 1}, got {n_inputs}"
        )

    # Request the un-jitted apply function: the wrapper below is jitted as a
    # whole, so scaling and model are traced and compiled together.
    frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)

    # Scaling constants closed over by the jitted function, moved to the
    # default JAX device once up front.
    x_min = jax.device_put(lo[:n_inputs])
    x_scale = jax.device_put(hi[:n_inputs] - lo[:n_inputs])
    y_min = jax.device_put(lo[n_inputs:])
    y_scale = jax.device_put(hi[n_inputs:] - lo[n_inputs:])

    @jax.jit
    def predict_physical(x_physical):
        """Scale physical inputs to [0, 1] bounds, apply the model, unscale outputs."""
        x_norm = (x_physical - x_min) / x_scale
        # The model outputs are treated as min-max normalised targets.
        y_norm = frozen_apply(x_norm)
        return y_norm * y_scale + y_min

    return predict_physical
|
|
|
|
|
# Predictor that accepts and returns physical units.
predict_physical = build_physical_predictor(emu)


# ------------------------------------------------------------------------------
# Make some physical inputs
# ------------------------------------------------------------------------------

no_points = 1000

# Batch of `no_points` inputs with columns (age, eep, feh): EEP is swept across
# 202..454 while age and metallicity are held fixed.
batch_of_predictions = np.zeros((no_points, len(DEFAULT_INPUTS)))
batch_of_predictions[:, 0] = 9.4  # age
batch_of_predictions[:, 1] = np.linspace(202, 454, no_points)  # eep
batch_of_predictions[:, 2] = 0.0  # feh

# Simplified domain check against the scaling bounds. Raise explicitly rather
# than using `assert`, which is stripped when Python runs with -O. The
# `not (all >= lo and all <= hi)` form also rejects NaN inputs.
for col, label in enumerate(("Age", "EEP", "FeH")):
    values = batch_of_predictions[:, col]
    if not (np.all(values >= MIN_VAL[col]) and np.all(values <= MAX_VAL[col])):
        raise ValueError(f"{label} out of domain")

# Move to JAX (e.g. onto a GPU when available).
batch_of_predictions = jnp.array(batch_of_predictions)
|
|
|
|
|
# ------------------------------------------------------------------------------
# Predict on the batch with the physical predictor and time it.
# This could be extended with analytical models for distance and extinction.
# ------------------------------------------------------------------------------

# Call the predictor twice on the same batch. The first call also pays for
# tracing and compilation; the second reuses the compiled function. Each result
# is synchronised with block_until_ready before its timestamp is taken.
stamps = [time.perf_counter()]
results = []
for _ in range(2):
    results.append(np.asarray(jax.block_until_ready(predict_physical(batch_of_predictions))))
    stamps.append(time.perf_counter())
y_pred_first, y_pred_second = results
t0, t1, t2 = stamps

# Report both timings and the output shape.
print(f"First call (compile + run): {t1 - t0:.6f} s")
print(f"Second call (run only): {t2 - t1:.6f} s")
print(f"Predictions size: {y_pred_second.shape}")
|
|
|
|
|
|
|
|
# Colour-magnitude diagram on the left, magnitude vs batch index on the right.
fig, axs = plt.subplots(1, 2, figsize=(12, 4))

ax_cmd = axs[0]
# Prediction columns are (G, BP, RP): plot BP - RP colour against G.
ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
ax_cmd.set_xlabel("BP - RP")
ax_cmd.set_ylabel("G")
ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
ax_cmd.grid(alpha=0.25)
ax_cmd.invert_yaxis()  # Smaller magnitudes are brighter, so put them at the top.

ax_step = axs[1]
# One line per target. Leave the colour to Matplotlib's cycle so each band is
# distinguishable in the legend; a single shared colour made them identical.
for target_name, column in zip(DEFAULT_TARGETS, y_pred_second.T):
    ax_step.plot(column, "-", alpha=0.9, label=f"Pred {target_name}")
ax_step.set_xlabel("Batch Index")
ax_step.set_ylabel("Magnitude")
# Title covers every plotted band (previously it reused the leftover loop index
# and named only the last target).
ax_step.set_title("Predicted magnitudes vs Batch Index")
ax_step.legend()
ax_step.grid(alpha=0.25)

plt.tight_layout()
plt.show()
|
|
``` |