--- license: mit --- ``` from pathlib import Path import numpy as np import matplotlib.pyplot as plt import jax import jax.numpy as jnp import time from astro_emulators_toolkit import Emulator script_dir = Path(__file__).parent.resolve() # ------------------------------------------------------------------------------ # Model description and data scaling info for physical prediction # ------------------------------------------------------------------------------ DEFAULT_INPUTS = ("age", "eep", "feh") DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag") MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32) MAX_VAL = np.array([1.02993574e01, 4.54000000e02, 5.95229030e-01, 1.50175705e01, 1.84394169e01, 1.36201954e01], dtype=np.float32) # ------------------------------------------------------------------------------ # Load pretrained emulator bundle from Hugging Face and build a physical predictor # ------------------------------------------------------------------------------ print("Attempting to load pretrained emulator bundle from Hugging Face...") repo_id = "RozanskiT/isochrones-mlp" try: emu = Emulator.from_pretrained( repo_id, cache_dir=script_dir / ".emuspec_cache", ) print(f"Loaded pretrained emulator from Hugging Face: {repo_id}") except Exception as exc: print(f"Hugging Face load failed ({exc}).") # ------------------------------------------------------------------------------ # Build a physical predictor that scales inputs and applies the frozen model # ------------------------------------------------------------------------------ def build_physical_predictor(emu: Emulator): """Return a jitted predictor that scales physical inputs then applies frozen model.""" frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False) x_min = jax.device_put(MIN_VAL[:3]) x_scale = jax.device_put(MAX_VAL[:3] - MIN_VAL[:3]) y_min = jax.device_put(MIN_VAL[3:]) y_scale = jax.device_put(MAX_VAL[3:] - MIN_VAL[3:]) @jax.jit def predict_physical(x_physical): x_norm = (x_physical - x_min) / x_scale y_norm = frozen_apply(x_norm) return y_norm * y_scale + y_min return predict_physical predict_physical = build_physical_predictor(emu) # ------------------------------------------------------------------------------ # Make some physical inputs # ------------------------------------------------------------------------------ no_points = 1000 batch_of_predictions = np.zeros((no_points, 3)) # dummy batch of 10 input points with 3 features (age, eep, feh) batch_of_predictions[:,0] = 9.4 # age batch_of_predictions[:,1] = np.linspace(202, 454, no_points) # eep batch_of_predictions[:,2] = 0.0 # feh # simplified check of domain: assert np.all(batch_of_predictions[:, 0] >= MIN_VAL[0]) and np.all(batch_of_predictions[:, 0] <= MAX_VAL[0]), "Age out of domain" assert np.all(batch_of_predictions[:, 1] >= MIN_VAL[1]) and np.all(batch_of_predictions[:, 1] <= MAX_VAL[1]), "EEP out of domain" assert np.all(batch_of_predictions[:, 2] >= MIN_VAL[2]) and np.all(batch_of_predictions[:, 2] <= MAX_VAL[2]), "FeH out of domain" # move to jax (eg. GPU when availible) batch_of_predictions = jnp.array(batch_of_predictions) # ------------------------------------------------------------------------------ # Create a predictor that uses the frozen model but scales physical inputs, then predict on the batch and time it # This could be extended to include a distance, extinction, by analitical model # ------------------------------------------------------------------------------ # A bit of timing info to see how fast the predictions are after the initial compilation. t0 = time.perf_counter() y_pred_first = predict_physical(batch_of_predictions) y_pred_first = np.asarray(jax.block_until_ready(y_pred_first)) t1 = time.perf_counter() y_pred_second = predict_physical(batch_of_predictions) y_pred_second = np.asarray(jax.block_until_ready(y_pred_second)) t2 = time.perf_counter() # Summarize timings and prediction shape print(f"First call (compile + run): {t1 - t0:.6f} s") print(f"Second call (run only): {t2 - t1:.6f} s") print(f"Predictions size: {y_pred_second.shape}") # color-magnitude at the left and magntude vs step at the right fig, axs = plt.subplots(1, 2, figsize=(12, 4)) ax_cmd = axs[0] ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange") ax_cmd.set_xlabel("BP - RP") ax_cmd.set_ylabel("G") ax_cmd.set_title(f"CMD (batch of {no_points} predictions)") ax_cmd.grid(alpha=0.25) ax_cmd.invert_yaxis() # Magnitudes are brighter when smaller, so invert y-axis for CMD ax_step = axs[1] for i in range(y_pred_second.shape[1]): ax_step.plot(y_pred_second[:, i], "-", color="tab:orange", alpha=0.9, label=f"Pred {DEFAULT_TARGETS[i]}") ax_step.set_xlabel("Batch Index") ax_step.set_ylabel("Magnitude") ax_step.set_title(f"Predicted {DEFAULT_TARGETS[i]} vs Batch Index") ax_step.legend() ax_step.grid(alpha=0.25) plt.tight_layout() plt.show() ```