# Author: TWLab
# Commit 0930e10 (verified): "Add DCDE: Depth-Conditioned Dynamic Ensemble architecture"
"""
=============================================================================
DCDE: Depth-Conditioned Dynamic Ensemble with Evidential Uncertainty
for Femtosecond Laser Internal Hydrogel Etching Prediction
A novel hybrid architecture combining:
1. FiLM-conditioned Neural Network (depth-adaptive feature modulation)
2. XGBoost gradient-boosted trees (capturing tabular feature interactions)
3. Learned dynamic gating network (input-conditioned fusion)
4. Evidential Deep Learning (Normal-Inverse-Gamma uncertainty)
5. Physics-informed regularization (monotonicity + energy constraints)
References:
- FiLM: Perez et al., AAAI 2018 (arxiv:1709.07871)
- Deep Evidential Regression: Amini et al., NeurIPS 2020 (arxiv:1910.02600)
- DELE gating: AAAI 2023 (arxiv:2302.00932)
- Physics-informed ML: Zhang et al. 2022 (arxiv:2211.08064)
=============================================================================
"""
from __future__ import annotations
import math
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# =============================================================================
# 1. PHYSICS-INFORMED FEATURE ENGINEERING (Depth-Dependent)
# =============================================================================
class DepthPhysicsFeatures:
    """
    Compute analytically-derived physics features that encode how
    femtosecond laser behavior changes with focusing depth in hydrogels.

    These features capture three primary depth-dependent effects:
    1. Spherical aberration (Strehl ratio degradation)
    2. Group velocity dispersion (pulse temporal broadening)
    3. Self-focusing proximity (Kerr nonlinearity regime)

    Scientific basis:
    - Vogel et al., Applied Physics B (2005) - fs-laser tissue interaction
    - Schaffer et al., Optics Letters (2001) - bulk modification thresholds
    - Boyd, Nonlinear Optics (2020) - self-focusing, GVD theory

    Parameters
    ----------
    n_medium : float
        Refractive index of the hydrogel (used for the air/medium mismatch
        and in the critical self-focusing power).
    beta2_fs2_mm : float
        Group-velocity-dispersion parameter in fs²/mm; stored in s²/m.
    n2_m2_W : float
        Nonlinear (Kerr) refractive index in m²/W.
    """

    def __init__(
        self,
        n_medium: float = 1.34,  # Refractive index of hydrogel
        beta2_fs2_mm: float = 55.0,  # GVD parameter (fs²/mm) for water-like medium
        n2_m2_W: float = 2.0e-20,  # Nonlinear refractive index (m²/W)
    ):
        self.n_medium = n_medium
        # fs²/mm → s²/m: 1 fs² = 1e-30 s², 1 mm = 1e-3 m
        self.beta2 = beta2_fs2_mm * 1e-30 / 1e-3
        self.n2 = n2_m2_W

    def compute(
        self,
        focusing_depth_um: np.ndarray,
        pulse_duration_fs: np.ndarray,
        wavelength_nm: np.ndarray,
        NA: np.ndarray,
        power_mW: np.ndarray,
        rep_rate_kHz: np.ndarray,
    ) -> np.ndarray:
        """
        Compute physics features from raw laser parameters.

        Inputs may be scalars or 1-D arrays of a common length N; they are
        broadcast against each other, so e.g. an array of depths can be
        combined with scalar laser settings.

        Parameters
        ----------
        focusing_depth_um : focusing depth below the surface, µm
        pulse_duration_fs : pulse duration, fs
        wavelength_nm : wavelength, nm
        NA : numerical aperture
        power_mW : average power, mW
        rep_rate_kHz : repetition rate, kHz

        Returns
        -------
        np.ndarray, float32, shape (N, 5) (N = 1 for all-scalar input), columns:
            [strehl_ratio, intensity_factor, z_normalized, self_focus_ratio,
             depth_aberration]

        Raises
        ------
        ValueError
            If the input shapes cannot be broadcast together.
        """
        # Convert to SI units and broadcast so every feature column has the
        # same length (a scalar-only feature would otherwise break column_stack).
        z, tau0, lam, na, P_avg, f_rep = np.broadcast_arrays(
            np.asarray(focusing_depth_um, dtype=np.float64) * 1e-6,  # µm → m
            np.asarray(pulse_duration_fs, dtype=np.float64) * 1e-15,  # fs → s
            np.asarray(wavelength_nm, dtype=np.float64) * 1e-9,  # nm → m
            np.asarray(NA, dtype=np.float64),
            np.asarray(power_mW, dtype=np.float64) * 1e-3,  # mW → W
            np.asarray(rep_rate_kHz, dtype=np.float64) * 1e3,  # kHz → Hz
        )

        # 1. Strehl ratio: S(z) = exp(-(2π·Δn·z·NA²/λ)²)
        #    Quantifies how much aberration degrades the focal spot.
        delta_n = self.n_medium - 1.0  # Air-hydrogel RI mismatch
        strehl = np.exp(-((2 * np.pi * delta_n * z * na**2) / lam) ** 2)
        strehl = np.clip(strehl, 1e-6, 1.0)

        # 2. GVD pulse broadening: τ(z) = τ₀·√(1 + (z/L_D)²)
        #    The ratio τ₀/τ(z) ∈ (0, 1] is the peak-intensity reduction at depth.
        L_D = tau0**2 / np.abs(self.beta2)  # Dispersion length
        tau_z = tau0 * np.sqrt(1 + (z / np.maximum(L_D, 1e-10)) ** 2)
        intensity_factor = tau0 / np.maximum(tau_z, tau0)

        # 3. Depth normalized by the (vacuum) Rayleigh range z_R = π·w0²/λ,
        #    with w0 = λ/(π·NA); NA is floored at 0.01 to avoid division by 0.
        w0 = lam / (np.pi * np.maximum(na, 0.01))
        z_rayleigh = np.pi * w0**2 / lam
        z_normalized = z / np.maximum(z_rayleigh, 1e-10)

        # 4. Self-focusing proximity: P_peak / P_critical (> 1: collapse regime)
        P_peak = P_avg / (f_rep * tau0)  # Peak power per pulse
        P_cr = 3.77 * lam**2 / (8 * np.pi * self.n_medium * self.n2)
        sf_ratio = P_peak / np.maximum(P_cr, 1e-10)
        sf_ratio = np.clip(sf_ratio, 0, 50)  # Cap at 50× critical

        # 5. Dimensionless depth-dependent aberration parameter Δn·z·NA²/λ
        depth_aberration = delta_n * z * na**2 / lam

        return np.column_stack([
            strehl,
            intensity_factor,
            z_normalized,
            sf_ratio,
            depth_aberration,
        ]).astype(np.float32)

    @property
    def feature_names(self) -> List[str]:
        """Column names matching the output of :meth:`compute`, in order."""
        return [
            "strehl_ratio",
            "intensity_factor_gvd",
            "z_normalized_rayleigh",
            "self_focusing_ratio",
            "depth_aberration_param",
        ]
# =============================================================================
# 2. FiLM-CONDITIONED NEURAL NETWORK (Depth-Adaptive)
# =============================================================================
class FiLMGenerator(nn.Module):
    """
    Feature-wise Linear Modulation (FiLM) generator.

    Maps a conditioning input (depth features) to one (γ, β) pair per hidden
    layer; the consumer applies them as h' = γ ⊙ h + β.

    Uses the Δγ initialization trick: γ = 1 + Δγ, with the last linear layer
    of every generator zero-initialized, so modulation is the identity
    (γ = 1, β = 0) at initialization.

    Reference: Perez et al., "FiLM: Visual Reasoning with a General
    Conditioning Layer", AAAI 2018.

    Parameters
    ----------
    conditioning_dim : int
        Width of the conditioning input (last dimension).
    hidden_dims : list of int
        Width of each modulated layer; one generator is built per entry.
    film_hidden_dim : int, default 64
        Hidden width of each per-layer generator MLP.
    """

    def __init__(
        self,
        conditioning_dim: int,
        hidden_dims: List[int],
        film_hidden_dim: int = 64,
    ):
        super().__init__()
        self.generators = nn.ModuleList([
            nn.Sequential(
                nn.Linear(conditioning_dim, film_hidden_dim),
                nn.SiLU(),
                nn.Linear(film_hidden_dim, h_dim * 2),  # Δγ and β
            )
            for h_dim in hidden_dims
        ])
        # Zero the output layer so Δγ = 0 and β = 0 at initialization.
        for gen in self.generators:
            nn.init.zeros_(gen[-1].weight)
            nn.init.zeros_(gen[-1].bias)

    def forward(self, conditioning: torch.Tensor) -> List[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Parameters
        ----------
        conditioning : Tensor, shape (..., conditioning_dim)
            Depth-related features for conditioning. Any leading dimensions
            are preserved.

        Returns
        -------
        list of (gamma, beta) tuples, one per entry of ``hidden_dims``;
        each tensor has shape (..., h_dim).
        """
        film_params = []
        for gen in self.generators:
            # Split along the feature axis so arbitrary leading dims work.
            delta_gamma, beta = gen(conditioning).chunk(2, dim=-1)
            gamma = 1.0 + delta_gamma  # Δγ trick
            film_params.append((gamma, beta))
        return film_params
class FiLMConditionedMLP(nn.Module):
    """
    MLP whose every hidden layer is modulated by FiLM parameters derived
    from a separate conditioning input.

    Per hidden layer:
        Linear → BatchNorm → FiLM(γ, β) → SiLU → Dropout
    followed by a final Linear projection.

    Because depth enters through multiplicative (γ) and additive (β)
    modulation of intermediate activations, the network can process inputs
    differently per focusing depth instead of treating depth as just another
    input column.

    Parameters
    ----------
    input_dim : width of ``x``.
    hidden_dims : widths of the hidden layers (must be non-empty).
    output_dim : width of the returned representation.
    conditioning_dim : width of the conditioning tensor.
    dropout : dropout rate of the first hidden layer; layer k uses
        ``dropout * (1 - k / len(hidden_dims))``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: List[int],
        output_dim: int,
        conditioning_dim: int,
        dropout: float = 0.15,
    ):
        super().__init__()
        self.hidden_dims = hidden_dims
        n_hidden = len(hidden_dims)
        widths = [input_dim, *hidden_dims]

        # Hidden stack: linear maps, their BatchNorms, and a linearly
        # decaying dropout schedule (strongest at the first layer).
        self.layers = nn.ModuleList(
            [nn.Linear(widths[k], widths[k + 1]) for k in range(n_hidden)]
        )
        self.batch_norms = nn.ModuleList(
            [nn.BatchNorm1d(width) for width in hidden_dims]
        )
        self.dropouts = nn.ModuleList(
            [nn.Dropout(dropout * (1 - k / n_hidden)) for k in range(n_hidden)]
        )

        # Conditioning → per-layer (γ, β).
        self.film_generator = FiLMGenerator(conditioning_dim, hidden_dims)

        self.output_layer = nn.Linear(hidden_dims[-1], output_dim)

    def forward(
        self,
        x: torch.Tensor,
        conditioning: torch.Tensor,
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        x : Tensor (B, input_dim) - laser + material features
        conditioning : Tensor (B, conditioning_dim) - depth physics features

        Returns
        -------
        Tensor (B, output_dim) - latent representation
        """
        modulations = self.film_generator(conditioning)
        hidden = x
        for linear, norm, drop, (scale, shift) in zip(
            self.layers, self.batch_norms, self.dropouts, modulations
        ):
            normalized = norm(linear(hidden))
            hidden = drop(F.silu(scale * normalized + shift))
        return self.output_layer(hidden)
# =============================================================================
# 3. EVIDENTIAL REGRESSION HEAD (Normal-Inverse-Gamma)
# =============================================================================
class EvidentialHead(nn.Module):
    """
    Normal-Inverse-Gamma (NIG) evidential regression head.

    A single linear layer emits four NIG parameters (γ, ν, α, β) per target,
    giving aleatoric and epistemic uncertainty from one forward pass (no
    ensemble or MC dropout).

    Per output dimension:
        μ  ~ N(γ, σ²/ν)          [predictive mean with epistemic noise]
        σ² ~ InvGamma(α, β)      [aleatoric variance]

    Uncertainty decomposition:
        Aleatoric: E[σ²]   = β / (α - 1)
        Epistemic: Var[μ]  = β / (ν(α - 1))

    Reference: Amini et al., "Deep Evidential Regression", NeurIPS 2020.
    """

    def __init__(self, input_dim: int, n_outputs: int):
        super().__init__()
        self.n_outputs = n_outputs
        # Output layout: for target k, columns 4k..4k+3 hold (γ, ν, α, β).
        self.fc = nn.Linear(input_dim, n_outputs * 4)
        # Small initial weights and zero bias for stable early NIG parameters.
        nn.init.xavier_normal_(self.fc.weight, gain=0.1)
        nn.init.zeros_(self.fc.bias)

    def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """
        Map latent features to NIG parameters.

        Returns
        -------
        gamma : Tensor (B, n_outputs) - predictive mean (unconstrained)
        nu    : Tensor (B, n_outputs) - evidence for the mean, > 0
        alpha : Tensor (B, n_outputs) - evidence for the variance, > 1
        beta  : Tensor (B, n_outputs) - variance scale, > 0
        """
        raw_gamma, raw_nu, raw_alpha, raw_beta = (
            self.fc(x).reshape(-1, self.n_outputs, 4).unbind(dim=-1)
        )
        # Softplus plus small offsets enforce the NIG positivity constraints.
        nu = F.softplus(raw_nu) + 1e-6
        alpha = F.softplus(raw_alpha) + 1.0 + 1e-6
        beta = F.softplus(raw_beta) + 1e-6
        return raw_gamma, nu, alpha, beta

    @staticmethod
    def aleatoric_uncertainty(alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """E[σ²] = β / (α - 1), with (α - 1) floored at 1e-6."""
        denominator = (alpha - 1.0).clamp(min=1e-6)
        return beta / denominator

    @staticmethod
    def epistemic_uncertainty(nu: torch.Tensor, alpha: torch.Tensor, beta: torch.Tensor) -> torch.Tensor:
        """Var[μ] = β / (ν(α - 1)), with (α - 1) floored at 1e-6."""
        denominator = nu * (alpha - 1.0).clamp(min=1e-6)
        return beta / denominator
# =============================================================================
# 4. DEPTH-CONDITIONED GATING NETWORK (Learned Dynamic Fusion)
# =============================================================================
class DepthConditionedGatingNetwork(nn.Module):
    """
    Input-conditioned gating network that dynamically determines how to
    fuse XGBoost and Neural Network predictions.

    Unlike a fixed 60/40 weighting, this network learns WHEN each expert
    is more reliable — conditioned on both input features and focusing depth.

    Key insight from DELE (arxiv:2302.00932): the gating network benefits
    from seeing the same features as the experts, plus the experts' own
    predictions as additional input.

    Architecture:
        [input_features ⊕ depth_physics ⊕ expert_predictions] → MLP → softmax(n_experts)

    A learnable temperature (floored at 0.1) divides the logits before the
    softmax to control its sharpness.

    Parameters
    ----------
    input_dim : width of ``features``.
    depth_dim : width of ``depth_physics``.
    n_expert_outputs : width of each expert's prediction tensor.
    n_experts : number of experts (softmax outputs).
    hidden_dim : hidden width of the gating MLP (second layer uses half).
    """

    def __init__(
        self,
        input_dim: int,
        depth_dim: int,
        n_expert_outputs: int,
        n_experts: int = 2,
        hidden_dim: int = 64,
    ):
        super().__init__()
        self.n_experts = n_experts
        total_input = input_dim + depth_dim + n_expert_outputs * n_experts
        self.gate = nn.Sequential(
            nn.Linear(total_input, hidden_dim),
            nn.SiLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_dim, hidden_dim // 2),
            nn.SiLU(),
            nn.Linear(hidden_dim // 2, n_experts),
        )
        # Temperature parameter (learnable) for softmax sharpness
        self.temperature = nn.Parameter(torch.ones(1))

    def forward(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        expert_preds: List[torch.Tensor],
    ) -> torch.Tensor:
        """
        Parameters
        ----------
        features : Tensor (B, input_dim)
        depth_physics : Tensor (B, depth_dim)
        expert_preds : sequence (list or tuple) of Tensor (B, n_outputs),
            one per expert, concatenated in the given order.

        Returns
        -------
        weights : Tensor (B, n_experts) - softmax weights summing to 1
        """
        # Unpacking accepts any sequence (a list + tuple concatenation would raise).
        gate_input = torch.cat([features, depth_physics, *expert_preds], dim=-1)
        logits = self.gate(gate_input) / self.temperature.clamp(min=0.1)
        return F.softmax(logits, dim=-1)
# =============================================================================
# 5. COMPLETE DCDE MODEL
# =============================================================================
class DCDE(nn.Module):
    """
    Depth-Conditioned Dynamic Ensemble (DCDE).

    A hybrid architecture for predicting femtosecond laser internal etching
    geometry in hydrogels. Combines:
    1. XGBoost branch: predictions from pre-trained gradient-boosted trees
       are supplied as an input tensor and embedded into the latent space
       (the trees themselves are not part of this module).
    2. FiLM-NN branch: depth-conditioned neural network where focusing
       depth modulates intermediate representations via FiLM layers.
    3. Dynamic gating: input-conditioned network producing per-sample
       weights for the two branches.
    4. Evidential head: NIG distribution output providing aleatoric +
       epistemic uncertainty.
    5. Physics-informed loss: soft monotonicity and energy constraints,
       applied by the training loss rather than by this module.

    Training protocol (3-phase, following DELE):
        Phase 1: Train XGBoost independently on tabular features
        Phase 2: Train FiLM-NN with evidential head (XGBoost frozen)
        Phase 3: Train gating network jointly (optionally fine-tune FiLM-NN)

    Parameters
    ----------
    input_dim : width of ``features``.
    depth_physics_dim : width of ``depth_physics`` (FiLM conditioning input).
    hidden_dims : FiLM-MLP hidden widths; defaults to [128, 96, 64]. The last
        width is the shared latent size of both branches.
    n_outputs : number of regression targets.
    n_experts : must be 2 (XGBoost + FiLM-NN); forward fuses exactly two.
    gating_hidden : hidden width of the gating MLP.

    Raises
    ------
    ValueError
        If ``n_experts`` is not 2.
    """

    def __init__(
        self,
        input_dim: int,
        depth_physics_dim: int = 5,
        hidden_dims: Optional[List[int]] = None,
        n_outputs: int = 5,
        n_experts: int = 2,
        gating_hidden: int = 64,
    ):
        super().__init__()
        if n_experts != 2:
            # forward() always passes exactly two expert predictions to the
            # gate; any other value would fail later with a shape mismatch.
            raise ValueError(
                f"DCDE fuses exactly two experts (XGBoost and FiLM-NN); got n_experts={n_experts}"
            )
        # Fresh list per instance (avoids a shared mutable default).
        hidden_dims = [128, 96, 64] if hidden_dims is None else list(hidden_dims)

        self.input_dim = input_dim
        self.n_outputs = n_outputs

        # FiLM-conditioned NN branch
        self.film_nn = FiLMConditionedMLP(
            input_dim=input_dim,
            hidden_dims=hidden_dims,
            output_dim=hidden_dims[-1],
            conditioning_dim=depth_physics_dim,
        )

        # XGBoost prediction embedding (projects XGB outputs to latent space)
        self.xgb_embed = nn.Sequential(
            nn.Linear(n_outputs, hidden_dims[-1]),
            nn.SiLU(),
            nn.Linear(hidden_dims[-1], hidden_dims[-1]),
        )

        # Gating network
        self.gating = DepthConditionedGatingNetwork(
            input_dim=input_dim,
            depth_dim=depth_physics_dim,
            n_expert_outputs=n_outputs,
            n_experts=n_experts,
            hidden_dim=gating_hidden,
        )

        # Evidential head (NIG parameters), shared by the NN branch preview
        # and the fused output.
        self.evidential_head = EvidentialHead(hidden_dims[-1], n_outputs)

        # NOTE: not used by forward(); kept so existing checkpoints
        # (state_dict keys) still load.
        self.xgb_output = nn.Linear(hidden_dims[-1], n_outputs)

    def forward(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        """
        Parameters
        ----------
        features : Tensor (B, input_dim) - all input features
        depth_physics : Tensor (B, depth_physics_dim) - computed physics features
        xgb_predictions : Tensor (B, n_outputs) - pre-computed XGBoost predictions

        Returns
        -------
        dict with keys:
            'gamma' : fused predictive mean (B, n_outputs)
            'nu', 'alpha', 'beta' : NIG parameters of the fused output
            'aleatoric_unc' : aleatoric uncertainty β/(α-1)
            'epistemic_unc' : epistemic uncertainty β/(ν(α-1))
            'gate_weights' : expert weights (B, 2); column 0 = XGBoost,
                column 1 = FiLM-NN
            'nn_pred' : NN-branch mean (evidential head on the NN latent)
            'xgb_pred' : the raw ``xgb_predictions`` input, passed through
        """
        # NN branch: depth-conditioned via FiLM
        nn_latent = self.film_nn(features, depth_physics)

        # XGBoost branch: embed predictions into latent space
        xgb_latent = self.xgb_embed(xgb_predictions)

        # NN-branch mean, used as a gating input
        nn_pred_raw = self.evidential_head(nn_latent)[0]

        # Dynamic gating; the NN prediction is detached so the gate's input
        # does not backpropagate into the NN branch.
        gate_weights = self.gating(
            features, depth_physics,
            [xgb_predictions, nn_pred_raw.detach()],
        )

        # Fuse latents with per-sample weights
        w_xgb = gate_weights[:, 0:1]  # (B, 1)
        w_nn = gate_weights[:, 1:2]  # (B, 1)
        fused_latent = w_xgb * xgb_latent + w_nn * nn_latent

        # Evidential output
        gamma, nu, alpha, beta = self.evidential_head(fused_latent)

        # Uncertainty decomposition
        aleatoric = EvidentialHead.aleatoric_uncertainty(alpha, beta)
        epistemic = EvidentialHead.epistemic_uncertainty(nu, alpha, beta)

        return {
            "gamma": gamma,
            "nu": nu,
            "alpha": alpha,
            "beta": beta,
            "aleatoric_unc": aleatoric,
            "epistemic_unc": epistemic,
            "gate_weights": gate_weights,
            "nn_pred": nn_pred_raw,
            "xgb_pred": xgb_predictions,
        }
# =============================================================================
# 6. LOSS FUNCTIONS (NIG + Physics-Informed)
# =============================================================================
class DCDELoss(nn.Module):
    """
    Composite loss for DCDE training:
        L_total = L_NIG + λ_mono·L_monotonicity + λ_energy·L_energy + λ_gate·L_gate_entropy

    Components:
    1. NIG loss (evidential regression) - primary data fitting.
    2. Monotonicity loss - predicted etch depth/width (targets 0 and 1)
       should not decrease when the power feature is nudged upward.
    3. Energy loss - the predicted volume proxy (depth × width²) should be
       positively correlated, across the batch, with the power feature.
    4. Gate entropy regularization - pushes each sample's gate weights
       toward uniform.

    Parameters
    ----------
    lambda_nig_reg : weight of the evidence regularizer inside the NIG loss.
    lambda_mono : weight of the monotonicity term (computed only when > 0
        and a model is passed to ``forward``).
    lambda_energy : weight of the energy term (computed only when > 0).
    lambda_gate : weight of the gate-entropy term.
    depth_feature_idx : stored as ``self.depth_idx``; not used by any term here.
    power_feature_idx : column of ``features`` holding laser power.
    """

    def __init__(
        self,
        lambda_nig_reg: float = 0.01,
        lambda_mono: float = 0.05,
        lambda_energy: float = 0.02,
        lambda_gate: float = 0.01,
        depth_feature_idx: int = -1,
        power_feature_idx: int = 0,
    ):
        super().__init__()
        self.lambda_nig_reg = lambda_nig_reg
        self.lambda_mono = lambda_mono
        self.lambda_energy = lambda_energy
        self.lambda_gate = lambda_gate
        self.depth_idx = depth_feature_idx
        self.power_idx = power_feature_idx

    def nig_loss(
        self,
        y: torch.Tensor,
        gamma: torch.Tensor,
        nu: torch.Tensor,
        alpha: torch.Tensor,
        beta: torch.Tensor,
    ) -> torch.Tensor:
        """
        Normal-Inverse-Gamma negative log-likelihood with evidence regularization.

            L = L_NLL + λ·|y - γ|·(2ν + α)

        The regularizer penalizes high evidence (ν, α) in proportion to the
        error, encouraging the model to be uncertain when inaccurate.
        Returns the mean over all elements.
        """
        # NLL term, with Ω = 2β(1 + ν)
        omega = 2 * beta * (1 + nu)
        nll = (
            0.5 * torch.log(torch.pi / nu.clamp(min=1e-6))
            - alpha * torch.log(omega.clamp(min=1e-10))
            + (alpha + 0.5) * torch.log(
                ((y - gamma) ** 2 * nu + omega).clamp(min=1e-10)
            )
            + torch.lgamma(alpha) - torch.lgamma(alpha + 0.5)
        )

        # Evidence regularization (penalize evidence when wrong)
        error = torch.abs(y - gamma)
        evidence = 2 * nu + alpha
        reg = error * evidence

        return (nll + self.lambda_nig_reg * reg).mean()

    def monotonicity_loss(
        self,
        features: torch.Tensor,
        gamma: torch.Tensor,
        model: nn.Module,
        depth_physics: torch.Tensor,
        xgb_pred: torch.Tensor,
    ) -> torch.Tensor:
        """
        Soft monotonicity constraint on etch depth (target 0) and etch width
        (target 1) with respect to the power feature, at fixed depth inputs.

        Implemented as a one-sided finite difference: the power column is
        raised by 5% of its magnitude and only decreases of the predictions
        are penalized (ReLU of base - perturbed), averaged over both targets.

        Both the baseline and the perturbed predictions are recomputed with
        the model temporarily in eval mode. In train mode the two passes
        would see different dropout masks and BatchNorm batch statistics
        (which can absorb part of a batch-wide shift), so the difference
        would mostly measure noise. The extra pass would also update the
        BatchNorm running statistics. Every submodule's train/eval flag is
        restored afterwards. Gradients flow through the baseline pass only.

        ``gamma`` is accepted for interface compatibility and is not used.
        """
        power = features[:, self.power_idx]
        features_perturbed = features.clone()
        # +5% of |power| equals ×1.05 for non-negative power, but still moves
        # upward when the feature is standardized and negative.
        features_perturbed[:, self.power_idx] = power + 0.05 * power.abs()

        saved_modes = [(module, module.training) for module in model.modules()]
        model.eval()
        try:
            base = model(features, depth_physics, xgb_pred)["gamma"]
            with torch.no_grad():
                perturbed = model(features_perturbed, depth_physics, xgb_pred)["gamma"]
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        # Penalize only decreases of depth (col 0) and width (col 1).
        violation = F.relu(base[:, :2] - perturbed[:, :2])
        return violation.mean()

    def energy_conservation_loss(
        self,
        features: torch.Tensor,
        gamma: torch.Tensor,
    ) -> torch.Tensor:
        """
        Soft energy constraint: the predicted ablated-volume proxy should be
        positively correlated, across the batch, with deposited energy.

            Volume proxy ∝ depth × width²   (negative predictions clamped to 0)
            Energy proxy ∝ power            (simplified; no scan speed/passes)

        Returns ReLU(-ρ), where ρ is the Pearson correlation of the two
        proxies over the batch (0 for a batch of size 1 or constant proxies).
        """
        depth_pred = gamma[:, 0].clamp(min=0)
        width_pred = gamma[:, 1].clamp(min=0)
        volume_proxy = depth_pred * width_pred ** 2

        energy_proxy = features[:, self.power_idx]

        # Center both proxies so the cosine similarity is a correlation.
        # The uncentered cosine of two non-negative vectors is always >= 0,
        # which would make this penalty identically zero.
        volume_centered = volume_proxy - volume_proxy.mean()
        energy_centered = energy_proxy - energy_proxy.mean()
        correlation = F.cosine_similarity(volume_centered, energy_centered, dim=0)

        return F.relu(-correlation)

    def gate_entropy_loss(self, gate_weights: torch.Tensor) -> torch.Tensor:
        """
        Return log(n_experts) minus the mean per-sample entropy of the gate
        weights.

        The penalty is zero when every sample's weights are uniform and grows
        as individual samples become confident (weights near 0 or 1). Because
        it acts per sample, it discourages confident input-dependent routing
        as well as the degenerate case of always picking the same expert.
        """
        # Per-sample entropy
        entropy = -(gate_weights * torch.log(gate_weights + 1e-8)).sum(dim=-1)
        max_entropy = math.log(gate_weights.shape[-1])
        return (max_entropy - entropy.mean())

    def forward(
        self,
        y: torch.Tensor,
        model_output: Dict[str, torch.Tensor],
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        model: Optional[nn.Module] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Compute the total loss and its components.

        Parameters
        ----------
        y : targets, same shape as ``model_output["gamma"]``.
        model_output : dict with keys 'gamma', 'nu', 'alpha', 'beta',
            'gate_weights' and 'xgb_pred'.
        features, depth_physics : the model inputs for this batch.
        model : required for the monotonicity term; when None that term is 0.

        Returns
        -------
        dict with 'total', 'nig', 'monotonicity', 'energy' and 'gate_entropy'.
        """
        gamma = model_output["gamma"]
        nu = model_output["nu"]
        alpha = model_output["alpha"]
        beta = model_output["beta"]
        gate_weights = model_output["gate_weights"]
        xgb_pred = model_output["xgb_pred"]

        # Primary loss: NIG
        l_nig = self.nig_loss(y, gamma, nu, alpha, beta)

        # Physics losses
        l_mono = torch.tensor(0.0, device=y.device)
        if model is not None and self.lambda_mono > 0:
            l_mono = self.monotonicity_loss(features, gamma, model, depth_physics, xgb_pred)

        l_energy = torch.tensor(0.0, device=y.device)
        if self.lambda_energy > 0:
            l_energy = self.energy_conservation_loss(features, gamma)

        # Gating regularization
        l_gate = self.gate_entropy_loss(gate_weights)

        total = (
            l_nig
            + self.lambda_mono * l_mono
            + self.lambda_energy * l_energy
            + self.lambda_gate * l_gate
        )

        return {
            "total": total,
            "nig": l_nig,
            "monotonicity": l_mono,
            "energy": l_energy,
            "gate_entropy": l_gate,
        }
# =============================================================================
# 7. TRAINING UTILITIES
# =============================================================================
class DCDETrainer:
    """
    Training/inference helper for the three-phase DCDE protocol.

    Phase 1: Train XGBoost on tabular features (external, uses sklearn/xgboost)
    Phase 2: Train FiLM-NN with evidential head (XGBoost predictions as input)
    Phase 3: Train gating network + fine-tune FiLM-NN end-to-end

    The step methods update whatever parameters the passed optimizer holds;
    which parameters train in each phase, and at what learning rate, is
    decided by how the caller builds that optimizer. ``lr_phase2``,
    ``lr_phase3`` and ``weight_decay`` are stored for that purpose and are
    not read by this class.

    All input tensors are moved to ``device`` before use.
    """

    def __init__(
        self,
        model: DCDE,
        loss_fn: DCDELoss,
        lr_phase2: float = 1e-3,
        lr_phase3: float = 3e-4,
        weight_decay: float = 1e-4,
        device: str = "cpu",
    ):
        self.model = model.to(device)
        self.loss_fn = loss_fn
        self.lr_phase2 = lr_phase2
        self.lr_phase3 = lr_phase3
        self.weight_decay = weight_decay
        self.device = device

    def _to_device(self, *tensors: torch.Tensor) -> Tuple[torch.Tensor, ...]:
        """Move each tensor to ``self.device`` (no-op if already there)."""
        return tuple(t.to(self.device) for t in tensors)

    def phase2_train_step(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
        targets: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """
        Run one optimization step in train mode.

        Computes the loss (passing the model so the monotonicity term can run),
        backpropagates, clips the global gradient norm to 1.0 and steps the
        optimizer. Returns every loss component as a Python float.
        """
        features, depth_physics, xgb_predictions, targets = self._to_device(
            features, depth_physics, xgb_predictions, targets
        )
        self.model.train()
        optimizer.zero_grad()
        output = self.model(features, depth_physics, xgb_predictions)
        losses = self.loss_fn(targets, output, features, depth_physics, self.model)
        losses["total"].backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
        optimizer.step()
        return {k: v.item() for k, v in losses.items()}

    def phase3_train_step(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
        targets: torch.Tensor,
        optimizer: torch.optim.Optimizer,
    ) -> Dict[str, float]:
        """
        Single training step for Phase 3 (end-to-end with gating).

        Identical to :meth:`phase2_train_step`; the phase difference (learning
        rate, which parameters are trainable) comes from the optimizer passed in.
        """
        return self.phase2_train_step(features, depth_physics, xgb_predictions, targets, optimizer)

    @torch.no_grad()
    def predict(
        self,
        features: torch.Tensor,
        depth_physics: torch.Tensor,
        xgb_predictions: torch.Tensor,
    ) -> Dict[str, np.ndarray]:
        """
        Inference with uncertainty quantification.

        Runs the model in eval mode without gradients, then restores each
        submodule's previous train/eval flag.

        Returns
        -------
        dict of NumPy arrays with:
            'mean': predicted values (B, n_outputs)
            'aleatoric_unc': aleatoric uncertainty per target
            'epistemic_unc': epistemic uncertainty per target
            'total_unc': aleatoric + epistemic
            'gate_weights': expert weights (XGB vs NN)
        """
        features, depth_physics, xgb_predictions = self._to_device(
            features, depth_physics, xgb_predictions
        )
        saved_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            output = self.model(features, depth_physics, xgb_predictions)
        finally:
            for module, was_training in saved_modes:
                module.training = was_training

        return {
            "mean": output["gamma"].cpu().numpy(),
            "aleatoric_unc": output["aleatoric_unc"].cpu().numpy(),
            "epistemic_unc": output["epistemic_unc"].cpu().numpy(),
            "total_unc": (output["aleatoric_unc"] + output["epistemic_unc"]).cpu().numpy(),
            "gate_weights": output["gate_weights"].cpu().numpy(),
        }