RozanskiT commited on
Commit
9f104bf
·
verified ·
1 Parent(s): bf7cd89

Replace repo contents

Browse files
README.md DELETED
@@ -1,127 +0,0 @@
1
- ---
2
- license: mit
3
- ---
4
- ```
5
- from pathlib import Path
6
-
7
- import numpy as np
8
- import matplotlib.pyplot as plt
9
- import jax
10
- import jax.numpy as jnp
11
- import time
12
-
13
- from astro_emulators_toolkit import Emulator
14
-
15
- script_dir = Path(__file__).parent.resolve()
16
-
17
- # ------------------------------------------------------------------------------
18
- # Model description and data scaling info for physical prediction
19
- # ------------------------------------------------------------------------------
20
-
21
- DEFAULT_INPUTS = ("age", "eep", "feh")
22
- DEFAULT_TARGETS = ("G_mag", "BP_mag", "RP_mag")
23
-
24
- MIN_VAL = np.array([5.8619833, 202.0, -0.87977487, -2.3778718, -2.4398916, -2.2926207], dtype=np.float32)
25
- MAX_VAL = np.array([1.02993574e01, 4.54000000e02, 5.95229030e-01, 1.50175705e01, 1.84394169e01, 1.36201954e01], dtype=np.float32)
26
-
27
-
28
- # ------------------------------------------------------------------------------
29
- # Load pretrained emulator bundle from Hugging Face and build a physical predictor
30
- # ------------------------------------------------------------------------------
31
-
32
- print("Attempting to load pretrained emulator bundle from Hugging Face...")
33
- repo_id = "RozanskiT/isochrones-mlp"
34
- try:
35
- emu = Emulator.from_pretrained(
36
- repo_id,
37
- cache_dir=script_dir / ".emuspec_cache",
38
- )
39
- print(f"Loaded pretrained emulator from Hugging Face: {repo_id}")
40
- except Exception as exc:
41
- print(f"Hugging Face load failed ({exc}).")
42
-
43
-
44
- # ------------------------------------------------------------------------------
45
- # Build a physical predictor that scales inputs and applies the frozen model
46
- # ------------------------------------------------------------------------------
47
-
48
- def build_physical_predictor(emu: Emulator):
49
- """Return a jitted predictor that scales physical inputs then applies frozen model."""
50
-
51
- frozen_apply = emu.make_frozen_apply_fn(postprocess=True, jit=False)
52
- x_min = jax.device_put(MIN_VAL[:3])
53
- x_scale = jax.device_put(MAX_VAL[:3] - MIN_VAL[:3])
54
- y_min = jax.device_put(MIN_VAL[3:])
55
- y_scale = jax.device_put(MAX_VAL[3:] - MIN_VAL[3:])
56
-
57
- @jax.jit
58
- def predict_physical(x_physical):
59
- x_norm = (x_physical - x_min) / x_scale
60
- y_norm = frozen_apply(x_norm)
61
- return y_norm * y_scale + y_min
62
-
63
- return predict_physical
64
-
65
- predict_physical = build_physical_predictor(emu)
66
-
67
-
68
- # ------------------------------------------------------------------------------
69
- # Make some physical inputs
70
- # ------------------------------------------------------------------------------
71
-
72
- no_points = 1000
73
- batch_of_predictions = np.zeros((no_points, 3)) # dummy batch of 10 input points with 3 features (age, eep, feh)
74
- batch_of_predictions[:,0] = 9.4 # age
75
- batch_of_predictions[:,1] = np.linspace(202, 454, no_points) # eep
76
- batch_of_predictions[:,2] = 0.0 # feh
77
-
78
- # simplified check of domain:
79
- assert np.all(batch_of_predictions[:, 0] >= MIN_VAL[0]) and np.all(batch_of_predictions[:, 0] <= MAX_VAL[0]), "Age out of domain"
80
- assert np.all(batch_of_predictions[:, 1] >= MIN_VAL[1]) and np.all(batch_of_predictions[:, 1] <= MAX_VAL[1]), "EEP out of domain"
81
- assert np.all(batch_of_predictions[:, 2] >= MIN_VAL[2]) and np.all(batch_of_predictions[:, 2] <= MAX_VAL[2]), "FeH out of domain"
82
-
83
- # move to jax (eg. GPU when availible)
84
- batch_of_predictions = jnp.array(batch_of_predictions)
85
-
86
- # ------------------------------------------------------------------------------
87
- # Create a predictor that uses the frozen model but scales physical inputs, then predict on the batch and time it
88
- # This could be extended to include a distance, extinction, by analitical model
89
- # ------------------------------------------------------------------------------
90
-
91
- # A bit of timing info to see how fast the predictions are after the initial compilation.
92
- t0 = time.perf_counter()
93
- y_pred_first = predict_physical(batch_of_predictions)
94
- y_pred_first = np.asarray(jax.block_until_ready(y_pred_first))
95
- t1 = time.perf_counter()
96
-
97
- y_pred_second = predict_physical(batch_of_predictions)
98
- y_pred_second = np.asarray(jax.block_until_ready(y_pred_second))
99
- t2 = time.perf_counter()
100
-
101
- # Summarize timings and prediction shape
102
- print(f"First call (compile + run): {t1 - t0:.6f} s")
103
- print(f"Second call (run only): {t2 - t1:.6f} s")
104
- print(f"Predictions size: {y_pred_second.shape}")
105
-
106
-
107
- # color-magnitude at the left and magntude vs step at the right
108
- fig, axs = plt.subplots(1, 2, figsize=(12, 4))
109
- ax_cmd = axs[0]
110
- ax_cmd.scatter(y_pred_second[:, 1] - y_pred_second[:, 2], y_pred_second[:, 0], s=18, alpha=0.8, color="tab:orange")
111
- ax_cmd.set_xlabel("BP - RP")
112
- ax_cmd.set_ylabel("G")
113
- ax_cmd.set_title(f"CMD (batch of {no_points} predictions)")
114
- ax_cmd.grid(alpha=0.25)
115
- ax_cmd.invert_yaxis() # Magnitudes are brighter when smaller, so invert y-axis for CMD
116
-
117
- ax_step = axs[1]
118
- for i in range(y_pred_second.shape[1]):
119
- ax_step.plot(y_pred_second[:, i], "-", color="tab:orange", alpha=0.9, label=f"Pred {DEFAULT_TARGETS[i]}")
120
- ax_step.set_xlabel("Batch Index")
121
- ax_step.set_ylabel("Magnitude")
122
- ax_step.set_title(f"Predicted {DEFAULT_TARGETS[i]} vs Batch Index")
123
- ax_step.legend()
124
- ax_step.grid(alpha=0.25)
125
- plt.tight_layout()
126
- plt.show()
127
- ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
README.txt ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Astro Emulators Toolkit Bundle
2
+
3
+ Summary:
4
+ model: mlp
5
+ release: mist-isochrone-modern-mlp-10m-128x3@0.1.0-collab.1 (released)
6
+ bundle_format_version: 1
7
+ config_schema_version: 1
8
+ spec_version: 1
9
+ weights_layout: params_plus_model_state_v1
10
+ model_family_id: mlp_v1
11
+ fingerprint_evaluation: present
12
+ task: regression
13
+ fit_method: gradient
14
+ solver_params: not provided
15
+ solver_diagnostics: not provided
16
+ solver_design_matrix: not provided
17
+ role_paths: {'input_leaf': 'inputs/parameters', 'output_leaf': 'outputs/magnitudes'}
18
+
19
+ Domain:
20
+ input_domain: {'kind': 'box_v1', 'max_tree': {'parameters': [10.299357414245605, 454.0, 0.5952290296554565]}, 'min_tree': {'parameters': [5.861983299255371, 202.0, -0.8797748684883118]}, 'storage': {'filename': 'input_domain.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'value_space': 'physical_input_dict_tree_v1'}
21
+ reference_scaling_inputs: {'applies_to': 'inputs', 'kind': 'affine_minmax_v1', 'max_tree': {'parameters': [10.299357414245605, 454.0, 0.5952290296554565]}, 'min_tree': {'parameters': [5.861983299255371, 202.0, -0.8797748684883118]}, 'source_space': 'physical_input_dict_tree_v1', 'storage': {'filename': 'reference_scaling_inputs.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'target_space': 'canonical_input_dict_tree_v1'}
22
+ reference_scaling_outputs: {'applies_to': 'outputs', 'kind': 'affine_minmax_v1', 'max_tree': {'magnitudes': [15.017570495605469, 18.439416885375977, 13.620195388793945]}, 'min_tree': {'magnitudes': [-2.3778717517852783, -2.4398915767669678, -2.2926206588745117]}, 'source_space': 'canonical_output_dict_tree_v1', 'storage': {'filename': 'reference_scaling_outputs.safetensors', 'format': 'safetensors_v1', 'layout': 'split_minmax_tree_v1'}, 'target_space': 'physical_output_dict_tree_v1'}
23
+ extras: not provided
24
+
25
+ Provenance:
26
+ toolkit_version: 0.1.0
27
+ created_at: 2026-04-21T18:10:12.535436+00:00
28
+ python_version: 3.12.13
29
+ git_commit: b3415cfd04a48359232624dba9a1a746cf91313f
30
+
31
+ spec:
32
+ input_domain:
33
+ kind: box_v1
34
+ max_tree:
35
+ parameters:
36
+ - 10.299357414245605
37
+ - 454.0
38
+ - 0.5952290296554565
39
+ min_tree:
40
+ parameters:
41
+ - 5.861983299255371
42
+ - 202.0
43
+ - -0.8797748684883118
44
+ storage:
45
+ filename: input_domain.safetensors
46
+ format: safetensors_v1
47
+ layout: split_minmax_tree_v1
48
+ value_space: physical_input_dict_tree_v1
49
+ inputs:
50
+ channel_meanings_tree:
51
+ parameters:
52
+ - log10 stellar age in years
53
+ - equivalent evolutionary phase
54
+ - metallicity relative to solar
55
+ channel_names_tree:
56
+ parameters:
57
+ - age
58
+ - eep
59
+ - feh
60
+ channel_units_tree:
61
+ parameters:
62
+ - log10(age [yr])
63
+ -
64
+ - [Fe/H]
65
+ leaf_meanings_tree: None
66
+ leaf_units_tree: None
67
+ structure_tree:
68
+ parameters: None
69
+ outputs:
70
+ channel_meanings_tree:
71
+ magnitudes:
72
+ - absolute Gaia G magnitude
73
+ - absolute Gaia BP magnitude
74
+ - absolute Gaia RP magnitude
75
+ channel_names_tree:
76
+ magnitudes:
77
+ - G_mag
78
+ - BP_mag
79
+ - RP_mag
80
+ channel_units_tree:
81
+ magnitudes:
82
+ - abs_mag
83
+ - abs_mag
84
+ - abs_mag
85
+ leaf_meanings_tree: None
86
+ leaf_units_tree: None
87
+ structure_tree:
88
+ magnitudes: None
89
+ reference_scaling_inputs:
90
+ applies_to: inputs
91
+ kind: affine_minmax_v1
92
+ max_tree:
93
+ parameters:
94
+ - 10.299357414245605
95
+ - 454.0
96
+ - 0.5952290296554565
97
+ min_tree:
98
+ parameters:
99
+ - 5.861983299255371
100
+ - 202.0
101
+ - -0.8797748684883118
102
+ source_space: physical_input_dict_tree_v1
103
+ storage:
104
+ filename: reference_scaling_inputs.safetensors
105
+ format: safetensors_v1
106
+ layout: split_minmax_tree_v1
107
+ target_space: canonical_input_dict_tree_v1
108
+ reference_scaling_outputs:
109
+ applies_to: outputs
110
+ kind: affine_minmax_v1
111
+ max_tree:
112
+ magnitudes:
113
+ - 15.017570495605469
114
+ - 18.439416885375977
115
+ - 13.620195388793945
116
+ min_tree:
117
+ magnitudes:
118
+ - -2.3778717517852783
119
+ - -2.4398915767669678
120
+ - -2.2926206588745117
121
+ source_space: canonical_output_dict_tree_v1
122
+ storage:
123
+ filename: reference_scaling_outputs.safetensors
124
+ format: safetensors_v1
125
+ layout: split_minmax_tree_v1
126
+ target_space: physical_output_dict_tree_v1
127
+ spec_version: 1
128
+
129
+ Note: this bundle is the canonical emulator artifact. Physical-space composition is external.
bundle_integrity.json ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "algorithm": "sha256",
3
+ "bundle_id": "sha256:f1d5c3f5792c57ac841a07dfa844f1637a357d402fa7e01a2139989d559462ab",
4
+ "integrity_format_version": 1,
5
+ "tree": [
6
+ {
7
+ "path": "README.txt",
8
+ "sha256": "5ecf592e929e06c29b11da5d2286047e58107465971c928276562130a5829cc7"
9
+ },
10
+ {
11
+ "path": "config.json",
12
+ "sha256": "8beec961e9bac4a9c13f087cb8a70af1fd91caf7e09733f2a681a133a02e1b7d"
13
+ },
14
+ {
15
+ "path": "fingerprint_evaluation/inputs.safetensors",
16
+ "sha256": "e411d571388dd13d2c5c9f020386c14378f2a63e63a8a28182b3ac2a45f60e60"
17
+ },
18
+ {
19
+ "path": "fingerprint_evaluation/outputs.safetensors",
20
+ "sha256": "5b168542ec3041851b3ac0bf8dd278c0ccb3f53623dc7c12f6172305e6ed25a2"
21
+ },
22
+ {
23
+ "path": "input_domain.safetensors",
24
+ "sha256": "5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66"
25
+ },
26
+ {
27
+ "path": "metadata.json",
28
+ "sha256": "3d6032c9918be3c88751164653632090eb22d1f578e1fb93c23faf60f83ec03a"
29
+ },
30
+ {
31
+ "path": "reference_scaling_inputs.safetensors",
32
+ "sha256": "5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66"
33
+ },
34
+ {
35
+ "path": "reference_scaling_outputs.safetensors",
36
+ "sha256": "d3420c76743566a2093846f876d0472c0da282e67ff09b8dab72b7b16c7bd5d1"
37
+ },
38
+ {
39
+ "path": "weights/weights.safetensors",
40
+ "sha256": "d407b20c90e930cd335ee6b3ef287114a1ae635c12323c126d3e51e043abda7b"
41
+ }
42
+ ]
43
+ }
config.json CHANGED
@@ -2,41 +2,129 @@
2
  "bundle": {
3
  "bundle_subdir": "bundle"
4
  },
5
- "data": {
6
- "columns": null,
7
- "dtype": "float32",
8
- "inputs": [],
9
- "memmap": true,
10
- "path": "",
11
- "targets": []
12
- },
13
  "hub": {
14
  "repo_id": null,
15
  "revision": null
16
  },
17
  "io": {
18
- "x_dim": 3,
19
- "x_names": [
20
- "age",
21
- "eep",
22
- "feh"
23
- ],
24
- "y_dim": 3,
25
- "y_names": [
26
- "G_mag",
27
- "BP_mag",
28
- "RP_mag"
29
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
  },
31
  "model": {
 
 
 
 
32
  "name": "mlp",
33
  "params": {
34
  "activation": "gelu",
35
  "dtype": "float32",
36
  "hidden_sizes": [
37
- 64,
38
- 64,
39
- 64
40
  ]
41
  }
42
  },
@@ -45,16 +133,20 @@
45
  "b2": 0.999,
46
  "decay_steps": 0,
47
  "eps": 1e-08,
48
- "lr": 0.003,
49
  "name": "soap",
50
  "precondition_1d": false,
51
- "precondition_frequency": 10,
52
  "schedule": "cosine",
53
- "warmup_steps": 10000,
54
  "weight_decay": 1e-05
55
  },
56
  "schema_version": 1,
57
  "seed": 0,
 
 
 
 
58
  "task": {
59
  "name": "regression",
60
  "params": {
@@ -71,15 +163,18 @@
71
  },
72
  "training": {
73
  "batch_size": 2048,
74
- "checkpoint_every_steps": 150000,
75
- "eval_every_steps": 5000,
76
- "log_every_steps": 1000,
77
- "max_checkpoints": 3,
78
- "num_steps": 100000,
 
 
 
79
  "shuffle": true,
80
  "shuffle_seed": 0,
81
  "steps_per_epoch": null,
82
  "val_fraction": 0.1,
83
- "workdir": "/Users/tr/repos/astro-emulators-toolkit/large_local_examples/isochrones/run2"
84
  }
85
  }
 
2
  "bundle": {
3
  "bundle_subdir": "bundle"
4
  },
 
 
 
 
 
 
 
 
5
  "hub": {
6
  "repo_id": null,
7
  "revision": null
8
  },
9
  "io": {
10
+ "input_domain": {
11
+ "max_tree": {
12
+ "parameters": [
13
+ 10.299357414245605,
14
+ 454.0,
15
+ 0.5952290296554565
16
+ ]
17
+ },
18
+ "min_tree": {
19
+ "parameters": [
20
+ 5.861983299255371,
21
+ 202.0,
22
+ -0.8797748684883118
23
+ ]
24
+ }
25
+ },
26
+ "inputs": {
27
+ "channel_meanings_tree": {
28
+ "parameters": [
29
+ "log10 stellar age in years",
30
+ "equivalent evolutionary phase",
31
+ "metallicity relative to solar"
32
+ ]
33
+ },
34
+ "channel_names_tree": {
35
+ "parameters": [
36
+ "age",
37
+ "eep",
38
+ "feh"
39
+ ]
40
+ },
41
+ "channel_units_tree": {
42
+ "parameters": [
43
+ "log10(age [yr])",
44
+ "",
45
+ "[Fe/H]"
46
+ ]
47
+ },
48
+ "leaf_meanings_tree": null,
49
+ "leaf_units_tree": null,
50
+ "structure_tree": {
51
+ "parameters": null
52
+ }
53
+ },
54
+ "outputs": {
55
+ "channel_meanings_tree": {
56
+ "magnitudes": [
57
+ "absolute Gaia G magnitude",
58
+ "absolute Gaia BP magnitude",
59
+ "absolute Gaia RP magnitude"
60
+ ]
61
+ },
62
+ "channel_names_tree": {
63
+ "magnitudes": [
64
+ "G_mag",
65
+ "BP_mag",
66
+ "RP_mag"
67
+ ]
68
+ },
69
+ "channel_units_tree": {
70
+ "magnitudes": [
71
+ "abs_mag",
72
+ "abs_mag",
73
+ "abs_mag"
74
+ ]
75
+ },
76
+ "leaf_meanings_tree": null,
77
+ "leaf_units_tree": null,
78
+ "structure_tree": {
79
+ "magnitudes": null
80
+ }
81
+ },
82
+ "reference_scaling_inputs": {
83
+ "max_tree": {
84
+ "parameters": [
85
+ 10.299357414245605,
86
+ 454.0,
87
+ 0.5952290296554565
88
+ ]
89
+ },
90
+ "min_tree": {
91
+ "parameters": [
92
+ 5.861983299255371,
93
+ 202.0,
94
+ -0.8797748684883118
95
+ ]
96
+ }
97
+ },
98
+ "reference_scaling_outputs": {
99
+ "max_tree": {
100
+ "magnitudes": [
101
+ 15.017570495605469,
102
+ 18.439416885375977,
103
+ 13.620195388793945
104
+ ]
105
+ },
106
+ "min_tree": {
107
+ "magnitudes": [
108
+ -2.3778717517852783,
109
+ -2.4398915767669678,
110
+ -2.2926206588745117
111
+ ]
112
+ }
113
+ }
114
  },
115
  "model": {
116
+ "init_hints": {
117
+ "input_last_axis": 3,
118
+ "output_last_axis": 3
119
+ },
120
  "name": "mlp",
121
  "params": {
122
  "activation": "gelu",
123
  "dtype": "float32",
124
  "hidden_sizes": [
125
+ 128,
126
+ 128,
127
+ 128
128
  ]
129
  }
130
  },
 
133
  "b2": 0.999,
134
  "decay_steps": 0,
135
  "eps": 1e-08,
136
+ "lr": 0.001,
137
  "name": "soap",
138
  "precondition_1d": false,
139
+ "precondition_frequency": 20,
140
  "schedule": "cosine",
141
+ "warmup_steps": 1000000,
142
  "weight_decay": 1e-05
143
  },
144
  "schema_version": 1,
145
  "seed": 0,
146
+ "solver": {
147
+ "name": "auto",
148
+ "params": {}
149
+ },
150
  "task": {
151
  "name": "regression",
152
  "params": {
 
163
  },
164
  "training": {
165
  "batch_size": 2048,
166
+ "checkpoint_interval_steps": null,
167
+ "checkpoint_steps": null,
168
+ "evaluation_interval_steps": 50000,
169
+ "evaluation_steps": null,
170
+ "logging_interval_steps": 10000,
171
+ "logging_steps": null,
172
+ "max_saved_checkpoints": 0,
173
+ "num_steps": 10000000,
174
  "shuffle": true,
175
  "shuffle_seed": 0,
176
  "steps_per_epoch": null,
177
  "val_fraction": 0.1,
178
+ "workdir": "./runs/from_bundle"
179
  }
180
  }
weights.safetensors → fingerprint_evaluation/inputs.safetensors RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:7916af20b55930166ba2f40866892ecf13cec5c9354d827ac44055d6797bf8e8
3
- size 35876
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e411d571388dd13d2c5c9f020386c14378f2a63e63a8a28182b3ac2a45f60e60
3
+ size 92
fingerprint_evaluation/outputs.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b168542ec3041851b3ac0bf8dd278c0ccb3f53623dc7c12f6172305e6ed25a2
3
+ size 92
input_domain.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66
3
+ size 192
metadata.json CHANGED
@@ -1,23 +1,235 @@
1
  {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
  "fit_method": "gradient",
3
- "schema_version": 1,
4
- "tool": "astro_emulators_toolkit",
5
- "x_schema": {
6
- "dim": 3,
7
- "names": [
8
- "age",
9
- "eep",
10
- "feh"
11
- ],
12
- "representation": "model-space inputs emitted directly by the dataset"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
  },
14
- "y_schema": {
15
- "dim": 3,
16
- "names": [
17
- "G_mag",
18
- "BP_mag",
19
- "RP_mag"
20
- ],
21
- "representation": "model-space targets emitted directly by the dataset"
22
- }
23
  }
 
1
  {
2
+ "bundle_format_version": 1,
3
+ "config_schema_version": 1,
4
+ "extras": {},
5
+ "fingerprint_evaluation": {
6
+ "atol": 1e-07,
7
+ "inputs": {
8
+ "filename": "fingerprint_evaluation/inputs.safetensors",
9
+ "format": "safetensors_v1",
10
+ "layout": "numeric_dict_tree_v1",
11
+ "space": "canonical_input_dict_trees_v1"
12
+ },
13
+ "kind": "canonical_inputs_outputs_v1",
14
+ "outputs": {
15
+ "filename": "fingerprint_evaluation/outputs.safetensors",
16
+ "format": "safetensors_v1",
17
+ "layout": "numeric_dict_tree_v1",
18
+ "space": "canonical_output_dict_trees_v1"
19
+ },
20
+ "rtol": 1e-05,
21
+ "selection_strategy": "midpoint_from_input_domain_then_reference_scaling_inputs_v1"
22
+ },
23
  "fit_method": "gradient",
24
+ "model_family_id": "mlp_v1",
25
+ "model_init": {
26
+ "hints": {
27
+ "input_last_axis": 3,
28
+ "output_last_axis": 3
29
+ },
30
+ "representation": "model-local init hints only"
31
+ },
32
+ "provenance": {
33
+ "created_at": "2026-04-21T18:10:12.535436+00:00",
34
+ "dependencies": {
35
+ "flax": "0.12.6",
36
+ "jax": "0.9.2",
37
+ "numpy": "2.4.4",
38
+ "optax": "0.2.8"
39
+ },
40
+ "git_commit": "b3415cfd04a48359232624dba9a1a746cf91313f",
41
+ "platform": "macOS-26.4.1-arm64-arm-64bit",
42
+ "python_version": "3.12.13",
43
+ "toolkit": "astro_emulators_toolkit",
44
+ "toolkit_version": "0.1.0"
45
+ },
46
+ "release": {
47
+ "name": "mist-isochrone-modern-mlp-10m-128x3",
48
+ "status": "released",
49
+ "version": "0.1.0-collab.1"
50
+ },
51
+ "resolved": {
52
+ "model": {
53
+ "name": "mlp",
54
+ "params": {
55
+ "activation": "gelu",
56
+ "dtype": "float32",
57
+ "hidden_sizes": [
58
+ 128,
59
+ 128,
60
+ 128
61
+ ],
62
+ "use_bias": true
63
+ }
64
+ },
65
+ "solver": {
66
+ "name": "gradient",
67
+ "params": {}
68
+ },
69
+ "task": {
70
+ "name": "regression",
71
+ "params": {
72
+ "loss": "mse",
73
+ "loss_weights": null,
74
+ "metric_axes": {
75
+ "global": "all",
76
+ "per_dim": []
77
+ },
78
+ "metrics": [
79
+ "mse",
80
+ "mae"
81
+ ]
82
+ }
83
+ }
84
+ },
85
+ "runtime_contract": {
86
+ "affine_leaf_specs": {
87
+ "inputs/parameters": {
88
+ "last_axis": 3,
89
+ "mode": "scalar_or_last_axis"
90
+ },
91
+ "outputs/magnitudes": {
92
+ "last_axis": 3,
93
+ "mode": "scalar_or_last_axis"
94
+ }
95
+ },
96
+ "role_paths": {
97
+ "input_leaf": "inputs/parameters",
98
+ "output_leaf": "outputs/magnitudes"
99
+ },
100
+ "surface": "canonical_dict_trees_v1"
101
+ },
102
+ "spec": {
103
+ "input_domain": {
104
+ "kind": "box_v1",
105
+ "max_tree": {
106
+ "parameters": [
107
+ 10.299357414245605,
108
+ 454.0,
109
+ 0.5952290296554565
110
+ ]
111
+ },
112
+ "min_tree": {
113
+ "parameters": [
114
+ 5.861983299255371,
115
+ 202.0,
116
+ -0.8797748684883118
117
+ ]
118
+ },
119
+ "storage": {
120
+ "filename": "input_domain.safetensors",
121
+ "format": "safetensors_v1",
122
+ "layout": "split_minmax_tree_v1"
123
+ },
124
+ "value_space": "physical_input_dict_tree_v1"
125
+ },
126
+ "inputs": {
127
+ "channel_meanings_tree": {
128
+ "parameters": [
129
+ "log10 stellar age in years",
130
+ "equivalent evolutionary phase",
131
+ "metallicity relative to solar"
132
+ ]
133
+ },
134
+ "channel_names_tree": {
135
+ "parameters": [
136
+ "age",
137
+ "eep",
138
+ "feh"
139
+ ]
140
+ },
141
+ "channel_units_tree": {
142
+ "parameters": [
143
+ "log10(age [yr])",
144
+ "",
145
+ "[Fe/H]"
146
+ ]
147
+ },
148
+ "leaf_meanings_tree": null,
149
+ "leaf_units_tree": null,
150
+ "structure_tree": {
151
+ "parameters": null
152
+ }
153
+ },
154
+ "outputs": {
155
+ "channel_meanings_tree": {
156
+ "magnitudes": [
157
+ "absolute Gaia G magnitude",
158
+ "absolute Gaia BP magnitude",
159
+ "absolute Gaia RP magnitude"
160
+ ]
161
+ },
162
+ "channel_names_tree": {
163
+ "magnitudes": [
164
+ "G_mag",
165
+ "BP_mag",
166
+ "RP_mag"
167
+ ]
168
+ },
169
+ "channel_units_tree": {
170
+ "magnitudes": [
171
+ "abs_mag",
172
+ "abs_mag",
173
+ "abs_mag"
174
+ ]
175
+ },
176
+ "leaf_meanings_tree": null,
177
+ "leaf_units_tree": null,
178
+ "structure_tree": {
179
+ "magnitudes": null
180
+ }
181
+ },
182
+ "reference_scaling_inputs": {
183
+ "applies_to": "inputs",
184
+ "kind": "affine_minmax_v1",
185
+ "max_tree": {
186
+ "parameters": [
187
+ 10.299357414245605,
188
+ 454.0,
189
+ 0.5952290296554565
190
+ ]
191
+ },
192
+ "min_tree": {
193
+ "parameters": [
194
+ 5.861983299255371,
195
+ 202.0,
196
+ -0.8797748684883118
197
+ ]
198
+ },
199
+ "source_space": "physical_input_dict_tree_v1",
200
+ "storage": {
201
+ "filename": "reference_scaling_inputs.safetensors",
202
+ "format": "safetensors_v1",
203
+ "layout": "split_minmax_tree_v1"
204
+ },
205
+ "target_space": "canonical_input_dict_tree_v1"
206
+ },
207
+ "reference_scaling_outputs": {
208
+ "applies_to": "outputs",
209
+ "kind": "affine_minmax_v1",
210
+ "max_tree": {
211
+ "magnitudes": [
212
+ 15.017570495605469,
213
+ 18.439416885375977,
214
+ 13.620195388793945
215
+ ]
216
+ },
217
+ "min_tree": {
218
+ "magnitudes": [
219
+ -2.3778717517852783,
220
+ -2.4398915767669678,
221
+ -2.2926206588745117
222
+ ]
223
+ },
224
+ "source_space": "canonical_output_dict_tree_v1",
225
+ "storage": {
226
+ "filename": "reference_scaling_outputs.safetensors",
227
+ "format": "safetensors_v1",
228
+ "layout": "split_minmax_tree_v1"
229
+ },
230
+ "target_space": "physical_output_dict_tree_v1"
231
+ },
232
+ "spec_version": 1
233
  },
234
+ "weights_layout": "params_plus_model_state_v1"
 
 
 
 
 
 
 
 
235
  }
reference_likelihood.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Quickstart: Gaussian likelihood for one Gaia photometric measurement.
2
+
3
+ This is a compact reference script for collaborators loading the released
4
+ isochrone emulator from Hugging Face. It follows the same pattern as the main
5
+ README examples:
6
+
7
+ 1. download a bundle with ``Emulator.from_pretrained(...)``;
8
+ 2. freeze a JAX callable with ``make_frozen_apply(jit=False)``;
9
+ 3. explicitly normalize physical inputs into the bundle's canonical space;
10
+ 4. explicitly denormalize canonical outputs back to physical magnitudes;
11
+ 5. evaluate a simple diagonal Gaussian log likelihood in one outer jit.
12
+
13
+ The input vector is ``[log10_age_yr, eep, feh]``. The measurement vector is
14
+ absolute ``[G_mag, BP_mag, RP_mag]``. If your data are apparent magnitudes,
15
+ include distance modulus, extinction, calibration offsets, or other nuisance
16
+ terms in your own likelihood around the emulator prediction.
17
+
18
+ Related examples in the source repository:
19
+
20
+ - examples/basic/02_load_bundle_predict.py
21
+ - examples/basic/04_use_bundle_in_map_fit.py
22
+ """
23
+
24
+ from __future__ import annotations
25
+
26
+ import jax
27
+ import jax.numpy as jnp
28
+ import numpy as np
29
+
30
+ from astro_emulators_toolkit import Emulator, denormalize_tree, normalize_tree
31
+
32
+ REPO_ID = "RozanskiT/isochrones-mlp"
33
+ REVISION = None
34
+ CACHE_DIR = ".emuspec_cache"
35
+
36
+ OUTPUT_LEAF = "magnitudes"
37
+ OUTPUT_CHANNELS = ("G_mag", "BP_mag", "RP_mag")
38
+
39
+ # One trial isochrone point: log10(age [yr]), EEP, [Fe/H].
40
+ THETA_PHYSICAL = np.asarray([9.4, 300.0, 0.0], dtype=np.float32)
41
+
42
+ # One example absolute photometric measurement in G, BP, RP.
43
+ # Replace these with your own absolute magnitudes and uncertainties.
44
+ OBSERVED_MAGNITUDES = np.asarray([6.94, 7.52, 6.21], dtype=np.float32)
45
+ OBSERVED_SIGMA_MAG = np.asarray([0.03, 0.03, 0.03], dtype=np.float32)
46
+
47
+
48
+ def main() -> None:
49
+ emu = Emulator.from_pretrained(
50
+ REPO_ID,
51
+ revision=REVISION,
52
+ cache_dir=CACHE_DIR,
53
+ verbose=True,
54
+ )
55
+ apply_magnitudes = emu.make_frozen_apply(jit=False)
56
+
57
+ ref_inputs = emu.reference_scaling_inputs
58
+ ref_outputs = emu.reference_scaling_outputs
59
+ if ref_inputs is None or ref_outputs is None:
60
+ raise ValueError(
61
+ "This likelihood example requires reference_scaling_inputs and "
62
+ "reference_scaling_outputs in the bundle metadata."
63
+ )
64
+
65
+ y_obs = jnp.asarray(OBSERVED_MAGNITUDES, dtype=jnp.float32)
66
+ y_err = jnp.asarray(OBSERVED_SIGMA_MAG, dtype=jnp.float32)
67
+
68
+ def predict_magnitudes(theta):
69
+ """Predict physical magnitudes; jit the outer objective, not this helper."""
70
+ x_physical = {"parameters": theta[None, :]}
71
+ x_scaled = normalize_tree(
72
+ x_physical,
73
+ ref_inputs["min_tree"],
74
+ ref_inputs["max_tree"],
75
+ )
76
+ y_scaled = apply_magnitudes(x_scaled)
77
+ y_physical = denormalize_tree(
78
+ y_scaled,
79
+ ref_outputs["min_tree"],
80
+ ref_outputs["max_tree"],
81
+ )
82
+ return y_physical[OUTPUT_LEAF][0]
83
+
84
+ @jax.jit
85
+ def evaluate_likelihood(theta):
86
+ y_model = predict_magnitudes(theta)
87
+ resid = (y_obs - y_model) / y_err
88
+ log_norm = jnp.sum(jnp.log(2.0 * jnp.pi * y_err**2))
89
+ log_likelihood = -0.5 * (jnp.sum(resid**2) + log_norm)
90
+ return y_model, log_likelihood
91
+
92
+ theta = jnp.asarray(THETA_PHYSICAL, dtype=jnp.float32)
93
+ model_magnitudes_jax, logp_jax = evaluate_likelihood(theta)
94
+ model_magnitudes = np.asarray(jax.block_until_ready(model_magnitudes_jax))
95
+ logp = float(jax.block_until_ready(logp_jax))
96
+
97
+ print("theta_physical [age, eep, feh]:", THETA_PHYSICAL.tolist())
98
+ print("model absolute magnitudes:")
99
+ for name, value in zip(OUTPUT_CHANNELS, model_magnitudes, strict=True):
100
+ print(f" {name}: {value:.6f}")
101
+ print("observed absolute magnitudes:")
102
+ for name, value in zip(OUTPUT_CHANNELS, OBSERVED_MAGNITUDES, strict=True):
103
+ print(f" {name}: {value:.6f}")
104
+ print("sigma_mag:", OBSERVED_SIGMA_MAG.tolist())
105
+ print("log_likelihood:", f"{logp:.6f}")
106
+
107
+
108
+ if __name__ == "__main__":
109
+ main()
reference_scaling_inputs.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5d8acd2ed09e2a417ff6f83a4482d0aa88be08ad949144969dc5f2e3b9329c66
3
+ size 192
reference_scaling_outputs.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d3420c76743566a2093846f876d0472c0da282e67ff09b8dab72b7b16c7bd5d1
3
+ size 192
weights/weights.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d407b20c90e930cd335ee6b3ef287114a1ae635c12323c126d3e51e043abda7b
3
+ size 136588