| from __future__ import annotations |
|
|
| import json |
| import warnings |
| from pathlib import Path |
| from typing import Any, Mapping |
|
|
| import numpy as np |
| import torch |
| from torch import nn |
|
|
| _INPUT_COLUMNS = ("thickness", "epsilon_real", "wavelength") |
| _OUTPUT_COLUMNS = ("transmission", "reflection", "intensity") |
|
|
|
|
| def _activation_layer(name: str) -> type[nn.Module]: |
| normalized = name.lower() |
| activations: dict[str, type[nn.Module]] = { |
| "relu": nn.ReLU, |
| "tanh": nn.Tanh, |
| "gelu": nn.GELU, |
| } |
| if normalized not in activations: |
| raise ValueError(f"Unsupported activation '{name}'.") |
| return activations[normalized] |
|
|
|
|
| def _make_mlp( |
| *, |
| input_dim: int, |
| output_dim: int, |
| hidden_sizes: list[int], |
| activation: str, |
| dropout: float, |
| ) -> nn.Sequential: |
| if not hidden_sizes: |
| raise ValueError("Hidden sizes must not be empty.") |
|
|
| activation_layer = _activation_layer(activation) |
| layer_sizes = [input_dim, *hidden_sizes, output_dim] |
| layers: list[nn.Module] = [] |
| for index in range(len(layer_sizes) - 2): |
| layers.append(nn.Linear(layer_sizes[index], layer_sizes[index + 1])) |
| layers.append(activation_layer()) |
| if dropout > 0.0: |
| layers.append(nn.Dropout(dropout)) |
| layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1])) |
| return nn.Sequential(*layers) |
|
|
|
|
class _ScalarNeonNet(nn.Module):
    """Two-stage MLP: an encoder to a latent vector, then a scalar head.

    The sub-networks are registered as ``encoder`` and ``scalar_head``, in
    that order; both share the same activation and dropout settings.
    """

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        encoder_hidden_sizes: list[int],
        scalar_hidden_sizes: list[int],
        activation: str,
        dropout: float,
    ) -> None:
        """Create the encoder and the scalar head.

        Args:
            input_dim: Width of the input features.
            output_dim: Width of the scalar head output.
            latent_dim: Width of the encoder output / scalar head input.
            encoder_hidden_sizes: Hidden layer widths of the encoder.
            scalar_hidden_sizes: Hidden layer widths of the scalar head.
            activation: Activation name used in both sub-networks.
            dropout: Dropout probability used in both sub-networks.
        """
        super().__init__()
        shared = {"activation": activation, "dropout": dropout}
        self.encoder = _make_mlp(
            input_dim=input_dim,
            output_dim=latent_dim,
            hidden_sizes=encoder_hidden_sizes,
            **shared,
        )
        self.scalar_head = _make_mlp(
            input_dim=latent_dim,
            output_dim=output_dim,
            hidden_sizes=scalar_hidden_sizes,
            **shared,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Encode ``features`` and map the latent vector through the head."""
        latent = self.encoder(features)
        return self.scalar_head(latent)
|
|
|
|
class Neon:
    """Inference wrapper around a ``_ScalarNeonNet`` and its config.

    The wrapper keeps the model in eval mode on the chosen device, normalises
    inputs with the mean/std stored in ``config["normalization"]``, validates
    them against ``config["training_data_range"]`` and de-normalises the
    network output into ``transmission`` / ``reflection`` / ``intensity``.
    """

    def __init__(
        self,
        *,
        model: _ScalarNeonNet,
        config: dict[str, Any],
        device: str,
    ) -> None:
        """Store the model and cache normalisation / range arrays.

        Args:
            model: Network to run; it is moved to ``device`` and set to eval.
            config: Parsed configuration. Must provide
                ``normalization.inputs.{mean,std}``,
                ``normalization.outputs.{mean,std}`` and, for every name in
                ``_INPUT_COLUMNS``, ``training_data_range[name].{min,max}``.
            device: Torch device string used for the model and input tensors.

        Raises:
            KeyError: If a required config entry is absent.
        """
        self.model = model.to(device)
        self.model.eval()
        self.config = config
        self.device = device

        normalization = config["normalization"]
        self.input_mean = np.asarray(normalization["inputs"]["mean"], dtype=np.float64)
        self.input_std = np.asarray(normalization["inputs"]["std"], dtype=np.float64)
        self.output_mean = np.asarray(normalization["outputs"]["mean"], dtype=np.float64)
        self.output_std = np.asarray(normalization["outputs"]["std"], dtype=np.float64)

        training_range = config["training_data_range"]
        self.training_min = np.asarray(
            [training_range[name]["min"] for name in _INPUT_COLUMNS], dtype=np.float64
        )
        self.training_max = np.asarray(
            [training_range[name]["max"] for name in _INPUT_COLUMNS], dtype=np.float64
        )

    @classmethod
    def from_pretrained(
        cls,
        model_dir: str | Path | None = None,
        *,
        device: str | None = None,
    ) -> "Neon":
        """Load ``config.json`` and ``model.pt`` from a directory.

        Args:
            model_dir: Directory holding both files. When falsy, the directory
                containing this module is used.
            device: Requested device; resolved through ``_resolve_device``.

        Returns:
            A ready-to-use ``Neon`` instance.

        Raises:
            FileNotFoundError: If ``config.json`` or ``model.pt`` is absent.
            RuntimeError: If the checkpoint's ``encoder.*`` / ``scalar_head.*``
                weights do not match the architecture described in the config.
        """
        base_dir = Path(model_dir).expanduser().resolve() if model_dir else Path(__file__).resolve().parent
        config_path = base_dir / "config.json"
        model_path = base_dir / "model.pt"

        if not config_path.exists():
            raise FileNotFoundError(f"Missing config.json at {config_path}.")
        if not model_path.exists():
            raise FileNotFoundError(f"Missing model.pt at {model_path}.")

        # JSON is UTF-8 by specification; decoding with the platform locale
        # can corrupt non-ASCII text (e.g. unit strings) on some systems.
        config = json.loads(config_path.read_text(encoding="utf-8"))
        resolved_device = _resolve_device(device)

        # NOTE(security): torch.load unpickles the file. On PyTorch versions
        # where ``weights_only`` defaults to False this can execute arbitrary
        # code, so only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=resolved_device)
        # Accept both a bare state dict and a {"state_dict": ...} wrapper.
        if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
            state_dict = checkpoint["state_dict"]
        else:
            state_dict = checkpoint

        architecture = config["architecture"]
        model = _ScalarNeonNet(
            input_dim=int(architecture["input_dim"]),
            output_dim=int(architecture["output_dim"]),
            latent_dim=int(architecture["latent_dim"]),
            encoder_hidden_sizes=[int(value) for value in architecture["encoder_hidden_sizes"]],
            scalar_hidden_sizes=[int(value) for value in architecture["scalar_hidden_sizes"]],
            activation=str(architecture["activation"]),
            dropout=float(architecture["dropout"]),
        )

        # Keep only the weights this inference model owns; any other entries
        # in the checkpoint are ignored.
        scalar_state = {
            key: value
            for key, value in state_dict.items()
            if key.startswith(("encoder.", "scalar_head."))
        }
        # strict=False so mismatches are reported with our own messages.
        missing, unexpected = model.load_state_dict(scalar_state, strict=False)
        if missing:
            raise RuntimeError(f"Checkpoint is missing scalar inference weights: {sorted(missing)}")
        if unexpected:
            raise RuntimeError(f"Checkpoint contains unexpected scalar inference weights: {sorted(unexpected)}")

        return cls(model=model, config=config, device=resolved_device)

    def predict(
        self,
        inputs: Any = None,
        *,
        thickness: float | None = None,
        epsilon_real: float | None = None,
        epsilon: float | None = None,
        wavelength: float | None = None,
        warn_only: bool = False,
    ) -> dict[str, float] | list[dict[str, float]]:
        """Run the model on one sample or a batch.

        Inputs are given either positionally/by mapping through ``inputs`` or
        through the keyword arguments, never both. ``epsilon`` is an alias for
        ``epsilon_real``.

        Args:
            inputs: A mapping with keys ``thickness``, ``epsilon_real`` (or
                ``epsilon``) and ``wavelength``, or an array-like of shape
                ``(3,)`` / ``(N, 3)`` ordered as ``_INPUT_COLUMNS``.
            thickness: Keyword form of the thickness input.
            epsilon_real: Keyword form of the epsilon input.
            epsilon: Alias of ``epsilon_real``.
            wavelength: Keyword form of the wavelength input.
            warn_only: When True, out-of-range inputs emit a
                ``RuntimeWarning`` instead of raising.

        Returns:
            One ``{output name: value}`` dict for a single sample, otherwise a
            list with one dict per row.

        Raises:
            ValueError: On malformed inputs, or on inputs outside the training
                range (including NaN) when ``warn_only`` is False.
        """
        values, single_input = self._coerce_inputs(
            inputs,
            thickness=thickness,
            epsilon_real=epsilon_real,
            epsilon=epsilon,
            wavelength=wavelength,
        )
        self._validate_inputs(values, warn_only=warn_only)
        normalized = (values - self.input_mean) / self.input_std

        with torch.inference_mode():
            prediction_norm = self.model(
                torch.as_tensor(normalized, dtype=torch.float32, device=self.device)
            ).cpu().numpy()

        prediction = prediction_norm * self.output_std + self.output_mean
        # Name the outputs from the shared column tuple so the keys cannot
        # drift from the module-level declaration.
        records = [
            {name: float(row[column]) for column, name in enumerate(_OUTPUT_COLUMNS)}
            for row in prediction
        ]
        return records[0] if single_input else records

    def _coerce_inputs(
        self,
        inputs: Any,
        *,
        thickness: float | None,
        epsilon_real: float | None,
        epsilon: float | None,
        wavelength: float | None,
    ) -> tuple[np.ndarray, bool]:
        """Turn any accepted input form into an ``(N, 3)`` float64 array.

        Returns:
            ``(values, single_input)`` where ``single_input`` tells ``predict``
            to unwrap the one-row result.

        Raises:
            ValueError: If both input styles are mixed, required values are
                absent, or an array has an unsupported shape.
        """
        has_keyword_inputs = any(
            value is not None for value in (thickness, epsilon_real, epsilon, wavelength)
        )
        if inputs is not None and has_keyword_inputs:
            raise ValueError("Pass either `inputs` or keyword arguments, not both.")

        if inputs is None:
            if epsilon_real is not None and epsilon is not None:
                raise ValueError("Use either `epsilon_real` or `epsilon`, not both.")
            epsilon_value = epsilon_real if epsilon_real is not None else epsilon
            if thickness is None or epsilon_value is None or wavelength is None:
                raise ValueError(
                    "Expected thickness, epsilon_real (or epsilon), and wavelength when `inputs` is not provided."
                )
            return self._mapping_to_array(
                {
                    "thickness": thickness,
                    "epsilon_real": epsilon_value,
                    "wavelength": wavelength,
                }
            )

        if isinstance(inputs, Mapping):
            return self._mapping_to_array(inputs)

        values = np.asarray(inputs, dtype=np.float64)
        if values.ndim == 1:
            if values.shape[0] != len(_INPUT_COLUMNS):
                raise ValueError(
                    f"Expected a 3-element input array ordered as {list(_INPUT_COLUMNS)}, received shape {values.shape}."
                )
            return values.reshape(1, -1), True
        if values.ndim == 2 and values.shape[1] == len(_INPUT_COLUMNS):
            return values, False
        raise ValueError(
            f"Expected an input array with shape (3,) or (N, 3) ordered as {list(_INPUT_COLUMNS)}, "
            f"received shape {values.shape}."
        )

    def _mapping_to_array(self, mapping: Mapping[str, Any]) -> tuple[np.ndarray, bool]:
        """Broadcast the three mapping entries into an ``(N, 3)`` array.

        Scalars and arrays may be mixed; they are broadcast against each other
        and flattened, one row per broadcast element.

        Returns:
            ``(values, single_input)``; ``single_input`` is True only when all
            three entries were scalars.

        Raises:
            ValueError: If both epsilon spellings are present, or a required
                entry is absent or ``None``.
        """
        if "epsilon_real" in mapping and "epsilon" in mapping:
            raise ValueError("Use either `epsilon_real` or `epsilon`, not both.")
        epsilon_value = mapping["epsilon_real"] if "epsilon_real" in mapping else mapping.get("epsilon")

        # A key holding None is treated as missing: np.asarray(None, float64)
        # would otherwise silently become NaN.
        missing = [name for name in ("thickness", "wavelength") if mapping.get(name) is None]
        if epsilon_value is None:
            missing.append("epsilon_real")
        if missing:
            raise ValueError(f"Missing required input keys: {', '.join(missing)}.")

        thickness = np.asarray(mapping["thickness"], dtype=np.float64)
        epsilon_real = np.asarray(epsilon_value, dtype=np.float64)
        wavelength = np.asarray(mapping["wavelength"], dtype=np.float64)
        broadcasted = np.broadcast_arrays(thickness, epsilon_real, wavelength)
        values = np.stack([item.reshape(-1) for item in broadcasted], axis=1)
        single_input = values.shape[0] == 1 and all(
            item.ndim == 0 for item in (thickness, epsilon_real, wavelength)
        )
        return values, single_input

    def _validate_inputs(self, values: np.ndarray, *, warn_only: bool) -> None:
        """Check every column of ``values`` against the training range.

        Args:
            values: ``(N, 3)`` array ordered as ``_INPUT_COLUMNS``.
            warn_only: Emit a ``RuntimeWarning`` instead of raising.

        Raises:
            ValueError: If any value is outside ``[min, max]`` or is NaN and
                ``warn_only`` is False.
        """
        messages: list[str] = []
        for index, name in enumerate(_INPUT_COLUMNS):
            lower = float(self.training_min[index])
            upper = float(self.training_max[index])
            column = values[:, index]
            # Negate the in-range test rather than testing `< lower | > upper`:
            # every comparison with NaN is False, so NaN must count as bad.
            out_of_range = ~((column >= lower) & (column <= upper))
            if not np.any(out_of_range):
                continue

            units = self.config["training_data_range"][name]["units"]
            bad_values = column[out_of_range]
            preview = ", ".join(f"{value:.6g}" for value in bad_values[:3])
            if bad_values.shape[0] > 3:
                preview = f"{preview}, ..."
            messages.append(
                f"{name} values [{preview}] are outside the training range [{lower:.6g}, {upper:.6g}] {units}."
            )

        if not messages:
            return

        message = " ".join(messages)
        if warn_only:
            # This helper is called from `predict`; stacklevel=3 attributes the
            # warning to the code that called `predict`.
            warnings.warn(message, RuntimeWarning, stacklevel=3)
            return
        raise ValueError(message)
|
|
|
|
| def _resolve_device(device: str | None) -> str: |
| if device is None: |
| return "cuda" if torch.cuda.is_available() else "cpu" |
| if device == "cuda" and not torch.cuda.is_available(): |
| return "cpu" |
| return device |
|
|
|
|
| __all__ = ["Neon"] |
|
|