# Neon / neon_slab.py
# Uploaded with huggingface_hub (revision aed74e2, verified).
from __future__ import annotations
import json
import warnings
from pathlib import Path
from typing import Any, Mapping
import numpy as np
import torch
from torch import nn
# Feature order expected by the network's input vector; also the order used
# for normalization statistics and training-range bounds.
_INPUT_COLUMNS = ("thickness", "epsilon_real", "wavelength")
# Order of the scalar quantities in the network's output vector.
_OUTPUT_COLUMNS = ("transmission", "reflection", "intensity")
def _activation_layer(name: str) -> type[nn.Module]:
normalized = name.lower()
activations: dict[str, type[nn.Module]] = {
"relu": nn.ReLU,
"tanh": nn.Tanh,
"gelu": nn.GELU,
}
if normalized not in activations:
raise ValueError(f"Unsupported activation '{name}'.")
return activations[normalized]
def _make_mlp(
    *,
    input_dim: int,
    output_dim: int,
    hidden_sizes: list[int],
    activation: str,
    dropout: float,
) -> nn.Sequential:
    """Build a fully-connected network.

    Each hidden layer is ``Linear -> activation`` (plus ``Dropout`` when
    *dropout* > 0), followed by one final ``Linear`` projection with no
    activation.

    Raises:
        ValueError: if *hidden_sizes* is empty.
    """
    if not hidden_sizes:
        raise ValueError("Hidden sizes must not be empty.")
    act_cls = _activation_layer(activation)
    dims = [input_dim, *hidden_sizes, output_dim]
    modules: list[nn.Module] = []
    # Pair up consecutive sizes for every layer except the output projection.
    for fan_in, fan_out in zip(dims[:-2], dims[1:-1]):
        modules.append(nn.Linear(fan_in, fan_out))
        modules.append(act_cls())
        if dropout > 0.0:
            modules.append(nn.Dropout(dropout))
    modules.append(nn.Linear(dims[-2], dims[-1]))
    return nn.Sequential(*modules)
class _ScalarNeonNet(nn.Module):
    """Two-stage MLP: a feature encoder followed by a scalar prediction head."""

    def __init__(
        self,
        *,
        input_dim: int,
        output_dim: int,
        latent_dim: int,
        encoder_hidden_sizes: list[int],
        scalar_hidden_sizes: list[int],
        activation: str,
        dropout: float,
    ) -> None:
        super().__init__()
        # Encoder: raw input features -> latent representation.
        # NOTE: attribute names ("encoder", "scalar_head") are part of the
        # checkpoint's state-dict key layout and must not change.
        self.encoder = _make_mlp(
            input_dim=input_dim,
            output_dim=latent_dim,
            hidden_sizes=encoder_hidden_sizes,
            activation=activation,
            dropout=dropout,
        )
        # Head: latent representation -> scalar outputs.
        self.scalar_head = _make_mlp(
            input_dim=latent_dim,
            output_dim=output_dim,
            hidden_sizes=scalar_hidden_sizes,
            activation=activation,
            dropout=dropout,
        )

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        """Encode *features*, then project the latent code to scalar outputs."""
        latent = self.encoder(features)
        return self.scalar_head(latent)
class Neon:
def __init__(
self,
*,
model: _ScalarNeonNet,
config: dict[str, Any],
device: str,
) -> None:
self.model = model.to(device)
self.model.eval()
self.config = config
self.device = device
self.input_mean = np.asarray(config["normalization"]["inputs"]["mean"], dtype=np.float64)
self.input_std = np.asarray(config["normalization"]["inputs"]["std"], dtype=np.float64)
self.output_mean = np.asarray(config["normalization"]["outputs"]["mean"], dtype=np.float64)
self.output_std = np.asarray(config["normalization"]["outputs"]["std"], dtype=np.float64)
training_range = config["training_data_range"]
self.training_min = np.asarray([training_range[name]["min"] for name in _INPUT_COLUMNS], dtype=np.float64)
self.training_max = np.asarray([training_range[name]["max"] for name in _INPUT_COLUMNS], dtype=np.float64)
@classmethod
def from_pretrained(
cls,
model_dir: str | Path | None = None,
*,
device: str | None = None,
) -> "Neon":
base_dir = Path(model_dir).expanduser().resolve() if model_dir else Path(__file__).resolve().parent
config_path = base_dir / "config.json"
model_path = base_dir / "model.pt"
if not config_path.exists():
raise FileNotFoundError(f"Missing config.json at {config_path}.")
if not model_path.exists():
raise FileNotFoundError(f"Missing model.pt at {model_path}.")
config = json.loads(config_path.read_text())
resolved_device = _resolve_device(device)
checkpoint = torch.load(model_path, map_location=resolved_device)
state_dict = checkpoint["state_dict"] if isinstance(checkpoint, dict) and "state_dict" in checkpoint else checkpoint
architecture = config["architecture"]
model = _ScalarNeonNet(
input_dim=int(architecture["input_dim"]),
output_dim=int(architecture["output_dim"]),
latent_dim=int(architecture["latent_dim"]),
encoder_hidden_sizes=[int(value) for value in architecture["encoder_hidden_sizes"]],
scalar_hidden_sizes=[int(value) for value in architecture["scalar_hidden_sizes"]],
activation=str(architecture["activation"]),
dropout=float(architecture["dropout"]),
)
scalar_state = {
key: value
for key, value in state_dict.items()
if key.startswith("encoder.") or key.startswith("scalar_head.")
}
missing, unexpected = model.load_state_dict(scalar_state, strict=False)
if missing:
raise RuntimeError(f"Checkpoint is missing scalar inference weights: {sorted(missing)}")
if unexpected:
raise RuntimeError(f"Checkpoint contains unexpected scalar inference weights: {sorted(unexpected)}")
return cls(model=model, config=config, device=resolved_device)
def predict(
self,
inputs: Any = None,
*,
thickness: float | None = None,
epsilon_real: float | None = None,
epsilon: float | None = None,
wavelength: float | None = None,
warn_only: bool = False,
) -> dict[str, float] | list[dict[str, float]]:
values, single_input = self._coerce_inputs(
inputs,
thickness=thickness,
epsilon_real=epsilon_real,
epsilon=epsilon,
wavelength=wavelength,
)
self._validate_inputs(values, warn_only=warn_only)
normalized = (values - self.input_mean) / self.input_std
with torch.inference_mode():
prediction_norm = self.model(
torch.as_tensor(normalized, dtype=torch.float32, device=self.device)
).cpu().numpy()
prediction = prediction_norm * self.output_std + self.output_mean
records = [
{
"transmission": float(row[0]),
"reflection": float(row[1]),
"intensity": float(row[2]),
}
for row in prediction
]
return records[0] if single_input else records
def _coerce_inputs(
self,
inputs: Any,
*,
thickness: float | None,
epsilon_real: float | None,
epsilon: float | None,
wavelength: float | None,
) -> tuple[np.ndarray, bool]:
has_keyword_inputs = any(value is not None for value in (thickness, epsilon_real, epsilon, wavelength))
if inputs is not None and has_keyword_inputs:
raise ValueError("Pass either `inputs` or keyword arguments, not both.")
if inputs is None:
if epsilon_real is not None and epsilon is not None:
raise ValueError("Use either `epsilon_real` or `epsilon`, not both.")
epsilon_value = epsilon_real if epsilon_real is not None else epsilon
if thickness is None or epsilon_value is None or wavelength is None:
raise ValueError(
"Expected thickness, epsilon_real (or epsilon), and wavelength when `inputs` is not provided."
)
return self._mapping_to_array(
{
"thickness": thickness,
"epsilon_real": epsilon_value,
"wavelength": wavelength,
}
)
if isinstance(inputs, Mapping):
return self._mapping_to_array(inputs)
values = np.asarray(inputs, dtype=np.float64)
if values.ndim == 1:
if values.shape[0] != len(_INPUT_COLUMNS):
raise ValueError(
f"Expected a 3-element input array ordered as {list(_INPUT_COLUMNS)}, received shape {values.shape}."
)
return values.reshape(1, -1), True
if values.ndim == 2 and values.shape[1] == len(_INPUT_COLUMNS):
return values, False
raise ValueError(
f"Expected an input array with shape (3,) or (N, 3) ordered as {list(_INPUT_COLUMNS)}, "
f"received shape {values.shape}."
)
def _mapping_to_array(self, mapping: Mapping[str, Any]) -> tuple[np.ndarray, bool]:
if "epsilon_real" in mapping and "epsilon" in mapping:
raise ValueError("Use either `epsilon_real` or `epsilon`, not both.")
epsilon_value = mapping["epsilon_real"] if "epsilon_real" in mapping else mapping.get("epsilon")
missing = [name for name in ("thickness", "wavelength") if name not in mapping]
if epsilon_value is None:
missing.append("epsilon_real")
if missing:
raise ValueError(f"Missing required input keys: {', '.join(missing)}.")
thickness = np.asarray(mapping["thickness"], dtype=np.float64)
epsilon_real = np.asarray(epsilon_value, dtype=np.float64)
wavelength = np.asarray(mapping["wavelength"], dtype=np.float64)
broadcasted = np.broadcast_arrays(thickness, epsilon_real, wavelength)
values = np.stack([item.reshape(-1) for item in broadcasted], axis=1)
single_input = values.shape[0] == 1 and all(item.ndim == 0 for item in (thickness, epsilon_real, wavelength))
return values, single_input
def _validate_inputs(self, values: np.ndarray, *, warn_only: bool) -> None:
messages: list[str] = []
for index, name in enumerate(_INPUT_COLUMNS):
lower = float(self.training_min[index])
upper = float(self.training_max[index])
out_of_range = (values[:, index] < lower) | (values[:, index] > upper)
if not np.any(out_of_range):
continue
units = self.config["training_data_range"][name]["units"]
bad_values = values[out_of_range, index]
preview = ", ".join(f"{value:.6g}" for value in bad_values[:3])
if bad_values.shape[0] > 3:
preview = f"{preview}, ..."
messages.append(
f"{name} values [{preview}] are outside the training range [{lower:.6g}, {upper:.6g}] {units}."
)
if not messages:
return
message = " ".join(messages)
if warn_only:
warnings.warn(message, RuntimeWarning, stacklevel=2)
return
raise ValueError(message)
def _resolve_device(device: str | None) -> str:
if device is None:
return "cuda" if torch.cuda.is_available() else "cpu"
if device == "cuda" and not torch.cuda.is_available():
return "cpu"
return device
# Public API: only the high-level inference wrapper is exported.
__all__ = ["Neon"]