Replace repo contents
Browse files- README.md +0 -127
- README.txt +129 -0
- bundle_integrity.json +43 -0
- config.json +127 -32
- weights.safetensors → fingerprint_evaluation/inputs.safetensors +2 -2
- fingerprint_evaluation/outputs.safetensors +3 -0
- input_domain.safetensors +3 -0
- metadata.json +231 -19
- reference_likelihood.py +109 -0
- reference_scaling_inputs.safetensors +3 -0
- reference_scaling_outputs.safetensors +3 -0
- weights/weights.safetensors +3 -0
README.md
DELETED
|
@@ -1,127 +0,0 @@
|
|
| 1 |
-
---
|
| 2 |
-
license: mit
|
| 3 |
-
---
|
| 4 |
-
```
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import matplotlib.pyplot as plt
|
| 9 |
-
import jax
|
| 10 |
-
import jax.numpy as jnp
|
| 11 |
-
import time
|
| 12 |
-
|
| 13 |
-
from astro_emulators_toolkit import Emulator
|
| 14 |
-
|
| 15 |
-
script_dir = Path(__file__).parent.resolve()
|
| 16 |
-
|
| 17 |
-
# ------------------------------------------------------------------------------
|
| 18 |
-
# Model description and data scaling info for physical prediction
|
| 19 |
-
# ------------------------------------------------------------------------------
|
| 20 |
-
|
| 21 |
-
DEFAULT_INPUTS = ("age", "eep", "feh")
|
| 22 |
-
DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")
|
| 23 |
-
|
| 24 |
-
MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32)
|
| 25 |
-
MAX_VAL = np.array([1.02993574e01, 4.54000000e02, 5.95229030e-01, 1.50175705e01, 1.84394169e01, 1.36201954e01], dtype=np.float32)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
# ------------------------------------------------------------------------------
|
| 29 |
-
# Load pretrained emulator bundle from Hugging Face and build a physical predictor
|
| 30 |
-
# ------------------------------------------------------------------------------
|
| 31 |
-
|
| 32 |
-
print("Attempting to load pretrained emulator bundle from Hugging Face...")
|
| 33 |
-
repo_id = "RozanskiT/isochrones-mlp"
|
| 34 |
-
try:
|
| 35 |
-
emu = Emulator.from_pretrained(
|
| 36 |
-
repo_id,
|
| 37 |
-
cache_dir=script_dir / ".emuspec_cache",
|
| 38 |
-
)
|
| 39 |
-
print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")
|
| 40 |
-
except Exception as exc:
|
| 41 |
-
print(f"Hugging Face load failed ({exc}).")
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
# ------------------------------------------------------------------------------
|
| 45 |
-
# Build a physical predictor that scales inputs and applies the frozen model
|
| 46 |
-
# ------------------------------------------------------------------------------
|
| 47 |
-
|
| 48 |
-
def build_physical_predictor(emu: Emulator):
|
| 49 |
-
"""Return a jitted predictor that scales physical inputs then applies frozen model."""
|
| 50 |
-
|
| 51 |
-
frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)
|
| 52 |
-
x_min = jax.device_put(MIN_VAL[:3])
|
| 53 |
-
x_scale = jax.device_put(MAX_VAL[:3] - MIN_VAL[:3])
|
| 54 |
-
y_min = jax.device_put(MIN_VAL[3:])
|
| 55 |
-
y_scale = jax.device_put(MAX_VAL[3:] - MIN_VAL[3:])
|
| 56 |
-
|
| 57 |
-
@jax.jit
|
| 58 |
-
def predict_physical(x_physical):
|
| 59 |
-
x_norm = (x_physical - x_min) / x_scale
|
| 60 |
-
y_norm = frozen_apply(x_norm)
|
| 61 |
-
return y_norm * y_scale + y_min
|
| 62 |
-
|
| 63 |
-
return predict_physical
|
| 64 |
-
|
| 65 |
-
predict_physical = build_physical_predictor(emu)
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
# ------------------------------------------------------------------------------
|
| 69 |
-
# Make some physical inputs
|
| 70 |
-
# ------------------------------------------------------------------------------
|
| 71 |
-
|
| 72 |
-
no_points = 1000
|
| 73 |
-
batch_of_predictions = np.zeros((no_points, 3)) # dummy batch of 10 input points with 3 features (age, eep, feh)
|
| 74 |
-
batch_of_predictions[:,0] = 9.4 # age
|
| 75 |
-
batch_of_predictions[:,1] = np.linspace(202, 454, no_points) # eep
|
| 76 |
-
batch_of_predictions[:,2] = 0.0 # feh
|
| 77 |
-
|
| 78 |
-
# simplified check of domain:
|
| 79 |
-
assert np.all(batch_of_predictions[:, 0] >= MIN_VAL[0]) and np.all(batch_of_predictions[:, 0] <= MAX_VAL[0]), "Age out of domain"
|
| 80 |
-
assert np.all(batch_of_predictions[:, 1] >= MIN_VAL[1]) and np.all(batch_of_predictions[:, 1] <= MAX_VAL[1]), "EEP out of domain"
|
| 81 |
-
assert np.all(batch_of_predictions[:, 2] >= MIN_VAL[2]) and np.all(batch_of_predictions[:, 2] <= MAX_VAL[2]), "FeH out of domain"
|
| 82 |
-
|
| 83 |
-
# move to jax (eg. GPU when availible)
|
| 84 |
-
batch_of_predictions = jnp.array(batch_of_predictions)
|
| 85 |
-
|
| 86 |
-
# ------------------------------------------------------------------------------
|
| 87 |
-
# Create a predictor that uses the frozen model but scales physical inputs, then predict on the batch and time it
|
| 88 |
-
# This could be extended to include a distance, extinction, by analitical model
|
| 89 |
-
# ------------------------------------------------------------------------------
|
| 90 |
-
|
| 91 |
-
# A bit of timing info to see how fast the predictions are after the initial compilation.
|
| 92 |
-
t0 = time.perf_counter()
|
| 93 |
-
y_pred_first = predict_physical(batch_of_predictions)
|
| 94 |
-
y_pred_first = np.asarray(jax.block_until_ready(y_pred_first))
|
| 95 |
-
t1 = time.perf_counter()
|
| 96 |
-
|
| 97 |
-
y_pred_second = predict_physical(batch_of_predictions)
|
| 98 |
-
y_pred_second = np.asarray(jax.block_until_ready(y_pred_second))
|
| 99 |
-
t2 = time.perf_counter()
|
| 100 |
-
|
| 101 |
-
# Summarize timings and prediction shape
|
| 102 |
-
print(f"First call (compile + run): {t1 - t0:.6f} s")
|
| 103 |
-
print(f"Second call (run only): {t2 - t1:.6f} s")
|
| 104 |
-
print(f"Predictions size: {y_pred_second.shape}")
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
# color-magnitude at the left and magntude vs step at the right
|
| 108 |
-
fig, axs = plt.subplots(1, 2, figsize=(12, 4))
|
| 109 |
-
ax_cmd = axs[0]
|
| 110 |
-
ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
|
| 111 |
-
ax_cmd.set_xlabel("BP - RP")
|
| 112 |
-
ax_cmd.set_ylabel("G")
|
| 113 |
-
ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
|
| 114 |
-
ax_cmd.grid(alpha=0.25)
|
| 115 |
-
ax_cmd.invert_yaxis() # Magnitudes are brighter when smaller, so invert y-axis for CMD
|
| 116 |
-
|
| 117 |
-
ax_step = axs[1]
|
| 118 |
-
for i in range(y_pred_second.shape[1]):
|
| 119 |
-
ax_step.plot(y_pred_second[:, i], "-", color="tab:orange", alpha=0.9, label=f"Pred {DEFAULT_TARGETS[i]}")
|
| 120 |
-
ax_step.set_xlabel("Batch Index")
|
| 121 |
-
ax_step.set_ylabel("Magnitude")
|
| 122 |
-
ax_step.set_title(f"Predicted {DEFAULT_TARGETS[i]} vs Batch Index")
|
| 123 |
-
ax_step.legend()
|
| 124 |
-
ax_step.grid(alpha=0.25)
|
| 125 |
-
plt.tight_layout()
|
| 126 |
-
plt.show()
|
| 127 |
-
```
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
README.txt
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Astro Emulators Toolkit Bundle
|
| 2 |
+
|
| 3 |
+
Summary:
|
| 4 |
+
model: mlp
|
| 5 |
+
release: mist-isochrone-modern-mlp-10m-128x3@0.1.0-collab.1 (released)
|
| 6 |
+
bundle_format_version: 1
|
| 7 |
+
config_schema_version: 1
|
| 8 |
+
spec_version: 1
|
| 9 |
+
weights_layout: params_plus_model_state_v1
|
| 10 |
+
model_family_id: mlp_v1
|
| 11 |
+
fingerprint_evaluation: present
|
| 12 |
+
task: regression
|
| 13 |
+
fit_method: gradient
|
| 14 |
+
solver_params: not provided
|
| 15 |
+
solver_diagnostics: not provided
|
| 16 |
+
solver_design_matrix: not provided
|
| 17 |
+
role_paths: {'input_leaf': 'inputs/parameters', 'output_leaf': 'outputs/magnitudes'}
|
| 18 |
+
|
| 19 |
+
Domain:
|
| 20 |
+
input_domain: {'kind': 'box_v1', 'max_tree': {'parameters': [10.299357414245605, 454.0, 0.5952290296554565]}, 'min_tree': {'parameters': [5.861983299255371, 202.0, -0.8797748684883118]}, 'storage': {'filename': 'input_domain.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'value_space': 'physical_input_dict_tree_v1'}
|
| 21 |
+
reference_scaling_inputs: {'applies_to': 'inputs', 'kind': 'affine_minmax_v1', 'max_tree': {'parameters': [10.299357414245605, 454.0, 0.5952290296554565]}, 'min_tree': {'parameters': [5.861983299255371, 202.0, -0.8797748684883118]}, 'source_space': 'physical_input_dict_tree_v1', 'storage': {'filename': 'reference_scaling_inputs.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'target_space': 'canonical_input_dict_tree_v1'}
|
| 22 |
+
reference_scaling_outputs: {'applies_to': 'outputs', 'kind': 'affine_minmax_v1', 'max_tree': {'magnitudes': [15.017570495605469, 18.439416885375977, 13.620195388793945]}, 'min_tree': {'magnitudes': [-2.3778717517852783, -2.4398915767669678, -2.2926206588745117]}, 'source_space': 'canonical_output_dict_tree_v1', 'storage': {'filename': 'reference_scaling_outputs.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'target_space': 'physical_output_dict_tree_v1'}
|
| 23 |
+
extras: not provided
|
| 24 |
+
|
| 25 |
+
Provenance:
|
| 26 |
+
toolkit_version: 0.1.0
|
| 27 |
+
created_at: 2026-04-21T18:10:12.535436+00:00
|
| 28 |
+
python_version: 3.12.13
|
| 29 |
+
git_commit: b3415cfd04a48359232624dba9a1a746cf91313f
|
| 30 |
+
|
| 31 |
+
spec:
|
| 32 |
+
input_domain:
|
| 33 |
+
kind: box_v1
|
| 34 |
+
max_tree:
|
| 35 |
+
parameters:
|
| 36 |
+
- 10.299357414245605
|
| 37 |
+
- 454.0
|
| 38 |
+
- 0.5952290296554565
|
| 39 |
+
min_tree:
|
| 40 |
+
parameters:
|
| 41 |
+
- 5.861983299255371
|
| 42 |
+
- 202.0
|
| 43 |
+
- -0.8797748684883118
|
| 44 |
+
storage:
|
| 45 |
+
filename: input_domain.safetensors
|
| 46 |
+
format: safetensors_v1
|
| 47 |
+
layout: split_minmax_tree_v1
|
| 48 |
+
value_space: physical_input_dict_tree_v1
|
| 49 |
+
inputs:
|
| 50 |
+
channel_meanings_tree:
|
| 51 |
+
parameters:
|
| 52 |
+
- log10 stellar age in years
|
| 53 |
+
- equivalent evolutionary phase
|
| 54 |
+
- metallicity relative to solar
|
| 55 |
+
channel_names_tree:
|
| 56 |
+
parameters:
|
| 57 |
+
- age
|
| 58 |
+
- eep
|
| 59 |
+
- feh
|
| 60 |
+
channel_units_tree:
|
| 61 |
+
parameters:
|
| 62 |
+
- log10(age [yr])
|
| 63 |
+
-
|
| 64 |
+
- [Fe/H]
|
| 65 |
+
leaf_meanings_tree: None
|
| 66 |
+
leaf_units_tree: None
|
| 67 |
+
structure_tree:
|
| 68 |
+
parameters: None
|
| 69 |
+
outputs:
|
| 70 |
+
channel_meanings_tree:
|
| 71 |
+
magnitudes:
|
| 72 |
+
- absolute Gaia G magnitude
|
| 73 |
+
- absolute Gaia BP magnitude
|
| 74 |
+
- absolute Gaia RP magnitude
|
| 75 |
+
channel_names_tree:
|
| 76 |
+
magnitudes:
|
| 77 |
+
- G_mag
|
| 78 |
+
- BP_mag
|
| 79 |
+
- RP_mag
|
| 80 |
+
channel_units_tree:
|
| 81 |
+
magnitudes:
|
| 82 |
+
- abs_mag
|
| 83 |
+
- abs_mag
|
| 84 |
+
- abs_mag
|
| 85 |
+
leaf_meanings_tree: None
|
| 86 |
+
leaf_units_tree: None
|
| 87 |
+
structure_tree:
|
| 88 |
+
magnitudes: None
|
| 89 |
+
reference_scaling_inputs:
|
| 90 |
+
applies_to: inputs
|
| 91 |
+
kind: affine_minmax_v1
|
| 92 |
+
max_tree:
|
| 93 |
+
parameters:
|
| 94 |
+
- 10.299357414245605
|
| 95 |
+
- 454.0
|
| 96 |
+
- 0.5952290296554565
|
| 97 |
+
min_tree:
|
| 98 |
+
parameters:
|
| 99 |
+
- 5.861983299255371
|
| 100 |
+
- 202.0
|
| 101 |
+
- -0.8797748684883118
|
| 102 |
+
source_space: physical_input_dict_tree_v1
|
| 103 |
+
storage:
|
| 104 |
+
filename: reference_scaling_inputs.safetensors
|
| 105 |
+
format: safetensors_v1
|
| 106 |
+
layout: split_minmax_tree_v1
|
| 107 |
+
target_space: canonical_input_dict_tree_v1
|
| 108 |
+
reference_scaling_outputs:
|
| 109 |
+
applies_to: outputs
|
| 110 |
+
kind: affine_minmax_v1
|
| 111 |
+
max_tree:
|
| 112 |
+
magnitudes:
|
| 113 |
+
- 15.017570495605469
|
| 114 |
+
- 18.439416885375977
|
| 115 |
+
- 13.620195388793945
|
| 116 |
+
min_tree:
|
| 117 |
+
magnitudes:
|
| 118 |
+
- -2.3778717517852783
|
| 119 |
+
- -2.4398915767669678
|
| 120 |
+
- -2.2926206588745117
|
| 121 |
+
source_space: canonical_output_dict_tree_v1
|
| 122 |
+
storage:
|
| 123 |
+
filename: reference_scaling_outputs.safetensors
|
| 124 |
+
format: safetensors_v1
|
| 125 |
+
layout: split_minmax_tree_v1
|
| 126 |
+
target_space: physical_output_dict_tree_v1
|
| 127 |
+
spec_version: 1
|
| 128 |
+
|
| 129 |
+
Note: this bundle is the canonical emulator artifact. Physical-space composition is external.
|
bundle_integrity.json
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"algorithm": "sha256",
|
| 3 |
+
"bundle_id": "sha256:f1d5c3f5792c57ac841a07dfa844f1637a357d402fa7e01a2139989d559462ab",
|
| 4 |
+
"integrity_format_version": 1,
|
| 5 |
+
"tree": [
|
| 6 |
+
{
|
| 7 |
+
"path": "README.txt",
|
| 8 |
+
"sha256": "5ecf592e929e06c29b11da5d2286047e58107465971c928276562130a5829cc7"
|
| 9 |
+
},
|
| 10 |
+
{
|
| 11 |
+
"path": "config.json",
|
| 12 |
+
"sha256": "8beec961e9bac4a9c13f087cb8a70af1fd91caf7e09733f2a681a133a02e1b7d"
|
| 13 |
+
},
|
| 14 |
+
{
|
| 15 |
+
"path": "fingerprint_evaluation/inputs.safetensors",
|
| 16 |
+
"sha256": "e411d571388dd13d2c5c9f020386c14378f2a63e63a8a28182b3ac2a45f60e60"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"path": "fingerprint_evaluation/outputs.safetensors",
|
| 20 |
+
"sha256": "5b168542ec3041851b3ac0bf8dd278c0ccb3f53623dc7c12f6172305e6ed25a2"
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"path": "input_domain.safetensors",
|
| 24 |
+
"sha256": "5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66"
|
| 25 |
+
},
|
| 26 |
+
{
|
| 27 |
+
"path": "metadata.json",
|
| 28 |
+
"sha256": "3d6032c9918be3c88751164653632090eb22d1f578e1fb93c23faf60f83ec03a"
|
| 29 |
+
},
|
| 30 |
+
{
|
| 31 |
+
"path": "reference_scaling_inputs.safetensors",
|
| 32 |
+
"sha256": "5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66"
|
| 33 |
+
},
|
| 34 |
+
{
|
| 35 |
+
"path": "reference_scaling_outputs.safetensors",
|
| 36 |
+
"sha256": "d3420c76743566a2093846f876d0472c0da282e67ff09b8dab72b7b16c7bd5d1"
|
| 37 |
+
},
|
| 38 |
+
{
|
| 39 |
+
"path": "weights/weights.safetensors",
|
| 40 |
+
"sha256": "d407b20c90e930cd335ee6b3ef287114a1ae635c12323c126d3e51e043abda7b"
|
| 41 |
+
}
|
| 42 |
+
]
|
| 43 |
+
}
|
config.json
CHANGED
|
@@ -2,41 +2,129 @@
|
|
| 2 |
"bundle": {
|
| 3 |
"bundle_subdir": "bundle"
|
| 4 |
},
|
| 5 |
-
"data": {
|
| 6 |
-
"columns": null,
|
| 7 |
-
"dtype": "float32",
|
| 8 |
-
"inputs": [],
|
| 9 |
-
"memmap": true,
|
| 10 |
-
"path": "",
|
| 11 |
-
"targets": []
|
| 12 |
-
},
|
| 13 |
"hub": {
|
| 14 |
"repo_id": null,
|
| 15 |
"revision": null
|
| 16 |
},
|
| 17 |
"io": {
|
| 18 |
-
"
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
"
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
},
|
| 31 |
"model": {
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
"name": "mlp",
|
| 33 |
"params": {
|
| 34 |
"activation": "gelu",
|
| 35 |
"dtype": "float32",
|
| 36 |
"hidden_sizes": [
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
]
|
| 41 |
}
|
| 42 |
},
|
|
@@ -45,16 +133,20 @@
|
|
| 45 |
"b2": 0.999,
|
| 46 |
"decay_steps": 0,
|
| 47 |
"eps": 1e-08,
|
| 48 |
-
"lr": 0.
|
| 49 |
"name": "soap",
|
| 50 |
"precondition_1d": false,
|
| 51 |
-
"precondition_frequency":
|
| 52 |
"schedule": "cosine",
|
| 53 |
-
"warmup_steps":
|
| 54 |
"weight_decay": 1e-05
|
| 55 |
},
|
| 56 |
"schema_version": 1,
|
| 57 |
"seed": 0,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
"task": {
|
| 59 |
"name": "regression",
|
| 60 |
"params": {
|
|
@@ -71,15 +163,18 @@
|
|
| 71 |
},
|
| 72 |
"training": {
|
| 73 |
"batch_size": 2048,
|
| 74 |
-
"
|
| 75 |
-
"
|
| 76 |
-
"
|
| 77 |
-
"
|
| 78 |
-
"
|
|
|
|
|
|
|
|
|
|
| 79 |
"shuffle": true,
|
| 80 |
"shuffle_seed": 0,
|
| 81 |
"steps_per_epoch": null,
|
| 82 |
"val_fraction": 0.1,
|
| 83 |
-
"workdir": "/
|
| 84 |
}
|
| 85 |
}
|
|
|
|
| 2 |
"bundle": {
|
| 3 |
"bundle_subdir": "bundle"
|
| 4 |
},
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
"hub": {
|
| 6 |
"repo_id": null,
|
| 7 |
"revision": null
|
| 8 |
},
|
| 9 |
"io": {
|
| 10 |
+
"input_domain": {
|
| 11 |
+
"max_tree": {
|
| 12 |
+
"parameters": [
|
| 13 |
+
10.299357414245605,
|
| 14 |
+
454.0,
|
| 15 |
+
0.5952290296554565
|
| 16 |
+
]
|
| 17 |
+
},
|
| 18 |
+
"min_tree": {
|
| 19 |
+
"parameters": [
|
| 20 |
+
5.861983299255371,
|
| 21 |
+
202.0,
|
| 22 |
+
-0.8797748684883118
|
| 23 |
+
]
|
| 24 |
+
}
|
| 25 |
+
},
|
| 26 |
+
"inputs": {
|
| 27 |
+
"channel_meanings_tree": {
|
| 28 |
+
"parameters": [
|
| 29 |
+
"log10 stellar age in years",
|
| 30 |
+
"equivalent evolutionary phase",
|
| 31 |
+
"metallicity relative to solar"
|
| 32 |
+
]
|
| 33 |
+
},
|
| 34 |
+
"channel_names_tree": {
|
| 35 |
+
"parameters": [
|
| 36 |
+
"age",
|
| 37 |
+
"eep",
|
| 38 |
+
"feh"
|
| 39 |
+
]
|
| 40 |
+
},
|
| 41 |
+
"channel_units_tree": {
|
| 42 |
+
"parameters": [
|
| 43 |
+
"log10(age [yr])",
|
| 44 |
+
"",
|
| 45 |
+
"[Fe/H]"
|
| 46 |
+
]
|
| 47 |
+
},
|
| 48 |
+
"leaf_meanings_tree": null,
|
| 49 |
+
"leaf_units_tree": null,
|
| 50 |
+
"structure_tree": {
|
| 51 |
+
"parameters": null
|
| 52 |
+
}
|
| 53 |
+
},
|
| 54 |
+
"outputs": {
|
| 55 |
+
"channel_meanings_tree": {
|
| 56 |
+
"magnitudes": [
|
| 57 |
+
"absolute Gaia G magnitude",
|
| 58 |
+
"absolute Gaia BP magnitude",
|
| 59 |
+
"absolute Gaia RP magnitude"
|
| 60 |
+
]
|
| 61 |
+
},
|
| 62 |
+
"channel_names_tree": {
|
| 63 |
+
"magnitudes": [
|
| 64 |
+
"G_mag",
|
| 65 |
+
"BP_mag",
|
| 66 |
+
"RP_mag"
|
| 67 |
+
]
|
| 68 |
+
},
|
| 69 |
+
"channel_units_tree": {
|
| 70 |
+
"magnitudes": [
|
| 71 |
+
"abs_mag",
|
| 72 |
+
"abs_mag",
|
| 73 |
+
"abs_mag"
|
| 74 |
+
]
|
| 75 |
+
},
|
| 76 |
+
"leaf_meanings_tree": null,
|
| 77 |
+
"leaf_units_tree": null,
|
| 78 |
+
"structure_tree": {
|
| 79 |
+
"magnitudes": null
|
| 80 |
+
}
|
| 81 |
+
},
|
| 82 |
+
"reference_scaling_inputs": {
|
| 83 |
+
"max_tree": {
|
| 84 |
+
"parameters": [
|
| 85 |
+
10.299357414245605,
|
| 86 |
+
454.0,
|
| 87 |
+
0.5952290296554565
|
| 88 |
+
]
|
| 89 |
+
},
|
| 90 |
+
"min_tree": {
|
| 91 |
+
"parameters": [
|
| 92 |
+
5.861983299255371,
|
| 93 |
+
202.0,
|
| 94 |
+
-0.8797748684883118
|
| 95 |
+
]
|
| 96 |
+
}
|
| 97 |
+
},
|
| 98 |
+
"reference_scaling_outputs": {
|
| 99 |
+
"max_tree": {
|
| 100 |
+
"magnitudes": [
|
| 101 |
+
15.017570495605469,
|
| 102 |
+
18.439416885375977,
|
| 103 |
+
13.620195388793945
|
| 104 |
+
]
|
| 105 |
+
},
|
| 106 |
+
"min_tree": {
|
| 107 |
+
"magnitudes": [
|
| 108 |
+
-2.3778717517852783,
|
| 109 |
+
-2.4398915767669678,
|
| 110 |
+
-2.2926206588745117
|
| 111 |
+
]
|
| 112 |
+
}
|
| 113 |
+
}
|
| 114 |
},
|
| 115 |
"model": {
|
| 116 |
+
"init_hints": {
|
| 117 |
+
"input_last_axis": 3,
|
| 118 |
+
"output_last_axis": 3
|
| 119 |
+
},
|
| 120 |
"name": "mlp",
|
| 121 |
"params": {
|
| 122 |
"activation": "gelu",
|
| 123 |
"dtype": "float32",
|
| 124 |
"hidden_sizes": [
|
| 125 |
+
128,
|
| 126 |
+
128,
|
| 127 |
+
128
|
| 128 |
]
|
| 129 |
}
|
| 130 |
},
|
|
|
|
| 133 |
"b2": 0.999,
|
| 134 |
"decay_steps": 0,
|
| 135 |
"eps": 1e-08,
|
| 136 |
+
"lr": 0.001,
|
| 137 |
"name": "soap",
|
| 138 |
"precondition_1d": false,
|
| 139 |
+
"precondition_frequency": 20,
|
| 140 |
"schedule": "cosine",
|
| 141 |
+
"warmup_steps": 1000000,
|
| 142 |
"weight_decay": 1e-05
|
| 143 |
},
|
| 144 |
"schema_version": 1,
|
| 145 |
"seed": 0,
|
| 146 |
+
"solver": {
|
| 147 |
+
"name": "auto",
|
| 148 |
+
"params": {}
|
| 149 |
+
},
|
| 150 |
"task": {
|
| 151 |
"name": "regression",
|
| 152 |
"params": {
|
|
|
|
| 163 |
},
|
| 164 |
"training": {
|
| 165 |
"batch_size": 2048,
|
| 166 |
+
"checkpoint_interval_steps": null,
|
| 167 |
+
"checkpoint_steps": null,
|
| 168 |
+
"evaluation_interval_steps": 50000,
|
| 169 |
+
"evaluation_steps": null,
|
| 170 |
+
"logging_interval_steps": 10000,
|
| 171 |
+
"logging_steps": null,
|
| 172 |
+
"max_saved_checkpoints": 0,
|
| 173 |
+
"num_steps": 10000000,
|
| 174 |
"shuffle": true,
|
| 175 |
"shuffle_seed": 0,
|
| 176 |
"steps_per_epoch": null,
|
| 177 |
"val_fraction": 0.1,
|
| 178 |
+
"workdir": "./runs/from_bundle"
|
| 179 |
}
|
| 180 |
}
|
weights.safetensors → fingerprint_evaluation/inputs.safetensors
RENAMED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e411d571388dd13d2c5c9f020386c14378f2a63e63a8a28182b3ac2a45f60e60
|
| 3 |
+
size 92
|
fingerprint_evaluation/outputs.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5b168542ec3041851b3ac0bf8dd278c0ccb3f53623dc7c12f6172305e6ed25a2
|
| 3 |
+
size 92
|
input_domain.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66
|
| 3 |
+
size 192
|
metadata.json
CHANGED
|
@@ -1,23 +1,235 @@
|
|
| 1 |
{
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
"fit_method": "gradient",
|
| 3 |
-
"
|
| 4 |
-
"
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 13 |
},
|
| 14 |
-
"
|
| 15 |
-
"dim": 3,
|
| 16 |
-
"names": [
|
| 17 |
-
"G_mag",
|
| 18 |
-
"BP_mag",
|
| 19 |
-
"RP_mag"
|
| 20 |
-
],
|
| 21 |
-
"representation": "model-space targets emitted directly by the dataset"
|
| 22 |
-
}
|
| 23 |
}
|
|
|
|
| 1 |
{
|
| 2 |
+
"bundle_format_version": 1,
|
| 3 |
+
"config_schema_version": 1,
|
| 4 |
+
"extras": {},
|
| 5 |
+
"fingerprint_evaluation": {
|
| 6 |
+
"atol": 1e-07,
|
| 7 |
+
"inputs": {
|
| 8 |
+
"filename": "fingerprint_evaluation/inputs.safetensors",
|
| 9 |
+
"format": "safetensors_v1",
|
| 10 |
+
"layout": "numeric_dict_tree_v1",
|
| 11 |
+
"space": "canonical_input_dict_trees_v1"
|
| 12 |
+
},
|
| 13 |
+
"kind": "canonical_inputs_outputs_v1",
|
| 14 |
+
"outputs": {
|
| 15 |
+
"filename": "fingerprint_evaluation/outputs.safetensors",
|
| 16 |
+
"format": "safetensors_v1",
|
| 17 |
+
"layout": "numeric_dict_tree_v1",
|
| 18 |
+
"space": "canonical_output_dict_trees_v1"
|
| 19 |
+
},
|
| 20 |
+
"rtol": 1e-05,
|
| 21 |
+
"selection_strategy": "midpoint_from_input_domain_then_reference_scaling_inputs_v1"
|
| 22 |
+
},
|
| 23 |
"fit_method": "gradient",
|
| 24 |
+
"model_family_id": "mlp_v1",
|
| 25 |
+
"model_init": {
|
| 26 |
+
"hints": {
|
| 27 |
+
"input_last_axis": 3,
|
| 28 |
+
"output_last_axis": 3
|
| 29 |
+
},
|
| 30 |
+
"representation": "model-local init hints only"
|
| 31 |
+
},
|
| 32 |
+
"provenance": {
|
| 33 |
+
"created_at": "2026-04-21T18:10:12.535436+00:00",
|
| 34 |
+
"dependencies": {
|
| 35 |
+
"flax": "0.12.6",
|
| 36 |
+
"jax": "0.9.2",
|
| 37 |
+
"numpy": "2.4.4",
|
| 38 |
+
"optax": "0.2.8"
|
| 39 |
+
},
|
| 40 |
+
"git_commit": "b3415cfd04a48359232624dba9a1a746cf91313f",
|
| 41 |
+
"platform": "macOS-26.4.1-arm64-arm-64bit",
|
| 42 |
+
"python_version": "3.12.13",
|
| 43 |
+
"toolkit": "astro_emulators_toolkit",
|
| 44 |
+
"toolkit_version": "0.1.0"
|
| 45 |
+
},
|
| 46 |
+
"release": {
|
| 47 |
+
"name": "mist-isochrone-modern-mlp-10m-128x3",
|
| 48 |
+
"status": "released",
|
| 49 |
+
"version": "0.1.0-collab.1"
|
| 50 |
+
},
|
| 51 |
+
"resolved": {
|
| 52 |
+
"model": {
|
| 53 |
+
"name": "mlp",
|
| 54 |
+
"params": {
|
| 55 |
+
"activation": "gelu",
|
| 56 |
+
"dtype": "float32",
|
| 57 |
+
"hidden_sizes": [
|
| 58 |
+
128,
|
| 59 |
+
128,
|
| 60 |
+
128
|
| 61 |
+
],
|
| 62 |
+
"use_bias": true
|
| 63 |
+
}
|
| 64 |
+
},
|
| 65 |
+
"solver": {
|
| 66 |
+
"name": "gradient",
|
| 67 |
+
"params": {}
|
| 68 |
+
},
|
| 69 |
+
"task": {
|
| 70 |
+
"name": "regression",
|
| 71 |
+
"params": {
|
| 72 |
+
"loss": "mse",
|
| 73 |
+
"loss_weights": null,
|
| 74 |
+
"metric_axes": {
|
| 75 |
+
"global": "all",
|
| 76 |
+
"per_dim": []
|
| 77 |
+
},
|
| 78 |
+
"metrics": [
|
| 79 |
+
"mse",
|
| 80 |
+
"mae"
|
| 81 |
+
]
|
| 82 |
+
}
|
| 83 |
+
}
|
| 84 |
+
},
|
| 85 |
+
"runtime_contract": {
|
| 86 |
+
"affine_leaf_specs": {
|
| 87 |
+
"inputs/parameters": {
|
| 88 |
+
"last_axis": 3,
|
| 89 |
+
"mode": "scalar_or_last_axis"
|
| 90 |
+
},
|
| 91 |
+
"outputs/magnitudes": {
|
| 92 |
+
"last_axis": 3,
|
| 93 |
+
"mode": "scalar_or_last_axis"
|
| 94 |
+
}
|
| 95 |
+
},
|
| 96 |
+
"role_paths": {
|
| 97 |
+
"input_leaf": "inputs/parameters",
|
| 98 |
+
"output_leaf": "outputs/magnitudes"
|
| 99 |
+
},
|
| 100 |
+
"surface": "canonical_dict_trees_v1"
|
| 101 |
+
},
|
| 102 |
+
"spec": {
|
| 103 |
+
"input_domain": {
|
| 104 |
+
"kind": "box_v1",
|
| 105 |
+
"max_tree": {
|
| 106 |
+
"parameters": [
|
| 107 |
+
10.299357414245605,
|
| 108 |
+
454.0,
|
| 109 |
+
0.5952290296554565
|
| 110 |
+
]
|
| 111 |
+
},
|
| 112 |
+
"min_tree": {
|
| 113 |
+
"parameters": [
|
| 114 |
+
5.861983299255371,
|
| 115 |
+
202.0,
|
| 116 |
+
-0.8797748684883118
|
| 117 |
+
]
|
| 118 |
+
},
|
| 119 |
+
"storage": {
|
| 120 |
+
"filename": "input_domain.safetensors",
|
| 121 |
+
"format": "safetensors_v1",
|
| 122 |
+
"layout": "split_minmax_tree_v1"
|
| 123 |
+
},
|
| 124 |
+
"value_space": "physical_input_dict_tree_v1"
|
| 125 |
+
},
|
| 126 |
+
"inputs": {
|
| 127 |
+
"channel_meanings_tree": {
|
| 128 |
+
"parameters": [
|
| 129 |
+
"log10 stellar age in years",
|
| 130 |
+
"equivalent evolutionary phase",
|
| 131 |
+
"metallicity relative to solar"
|
| 132 |
+
]
|
| 133 |
+
},
|
| 134 |
+
"channel_names_tree": {
|
| 135 |
+
"parameters": [
|
| 136 |
+
"age",
|
| 137 |
+
"eep",
|
| 138 |
+
"feh"
|
| 139 |
+
]
|
| 140 |
+
},
|
| 141 |
+
"channel_units_tree": {
|
| 142 |
+
"parameters": [
|
| 143 |
+
"log10(age [yr])",
|
| 144 |
+
"",
|
| 145 |
+
"[Fe/H]"
|
| 146 |
+
]
|
| 147 |
+
},
|
| 148 |
+
"leaf_meanings_tree": null,
|
| 149 |
+
"leaf_units_tree": null,
|
| 150 |
+
"structure_tree": {
|
| 151 |
+
"parameters": null
|
| 152 |
+
}
|
| 153 |
+
},
|
| 154 |
+
"outputs": {
|
| 155 |
+
"channel_meanings_tree": {
|
| 156 |
+
"magnitudes": [
|
| 157 |
+
"absolute Gaia G magnitude",
|
| 158 |
+
"absolute Gaia BP magnitude",
|
| 159 |
+
"absolute Gaia RP magnitude"
|
| 160 |
+
]
|
| 161 |
+
},
|
| 162 |
+
"channel_names_tree": {
|
| 163 |
+
"magnitudes": [
|
| 164 |
+
"G_mag",
|
| 165 |
+
"BP_mag",
|
| 166 |
+
"RP_mag"
|
| 167 |
+
]
|
| 168 |
+
},
|
| 169 |
+
"channel_units_tree": {
|
| 170 |
+
"magnitudes": [
|
| 171 |
+
"abs_mag",
|
| 172 |
+
"abs_mag",
|
| 173 |
+
"abs_mag"
|
| 174 |
+
]
|
| 175 |
+
},
|
| 176 |
+
"leaf_meanings_tree": null,
|
| 177 |
+
"leaf_units_tree": null,
|
| 178 |
+
"structure_tree": {
|
| 179 |
+
"magnitudes": null
|
| 180 |
+
}
|
| 181 |
+
},
|
| 182 |
+
"reference_scaling_inputs": {
|
| 183 |
+
"applies_to": "inputs",
|
| 184 |
+
"kind": "affine_minmax_v1",
|
| 185 |
+
"max_tree": {
|
| 186 |
+
"parameters": [
|
| 187 |
+
10.299357414245605,
|
| 188 |
+
454.0,
|
| 189 |
+
0.5952290296554565
|
| 190 |
+
]
|
| 191 |
+
},
|
| 192 |
+
"min_tree": {
|
| 193 |
+
"parameters": [
|
| 194 |
+
5.861983299255371,
|
| 195 |
+
202.0,
|
| 196 |
+
-0.8797748684883118
|
| 197 |
+
]
|
| 198 |
+
},
|
| 199 |
+
"source_space": "physical_input_dict_tree_v1",
|
| 200 |
+
"storage": {
|
| 201 |
+
"filename": "reference_scaling_inputs.safetensors",
|
| 202 |
+
"format": "safetensors_v1",
|
| 203 |
+
"layout": "split_minmax_tree_v1"
|
| 204 |
+
},
|
| 205 |
+
"target_space": "canonical_input_dict_tree_v1"
|
| 206 |
+
},
|
| 207 |
+
"reference_scaling_outputs": {
|
| 208 |
+
"applies_to": "outputs",
|
| 209 |
+
"kind": "affine_minmax_v1",
|
| 210 |
+
"max_tree": {
|
| 211 |
+
"magnitudes": [
|
| 212 |
+
15.017570495605469,
|
| 213 |
+
18.439416885375977,
|
| 214 |
+
13.620195388793945
|
| 215 |
+
]
|
| 216 |
+
},
|
| 217 |
+
"min_tree": {
|
| 218 |
+
"magnitudes": [
|
| 219 |
+
-2.3778717517852783,
|
| 220 |
+
-2.4398915767669678,
|
| 221 |
+
-2.2926206588745117
|
| 222 |
+
]
|
| 223 |
+
},
|
| 224 |
+
"source_space": "canonical_output_dict_tree_v1",
|
| 225 |
+
"storage": {
|
| 226 |
+
"filename": "reference_scaling_outputs.safetensors",
|
| 227 |
+
"format": "safetensors_v1",
|
| 228 |
+
"layout": "split_minmax_tree_v1"
|
| 229 |
+
},
|
| 230 |
+
"target_space": "physical_output_dict_tree_v1"
|
| 231 |
+
},
|
| 232 |
+
"spec_version": 1
|
| 233 |
},
|
| 234 |
+
"weights_layout": "params_plus_model_state_v1"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 235 |
}
|
reference_likelihood.py
ADDED
|
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Quickstart: Gaussian likelihood for one Gaia photometric measurement.
|
| 2 |
+
|
| 3 |
+
This is a compact reference script for collaborators loading the released
|
| 4 |
+
isochrone emulator from Hugging Face. It follows the same pattern as the main
|
| 5 |
+
README examples:
|
| 6 |
+
|
| 7 |
+
1. download a bundle with ``Emulator.from_pretrained(...)``;
|
| 8 |
+
2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
|
| 9 |
+
3. explicitly normalize physical inputs into the bundle's canonical space;
|
| 10 |
+
4. explicitly denormalize canonical outputs back to physical magnitudes;
|
| 11 |
+
5. evaluate a simple diagonal Gaussian log likelihood in one outer jit.
|
| 12 |
+
|
| 13 |
+
The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector is
|
| 14 |
+
absolute ``[G_mag, BP_mag, RP_mag]``. If your data are apparent magnitudes,
|
| 15 |
+
include distance modulus, extinction, calibration offsets, or other nuisance
|
| 16 |
+
terms in your own likelihood around the emulator prediction.
|
| 17 |
+
|
| 18 |
+
Related examples in the source repository:
|
| 19 |
+
|
| 20 |
+
- examples/basic/02_load_bundle_predict.py
|
| 21 |
+
- examples/basic/04_use_bundle_in_map_fit.py
|
| 22 |
+
"""
|
| 23 |
+
|
| 24 |
+
from __future__ import annotations
|
| 25 |
+
|
| 26 |
+
import jax
|
| 27 |
+
import jax.numpy as jnp
|
| 28 |
+
import numpy as np
|
| 29 |
+
|
| 30 |
+
from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree
|
| 31 |
+
|
| 32 |
+
REPO_ID = "RozanskiT/isochrones-mlp"
|
| 33 |
+
REVISION = None
|
| 34 |
+
CACHE_DIR = ".emuspec_cache"
|
| 35 |
+
|
| 36 |
+
OUTPUT_LEAF = "magnitudes"
|
| 37 |
+
OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")
|
| 38 |
+
|
| 39 |
+
# One trial isochrone point: log10(age [yr]), EEP, [Fe/H].
|
| 40 |
+
THETA_PHYSICAL = np.asarray([9.4, 300.0, 0.0], dtype=np.float32)
|
| 41 |
+
|
| 42 |
+
# One example absolute photometric measurement in G, BP, RP.
|
| 43 |
+
# Replace these with your own absolute magnitudes and uncertainties.
|
| 44 |
+
OBSERVED_MAGNITUDES = np.asarray([6.94, 7.52, 6.21], dtype=np.float32)
|
| 45 |
+
OBSERVED_SIGMA_MAG = np.asarray([0.03, 0.03, 0.03], dtype=np.float32)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def main() -> None:
|
| 49 |
+
emu = Emulator.from_pretrained(
|
| 50 |
+
REPO_ID,
|
| 51 |
+
revision=REVISION,
|
| 52 |
+
cache_dir=CACHE_DIR,
|
| 53 |
+
verbose=True,
|
| 54 |
+
)
|
| 55 |
+
apply_magnitudes = emu.make_frozen_apply(jit=False)
|
| 56 |
+
|
| 57 |
+
ref_inputs = emu.reference_scaling_inputs
|
| 58 |
+
ref_outputs = emu.reference_scaling_outputs
|
| 59 |
+
if ref_inputs is None or ref_outputs is None:
|
| 60 |
+
raise ValueError(
|
| 61 |
+
"This likelihood example requires reference_scaling_inputs and "
|
| 62 |
+
"reference_scaling_outputs in the bundle metadata."
|
| 63 |
+
)
|
| 64 |
+
|
| 65 |
+
y_obs = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
|
| 66 |
+
y_err = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)
|
| 67 |
+
|
| 68 |
+
def predict_magnitudes(theta):
|
| 69 |
+
"""Predict physical magnitudes; jit the outer objective, not this helper."""
|
| 70 |
+
x_physical = {"parameters": theta[None, :]}
|
| 71 |
+
x_scaled = normalize_tree(
|
| 72 |
+
x_physical,
|
| 73 |
+
ref_inputs["min_tree"],
|
| 74 |
+
ref_inputs["max_tree"],
|
| 75 |
+
)
|
| 76 |
+
y_scaled = apply_magnitudes(x_scaled)
|
| 77 |
+
y_physical = denormalize_tree(
|
| 78 |
+
y_scaled,
|
| 79 |
+
ref_outputs["min_tree"],
|
| 80 |
+
ref_outputs["max_tree"],
|
| 81 |
+
)
|
| 82 |
+
return y_physical[OUTPUT_LEAF][0]
|
| 83 |
+
|
| 84 |
+
@jax.jit
|
| 85 |
+
def evaluate_likelihood(theta):
|
| 86 |
+
y_model = predict_magnitudes(theta)
|
| 87 |
+
resid = (y_obs - y_model) / y_err
|
| 88 |
+
log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * y_err**2))
|
| 89 |
+
log_likelihood = -0.5 * (jnp.sum(resid**2) + log_norm)
|
| 90 |
+
return y_model, log_likelihood
|
| 91 |
+
|
| 92 |
+
theta = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
|
| 93 |
+
model_magnitudes_jax, logp_jax = evaluate_likelihood(theta)
|
| 94 |
+
model_magnitudes = np.asarray(jax.block_until_ready(model_magnitudes_jax))
|
| 95 |
+
logp = float(jax.block_until_ready(logp_jax))
|
| 96 |
+
|
| 97 |
+
print("theta_physical [age, eep, feh]:", THETA_PHYSICAL.tolist())
|
| 98 |
+
print("model absolute magnitudes:")
|
| 99 |
+
for name, value in zip(OUTPUT_CHANNELS, model_magnitudes, strict=True):
|
| 100 |
+
print(f" {name}: {value:.6f}")
|
| 101 |
+
print("observed absolute magnitudes:")
|
| 102 |
+
for name, value in zip(OUTPUT_CHANNELS, OBSERVED_MAGNITUDES, strict=True):
|
| 103 |
+
print(f" {name}: {value:.6f}")
|
| 104 |
+
print("sigma_mag:", OBSERVED_SIGMA_MAG.tolist())
|
| 105 |
+
print("log_likelihood:", f"{logp:.6f}")
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
if __name__ == "__main__":
|
| 109 |
+
main()
|
reference_scaling_inputs.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66
|
| 3 |
+
size 192
|
reference_scaling_outputs.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d3420c76743566a2093846f876d0472c0da282e67ff09b8dab72b7b16c7bd5d1
|
| 3 |
+
size 192
|
weights/weights.safetensors
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d407b20c90e930cd335ee6b3ef287114a1ae635c12323c126d3e51e043abda7b
|
| 3 |
+
size 136588
|