from __future__ import annotations import json import warnings from pathlib import Path from typing import Any, Mapping import numpy as np import torch from torch import nn _INPUT_COLUMNS = ("thickness", "epsilon_real", "wavelength") _OUTPUT_COLUMNS = ("transmission", "reflection", "intensity") def _activation_layer(name: str) -> type[nn.Module]: normalized = name.lower() activations: dict[str, type[nn.Module]] = { "relu": nn.ReLU, "tanh": nn.Tanh, "gelu": nn.GELU, } if normalized not in activations: raise ValueError(f"Unsupported activation '{name}'.") return activations[normalized] def _make_mlp( *, input_dim: int, output_dim: int, hidden_sizes: list[int], activation: str, dropout: float, ) -> nn.Sequential: if not hidden_sizes: raise ValueError("Hidden sizes must not be empty.") activation_layer = _activation_layer(activation) layer_sizes = [input_dim, *hidden_sizes, output_dim] layers: list[nn.Module] = [] for index in range(len(layer_sizes) - 2): layers.append(nn.Linear(layer_sizes[index], layer_sizes[index + 1])) layers.append(activation_layer()) if dropout > 0.0: layers.append(nn.Dropout(dropout)) layers.append(nn.Linear(layer_sizes[-2], layer_sizes[-1])) return nn.Sequential(*layers) class _ScalarNeonNet(nn.Module): def __init__( self, *, input_dim: int, output_dim: int, latent_dim: int, encoder_hidden_sizes: list[int], scalar_hidden_sizes: list[int], activation: str, dropout: float, ) -> None: super().__init__() self.encoder = _make_mlp( input_dim=input_dim, output_dim=latent_dim, hidden_sizes=encoder_hidden_sizes, activation=activation, dropout=dropout, ) self.scalar_head = _make_mlp( input_dim=latent_dim, output_dim=output_dim, hidden_sizes=scalar_hidden_sizes, activation=activation, dropout=dropout, ) def forward(self, features: torch.Tensor) -> torch.Tensor: return self.scalar_head(self.encoder(features)) class Neon: def __init__( self, *, model: _ScalarNeonNet, config: dict[str, Any], device: str, ) -> None: self.model = model.to(device) self.model.eval() self.config = config self.device = device self.input_mean = np.asarray(config["normalization"]["inputs"]["mean"], dtype=np.float64) self.input_std = np.asarray(config["normalization"]["inputs"]["std"], dtype=np.float64) self.output_mean = np.asarray(config["normalization"]["outputs"]["mean"], dtype=np.float64) self.output_std = np.asarray(config["normalization"]["outputs"]["std"], dtype=np.float64) training_range = config["training_data_range"] self.training_min = np.asarray([training_range[name]["min"] for name in _INPUT_COLUMNS], dtype=np.float64) self.training_max = np.asarray([training_range[name]["max"] for name in _INPUT_COLUMNS], dtype=np.float64) @classmethod def from_pretrained( cls, model_dir: str | Path | None = None, *, device: str | None = None, ) -> "Neon": base_dir = Path(model_dir).expanduser().resolve() if model_dir else Path(__file__).resolve().parent config_path = base_dir / "config.json" model_path = base_dir / "model.pt" if not config_path.exists(): raise FileNotFoundError(f"Missing config.json at {config_path}.") if not model_path.exists(): raise FileNotFoundError(f"Missing model.pt at {model_path}.") config = json.loads(config_path.read_text()) resolved_device = _resolve_device(device) checkpoint = torch.load(model_path, map_location=resolved_device) state_dict = checkpoint["state_dict"] if isinstance(checkpoint, dict) and "state_dict" in checkpoint else checkpoint architecture = config["architecture"] model = _ScalarNeonNet( input_dim=int(architecture["input_dim"]), output_dim=int(architecture["output_dim"]), latent_dim=int(architecture["latent_dim"]), encoder_hidden_sizes=[int(value) for value in architecture["encoder_hidden_sizes"]], scalar_hidden_sizes=[int(value) for value in architecture["scalar_hidden_sizes"]], activation=str(architecture["activation"]), dropout=float(architecture["dropout"]), ) scalar_state = { key: value for key, value in state_dict.items() if key.startswith("encoder.") or key.startswith("scalar_head.") } missing, unexpected = model.load_state_dict(scalar_state, strict=False) if missing: raise RuntimeError(f"Checkpoint is missing scalar inference weights: {sorted(missing)}") if unexpected: raise RuntimeError(f"Checkpoint contains unexpected scalar inference weights: {sorted(unexpected)}") return cls(model=model, config=config, device=resolved_device) def predict( self, inputs: Any = None, *, thickness: float | None = None, epsilon_real: float | None = None, epsilon: float | None = None, wavelength: float | None = None, warn_only: bool = False, ) -> dict[str, float] | list[dict[str, float]]: values, single_input = self._coerce_inputs( inputs, thickness=thickness, epsilon_real=epsilon_real, epsilon=epsilon, wavelength=wavelength, ) self._validate_inputs(values, warn_only=warn_only) normalized = (values - self.input_mean) / self.input_std with torch.inference_mode(): prediction_norm = self.model( torch.as_tensor(normalized, dtype=torch.float32, device=self.device) ).cpu().numpy() prediction = prediction_norm * self.output_std + self.output_mean records = [ { "transmission": float(row[0]), "reflection": float(row[1]), "intensity": float(row[2]), } for row in prediction ] return records[0] if single_input else records def _coerce_inputs( self, inputs: Any, *, thickness: float | None, epsilon_real: float | None, epsilon: float | None, wavelength: float | None, ) -> tuple[np.ndarray, bool]: has_keyword_inputs = any(value is not None for value in (thickness, epsilon_real, epsilon, wavelength)) if inputs is not None and has_keyword_inputs: raise ValueError("Pass either `inputs` or keyword arguments, not both.") if inputs is None: if epsilon_real is not None and epsilon is not None: raise ValueError("Use either `epsilon_real` or `epsilon`, not both.") epsilon_value = epsilon_real if epsilon_real is not None else epsilon if thickness is None or epsilon_value is None or wavelength is None: raise ValueError( "Expected thickness, epsilon_real (or epsilon), and wavelength when `inputs` is not provided." ) return self._mapping_to_array( { "thickness": thickness, "epsilon_real": epsilon_value, "wavelength": wavelength, } ) if isinstance(inputs, Mapping): return self._mapping_to_array(inputs) values = np.asarray(inputs, dtype=np.float64) if values.ndim == 1: if values.shape[0] != len(_INPUT_COLUMNS): raise ValueError( f"Expected a 3-element input array ordered as {list(_INPUT_COLUMNS)}, received shape {values.shape}." ) return values.reshape(1, -1), True if values.ndim == 2 and values.shape[1] == len(_INPUT_COLUMNS): return values, False raise ValueError( f"Expected an input array with shape (3,) or (N, 3) ordered as {list(_INPUT_COLUMNS)}, " f"received shape {values.shape}." ) def _mapping_to_array(self, mapping: Mapping[str, Any]) -> tuple[np.ndarray, bool]: if "epsilon_real" in mapping and "epsilon" in mapping: raise ValueError("Use either `epsilon_real` or `epsilon`, not both.") epsilon_value = mapping["epsilon_real"] if "epsilon_real" in mapping else mapping.get("epsilon") missing = [name for name in ("thickness", "wavelength") if name not in mapping] if epsilon_value is None: missing.append("epsilon_real") if missing: raise ValueError(f"Missing required input keys: {', '.join(missing)}.") thickness = np.asarray(mapping["thickness"], dtype=np.float64) epsilon_real = np.asarray(epsilon_value, dtype=np.float64) wavelength = np.asarray(mapping["wavelength"], dtype=np.float64) broadcasted = np.broadcast_arrays(thickness, epsilon_real, wavelength) values = np.stack([item.reshape(-1) for item in broadcasted], axis=1) single_input = values.shape[0] == 1 and all(item.ndim == 0 for item in (thickness, epsilon_real, wavelength)) return values, single_input def _validate_inputs(self, values: np.ndarray, *, warn_only: bool) -> None: messages: list[str] = [] for index, name in enumerate(_INPUT_COLUMNS): lower = float(self.training_min[index]) upper = float(self.training_max[index]) out_of_range = (values[:, index] < lower) | (values[:, index] > upper) if not np.any(out_of_range): continue units = self.config["training_data_range"][name]["units"] bad_values = values[out_of_range, index] preview = ", ".join(f"{value:.6g}" for value in bad_values[:3]) if bad_values.shape[0] > 3: preview = f"{preview}, ..." messages.append( f"{name} values [{preview}] are outside the training range [{lower:.6g}, {upper:.6g}] {units}." ) if not messages: return message = " ".join(messages) if warn_only: warnings.warn(message, RuntimeWarning, stacklevel=2) return raise ValueError(message) def _resolve_device(device: str | None) -> str: if device is None: return "cuda" if torch.cuda.is_available() else "cpu" if device == "cuda" and not torch.cuda.is_available(): return "cpu" return device __all__ = ["Neon"]