# Source: Lilith-Weather/models/components/ensemble_head.py
# Uploader: consigcody94
# Commit message: Upload source code and documentation
# Commit: 8bcb60f (verified)
"""
Ensemble Head for LILITH.
Generates ensemble forecasts for uncertainty quantification using
diffusion-based sampling or dropout-based approaches.
"""
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
class GaussianHead(nn.Module):
    """
    Parametric Gaussian head.

    Maps input features to a per-variable mean and a bounded standard
    deviation, and can draw samples from the resulting distribution.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        min_std: float = 0.01,
        max_std: float = 10.0,
    ):
        """
        Build the shared trunk and the two linear output branches.

        Args:
            input_dim: Size of the last dimension of the input features
            output_dim: Number of predicted variables
            hidden_dim: Width of the shared hidden layer; falls back to
                ``input_dim`` when omitted (or falsy)
            min_std: Lower clamp applied to the predicted standard deviation
            max_std: Upper clamp applied to the predicted standard deviation
        """
        super().__init__()
        self.output_dim = output_dim
        self.min_std = min_std
        self.max_std = max_std

        width = hidden_dim if hidden_dim else input_dim

        # Trunk shared by both output branches.
        self.shared = nn.Sequential(nn.Linear(input_dim, width), nn.GELU())
        # Branch producing the mean.
        self.mean_head = nn.Linear(width, output_dim)
        # Branch producing the log-variance (exponentiated in forward()).
        self.logvar_head = nn.Linear(width, output_dim)

    def forward(
        self,
        x: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Compute the mean and the clamped standard deviation.

        Args:
            x: Input features of shape (batch, ..., input_dim)

        Returns:
            Tuple of (mean, std), each of shape (batch, ..., output_dim)
        """
        features = self.shared(x)
        log_variance = self.logvar_head(features)
        # std = exp(logvar / 2), then restricted to [min_std, max_std].
        sigma = (0.5 * log_variance).exp().clamp(self.min_std, self.max_std)
        return self.mean_head(features), sigma

    def sample(
        self,
        x: torch.Tensor,
        n_samples: int = 1,
    ) -> torch.Tensor:
        """
        Draw samples as ``mean + std * eps`` with standard-normal ``eps``.

        Args:
            x: Input features
            n_samples: Number of samples to draw

        Returns:
            Samples of shape (n_samples, batch, ..., output_dim)
        """
        mu, sigma = self.forward(x)

        # Broadcast both parameters along a new leading sample axis.
        member_shape = (n_samples, *mu.shape)
        mu = mu[None].expand(member_shape)
        sigma = sigma[None].expand(member_shape)

        noise = torch.randn_like(mu)
        return mu + sigma * noise
class QuantileHead(nn.Module):
    """
    Predicts multiple quantiles for non-Gaussian distributions.

    Useful for skewed variables like precipitation. The predicted quantiles
    are non-decreasing along the last axis and are not restricted to any
    fixed value range.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        quantiles: Tuple[float, ...] = (0.05, 0.25, 0.5, 0.75, 0.95),
        hidden_dim: Optional[int] = None,
    ):
        """
        Initialize quantile prediction head.

        Args:
            input_dim: Input feature dimension
            output_dim: Output dimension (number of predicted variables)
            quantiles: Quantile levels to predict. They are stored on the
                module; ``forward`` only uses how many there are and emits
                values in ascending order along the last axis.
            hidden_dim: Hidden layer dimension (defaults to ``input_dim``)
        """
        super().__init__()
        self.output_dim = output_dim
        self.quantiles = quantiles
        self.n_quantiles = len(quantiles)
        hidden_dim = hidden_dim or input_dim
        self.net = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, output_dim * self.n_quantiles),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict quantiles.

        Args:
            x: Input features of shape (batch, ..., input_dim)

        Returns:
            Quantiles of shape (batch, ..., output_dim, n_quantiles),
            non-decreasing along the last dimension.
        """
        lead_shape = x.shape[:-1]
        raw = self.net(x).reshape(*lead_shape, self.output_dim, self.n_quantiles)

        # The first channel is the lowest quantile, left unconstrained so the
        # head can sit anywhere on the real line. The remaining channels are
        # turned into non-negative gaps and accumulated, which guarantees
        # ordering without squeezing the outputs into (0, 1] or forcing the
        # top quantile to a constant (as a softmax + cumsum would).
        lowest = raw[..., :1]
        gaps = F.softplus(raw[..., 1:])
        higher = lowest + torch.cumsum(gaps, dim=-1)
        return torch.cat([lowest, higher], dim=-1)
class MCDropoutHead(nn.Module):
    """
    Monte Carlo Dropout for uncertainty estimation.

    Uses dropout at inference time to generate ensemble samples.
    Simple and computationally efficient.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: Optional[int] = None,
        dropout: float = 0.1,
        n_layers: int = 2,
    ):
        """
        Initialize MC Dropout head.

        Args:
            input_dim: Input feature dimension
            output_dim: Output dimension
            hidden_dim: Hidden layer dimension (defaults to ``input_dim``)
            dropout: Dropout probability
            n_layers: Number of hidden (Linear -> GELU -> Dropout) layers
        """
        super().__init__()
        self.output_dim = output_dim
        self.dropout = dropout
        hidden_dim = hidden_dim or input_dim

        layers = []
        in_dim = input_dim
        for _ in range(n_layers):
            layers.extend([
                nn.Linear(in_dim, hidden_dim),
                nn.GELU(),
                nn.Dropout(dropout),
            ])
            in_dim = hidden_dim
        # Use the running width so that n_layers == 0 yields a valid single
        # Linear(input_dim, output_dim) even when hidden_dim != input_dim.
        layers.append(nn.Linear(in_dim, output_dim))
        self.net = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Standard forward pass (dropout follows the module's train/eval mode)."""
        return self.net(x)

    def sample(
        self,
        x: torch.Tensor,
        n_samples: int = 10,
    ) -> torch.Tensor:
        """
        Generate samples using MC Dropout.

        Dropout is switched on for the duration of the call and the
        train/eval mode of this module and all of its submodules is restored
        afterwards, so calling this on an eval-mode model does not leave it
        in training mode.

        Args:
            x: Input features
            n_samples: Number of samples to generate

        Returns:
            Samples of shape (n_samples, batch, ..., output_dim)
        """
        # Remember every submodule's mode so it can be restored exactly.
        previous_modes = [(module, module.training) for module in self.modules()]
        self.train()  # ensure dropout is active
        try:
            samples = [self.forward(x) for _ in range(n_samples)]
        finally:
            for module, was_training in previous_modes:
                module.training = was_training
        return torch.stack(samples, dim=0)
class DiffusionEnsembleHead(nn.Module):
    """
    Diffusion-based ensemble generation for high-quality uncertainty.

    Uses a lightweight denoising diffusion model to generate diverse
    ensemble members conditioned on the input features. Each member is the
    result of a reverse-diffusion chain added to a deterministic linear
    mean prediction.
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
        n_steps: int = 50,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
    ):
        """
        Initialize diffusion ensemble head.

        Args:
            input_dim: Input (conditioning) feature dimension
            output_dim: Output dimension
            hidden_dim: Hidden dimension for denoising network
            n_steps: Number of diffusion steps
            beta_start: First value of the linear beta (noise) schedule
            beta_end: Last value of the linear beta (noise) schedule
        """
        super().__init__()
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.n_steps = n_steps

        # Linear noise schedule and the derived per-step constants. They are
        # registered as buffers so they follow the module across devices and
        # dtypes.
        betas = torch.linspace(beta_start, beta_end, n_steps)
        alphas = 1 - betas
        alphas_cumprod = torch.cumprod(alphas, dim=0)
        self.register_buffer("betas", betas)
        self.register_buffer("alphas", alphas)
        self.register_buffer("alphas_cumprod", alphas_cumprod)
        self.register_buffer("sqrt_alphas_cumprod", torch.sqrt(alphas_cumprod))
        self.register_buffer(
            "sqrt_one_minus_alphas_cumprod", torch.sqrt(1 - alphas_cumprod)
        )

        # Time embedding: scalar time -> hidden_dim vector.
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim),
        )

        # Denoising network (simple MLP) over [noisy sample, condition, time].
        self.denoise_net = nn.Sequential(
            nn.Linear(output_dim + input_dim + hidden_dim, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, hidden_dim * 2),
            nn.GELU(),
            nn.Linear(hidden_dim * 2, output_dim),
        )

        # Mean prediction (deterministic baseline).
        self.mean_head = nn.Linear(input_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return deterministic mean prediction."""
        return self.mean_head(x)

    def denoise_step(
        self,
        x_t: torch.Tensor,
        t: torch.Tensor,
        condition: torch.Tensor,
    ) -> torch.Tensor:
        """
        Predict the noise contained in ``x_t``.

        Args:
            x_t: Noisy sample at time t, shape (..., output_dim)
            t: Normalized time step with the leading shape of ``x_t``
                (no feature axis); ``sample`` passes ``i / n_steps``
            condition: Conditioning features, shape (..., input_dim)

        Returns:
            Predicted noise of shape (..., output_dim)
        """
        # Embed the scalar time as a feature vector.
        t_emb = self.time_embed(t.unsqueeze(-1))
        # Concatenate all inputs along the feature axis.
        h = torch.cat([x_t, condition, t_emb], dim=-1)
        return self.denoise_net(h)

    def sample(
        self,
        x: torch.Tensor,
        n_samples: int = 10,
    ) -> torch.Tensor:
        """
        Generate ensemble samples via reverse diffusion.

        Args:
            x: Conditioning features of shape (batch, ..., input_dim)
            n_samples: Number of ensemble members

        Returns:
            Samples of shape (n_samples, batch, ..., output_dim)
        """
        lead_shape = x.shape[:-1]
        # Create the noise and time tensors with the input's device and dtype
        # so the head also works when module and input are float64 or
        # reduced precision.
        factory = {"device": x.device, "dtype": x.dtype}

        # The deterministic mean does not depend on the member, so it is
        # computed once.
        mean = self.mean_head(x)

        samples = []
        for _ in range(n_samples):
            # Start from pure noise.
            x_t = torch.randn(*lead_shape, self.output_dim, **factory)

            # Reverse diffusion from the last step down to step 0.
            for i in reversed(range(self.n_steps)):
                t = torch.full(lead_shape, i / self.n_steps, **factory)
                noise_pred = self.denoise_step(x_t, t, x)

                alpha = self.alphas[i]
                beta = self.betas[i]

                # Remove the predicted noise component for this step.
                x_t = (
                    1 / torch.sqrt(alpha)
                    * (x_t - beta / self.sqrt_one_minus_alphas_cumprod[i] * noise_pred)
                )
                # Fresh noise is injected at every step except the final one.
                if i > 0:
                    x_t = x_t + torch.sqrt(beta) * torch.randn_like(x_t)

            # Add deterministic mean.
            samples.append(x_t + mean)

        return torch.stack(samples, dim=0)
class EnsembleHead(nn.Module):
    """
    Unified ensemble head that wraps one of several uncertainty methods.

    Supports:
    - Gaussian parametric uncertainty ("gaussian")
    - Quantile regression ("quantile")
    - MC Dropout ("mc_dropout")
    - Diffusion ensemble ("diffusion")
    """

    def __init__(
        self,
        input_dim: int,
        output_dim: int,
        hidden_dim: int = 128,
        method: str = "gaussian",  # "gaussian", "quantile", "mc_dropout", "diffusion"
        n_quantiles: int = 5,
        dropout: float = 0.1,
        diffusion_steps: int = 50,
    ):
        """
        Initialize ensemble head.

        Args:
            input_dim: Input feature dimension
            output_dim: Output dimension
            hidden_dim: Hidden dimension
            method: Uncertainty method to use
            n_quantiles: Number of quantiles (for quantile method); levels are
                evenly spaced at ``i / (n_quantiles + 1)``
            dropout: Dropout rate (for MC dropout)
            diffusion_steps: Diffusion steps (for diffusion method)

        Raises:
            ValueError: If ``method`` is unknown, or if ``n_quantiles < 1``
                for the quantile method.
        """
        super().__init__()
        self.method = method
        self.output_dim = output_dim

        if method == "gaussian":
            self.head = GaussianHead(input_dim, output_dim, hidden_dim)
        elif method == "quantile":
            if n_quantiles < 1:
                raise ValueError(f"n_quantiles must be >= 1, got {n_quantiles}")
            quantiles = tuple(
                i / (n_quantiles + 1) for i in range(1, n_quantiles + 1)
            )
            self.head = QuantileHead(input_dim, output_dim, quantiles, hidden_dim)
        elif method == "mc_dropout":
            self.head = MCDropoutHead(input_dim, output_dim, hidden_dim, dropout)
        elif method == "diffusion":
            self.head = DiffusionEnsembleHead(
                input_dim, output_dim, hidden_dim, diffusion_steps
            )
        else:
            raise ValueError(f"Unknown method: {method}")

    @staticmethod
    def _median_from_quantiles(quantiles: torch.Tensor) -> torch.Tensor:
        """Return the median estimate from quantiles stacked on the last axis."""
        count = quantiles.size(-1)
        upper = count // 2
        if count % 2 == 1:
            return quantiles[..., upper]
        # Even count: the evenly spaced levels built in __init__ straddle 0.5
        # symmetrically, so the median is midway between the two middle ones.
        return 0.5 * (quantiles[..., upper - 1] + quantiles[..., upper])

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get deterministic prediction.

        Args:
            x: Input features

        Returns:
            Prediction: the mean for Gaussian, the median for quantile, and
            the wrapped head's own forward output otherwise.
        """
        if self.method == "gaussian":
            mean, _ = self.head(x)
            return mean
        if self.method == "quantile":
            return self._median_from_quantiles(self.head(x))
        return self.head(x)

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        n_samples: int = 10,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Get prediction with uncertainty estimates.

        Args:
            x: Input features
            n_samples: Number of samples for the sampling-based methods

        Returns:
            Tuple of (mean, lower_bound, upper_bound). For Gaussian the bounds
            are mean +/- 1.96 std; for quantile they are the lowest and
            highest predicted quantiles (with the median as the centre); for
            the other methods they are the 2.5% / 97.5% empirical quantiles
            of ``n_samples`` draws.
        """
        if self.method == "gaussian":
            mean, std = self.head(x)
            lower = mean - 1.96 * std  # 95% interval under normality
            upper = mean + 1.96 * std
            return mean, lower, upper

        if self.method == "quantile":
            quantiles = self.head(x)
            centre = self._median_from_quantiles(quantiles)
            return centre, quantiles[..., 0], quantiles[..., -1]

        # Sampling-based methods: summarise the drawn ensemble.
        samples = self.head.sample(x, n_samples)
        mean = samples.mean(dim=0)
        lower = samples.quantile(0.025, dim=0)
        upper = samples.quantile(0.975, dim=0)
        return mean, lower, upper

    def sample(
        self,
        x: torch.Tensor,
        n_samples: int = 10,
    ) -> torch.Tensor:
        """
        Generate ensemble samples.

        Args:
            x: Input features
            n_samples: Number of samples

        Returns:
            Ensemble samples with a leading axis of size ``n_samples``
        """
        # Any head that can sample by itself is used directly.
        if hasattr(self.head, "sample"):
            return self.head.sample(x, n_samples)

        # Otherwise (quantile head): interpolate between adjacent quantiles.
        quantiles = self.head(x)
        n_intervals = quantiles.size(-1) - 1

        if n_intervals < 1:
            # A single quantile leaves nothing to interpolate between; every
            # member equals that quantile.
            only = quantiles[..., 0]
            return only.unsqueeze(0).expand(n_samples, *only.shape).clone()

        samples = []
        for _ in range(n_samples):
            # One random interval and one mixing weight per ensemble member.
            idx = int(torch.randint(0, n_intervals, (1,)).item())
            alpha = torch.rand(1, device=quantiles.device, dtype=quantiles.dtype)
            member = (1 - alpha) * quantiles[..., idx] + alpha * quantiles[..., idx + 1]
            samples.append(member)
        return torch.stack(samples, dim=0)