RozanskiT commited on
Commit
bf7cd89
·
verified ·
1 Parent(s): df51abf

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +127 -3
README.md CHANGED
@@ -1,3 +1,127 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+ ```
5
+ from pathlib import Path
6
+
7
+ import numpy as np
8
+ import matplotlib.pyplot as plt
9
+ import jax
10
+ import jax.numpy as jnp
11
+ import time
12
+
13
+ from astro_emulators_toolkit import Emulator
14
+
15
+ script_dir = Path(__file__).parent.resolve()
16
+
17
+ # ------------------------------------------------------------------------------
18
+ # Model description and data scaling info for physical prediction
19
+ # ------------------------------------------------------------------------------
20
+
21
+ DEFAULT_INPUTS = ("age", "eep", "feh")
22
+ DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")
23
+
24
+ MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32)
25
+ MAX_VAL = np.array([1.02993574e01, 4.54000000e02, 5.95229030e-01, 1.50175705e01, 1.84394169e01, 1.36201954e01], dtype=np.float32)
26
+
27
+
28
+ # ------------------------------------------------------------------------------
29
+ # Load pretrained emulator bundle from Hugging Face and build a physical predictor
30
+ # ------------------------------------------------------------------------------
31
+
32
+ print("Attempting to load pretrained emulator bundle from Hugging Face...")
33
+ repo_id = "RozanskiT/isochrones-mlp"
34
+ try:
35
+ emu = Emulator.from_pretrained(
36
+ repo_id,
37
+ cache_dir=script_dir / ".emuspec_cache",
38
+ )
39
+ print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")
40
+ except Exception as exc:
41
+ print(f"Hugging Face load failed ({exc}).")
42
+
43
+
44
+ # ------------------------------------------------------------------------------
45
+ # Build a physical predictor that scales inputs and applies the frozen model
46
+ # ------------------------------------------------------------------------------
47
+
48
+ def build_physical_predictor(emu: Emulator):
49
+ """Return a jitted predictor that scales physical inputs then applies frozen model."""
50
+
51
+ frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)
52
+ x_min = jax.device_put(MIN_VAL[:3])
53
+ x_scale = jax.device_put(MAX_VAL[:3] - MIN_VAL[:3])
54
+ y_min = jax.device_put(MIN_VAL[3:])
55
+ y_scale = jax.device_put(MAX_VAL[3:] - MIN_VAL[3:])
56
+
57
+ @jax.jit
58
+ def predict_physical(x_physical):
59
+ x_norm = (x_physical - x_min) / x_scale
60
+ y_norm = frozen_apply(x_norm)
61
+ return y_norm * y_scale + y_min
62
+
63
+ return predict_physical
64
+
65
+ predict_physical = build_physical_predictor(emu)
66
+
67
+
68
+ # ------------------------------------------------------------------------------
69
+ # Make some physical inputs
70
+ # ------------------------------------------------------------------------------
71
+
72
+ no_points = 1000
73
+ batch_of_predictions = np.zeros((no_points, 3)) # dummy batch of 10 input points with 3 features (age, eep, feh)
74
+ batch_of_predictions[:,0] = 9.4 # age
75
+ batch_of_predictions[:,1] = np.linspace(202, 454, no_points) # eep
76
+ batch_of_predictions[:,2] = 0.0 # feh
77
+
78
+ # simplified check of domain:
79
+ assert np.all(batch_of_predictions[:, 0] >= MIN_VAL[0]) and np.all(batch_of_predictions[:, 0] <= MAX_VAL[0]), "Age out of domain"
80
+ assert np.all(batch_of_predictions[:, 1] >= MIN_VAL[1]) and np.all(batch_of_predictions[:, 1] <= MAX_VAL[1]), "EEP out of domain"
81
+ assert np.all(batch_of_predictions[:, 2] >= MIN_VAL[2]) and np.all(batch_of_predictions[:, 2] <= MAX_VAL[2]), "FeH out of domain"
82
+
83
+ # move to jax (eg. GPU when availible)
84
+ batch_of_predictions = jnp.array(batch_of_predictions)
85
+
86
+ # ------------------------------------------------------------------------------
87
+ # Create a predictor that uses the frozen model but scales physical inputs, then predict on the batch and time it
88
+ # This could be extended to include a distance, extinction, by analitical model
89
+ # ------------------------------------------------------------------------------
90
+
91
+ # A bit of timing info to see how fast the predictions are after the initial compilation.
92
+ t0 = time.perf_counter()
93
+ y_pred_first = predict_physical(batch_of_predictions)
94
+ y_pred_first = np.asarray(jax.block_until_ready(y_pred_first))
95
+ t1 = time.perf_counter()
96
+
97
+ y_pred_second = predict_physical(batch_of_predictions)
98
+ y_pred_second = np.asarray(jax.block_until_ready(y_pred_second))
99
+ t2 = time.perf_counter()
100
+
101
+ # Summarize timings and prediction shape
102
+ print(f"First call (compile + run): {t1 - t0:.6f} s")
103
+ print(f"Second call (run only): {t2 - t1:.6f} s")
104
+ print(f"Predictions size: {y_pred_second.shape}")
105
+
106
+
107
+ # color-magnitude at the left and magntude vs step at the right
108
+ fig, axs = plt.subplots(1, 2, figsize=(12, 4))
109
+ ax_cmd = axs[0]
110
+ ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
111
+ ax_cmd.set_xlabel("BP - RP")
112
+ ax_cmd.set_ylabel("G")
113
+ ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
114
+ ax_cmd.grid(alpha=0.25)
115
+ ax_cmd.invert_yaxis() # Magnitudes are brighter when smaller, so invert y-axis for CMD
116
+
117
+ ax_step = axs[1]
118
+ for i in range(y_pred_second.shape[1]):
119
+ ax_step.plot(y_pred_second[:, i], "-", color="tab:orange", alpha=0.9, label=f"Pred {DEFAULT_TARGETS[i]}")
120
+ ax_step.set_xlabel("Batch Index")
121
+ ax_step.set_ylabel("Magnitude")
122
+ ax_step.set_title(f"Predicted {DEFAULT_TARGETS[i]} vs Batch Index")
123
+ ax_step.legend()
124
+ ax_step.grid(alpha=0.25)
125
+ plt.tight_layout()
126
+ plt.show()
127
+ ```