"""Gradio app for surface wave inversion with neural posterior estimation (NPE).

Loads pretrained posterior networks from ``models/`` and optional built-in
dispersion curves from ``disp_curves/``, then serves a UI that samples
subsurface velocity profiles conditioned on a dispersion curve.
"""

from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Iterable, List, Optional, Tuple

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import torch
import yaml
from safetensors.torch import load_file
from sbi.neural_nets.factory import posterior_nn

MODELS_DIR = Path(__file__).resolve().parent / "models"
DISPERSION_CURVES_DIR = Path(__file__).resolve().parent / "disp_curves"
# Sentinel dropdown entry meaning "no built-in curve selected".
DEFAULT_CURVE_NONE_LABEL = "Upload custom curve"


@dataclass
class LoadedModel:
    """A discovered model: display name plus its ready-to-use sampler."""

    name: str
    sampler: "PosteriorSampler"


class PosteriorSampler:
    """Thin wrapper around the trained neural posterior for sampling."""

    def __init__(self, weights_path: Path, config_path: Path, device: Optional[str] = None) -> None:
        """Rebuild the density estimator from config and load trained weights.

        Args:
            weights_path: ``.safetensors`` file with the trained parameters.
            config_path: YAML file describing dataset shapes and model hyperparameters.
            device: Optional torch device string; defaults to CPU.

        Raises:
            KeyError: if the config lacks ``dataset.input_shape`` / ``dataset.output_shape``.
        """
        self.weights_path = weights_path
        self.config = yaml.safe_load(config_path.read_text())
        dataset_cfg = self.config.get("dataset", {})
        model_cfg = self.config.get("model", {})
        params_cfg = model_cfg.get("parameters", {})
        # Required shape info: context = dispersion curve length, theta = profile length.
        self.context_dim = int(dataset_cfg["input_shape"])
        self.theta_dim = int(dataset_cfg["output_shape"])
        # Only forward hyperparameters that are actually present in the config.
        build_kwargs: Dict[str, int] = {}
        for key in ("hidden_features", "num_transforms", "num_bins", "num_components"):
            if key in params_cfg and params_cfg[key] is not None:
                build_kwargs[key] = int(params_cfg[key])
        density_estimator_builder = posterior_nn(
            model=params_cfg.get("density_estimator", "nsf"),
            z_score_theta=params_cfg.get("z_score_theta", "independent"),
            z_score_x=params_cfg.get("z_score_x", "independent"),
            **build_kwargs,
        )
        # Create a dummy network to load the trained parameters. The actual statistics
        # (e.g. z-score buffers) are restored from the safetensors file.
        theta_prototype = torch.zeros(2, self.theta_dim)
        context_prototype = torch.zeros(2, self.context_dim)
        net = density_estimator_builder(theta_prototype, context_prototype)
        state_dict = load_file(str(weights_path))
        net.load_state_dict(state_dict)
        net.eval()
        runtime_device = torch.device(device) if device else torch.device("cpu")
        self.net = net.to(runtime_device)
        self.device = runtime_device

    def sample(self, context: np.ndarray, num_samples: int) -> np.ndarray:
        """Draw posterior samples conditioned on a dispersion curve.

        Args:
            context: 1-D array (or anything flattenable) of length ``context_dim``.
            num_samples: Number of posterior draws.

        Returns:
            Array of shape ``(num_samples, theta_dim)``.

        Raises:
            ValueError: if the context length or the sampled shape is unexpected.
        """
        with torch.no_grad():
            context_tensor = torch.as_tensor(
                context, dtype=torch.float32, device=self.device
            ).reshape(-1)
            if context_tensor.numel() != self.context_dim:
                raise ValueError(
                    f"Expected context with {self.context_dim} elements, received {context_tensor.numel()}."
                )
            samples = self.net.sample((num_samples,), context=context_tensor)
        samples_np = samples.cpu().numpy()
        # Some estimators return (num_samples, 1, theta_dim); squeeze the batch axis.
        if samples_np.ndim == 3:
            samples_np = samples_np[:, 0, :]
        elif samples_np.ndim != 2:
            raise ValueError(f"Unexpected sample shape {samples_np.shape}.")
        return samples_np


def discover_dispersion_curves(curves_dir: Path) -> Dict[str, Tuple[Path, Path]]:
    """Map display names to ``(curve_csv, theta_csv)`` pairs found on disk.

    Curves are files matching ``disp_curve_<suffix>.csv``; the paired theta file
    is assumed to be ``theta_<suffix>.csv`` (existence is checked at load time).
    """
    if not curves_dir.exists():
        return {}
    discovered: Dict[str, Tuple[Path, Path]] = {}
    for curve_path in sorted(curves_dir.glob("disp_curve_*.csv")):
        suffix = curve_path.stem.split("disp_curve_")[-1]
        theta_path = curves_dir / f"theta_{suffix}.csv"
        display_name = f"Curve {suffix}"
        discovered[display_name] = (curve_path, theta_path)
    return discovered


PREDEFINED_DISPERSION_CURVES = discover_dispersion_curves(DISPERSION_CURVES_DIR)


def discover_models(models_dir: Path) -> List[LoadedModel]:
    """Instantiate a sampler for every ``.safetensors`` file in *models_dir*.

    Each weights file is paired with a config: a sibling ``.yaml``/``.yml`` with
    the same stem, or a shared ``config.yaml`` in the directory.

    Raises:
        FileNotFoundError: if the directory, a config, or any model is missing.
    """
    if not models_dir.exists():
        raise FileNotFoundError(f"Expected models directory at {models_dir}")
    discovered: List[LoadedModel] = []
    for weights_path in sorted(models_dir.glob("*.safetensors")):
        config_candidates = [
            weights_path.with_suffix(".yaml"),
            weights_path.with_suffix(".yml"),
            models_dir / "config.yaml",
        ]
        config_path = next((path for path in config_candidates if path.exists()), None)
        if not config_path:
            raise FileNotFoundError(f"No configuration file found for {weights_path.name}")
        sampler = PosteriorSampler(weights_path, config_path)
        display_name = weights_path.stem.replace("_", " ").title()
        discovered.append(LoadedModel(name=display_name, sampler=sampler))
    if not discovered:
        raise FileNotFoundError(f"No .safetensors models found in {models_dir}")
    return discovered


class ModelRegistry:
    """Name-indexed lookup of loaded posterior samplers."""

    def __init__(self, models: Iterable[LoadedModel]):
        self._registry: Dict[str, PosteriorSampler] = {}
        for item in models:
            if item.name in self._registry:
                raise ValueError(f"Duplicate model name detected: {item.name}")
            self._registry[item.name] = item.sampler

    @property
    def names(self) -> List[str]:
        """Display names in insertion (i.e. sorted-filename) order."""
        return list(self._registry.keys())

    def get(self, name: str) -> PosteriorSampler:
        """Return the sampler registered under *name*; raise ``KeyError`` otherwise."""
        if name not in self._registry:
            raise KeyError(f"Unknown model '{name}'")
        return self._registry[name]


REGISTRY = ModelRegistry(discover_models(MODELS_DIR))
DEFAULT_MODEL_NAME = REGISTRY.names[0]


def load_predefined_dispersion_curve(name: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Load a built-in curve: ``(periods, group_velocities, theta_values)``.

    The curve CSV's first two columns are period and group velocity; the theta
    CSV is flattened and NaNs dropped. Errors surface as ``gr.Error`` so they
    render in the UI.
    """
    if name not in PREDEFINED_DISPERSION_CURVES:
        raise gr.Error("Unknown dispersion curve selection.")
    curve_path, theta_path = PREDEFINED_DISPERSION_CURVES[name]
    if not curve_path.exists():
        raise gr.Error(f"Unable to find dispersion curve file at {curve_path}.")
    if not theta_path.exists():
        raise gr.Error(f"Unable to find theta file at {theta_path}.")
    curve_df = pd.read_csv(curve_path)
    if curve_df.shape[1] < 2:
        raise gr.Error(f"Dispersion curve file {curve_path.name} must contain period and vg columns.")
    periods = pd.to_numeric(curve_df.iloc[:, 0], errors="coerce").to_numpy(dtype=np.float32)
    vg_values = pd.to_numeric(curve_df.iloc[:, 1], errors="coerce").to_numpy(dtype=np.float32)
    theta_df = pd.read_csv(theta_path)
    theta_values = pd.to_numeric(theta_df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
    theta_values = theta_values[~np.isnan(theta_values)]
    if periods.size != vg_values.size:
        raise gr.Error(
            f"Dispersion curve file {curve_path.name} contains mismatched period and vg counts."
        )
    if np.isnan(periods).any() or np.isnan(vg_values).any():
        raise gr.Error(f"Dispersion curve file {curve_path.name} contains non-numeric entries.")
    return periods, vg_values, theta_values


def read_dispersion_curve(upload: Optional[Any], expected_length: int) -> np.ndarray:
    """Parse an uploaded CSV into a flat float32 array of *expected_length* values.

    Non-numeric cells are dropped; the remaining count must match exactly.
    """
    if upload is None:
        raise gr.Error("Please upload a CSV file containing the dispersion curve.")
    # Gradio may hand us a tempfile-like object with a ``.name`` path (older
    # versions) or a plain filepath string (newer versions) — accept both.
    upload_path = getattr(upload, "name", upload)
    try:
        df = pd.read_csv(upload_path, header=None)
    except Exception as exc:  # pylint: disable=broad-except
        raise gr.Error(f"Unable to read CSV file: {exc}") from exc
    numeric_values = pd.to_numeric(df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
    numeric_values = numeric_values[~np.isnan(numeric_values)]
    if numeric_values.size != expected_length:
        raise gr.Error(
            f"Expected {expected_length} values in the dispersion curve, but found {numeric_values.size}. "
            "Please provide a CSV with exactly one value per frequency sample."
        )
    return numeric_values


def build_plot(samples: np.ndarray) -> go.Figure:
    """Plot each posterior sample as a velocity-vs-layer-index line."""
    depth_axis = np.arange(1, samples.shape[1] + 1)
    fig = go.Figure()
    for idx, sample in enumerate(samples, start=1):
        fig.add_trace(
            go.Scatter(
                x=depth_axis,
                y=sample,
                mode="lines",
                name=f"Sample {idx}",
            )
        )
    fig.update_layout(
        xaxis_title="Layer index",
        yaxis_title="Velocity",
        legend_title="Generated samples",
        template="plotly_white",
        margin=dict(l=40, r=10, t=40, b=40),
    )
    return fig


def build_dispersion_plot(periods: np.ndarray, group_velocities: np.ndarray) -> go.Figure:
    """Plot a single dispersion curve (group velocity vs period)."""
    fig = go.Figure()
    fig.add_trace(
        go.Scatter(
            x=periods,
            y=group_velocities,
            mode="lines+markers",
            name="Dispersion curve",
        )
    )
    fig.update_layout(
        xaxis_title="Period",
        yaxis_title="Group velocity",
        template="plotly_white",
        margin=dict(l=40, r=10, t=40, b=40),
        showlegend=False,
    )
    return fig


def handle_predefined_curve_selection(
    selection: Optional[str],
) -> Tuple[Any, Optional[np.ndarray], Optional[np.ndarray]]:
    """Dropdown callback: return ``(plot, vg_state, theta_state)``.

    Choosing the "upload" sentinel clears the plot and both state slots.
    """
    if not selection or selection == DEFAULT_CURVE_NONE_LABEL:
        return gr.update(value=None), None, None
    periods, vg_values, theta_values = load_predefined_dispersion_curve(selection)
    figure = build_dispersion_plot(periods, vg_values)
    return figure, vg_values, theta_values


def format_samples(samples: np.ndarray) -> pd.DataFrame:
    """Tabulate samples with layers as rows and samples as columns."""
    index = [f"Layer {i}" for i in range(1, samples.shape[1] + 1)]
    columns = [f"Sample {idx}" for idx in range(1, samples.shape[0] + 1)]
    return pd.DataFrame(samples.T, index=index, columns=columns)


def generate_velocity_models(
    upload: Optional[Any],
    model_name: str,
    num_samples: int,
    predefined_curve_name: Optional[str],
    predefined_vg: Optional[np.ndarray],
    _preloaded_theta: Optional[np.ndarray],
) -> Tuple[go.Figure, pd.DataFrame]:
    """Button callback: sample the chosen posterior and render plot + table.

    A selected built-in curve takes precedence over an uploaded file. The
    ``_preloaded_theta`` state is accepted but unused here (kept for wiring
    symmetry with the curve-selection callback).
    """
    sampler = REGISTRY.get(model_name)
    dispersion_curve: Optional[np.ndarray] = None
    if predefined_curve_name and predefined_curve_name != DEFAULT_CURVE_NONE_LABEL:
        if predefined_vg is None:
            # Reload from disk if the state is empty for any reason.
            _, vg_values, _ = load_predefined_dispersion_curve(predefined_curve_name)
            predefined_vg = vg_values
        dispersion_curve = np.asarray(predefined_vg, dtype=np.float32)
    else:
        dispersion_curve = read_dispersion_curve(upload, sampler.context_dim)
    if dispersion_curve.size != sampler.context_dim:
        raise gr.Error(
            f"The selected dispersion curve contains {dispersion_curve.size} samples, "
            f"but the posterior expects {sampler.context_dim}."
        )
    samples = sampler.sample(dispersion_curve, int(num_samples))
    return build_plot(samples), format_samples(samples)


with gr.Blocks(title="Surface Wave Inversion with NPE") as demo:
    default_curve_choices = [DEFAULT_CURVE_NONE_LABEL] + list(PREDEFINED_DISPERSION_CURVES.keys())
    # Session state: group velocities and theta of the selected built-in curve.
    selected_vg_state = gr.State(value=None)
    selected_theta_state = gr.State(value=None)
    gr.Markdown(
        "## Neural Posterior Estimation for Surface Wave Inversion\n"
        "Select a built-in dispersion curve or upload your own, then choose a pretrained posterior model "
        "to draw samples of the subsurface velocity structure."
    )
    with gr.Row():
        with gr.Column(scale=1):
            default_curve_choice = gr.Dropdown(
                label="Default dispersion curve",
                choices=default_curve_choices,
                value=DEFAULT_CURVE_NONE_LABEL,
                interactive=len(default_curve_choices) > 1,
                info="Pick a built-in curve or stay on Upload custom curve to provide your own file.",
            )
            curve_input = gr.File(
                label="Dispersion curve (.csv)",
                file_types=[".csv"],
            )
            model_choice = gr.Dropdown(
                label="Posterior model",
                choices=REGISTRY.names,
                value=DEFAULT_MODEL_NAME,
            )
            sample_count = gr.Slider(
                label="Number of samples",
                minimum=1,
                maximum=200,
                value=20,
                step=1,
            )
            generate_btn = gr.Button("Generate velocity models", variant="primary")
        with gr.Column(scale=1):
            dispersion_plot = gr.Plot(label="Selected dispersion curve")
            plot_output = gr.Plot(label="Sampled velocity profiles")
            table_output = gr.Dataframe(
                headers=[f"Sample {idx}" for idx in range(1, 6)],
                datatype="number",
                interactive=False,
                label="Sample values",
            )
    default_curve_choice.change(
        handle_predefined_curve_selection,
        inputs=default_curve_choice,
        outputs=[dispersion_plot, selected_vg_state, selected_theta_state],
    )
    generate_btn.click(
        generate_velocity_models,
        inputs=[
            curve_input,
            model_choice,
            sample_count,
            default_curve_choice,
            selected_vg_state,
            selected_theta_state,
        ],
        outputs=[plot_output, table_output],
    )

if __name__ == "__main__":
    demo.launch()