schattin commited on
Commit
771a91d
·
1 Parent(s): 0213ef9

feature(model): model configuration

Browse files
app.py ADDED
@@ -0,0 +1,357 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from pathlib import Path
3
+ from typing import Any, Dict, Iterable, List, Optional, Tuple
4
+
5
+ import gradio as gr
6
+ import numpy as np
7
+ import pandas as pd
8
+ import plotly.graph_objects as go
9
+ import torch
10
+ import yaml
11
+ from safetensors.torch import load_file
12
+ from sbi.neural_nets.factory import posterior_nn
13
+
14
+
15
+ MODELS_DIR = Path(__file__).resolve().parent / "models"
16
+ DISPERSION_CURVES_DIR = Path(__file__).resolve().parent / "disp_curves"
17
+ DEFAULT_CURVE_NONE_LABEL = "Upload custom curve"
18
+
19
+
20
+ @dataclass
21
+ class LoadedModel:
22
+ name: str
23
+ sampler: "PosteriorSampler"
24
+
25
+
26
+ class PosteriorSampler:
27
+ """Thin wrapper around the trained neural posterior for sampling."""
28
+
29
+ def __init__(self, weights_path: Path, config_path: Path, device: Optional[str] = None) -> None:
30
+ self.weights_path = weights_path
31
+ self.config = yaml.safe_load(config_path.read_text())
32
+
33
+ dataset_cfg = self.config.get("dataset", {})
34
+ model_cfg = self.config.get("model", {})
35
+ params_cfg = model_cfg.get("parameters", {})
36
+
37
+ self.context_dim = int(dataset_cfg["input_shape"])
38
+ self.theta_dim = int(dataset_cfg["output_shape"])
39
+
40
+ build_kwargs: Dict[str, int] = {}
41
+ for key in ("hidden_features", "num_transforms", "num_bins", "num_components"):
42
+ if key in params_cfg and params_cfg[key] is not None:
43
+ build_kwargs[key] = int(params_cfg[key])
44
+
45
+ density_estimator_builder = posterior_nn(
46
+ model=params_cfg.get("density_estimator", "nsf"),
47
+ z_score_theta=params_cfg.get("z_score_theta", "independent"),
48
+ z_score_x=params_cfg.get("z_score_x", "independent"),
49
+ **build_kwargs,
50
+ )
51
+
52
+ # Create a dummy network to load the trained parameters. The actual statistics
53
+ # (e.g. z-score buffers) are restored from the safetensors file.
54
+ theta_prototype = torch.zeros(2, self.theta_dim)
55
+ context_prototype = torch.zeros(2, self.context_dim)
56
+ net = density_estimator_builder(theta_prototype, context_prototype)
57
+
58
+ state_dict = load_file(str(weights_path))
59
+ net.load_state_dict(state_dict)
60
+ net.eval()
61
+
62
+ runtime_device = torch.device(device) if device else torch.device("cpu")
63
+ self.net = net.to(runtime_device)
64
+ self.device = runtime_device
65
+
66
+ def sample(self, context: np.ndarray, num_samples: int) -> np.ndarray:
67
+ with torch.no_grad():
68
+ context_tensor = torch.as_tensor(context, dtype=torch.float32, device=self.device).reshape(-1)
69
+ if context_tensor.numel() != self.context_dim:
70
+ raise ValueError(
71
+ f"Expected context with {self.context_dim} elements, received {context_tensor.numel()}."
72
+ )
73
+ samples = self.net.sample((num_samples,), context=context_tensor)
74
+ samples_np = samples.cpu().numpy()
75
+ if samples_np.ndim == 3:
76
+ samples_np = samples_np[:, 0, :]
77
+ elif samples_np.ndim != 2:
78
+ raise ValueError(f"Unexpected sample shape {samples_np.shape}.")
79
+ return samples_np
80
+
81
+
82
+ def discover_dispersion_curves(curves_dir: Path) -> Dict[str, Tuple[Path, Path]]:
83
+ if not curves_dir.exists():
84
+ return {}
85
+
86
+ discovered: Dict[str, Tuple[Path, Path]] = {}
87
+ for curve_path in sorted(curves_dir.glob("disp_curve_*.csv")):
88
+ suffix = curve_path.stem.split("disp_curve_")[-1]
89
+ theta_path = curves_dir / f"theta_{suffix}.csv"
90
+ display_name = f"Curve {suffix}"
91
+ discovered[display_name] = (curve_path, theta_path)
92
+
93
+ return discovered
94
+
95
+
96
+ PREDEFINED_DISPERSION_CURVES = discover_dispersion_curves(DISPERSION_CURVES_DIR)
97
+
98
+
99
+ def discover_models(models_dir: Path) -> List[LoadedModel]:
100
+ if not models_dir.exists():
101
+ raise FileNotFoundError(f"Expected models directory at {models_dir}")
102
+
103
+ discovered: List[LoadedModel] = []
104
+ for weights_path in sorted(models_dir.glob("*.safetensors")):
105
+ config_candidates = [
106
+ weights_path.with_suffix(".yaml"),
107
+ weights_path.with_suffix(".yml"),
108
+ models_dir / "config.yaml",
109
+ ]
110
+ config_path = next((path for path in config_candidates if path.exists()), None)
111
+ if not config_path:
112
+ raise FileNotFoundError(f"No configuration file found for {weights_path.name}")
113
+
114
+ sampler = PosteriorSampler(weights_path, config_path)
115
+ display_name = weights_path.stem.replace("_", " ").title()
116
+ discovered.append(LoadedModel(name=display_name, sampler=sampler))
117
+
118
+ if not discovered:
119
+ raise FileNotFoundError(f"No .safetensors models found in {models_dir}")
120
+
121
+ return discovered
122
+
123
+
124
+ class ModelRegistry:
125
+ def __init__(self, models: Iterable[LoadedModel]):
126
+ self._registry: Dict[str, PosteriorSampler] = {}
127
+ for item in models:
128
+ if item.name in self._registry:
129
+ raise ValueError(f"Duplicate model name detected: {item.name}")
130
+ self._registry[item.name] = item.sampler
131
+
132
+ @property
133
+ def names(self) -> List[str]:
134
+ return list(self._registry.keys())
135
+
136
+ def get(self, name: str) -> PosteriorSampler:
137
+ if name not in self._registry:
138
+ raise KeyError(f"Unknown model '{name}'")
139
+ return self._registry[name]
140
+
141
+
142
+ REGISTRY = ModelRegistry(discover_models(MODELS_DIR))
143
+ DEFAULT_MODEL_NAME = REGISTRY.names[0]
144
+
145
+
146
+ def load_predefined_dispersion_curve(name: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
147
+ if name not in PREDEFINED_DISPERSION_CURVES:
148
+ raise gr.Error("Unknown dispersion curve selection.")
149
+
150
+ curve_path, theta_path = PREDEFINED_DISPERSION_CURVES[name]
151
+ if not curve_path.exists():
152
+ raise gr.Error(f"Unable to find dispersion curve file at {curve_path}.")
153
+ if not theta_path.exists():
154
+ raise gr.Error(f"Unable to find theta file at {theta_path}.")
155
+
156
+ curve_df = pd.read_csv(curve_path)
157
+ if curve_df.shape[1] < 2:
158
+ raise gr.Error(f"Dispersion curve file {curve_path.name} must contain period and vg columns.")
159
+
160
+ periods = pd.to_numeric(curve_df.iloc[:, 0], errors="coerce").to_numpy(dtype=np.float32)
161
+ vg_values = pd.to_numeric(curve_df.iloc[:, 1], errors="coerce").to_numpy(dtype=np.float32)
162
+
163
+ theta_df = pd.read_csv(theta_path)
164
+ theta_values = pd.to_numeric(theta_df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
165
+ theta_values = theta_values[~np.isnan(theta_values)]
166
+
167
+ if periods.size != vg_values.size:
168
+ raise gr.Error(
169
+ f"Dispersion curve file {curve_path.name} contains mismatched period and vg counts."
170
+ )
171
+
172
+ if np.isnan(periods).any() or np.isnan(vg_values).any():
173
+ raise gr.Error(f"Dispersion curve file {curve_path.name} contains non-numeric entries.")
174
+
175
+ return periods, vg_values, theta_values
176
+
177
+
178
+ def read_dispersion_curve(upload: Optional[Any], expected_length: int) -> np.ndarray:
179
+ if upload is None:
180
+ raise gr.Error("Please upload a CSV file containing the dispersion curve.")
181
+
182
+ try:
183
+ df = pd.read_csv(upload.name, header=None)
184
+ except Exception as exc: # pylint: disable=broad-except
185
+ raise gr.Error(f"Unable to read CSV file: {exc}") from exc
186
+
187
+ numeric_values = pd.to_numeric(df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
188
+ numeric_values = numeric_values[~np.isnan(numeric_values)]
189
+
190
+ if numeric_values.size != expected_length:
191
+ raise gr.Error(
192
+ f"Expected {expected_length} values in the dispersion curve, but found {numeric_values.size}. "
193
+ "Please provide a CSV with exactly one value per frequency sample."
194
+ )
195
+
196
+ return numeric_values
197
+
198
+
199
+ def build_plot(samples: np.ndarray) -> go.Figure:
200
+ depth_axis = np.arange(1, samples.shape[1] + 1)
201
+ fig = go.Figure()
202
+ for idx, sample in enumerate(samples, start=1):
203
+ fig.add_trace(
204
+ go.Scatter(
205
+ x=depth_axis,
206
+ y=sample,
207
+ mode="lines",
208
+ name=f"Sample {idx}",
209
+ )
210
+ )
211
+
212
+ fig.update_layout(
213
+ xaxis_title="Layer index",
214
+ yaxis_title="Velocity",
215
+ legend_title="Generated samples",
216
+ template="plotly_white",
217
+ margin=dict(l=40, r=10, t=40, b=40),
218
+ )
219
+ return fig
220
+
221
+
222
+ def build_dispersion_plot(periods: np.ndarray, group_velocities: np.ndarray) -> go.Figure:
223
+ fig = go.Figure()
224
+ fig.add_trace(
225
+ go.Scatter(
226
+ x=periods,
227
+ y=group_velocities,
228
+ mode="lines+markers",
229
+ name="Dispersion curve",
230
+ )
231
+ )
232
+ fig.update_layout(
233
+ xaxis_title="Period",
234
+ yaxis_title="Group velocity",
235
+ template="plotly_white",
236
+ margin=dict(l=40, r=10, t=40, b=40),
237
+ showlegend=False,
238
+ )
239
+ return fig
240
+
241
+
242
+ def handle_predefined_curve_selection(selection: Optional[str]) -> Tuple[Any, Optional[np.ndarray], Optional[np.ndarray]]:
243
+ if not selection or selection == DEFAULT_CURVE_NONE_LABEL:
244
+ return gr.update(value=None), None, None
245
+
246
+ periods, vg_values, theta_values = load_predefined_dispersion_curve(selection)
247
+ figure = build_dispersion_plot(periods, vg_values)
248
+ return figure, vg_values, theta_values
249
+
250
+
251
+ def format_samples(samples: np.ndarray) -> pd.DataFrame:
252
+ index = [f"Layer {i}" for i in range(1, samples.shape[1] + 1)]
253
+ columns = [f"Sample {idx}" for idx in range(1, samples.shape[0] + 1)]
254
+ return pd.DataFrame(samples.T, index=index, columns=columns)
255
+
256
+
257
+ def generate_velocity_models(
258
+ upload: Optional[Any],
259
+ model_name: str,
260
+ num_samples: int,
261
+ predefined_curve_name: Optional[str],
262
+ predefined_vg: Optional[np.ndarray],
263
+ _preloaded_theta: Optional[np.ndarray],
264
+ ) -> Tuple[go.Figure, pd.DataFrame]:
265
+ sampler = REGISTRY.get(model_name)
266
+ dispersion_curve: Optional[np.ndarray] = None
267
+
268
+ if predefined_curve_name and predefined_curve_name != DEFAULT_CURVE_NONE_LABEL:
269
+ if predefined_vg is None:
270
+ # Reload from disk if the state is empty for any reason.
271
+ _, vg_values, _ = load_predefined_dispersion_curve(predefined_curve_name)
272
+ predefined_vg = vg_values
273
+ dispersion_curve = np.asarray(predefined_vg, dtype=np.float32)
274
+ else:
275
+ dispersion_curve = read_dispersion_curve(upload, sampler.context_dim)
276
+
277
+ if dispersion_curve.size != sampler.context_dim:
278
+ raise gr.Error(
279
+ f"The selected dispersion curve contains {dispersion_curve.size} samples, "
280
+ f"but the posterior expects {sampler.context_dim}."
281
+ )
282
+
283
+ samples = sampler.sample(dispersion_curve, int(num_samples))
284
+
285
+ return build_plot(samples), format_samples(samples)
286
+
287
+
288
+ with gr.Blocks(title="Surface Wave Inversion with NPE") as demo:
289
+ default_curve_choices = [DEFAULT_CURVE_NONE_LABEL] + list(PREDEFINED_DISPERSION_CURVES.keys())
290
+ selected_vg_state = gr.State(value=None)
291
+ selected_theta_state = gr.State(value=None)
292
+
293
+ gr.Markdown(
294
+ "## Neural Posterior Estimation for Surface Wave Inversion\n"
295
+ "Select a built-in dispersion curve or upload your own, then choose a pretrained posterior model "
296
+ "to draw samples of the subsurface velocity structure."
297
+ )
298
+
299
+ with gr.Row():
300
+ with gr.Column(scale=1):
301
+ default_curve_choice = gr.Dropdown(
302
+ label="Default dispersion curve",
303
+ choices=default_curve_choices,
304
+ value=DEFAULT_CURVE_NONE_LABEL,
305
+ interactive=len(default_curve_choices) > 1,
306
+ info="Pick a built-in curve or stay on Upload custom curve to provide your own file.",
307
+ )
308
+ curve_input = gr.File(
309
+ label="Dispersion curve (.csv)",
310
+ file_types=[".csv"],
311
+ )
312
+ model_choice = gr.Dropdown(
313
+ label="Posterior model",
314
+ choices=REGISTRY.names,
315
+ value=DEFAULT_MODEL_NAME,
316
+ )
317
+ sample_count = gr.Slider(
318
+ label="Number of samples",
319
+ minimum=1,
320
+ maximum=200,
321
+ value=20,
322
+ step=1,
323
+ )
324
+ generate_btn = gr.Button("Generate velocity models", variant="primary")
325
+
326
+ with gr.Column(scale=1):
327
+ dispersion_plot = gr.Plot(label="Selected dispersion curve")
328
+ plot_output = gr.Plot(label="Sampled velocity profiles")
329
+ table_output = gr.Dataframe(
330
+ headers=[f"Sample {idx}" for idx in range(1, 6)],
331
+ datatype="number",
332
+ interactive=False,
333
+ label="Sample values",
334
+ )
335
+
336
+ default_curve_choice.change(
337
+ handle_predefined_curve_selection,
338
+ inputs=default_curve_choice,
339
+ outputs=[dispersion_plot, selected_vg_state, selected_theta_state],
340
+ )
341
+
342
+ generate_btn.click(
343
+ generate_velocity_models,
344
+ inputs=[
345
+ curve_input,
346
+ model_choice,
347
+ sample_count,
348
+ default_curve_choice,
349
+ selected_vg_state,
350
+ selected_theta_state,
351
+ ],
352
+ outputs=[plot_output, table_output],
353
+ )
354
+
355
+
356
+ if __name__ == "__main__":
357
+ demo.launch()
disp_curves/disp_curve_01.csv ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ period,vg
2
+ 1.0000,1.4450
3
+ 1.0374,1.4460
4
+ 1.0748,1.4469
5
+ 1.1121,1.4476
6
+ 1.1495,1.4482
7
+ 1.1869,1.4487
8
+ 1.2243,1.4490
9
+ 1.2617,1.4493
10
+ 1.2991,1.4494
11
+ 1.3364,1.4495
12
+ 1.3738,1.4494
13
+ 1.4112,1.4494
14
+ 1.4486,1.4497
15
+ 1.4860,1.4501
16
+ 1.5234,1.4504
17
+ 1.5607,1.4509
18
+ 1.5981,1.4516
19
+ 1.6355,1.4525
20
+ 1.6729,1.4536
21
+ 1.7103,1.4553
22
+ 1.7477,1.4571
23
+ 1.7850,1.4596
24
+ 1.8224,1.4624
25
+ 1.8598,1.4663
26
+ 1.8972,1.4709
27
+ 1.9346,1.4767
28
+ 1.9720,1.4842
29
+ 2.0093,1.4929
30
+ 2.0467,1.5059
31
+ 2.0841,1.5201
32
+ 2.1215,1.5466
33
+ 2.1589,1.5816
34
+ 2.1963,1.6574
35
+ 2.2336,1.6195
36
+ 2.2710,1.3761
37
+ 2.3084,1.2702
38
+ 2.3458,1.2804
39
+ 2.3832,1.2859
40
+ 2.4206,1.2897
41
+ 2.4579,1.2911
42
+ 2.4953,1.2922
43
+ 2.5327,1.2921
44
+ 2.5701,1.2918
45
+ 2.6075,1.2909
46
+ 2.6449,1.2899
47
+ 2.6822,1.2886
48
+ 2.7196,1.2871
49
+ 2.7570,1.2855
50
+ 2.7944,1.2837
51
+ 2.8318,1.2819
52
+ 2.8692,1.2799
53
+ 2.9065,1.2780
54
+ 2.9439,1.2759
55
+ 2.9813,1.2739
56
+ 3.0187,1.2718
57
+ 3.0561,1.2696
58
+ 3.0935,1.2675
59
+ 3.1308,1.2654
60
+ 3.1682,1.2633
61
+ 3.2056,1.2611
62
+ 3.2430,1.2590
63
+ 3.2804,1.2569
64
+ 3.3178,1.2548
65
+ 3.3551,1.2527
66
+ 3.3925,1.2506
67
+ 3.4299,1.2485
68
+ 3.4673,1.2465
69
+ 3.5047,1.2444
70
+ 3.5421,1.2424
71
+ 3.5794,1.2404
72
+ 3.6168,1.2385
73
+ 3.6542,1.2365
74
+ 3.6916,1.2346
75
+ 3.7290,1.2327
76
+ 3.7664,1.2308
77
+ 3.8037,1.2290
78
+ 3.8411,1.2272
79
+ 3.8785,1.2254
80
+ 3.9159,1.2237
81
+ 3.9533,1.2219
82
+ 3.9907,1.2202
83
+ 4.0280,1.2185
84
+ 4.0654,1.2168
85
+ 4.1028,1.2152
86
+ 4.1402,1.2135
87
+ 4.1776,1.2119
88
+ 4.2149,1.2103
89
+ 4.2523,1.2088
90
+ 4.2897,1.2073
91
+ 4.3271,1.2058
92
+ 4.3645,1.2043
93
+ 4.4019,1.2029
94
+ 4.4393,1.2014
95
+ 4.4766,1.2000
96
+ 4.5140,1.1987
97
+ 4.5514,1.1973
98
+ 4.5888,1.1960
99
+ 4.6262,1.1947
100
+ 4.6636,1.1934
101
+ 4.7009,1.1922
102
+ 4.7383,1.1909
103
+ 4.7757,1.1897
104
+ 4.8131,1.1886
105
+ 4.8505,1.1874
106
+ 4.8879,1.1862
107
+ 4.9252,1.1851
108
+ 4.9626,1.1840
109
+ 5.0000,1.1830
disp_curves/theta_01.csv ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ theta
2
+ 2.0000
3
+ 1.7500
4
+ 0.9000
5
+ 1.2000
6
+ 0.3400
7
+ 0.0000
8
+ 1.5000
9
+ 1.2000
10
+ 2.5000
11
+ 3.5000
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ gradio==5.49.1
2
+ pandas
3
+ plotly
4
+ pyyaml
5
+ safetensors
6
+ sbi
7
+ torch