# Snapshot metadata (scraped page header, kept as a comment):
# commit "feature(model): model configuration" (771a91d) by schattin
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple
import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import yaml
from safetensors.torch import load_file
from sbi.neural_nets.factory import posterior_nn
# Directory holding trained posterior weights (*.safetensors) and their YAML configs.
MODELS_DIR = Path(__file__).resolve().parent / "models"
# Directory holding bundled example curves (disp_curve_*.csv paired with theta_*.csv).
DISPERSION_CURVES_DIR = Path(__file__).resolve().parent / "disp_curves"
# Dropdown entry that switches the UI to the user-uploaded-curve workflow.
DEFAULT_CURVE_NONE_LABEL = "Upload custom curve"
@dataclass
class LoadedModel:
    """A discovered posterior model: UI display name plus its ready-to-use sampler."""

    # Human-readable name shown in the model dropdown.
    name: str
    # Wrapper around the trained neural posterior (forward reference: defined below).
    sampler: "PosteriorSampler"
class PosteriorSampler:
    """Thin wrapper around the trained neural posterior for sampling."""

    def __init__(self, weights_path: Path, config_path: Path, device: Optional[str] = None) -> None:
        """Rebuild the density-estimator architecture from the YAML config and
        restore the trained parameters from the safetensors file.

        The config is expected to contain ``dataset.input_shape`` (dispersion
        curve length) and ``dataset.output_shape`` (velocity-model dimension);
        a missing key raises ``KeyError``.
        """
        self.weights_path = weights_path
        self.config = yaml.safe_load(config_path.read_text())
        dataset_cfg = self.config.get("dataset", {})
        model_cfg = self.config.get("model", {})
        params_cfg = model_cfg.get("parameters", {})
        # context = conditioning observation (dispersion curve); theta = parameters sampled.
        self.context_dim = int(dataset_cfg["input_shape"])
        self.theta_dim = int(dataset_cfg["output_shape"])
        # Forward only the architecture hyperparameters that are actually set.
        build_kwargs: Dict[str, int] = {}
        for key in ("hidden_features", "num_transforms", "num_bins", "num_components"):
            if key in params_cfg and params_cfg[key] is not None:
                build_kwargs[key] = int(params_cfg[key])
        density_estimator_builder = posterior_nn(
            model=params_cfg.get("density_estimator", "nsf"),
            z_score_theta=params_cfg.get("z_score_theta", "independent"),
            z_score_x=params_cfg.get("z_score_x", "independent"),
            **build_kwargs,
        )
        # Create a dummy network to load the trained parameters. The actual statistics
        # (e.g. z-score buffers) are restored from the safetensors file.
        theta_prototype = torch.zeros(2, self.theta_dim)
        context_prototype = torch.zeros(2, self.context_dim)
        net = density_estimator_builder(theta_prototype, context_prototype)
        state_dict = load_file(str(weights_path))
        net.load_state_dict(state_dict)
        net.eval()
        runtime_device = torch.device(device) if device else torch.device("cpu")
        self.net = net.to(runtime_device)
        self.device = runtime_device

    def sample(self, context: np.ndarray, num_samples: int) -> np.ndarray:
        """Draw ``num_samples`` posterior samples conditioned on ``context``.

        Returns a ``(num_samples, theta_dim)`` float array. Raises ``ValueError``
        if the flattened context does not have exactly ``context_dim`` elements,
        or if the network returns an unexpected sample shape.
        """
        with torch.no_grad():
            context_tensor = torch.as_tensor(context, dtype=torch.float32, device=self.device).reshape(-1)
            if context_tensor.numel() != self.context_dim:
                raise ValueError(
                    f"Expected context with {self.context_dim} elements, received {context_tensor.numel()}."
                )
            samples = self.net.sample((num_samples,), context=context_tensor)
            samples_np = samples.cpu().numpy()
            # NOTE(review): a 3-D result is squeezed on axis 1 — presumably a
            # batch dimension of size one from the single conditioning context;
            # confirm against the sbi density-estimator sample() contract.
            if samples_np.ndim == 3:
                samples_np = samples_np[:, 0, :]
            elif samples_np.ndim != 2:
                raise ValueError(f"Unexpected sample shape {samples_np.shape}.")
            return samples_np
def discover_dispersion_curves(curves_dir: Path) -> Dict[str, Tuple[Path, Path]]:
    """Scan *curves_dir* for bundled dispersion curves.

    Every ``disp_curve_<id>.csv`` file is keyed by the label ``"Curve <id>"``
    and paired with the corresponding ``theta_<id>.csv`` path (the theta file
    is not required to exist at discovery time). Returns an empty mapping when
    the directory is absent.
    """
    if not curves_dir.exists():
        return {}
    catalog: Dict[str, Tuple[Path, Path]] = {}
    for csv_file in sorted(curves_dir.glob("disp_curve_*.csv")):
        curve_id = csv_file.stem.split("disp_curve_")[-1]
        catalog[f"Curve {curve_id}"] = (csv_file, curves_dir / f"theta_{curve_id}.csv")
    return catalog
# Catalog of bundled curves, discovered once at import time.
PREDEFINED_DISPERSION_CURVES = discover_dispersion_curves(DISPERSION_CURVES_DIR)
def discover_models(models_dir: Path) -> List[LoadedModel]:
    """Load every ``*.safetensors`` posterior found in *models_dir*.

    Each weights file is matched with the first existing config among
    ``<stem>.yaml``, ``<stem>.yml``, and a shared ``config.yaml``. Raises
    ``FileNotFoundError`` when the directory, a config, or any model is missing.
    """
    if not models_dir.exists():
        raise FileNotFoundError(f"Expected models directory at {models_dir}")
    loaded: List[LoadedModel] = []
    for weights_file in sorted(models_dir.glob("*.safetensors")):
        candidates = (
            weights_file.with_suffix(".yaml"),
            weights_file.with_suffix(".yml"),
            models_dir / "config.yaml",
        )
        config_file: Optional[Path] = None
        for candidate in candidates:
            if candidate.exists():
                config_file = candidate
                break
        if not config_file:
            raise FileNotFoundError(f"No configuration file found for {weights_file.name}")
        # Turn "my_model" into "My Model" for the dropdown.
        label = weights_file.stem.replace("_", " ").title()
        loaded.append(LoadedModel(name=label, sampler=PosteriorSampler(weights_file, config_file)))
    if not loaded:
        raise FileNotFoundError(f"No .safetensors models found in {models_dir}")
    return loaded
class ModelRegistry:
    """Name-to-sampler lookup table for all discovered posterior models."""

    def __init__(self, models: Iterable[LoadedModel]):
        """Index *models* by display name; duplicate names raise ``ValueError``."""
        self._registry: Dict[str, PosteriorSampler] = {}
        for model in models:
            if model.name in self._registry:
                raise ValueError(f"Duplicate model name detected: {model.name}")
            self._registry[model.name] = model.sampler

    @property
    def names(self) -> List[str]:
        """All registered model names, in insertion order."""
        return list(self._registry)

    def get(self, name: str) -> PosteriorSampler:
        """Return the sampler registered under *name*; ``KeyError`` if unknown."""
        try:
            return self._registry[name]
        except KeyError:
            raise KeyError(f"Unknown model '{name}'") from None
# Load every model eagerly at import time; fails fast if none are present.
REGISTRY = ModelRegistry(discover_models(MODELS_DIR))
# First discovered model pre-selected in the UI.
DEFAULT_MODEL_NAME = REGISTRY.names[0]
def load_predefined_dispersion_curve(name: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read a bundled dispersion curve and its theta file from disk.

    Returns ``(periods, group_velocities, theta_values)`` as float32 arrays;
    NaNs are dropped from theta only. Raises ``gr.Error`` for unknown
    selections, missing files, or malformed curve data.
    """
    if name not in PREDEFINED_DISPERSION_CURVES:
        raise gr.Error("Unknown dispersion curve selection.")
    curve_path, theta_path = PREDEFINED_DISPERSION_CURVES[name]
    if not curve_path.exists():
        raise gr.Error(f"Unable to find dispersion curve file at {curve_path}.")
    if not theta_path.exists():
        raise gr.Error(f"Unable to find theta file at {theta_path}.")

    def _as_float32(column) -> np.ndarray:
        # Coerce non-numeric cells to NaN instead of raising.
        return pd.to_numeric(column, errors="coerce").to_numpy(dtype=np.float32)

    curve_table = pd.read_csv(curve_path)
    if curve_table.shape[1] < 2:
        raise gr.Error(f"Dispersion curve file {curve_path.name} must contain period and vg columns.")
    periods = _as_float32(curve_table.iloc[:, 0])
    velocities = _as_float32(curve_table.iloc[:, 1])
    theta_raw = pd.to_numeric(pd.read_csv(theta_path).to_numpy().reshape(-1), errors="coerce").astype(np.float32)
    theta_values = theta_raw[~np.isnan(theta_raw)]
    if periods.size != velocities.size:
        raise gr.Error(
            f"Dispersion curve file {curve_path.name} contains mismatched period and vg counts."
        )
    if np.isnan(periods).any() or np.isnan(velocities).any():
        raise gr.Error(f"Dispersion curve file {curve_path.name} contains non-numeric entries.")
    return periods, velocities, theta_values
def read_dispersion_curve(upload: Optional[Any], expected_length: int) -> np.ndarray:
    """Parse an uploaded CSV into a flat float32 array of dispersion values.

    Accepts either an upload object exposing a ``.name`` path attribute
    (older Gradio file objects) or a plain path string (Gradio >= 4 passes
    filepaths directly); the original code crashed with ``AttributeError``
    on plain strings. Non-numeric cells are coerced to NaN and dropped; the
    remaining values must number exactly *expected_length*.

    Raises ``gr.Error`` when no file was provided, the CSV cannot be read,
    or the value count does not match *expected_length*.
    """
    if upload is None:
        raise gr.Error("Please upload a CSV file containing the dispersion curve.")
    # Gradio may hand us an object with a .name attribute or a bare path string.
    csv_path = getattr(upload, "name", upload)
    try:
        df = pd.read_csv(csv_path, header=None)
    except Exception as exc:  # pylint: disable=broad-except
        raise gr.Error(f"Unable to read CSV file: {exc}") from exc
    numeric_values = pd.to_numeric(df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
    numeric_values = numeric_values[~np.isnan(numeric_values)]
    if numeric_values.size != expected_length:
        raise gr.Error(
            f"Expected {expected_length} values in the dispersion curve, but found {numeric_values.size}. "
            "Please provide a CSV with exactly one value per frequency sample."
        )
    return numeric_values
def build_plot(samples: np.ndarray) -> go.Figure:
    """Render each posterior sample as a velocity-vs-layer-index line trace."""
    layer_axis = np.arange(1, samples.shape[1] + 1)
    traces = [
        go.Scatter(x=layer_axis, y=profile, mode="lines", name=f"Sample {num}")
        for num, profile in enumerate(samples, start=1)
    ]
    fig = go.Figure(data=traces)
    fig.update_layout(
        xaxis_title="Layer index",
        yaxis_title="Velocity",
        legend_title="Generated samples",
        template="plotly_white",
        margin=dict(l=40, r=10, t=40, b=40),
    )
    return fig
def build_dispersion_plot(periods: np.ndarray, group_velocities: np.ndarray) -> go.Figure:
    """Render the selected dispersion curve (group velocity vs period)."""
    curve_trace = go.Scatter(
        x=periods,
        y=group_velocities,
        mode="lines+markers",
        name="Dispersion curve",
    )
    fig = go.Figure(data=[curve_trace])
    fig.update_layout(
        xaxis_title="Period",
        yaxis_title="Group velocity",
        template="plotly_white",
        margin=dict(l=40, r=10, t=40, b=40),
        showlegend=False,
    )
    return fig
def handle_predefined_curve_selection(selection: Optional[str]) -> Tuple[Any, Optional[np.ndarray], Optional[np.ndarray]]:
    """React to the curve dropdown: plot the chosen built-in curve and cache its data.

    Returns ``(plot update or figure, group-velocity state, theta state)``;
    both states are cleared when the user switches back to the custom-upload
    option (or nothing is selected).
    """
    wants_custom_upload = not selection or selection == DEFAULT_CURVE_NONE_LABEL
    if wants_custom_upload:
        return gr.update(value=None), None, None
    periods, velocities, thetas = load_predefined_dispersion_curve(selection)
    return build_dispersion_plot(periods, velocities), velocities, thetas
def format_samples(samples: np.ndarray) -> pd.DataFrame:
    """Tabulate samples as a (layers x samples) DataFrame with readable labels."""
    n_samples, n_layers = samples.shape
    return pd.DataFrame(
        samples.T,
        index=[f"Layer {layer}" for layer in range(1, n_layers + 1)],
        columns=[f"Sample {col}" for col in range(1, n_samples + 1)],
    )
def generate_velocity_models(
    upload: Optional[Any],
    model_name: str,
    num_samples: int,
    predefined_curve_name: Optional[str],
    predefined_vg: Optional[np.ndarray],
    _preloaded_theta: Optional[np.ndarray],
) -> Tuple[go.Figure, pd.DataFrame]:
    """Draw posterior samples conditioned on the chosen dispersion curve.

    Uses the cached built-in curve when one is selected (reloading it from
    disk if the Gradio state is empty), otherwise parses the uploaded CSV.
    Returns the sample plot and a table of the sampled velocity profiles.
    Raises ``gr.Error`` when the curve length does not match the model.
    """
    sampler = REGISTRY.get(model_name)
    use_predefined = bool(predefined_curve_name) and predefined_curve_name != DEFAULT_CURVE_NONE_LABEL
    if use_predefined:
        if predefined_vg is None:
            # Reload from disk if the state is empty for any reason.
            predefined_vg = load_predefined_dispersion_curve(predefined_curve_name)[1]
        curve = np.asarray(predefined_vg, dtype=np.float32)
    else:
        curve = read_dispersion_curve(upload, sampler.context_dim)
    if curve.size != sampler.context_dim:
        raise gr.Error(
            f"The selected dispersion curve contains {curve.size} samples, "
            f"but the posterior expects {sampler.context_dim}."
        )
    drawn = sampler.sample(curve, int(num_samples))
    return build_plot(drawn), format_samples(drawn)
# --- Gradio UI -------------------------------------------------------------
with gr.Blocks(title="Surface Wave Inversion with NPE") as demo:
    # First dropdown entry keeps the custom-upload workflow selectable.
    default_curve_choices = [DEFAULT_CURVE_NONE_LABEL] + list(PREDEFINED_DISPERSION_CURVES.keys())
    # Hidden states carrying the selected built-in curve's data between events.
    selected_vg_state = gr.State(value=None)
    selected_theta_state = gr.State(value=None)
    gr.Markdown(
        "## Neural Posterior Estimation for Surface Wave Inversion\n"
        "Select a built-in dispersion curve or upload your own, then choose a pretrained posterior model "
        "to draw samples of the subsurface velocity structure."
    )
    with gr.Row():
        with gr.Column(scale=1):
            default_curve_choice = gr.Dropdown(
                label="Default dispersion curve",
                choices=default_curve_choices,
                value=DEFAULT_CURVE_NONE_LABEL,
                # Dropdown stays disabled when no bundled curves were discovered.
                interactive=len(default_curve_choices) > 1,
                info="Pick a built-in curve or stay on Upload custom curve to provide your own file.",
            )
            curve_input = gr.File(
                label="Dispersion curve (.csv)",
                file_types=[".csv"],
            )
            model_choice = gr.Dropdown(
                label="Posterior model",
                choices=REGISTRY.names,
                value=DEFAULT_MODEL_NAME,
            )
            sample_count = gr.Slider(
                label="Number of samples",
                minimum=1,
                maximum=200,
                value=20,
                step=1,
            )
            generate_btn = gr.Button("Generate velocity models", variant="primary")
        with gr.Column(scale=1):
            dispersion_plot = gr.Plot(label="Selected dispersion curve")
            plot_output = gr.Plot(label="Sampled velocity profiles")
            table_output = gr.Dataframe(
                # NOTE(review): placeholder shows five sample columns regardless
                # of the slider value; the generated DataFrame's own columns take
                # over once results arrive — confirm this is the intended default.
                headers=[f"Sample {idx}" for idx in range(1, 6)],
                datatype="number",
                interactive=False,
                label="Sample values",
            )
    # Selecting a built-in curve plots it and caches its vg/theta arrays.
    default_curve_choice.change(
        handle_predefined_curve_selection,
        inputs=default_curve_choice,
        outputs=[dispersion_plot, selected_vg_state, selected_theta_state],
    )
    # Generate button: sample the posterior and render plot + table.
    generate_btn.click(
        generate_velocity_models,
        inputs=[
            curve_input,
            model_choice,
            sample_count,
            default_curve_choice,
            selected_vg_state,
            selected_theta_state,
        ],
        outputs=[plot_output, table_output],
    )

if __name__ == "__main__":
    demo.launch()