Spaces:
Running
Running
| """ | |
| src.models.deep.vae_lstm | |
| ======================== | |
| Variational Autoencoder with LSTM encoder/decoder for battery health | |
| state embedding and anomaly detection. | |
| Architecture: | |
| - Encoder: 2-layer bi-LSTM → μ and log-σ (latent dim) | |
| - Reparameterization: z = μ + ε·σ | |
| - Decoder: 2-layer LSTM → reconstruct input sequence | |
| - Health head: latent μ → MLP → SOH/RUL prediction | |
| - Anomaly: cycles with reconstruction error > 3σ flagged | |
| Loss: L = L_recon + β·L_KL (β annealed from 0→1 during training) | |
| """ | |
| from __future__ import annotations | |
| import math | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| class VAE_LSTM(nn.Module): | |
| """Variational Autoencoder with LSTM backbone for battery sequences.""" | |
| def __init__( | |
| self, | |
| input_dim: int, | |
| seq_len: int, | |
| hidden_dim: int = 128, | |
| latent_dim: int = 16, | |
| n_layers: int = 2, | |
| dropout: float = 0.2, | |
| ): | |
| super().__init__() | |
| self.input_dim = input_dim | |
| self.seq_len = seq_len | |
| self.hidden_dim = hidden_dim | |
| self.latent_dim = latent_dim | |
| # ── Encoder ── | |
| self.encoder_lstm = nn.LSTM( | |
| input_dim, hidden_dim, num_layers=n_layers, | |
| batch_first=True, bidirectional=True, | |
| dropout=dropout if n_layers > 1 else 0, | |
| ) | |
| self.fc_mu = nn.Linear(hidden_dim * 2, latent_dim) | |
| self.fc_logvar = nn.Linear(hidden_dim * 2, latent_dim) | |
| # ── Decoder ── | |
| self.decoder_input = nn.Linear(latent_dim, hidden_dim) | |
| self.decoder_lstm = nn.LSTM( | |
| hidden_dim, hidden_dim, num_layers=n_layers, | |
| batch_first=True, dropout=dropout if n_layers > 1 else 0, | |
| ) | |
| self.decoder_output = nn.Linear(hidden_dim, input_dim) | |
| # ── Health prediction head ── | |
| self.health_head = nn.Sequential( | |
| nn.Linear(latent_dim, 64), | |
| nn.ReLU(), | |
| nn.Dropout(dropout), | |
| nn.Linear(64, 32), | |
| nn.ReLU(), | |
| nn.Linear(32, 1), | |
| ) | |
| def encode(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: | |
| """Encode input sequence to latent distribution parameters.""" | |
| out, (h_n, _) = self.encoder_lstm(x) | |
| # Concatenate last forward and backward hidden states | |
| h_fwd = h_n[-2] | |
| h_bwd = h_n[-1] | |
| h_cat = torch.cat([h_fwd, h_bwd], dim=-1) # (B, 2*H) | |
| mu = self.fc_mu(h_cat) # (B, latent_dim) | |
| logvar = self.fc_logvar(h_cat) # (B, latent_dim) | |
| return mu, logvar | |
| def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: | |
| """Reparameterization trick: z = μ + ε·σ.""" | |
| std = torch.exp(0.5 * logvar) | |
| eps = torch.randn_like(std) | |
| return mu + eps * std | |
| def decode(self, z: torch.Tensor) -> torch.Tensor: | |
| """Decode latent vector to reconstructed sequence.""" | |
| # Repeat latent vector across sequence length | |
| z_proj = self.decoder_input(z) # (B, H) | |
| z_seq = z_proj.unsqueeze(1).repeat(1, self.seq_len, 1) # (B, T, H) | |
| out, _ = self.decoder_lstm(z_seq) # (B, T, H) | |
| recon = self.decoder_output(out) # (B, T, input_dim) | |
| return recon | |
| def forward(self, x: torch.Tensor) -> dict[str, torch.Tensor]: | |
| """Full forward pass. | |
| Returns dict with keys: recon, mu, logvar, z, health_pred | |
| """ | |
| mu, logvar = self.encode(x) | |
| z = self.reparameterize(mu, logvar) | |
| recon = self.decode(z) | |
| health_pred = self.health_head(mu).squeeze(-1) # Use μ (not z) for deterministic health estimate | |
| return { | |
| "recon": recon, | |
| "mu": mu, | |
| "logvar": logvar, | |
| "z": z, | |
| "health_pred": health_pred, | |
| } | |
| def vae_loss( | |
| x: torch.Tensor, | |
| recon: torch.Tensor, | |
| mu: torch.Tensor, | |
| logvar: torch.Tensor, | |
| beta: float = 1.0, | |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: | |
| """VAE loss = reconstruction loss + β × KL divergence. | |
| Returns (total_loss, recon_loss, kl_loss). | |
| """ | |
| recon_loss = F.mse_loss(recon, x, reduction="mean") | |
| # KL divergence: -0.5 * Σ(1 + log(σ²) - μ² - σ²) | |
| kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp()) | |
| total = recon_loss + beta * kl_loss | |
| return total, recon_loss, kl_loss | |
| class BetaScheduler: | |
| """KL annealing: β increases linearly from 0 to 1 over warmup epochs.""" | |
| def __init__(self, warmup_epochs: int = 30, max_beta: float = 1.0): | |
| self.warmup_epochs = warmup_epochs | |
| self.max_beta = max_beta | |
| def get_beta(self, epoch: int) -> float: | |
| if epoch >= self.warmup_epochs: | |
| return self.max_beta | |
| return self.max_beta * (epoch / self.warmup_epochs) | |
| def detect_anomalies( | |
| model: VAE_LSTM, | |
| dataloader: torch.utils.data.DataLoader, | |
| device: str = "cpu", | |
| threshold_sigma: float = 3.0, | |
| ) -> tuple[list[bool], list[float]]: | |
| """Flag cycles with reconstruction error > threshold_sigma × σ. | |
| Returns: | |
| - anomaly_flags: list of bool per sample | |
| - recon_errors: list of float (MSE per sample) | |
| """ | |
| model.eval() | |
| all_errors = [] | |
| with torch.no_grad(): | |
| for xb, *_ in dataloader: | |
| xb = xb.to(device) | |
| out = model(xb) | |
| mse = F.mse_loss(out["recon"], xb, reduction="none").mean(dim=(1, 2)) | |
| all_errors.extend(mse.cpu().tolist()) | |
| errors = torch.tensor(all_errors) | |
| mu_err = errors.mean() | |
| std_err = errors.std() | |
| threshold = mu_err + threshold_sigma * std_err | |
| flags = (errors > threshold).tolist() | |
| return flags, all_errors | |
| def train_vae( | |
| model: VAE_LSTM, | |
| train_loader: torch.utils.data.DataLoader, | |
| val_loader: torch.utils.data.DataLoader, | |
| y_train_health: torch.Tensor | None = None, | |
| *, | |
| max_epochs: int = 150, | |
| lr: float = 1e-3, | |
| patience: int = 20, | |
| device: str = "cpu", | |
| warmup_epochs: int = 30, | |
| health_weight: float = 1.0, | |
| ) -> dict: | |
| """Train VAE-LSTM with KL annealing and optional health prediction loss.""" | |
| from src.models.deep.lstm import EarlyStopping | |
| model = model.to(device) | |
| optimizer = torch.optim.Adam(model.parameters(), lr=lr) | |
| scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau( | |
| optimizer, mode="min", factor=0.5, patience=10, | |
| ) | |
| beta_scheduler = BetaScheduler(warmup_epochs) | |
| early_stop = EarlyStopping(patience=patience) | |
| train_losses, val_losses = [], [] | |
| for epoch in range(1, max_epochs + 1): | |
| beta = beta_scheduler.get_beta(epoch) | |
| model.train() | |
| epoch_loss = 0.0 | |
| n_batches = 0 | |
| for batch in train_loader: | |
| if len(batch) == 2: | |
| xb, yb = batch | |
| xb, yb = xb.to(device), yb.to(device) | |
| else: | |
| xb = batch[0].to(device) | |
| yb = None | |
| optimizer.zero_grad() | |
| out = model(xb) | |
| total, recon_l, kl_l = vae_loss(xb, out["recon"], out["mu"], out["logvar"], beta) | |
| if yb is not None: | |
| health_loss = F.l1_loss(out["health_pred"], yb) | |
| total = total + health_weight * health_loss | |
| total.backward() | |
| nn.utils.clip_grad_norm_(model.parameters(), 1.0) | |
| optimizer.step() | |
| epoch_loss += total.item() | |
| n_batches += 1 | |
| train_losses.append(epoch_loss / max(n_batches, 1)) | |
| # Validation | |
| model.eval() | |
| val_loss = 0.0 | |
| n_val = 0 | |
| with torch.no_grad(): | |
| for batch in val_loader: | |
| if len(batch) == 2: | |
| xb, yb = batch | |
| xb, yb = xb.to(device), yb.to(device) | |
| else: | |
| xb = batch[0].to(device) | |
| yb = None | |
| out = model(xb) | |
| total, _, _ = vae_loss(xb, out["recon"], out["mu"], out["logvar"], beta) | |
| if yb is not None: | |
| total = total + health_weight * F.l1_loss(out["health_pred"], yb) | |
| val_loss += total.item() | |
| n_val += 1 | |
| val_losses.append(val_loss / max(n_val, 1)) | |
| scheduler.step(val_losses[-1]) | |
| if early_stop.step(val_losses[-1], model): | |
| break | |
| early_stop.load_best(model) | |
| return {"train_losses": train_losses, "val_losses": val_losses} | |