Spaces:
Sleeping
Sleeping
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Any, Dict, Iterable, List, Optional, Tuple | |
| import gradio as gr | |
| import numpy as np | |
| import pandas as pd | |
| import plotly.graph_objects as go | |
| import torch | |
| import yaml | |
| from safetensors.torch import load_file | |
| from sbi.neural_nets.factory import posterior_nn | |
# Directory next to this file that is scanned for ``*.safetensors`` weights and their YAML configs.
MODELS_DIR = Path(__file__).resolve().parent / "models"
# Directory next to this file that is scanned for ``disp_curve_*.csv`` files and their ``theta_*.csv`` partners.
DISPERSION_CURVES_DIR = Path(__file__).resolve().parent / "disp_curves"
# Dropdown entry meaning "no built-in curve selected"; the uploaded file is used instead.
DEFAULT_CURVE_NONE_LABEL = "Upload custom curve"
@dataclass
class LoadedModel:
    """Pair a display name with the sampler built for one set of model weights.

    Declared as a dataclass so instances can be created with keyword
    arguments, e.g. ``LoadedModel(name=..., sampler=...)``.
    """

    # Human-readable model name; also used as the lookup key in the model registry.
    name: str
    # Sampler wrapping the trained posterior network for this model.
    sampler: "PosteriorSampler"
class PosteriorSampler:
    """Rebuild a trained neural posterior from disk and draw samples from it."""

    def __init__(self, weights_path: Path, config_path: Path, device: Optional[str] = None) -> None:
        """Construct the density estimator described by the config and load its weights.

        Args:
            weights_path: safetensors file containing the trained state dict.
            config_path: YAML file. ``dataset.input_shape`` and
                ``dataset.output_shape`` are required; entries under
                ``model.parameters`` are optional.
            device: Torch device string; CPU is used when empty or ``None``.
        """
        self.weights_path = weights_path
        self.config = yaml.safe_load(config_path.read_text())

        dataset_section = self.config.get("dataset", {})
        model_section = self.config.get("model", {})
        hyperparams = model_section.get("parameters", {})

        self.context_dim = int(dataset_section["input_shape"])
        self.theta_dim = int(dataset_section["output_shape"])

        # Forward only the hyperparameters that are present and non-null, so
        # the builder's own defaults apply to everything else.
        optional_keys = ("hidden_features", "num_transforms", "num_bins", "num_components")
        estimator_kwargs: Dict[str, int] = {
            key: int(hyperparams[key])
            for key in optional_keys
            if key in hyperparams and hyperparams[key] is not None
        }

        build_network = posterior_nn(
            model=hyperparams.get("density_estimator", "nsf"),
            z_score_theta=hyperparams.get("z_score_theta", "independent"),
            z_score_x=hyperparams.get("z_score_x", "independent"),
            **estimator_kwargs,
        )

        # The builder is fed all-zero placeholder batches just to instantiate
        # a network of the right size. The trained parameters (and, per the
        # original author's note, statistics such as z-score buffers) are then
        # restored from the safetensors file.
        network = build_network(
            torch.zeros(2, self.theta_dim),
            torch.zeros(2, self.context_dim),
        )
        network.load_state_dict(load_file(str(weights_path)))
        network.eval()

        target_device = torch.device(device or "cpu")
        self.net = network.to(target_device)
        self.device = target_device

    def sample(self, context: np.ndarray, num_samples: int) -> np.ndarray:
        """Draw ``num_samples`` posterior samples conditioned on ``context``.

        Args:
            context: Observation; it is flattened and must contain exactly
                ``context_dim`` elements.
            num_samples: Number of samples to draw.

        Returns:
            A 2-D NumPy array of samples. A 3-D network output is reduced by
            taking index 0 along its second axis.

        Raises:
            ValueError: If the context has the wrong number of elements or the
                network output is neither 2-D nor 3-D.
        """
        with torch.no_grad():
            flat_context = torch.as_tensor(
                context, dtype=torch.float32, device=self.device
            ).reshape(-1)
            received = flat_context.numel()
            if received != self.context_dim:
                raise ValueError(
                    f"Expected context with {self.context_dim} elements, received {received}."
                )
            # NOTE(review): the shape argument and the ``context`` keyword
            # accepted by ``net.sample`` depend on the installed sbi version — confirm.
            drawn = self.net.sample((num_samples,), context=flat_context).cpu().numpy()

        if drawn.ndim == 2:
            return drawn
        if drawn.ndim == 3:
            return drawn[:, 0, :]
        raise ValueError(f"Unexpected sample shape {drawn.shape}.")
def discover_dispersion_curves(curves_dir: Path) -> Dict[str, Tuple[Path, Path]]:
    """Map display names to ``(curve CSV, theta CSV)`` path pairs in ``curves_dir``.

    Each ``disp_curve_<suffix>.csv`` file is listed as ``"Curve <suffix>"`` and
    paired with ``theta_<suffix>.csv`` from the same directory. Whether the
    theta file actually exists is not checked here. A missing directory
    yields an empty mapping.
    """
    if not curves_dir.exists():
        return {}

    prefix = "disp_curve_"
    catalogue: Dict[str, Tuple[Path, Path]] = {}
    for csv_file in sorted(curves_dir.glob("disp_curve_*.csv")):
        # Everything after the last occurrence of the prefix identifies the curve.
        tag = csv_file.stem.rpartition(prefix)[2]
        catalogue[f"Curve {tag}"] = (csv_file, curves_dir / f"theta_{tag}.csv")
    return catalogue
# Built-in curves are discovered once, when the module is imported.
# Maps display name -> (curve CSV path, theta CSV path).
PREDEFINED_DISPERSION_CURVES = discover_dispersion_curves(DISPERSION_CURVES_DIR)
def discover_models(models_dir: Path) -> List[LoadedModel]:
    """Build a sampler for every ``*.safetensors`` file directly inside ``models_dir``.

    For each weights file the configuration is looked up in this order: a
    same-stem ``.yaml`` file, a same-stem ``.yml`` file, then a shared
    ``config.yaml`` in ``models_dir``. The display name is the file stem with
    underscores replaced by spaces, in title case.

    Raises:
        FileNotFoundError: If the directory is missing, a weights file has no
            configuration, or no weights files are found at all.
    """
    if not models_dir.exists():
        raise FileNotFoundError(f"Expected models directory at {models_dir}")

    loaded: List[LoadedModel] = []
    for weights_file in sorted(models_dir.glob("*.safetensors")):
        config_file = None
        for candidate in (
            weights_file.with_suffix(".yaml"),
            weights_file.with_suffix(".yml"),
            models_dir / "config.yaml",
        ):
            if candidate.exists():
                config_file = candidate
                break
        if config_file is None:
            raise FileNotFoundError(f"No configuration file found for {weights_file.name}")

        model_sampler = PosteriorSampler(weights_file, config_file)
        pretty_name = weights_file.stem.replace("_", " ").title()
        loaded.append(LoadedModel(name=pretty_name, sampler=model_sampler))

    if loaded:
        return loaded
    raise FileNotFoundError(f"No .safetensors models found in {models_dir}")
class ModelRegistry:
    """Name-indexed lookup table of the available posterior samplers."""

    def __init__(self, models: "Iterable[LoadedModel]") -> None:
        """Index ``models`` by their ``name`` attribute.

        Raises:
            ValueError: If two models share the same name.
        """
        self._registry: Dict[str, "PosteriorSampler"] = {}
        for item in models:
            if item.name in self._registry:
                raise ValueError(f"Duplicate model name detected: {item.name}")
            self._registry[item.name] = item.sampler

    @property
    def names(self) -> List[str]:
        """Registered model names, in insertion order, as a fresh list.

        Exposed as a property because the rest of the module reads it as an
        attribute (``REGISTRY.names[0]``, ``choices=REGISTRY.names``); a plain
        method would not be subscriptable there.
        """
        return list(self._registry)

    def get(self, name: str) -> "PosteriorSampler":
        """Return the sampler registered under ``name``.

        Raises:
            KeyError: If no model with that name is registered.
        """
        if name not in self._registry:
            raise KeyError(f"Unknown model '{name}'")
        return self._registry[name]
# All models are discovered and loaded once, when the module is imported.
REGISTRY = ModelRegistry(discover_models(MODELS_DIR))
# First registered model is preselected in the UI.
# NOTE(review): ``REGISTRY.names`` is subscripted without being called, which
# only works if ``names`` is a property (or other subscriptable attribute) — confirm.
DEFAULT_MODEL_NAME = REGISTRY.names[0]
def load_predefined_dispersion_curve(name: str) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Read a built-in dispersion curve together with its theta values.

    Both CSV files are read with ``pd.read_csv`` defaults. The first two
    columns of the curve file are used as periods and group velocities. The
    theta file is flattened and cells that are not numeric are dropped.

    Returns:
        ``(periods, vg_values, theta_values)`` as float32 arrays.

    Raises:
        gr.Error: For an unknown name, a missing file, a curve file with fewer
            than two columns, or non-numeric period / velocity entries.
    """
    entry = PREDEFINED_DISPERSION_CURVES.get(name)
    if entry is None:
        raise gr.Error("Unknown dispersion curve selection.")
    curve_file, theta_file = entry

    if not curve_file.exists():
        raise gr.Error(f"Unable to find dispersion curve file at {curve_file}.")
    if not theta_file.exists():
        raise gr.Error(f"Unable to find theta file at {theta_file}.")

    curve_table = pd.read_csv(curve_file)
    if curve_table.shape[1] < 2:
        raise gr.Error(f"Dispersion curve file {curve_file.name} must contain period and vg columns.")

    def _float32_column(position: int) -> np.ndarray:
        # Unparseable cells become NaN here and are rejected further down.
        column = pd.to_numeric(curve_table.iloc[:, position], errors="coerce")
        return column.to_numpy(dtype=np.float32)

    periods = _float32_column(0)
    vg_values = _float32_column(1)

    theta_flat = pd.read_csv(theta_file).to_numpy().reshape(-1)
    theta_values = pd.to_numeric(theta_flat, errors="coerce").astype(np.float32)
    theta_values = theta_values[~np.isnan(theta_values)]

    if periods.size != vg_values.size:
        raise gr.Error(
            f"Dispersion curve file {curve_file.name} contains mismatched period and vg counts."
        )
    if any(np.isnan(column).any() for column in (periods, vg_values)):
        raise gr.Error(f"Dispersion curve file {curve_file.name} contains non-numeric entries.")
    return periods, vg_values, theta_values
def read_dispersion_curve(upload: Optional[Any], expected_length: int) -> np.ndarray:
    """Parse an uploaded CSV into a flat float32 dispersion curve.

    The file is read without a header row, flattened in row-major order, and
    every cell that cannot be parsed as a number (for example a header label)
    is dropped before the length check.

    Args:
        upload: Value delivered by the file-upload widget. Plain ``str`` /
            ``Path`` values are used as the file path directly; any other
            object must expose the path through a ``name`` attribute.
        expected_length: Number of numeric values the curve must contain.

    Returns:
        1-D float32 array with exactly ``expected_length`` values.

    Raises:
        gr.Error: If nothing was uploaded, the file cannot be read, or the
            number of numeric values differs from ``expected_length``.
    """
    if upload is None:
        raise gr.Error("Please upload a CSV file containing the dispersion curve.")
    try:
        # A ``Path`` also has a ``.name`` attribute, but it is only the final
        # path component, so plain paths are used as-is instead of via ``.name``.
        csv_source = upload if isinstance(upload, (str, Path)) else upload.name
        df = pd.read_csv(csv_source, header=None)
    except Exception as exc:  # pylint: disable=broad-except
        # Deliberately broad: any read/parse failure is reported to the user.
        raise gr.Error(f"Unable to read CSV file: {exc}") from exc
    numeric_values = pd.to_numeric(df.to_numpy().reshape(-1), errors="coerce").astype(np.float32)
    numeric_values = numeric_values[~np.isnan(numeric_values)]
    if numeric_values.size != expected_length:
        raise gr.Error(
            f"Expected {expected_length} values in the dispersion curve, but found {numeric_values.size}. "
            "Please provide a CSV with exactly one value per frequency sample."
        )
    return numeric_values
def build_plot(samples: np.ndarray) -> go.Figure:
    """Plot every row of ``samples`` as a line against a 1-based layer index.

    ``samples`` is indexed as ``(sample, layer)``: each row becomes one trace
    named ``"Sample <k>"``.
    """
    layer_numbers = np.arange(1, samples.shape[1] + 1)
    traces = [
        go.Scatter(x=layer_numbers, y=profile, mode="lines", name=f"Sample {number}")
        for number, profile in enumerate(samples, start=1)
    ]
    fig = go.Figure(data=traces)

    layout_options = {
        "xaxis_title": "Layer index",
        "yaxis_title": "Velocity",
        "legend_title": "Generated samples",
        "template": "plotly_white",
        "margin": {"l": 40, "r": 10, "t": 40, "b": 40},
    }
    fig.update_layout(**layout_options)
    return fig
def build_dispersion_plot(periods: np.ndarray, group_velocities: np.ndarray) -> go.Figure:
    """Plot group velocity against period as a single line-and-marker trace without a legend."""
    curve_trace = go.Scatter(
        x=periods,
        y=group_velocities,
        mode="lines+markers",
        name="Dispersion curve",
    )
    fig = go.Figure(data=[curve_trace])

    layout_options = {
        "xaxis_title": "Period",
        "yaxis_title": "Group velocity",
        "template": "plotly_white",
        "margin": {"l": 40, "r": 10, "t": 40, "b": 40},
        "showlegend": False,
    }
    fig.update_layout(**layout_options)
    return fig
def handle_predefined_curve_selection(selection: Optional[str]) -> Tuple[Any, Optional[np.ndarray], Optional[np.ndarray]]:
    """React to a change of the built-in curve dropdown.

    Returns:
        ``(plot value, group velocities, theta values)``. When nothing is
        selected, or the "upload" placeholder is chosen, the plot is cleared
        and both arrays are ``None``.
    """
    wants_upload = not selection or selection == DEFAULT_CURVE_NONE_LABEL
    if wants_upload:
        return gr.update(value=None), None, None

    periods, group_velocities, theta = load_predefined_dispersion_curve(selection)
    return build_dispersion_plot(periods, group_velocities), group_velocities, theta
def format_samples(samples: np.ndarray) -> pd.DataFrame:
    """Tabulate samples with one row per layer and one column per sample.

    ``samples`` is indexed as ``(sample, layer)`` and is transposed for display.
    """
    sample_count = samples.shape[0]
    layer_count = samples.shape[1]
    row_labels = ["Layer %d" % number for number in range(1, layer_count + 1)]
    column_labels = ["Sample %d" % number for number in range(1, sample_count + 1)]
    table = pd.DataFrame(samples.T, index=row_labels, columns=column_labels)
    return table
def generate_velocity_models(
    upload: Optional[Any],
    model_name: str,
    num_samples: int,
    predefined_curve_name: Optional[str],
    predefined_vg: Optional[np.ndarray],
    _preloaded_theta: Optional[np.ndarray],
) -> Tuple[go.Figure, pd.DataFrame]:
    """Sample velocity models for the chosen dispersion curve.

    A selected built-in curve takes precedence over the uploaded file. Its
    group velocities come from ``predefined_vg`` when available and are
    re-read from disk otherwise. ``_preloaded_theta`` is accepted but unused.

    Returns:
        ``(figure of sampled profiles, table of sample values)``.

    Raises:
        gr.Error: If the curve length does not match the model's context size
            (or the upload cannot be read).
    """
    sampler = REGISTRY.get(model_name)

    use_upload = not predefined_curve_name or predefined_curve_name == DEFAULT_CURVE_NONE_LABEL
    if use_upload:
        curve = read_dispersion_curve(upload, sampler.context_dim)
    else:
        if predefined_vg is None:
            # The state can be empty; fall back to reading the curve from disk.
            predefined_vg = load_predefined_dispersion_curve(predefined_curve_name)[1]
        curve = np.asarray(predefined_vg, dtype=np.float32)

    expected = sampler.context_dim
    if curve.size != expected:
        raise gr.Error(
            f"The selected dispersion curve contains {curve.size} samples, "
            f"but the posterior expects {expected}."
        )

    drawn = sampler.sample(curve, int(num_samples))
    figure = build_plot(drawn)
    table = format_samples(drawn)
    return figure, table
# Gradio UI: inputs in the left column, outputs in the right column, followed
# by the event wiring. Components are created in display order.
with gr.Blocks(title="Surface Wave Inversion with NPE") as demo:
    # The "upload" placeholder comes first, followed by every discovered built-in curve.
    default_curve_choices = [DEFAULT_CURVE_NONE_LABEL] + list(PREDEFINED_DISPERSION_CURVES.keys())
    # State slots written by the curve-selection callback and read back by the
    # generate callback (group velocities and theta values of the selected curve).
    selected_vg_state = gr.State(value=None)
    selected_theta_state = gr.State(value=None)
    gr.Markdown(
        "## Neural Posterior Estimation for Surface Wave Inversion\n"
        "Select a built-in dispersion curve or upload your own, then choose a pretrained posterior model "
        "to draw samples of the subsurface velocity structure."
    )
    with gr.Row():
        with gr.Column(scale=1):
            default_curve_choice = gr.Dropdown(
                label="Default dispersion curve",
                choices=default_curve_choices,
                value=DEFAULT_CURVE_NONE_LABEL,
                # Only interactive when at least one built-in curve was discovered.
                interactive=len(default_curve_choices) > 1,
                info="Pick a built-in curve or stay on Upload custom curve to provide your own file.",
            )
            curve_input = gr.File(
                label="Dispersion curve (.csv)",
                file_types=[".csv"],
            )
            model_choice = gr.Dropdown(
                label="Posterior model",
                # NOTE(review): ``REGISTRY.names`` is passed without being called;
                # this yields a list only if ``names`` is a property — confirm.
                choices=REGISTRY.names,
                value=DEFAULT_MODEL_NAME,
            )
            sample_count = gr.Slider(
                label="Number of samples",
                minimum=1,
                maximum=200,
                value=20,
                step=1,
            )
            generate_btn = gr.Button("Generate velocity models", variant="primary")
        with gr.Column(scale=1):
            dispersion_plot = gr.Plot(label="Selected dispersion curve")
            plot_output = gr.Plot(label="Sampled velocity profiles")
            table_output = gr.Dataframe(
                # Initial placeholder headers for five samples; the callback
                # returns a DataFrame with one column per generated sample.
                headers=[f"Sample {idx}" for idx in range(1, 6)],
                datatype="number",
                interactive=False,
                label="Sample values",
            )
    # Selecting a built-in curve updates its plot and stores its values in the state slots.
    default_curve_choice.change(
        handle_predefined_curve_selection,
        inputs=default_curve_choice,
        outputs=[dispersion_plot, selected_vg_state, selected_theta_state],
    )
    # Input order must match the parameter order of ``generate_velocity_models``.
    generate_btn.click(
        generate_velocity_models,
        inputs=[
            curve_input,
            model_choice,
            sample_count,
            default_curve_choice,
            selected_vg_state,
            selected_theta_state,
        ],
        outputs=[plot_output, table_output],
    )
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()