# cyclone-pred-api / src/fuzzy_neural_network.py
"""
Fuzzy Neural Network (ANFIS-style) Implementation
===================================================
Architecture:
Layer 1 - Fuzzification : Gaussian membership functions per input feature
Layer 2 - Rule Strength : Product T-norm across fuzzy sets per rule
Layer 3 - Normalization : Normalize rule firing strengths
Layer 4 - Consequence : Linear Takagi-Sugeno consequent per rule
Layer 5 - Defuzzification : Weighted sum → crisp output in [0, 1]
Why ANFIS over plain XGBoost for disaster risk:
- Risk is inherently fuzzy (e.g. "moderately high rainfall" has no hard boundary)
- Gaussian MFs produce smooth, interpretable uncertainty bounds per feature
- Fully differentiable → trainable end-to-end with backprop
- Outputs are probabilistic risk scores, not just class labels
"""
import torch
import torch.nn as nn
import numpy as np
from typing import List, Tuple, Optional
import pickle
import os
class GaussianMembershipLayer(nn.Module):
    """
    Layer 1: Fuzzification.

    Maps each raw input feature onto ``n_terms`` fuzzy sets using Gaussian
    membership functions. Both the centers and the widths are learnable
    parameters, so the fuzzy vocabulary adapts during training.

    Input  : (batch, n_features)
    Output : (batch, n_features, n_terms) membership degrees in (0, 1].
    """

    def __init__(self, n_features: int, n_terms: int = 3):
        super().__init__()
        self.n_features = n_features
        self.n_terms = n_terms
        # Centers start evenly spread over [0.1, 0.9]; inputs are presumably
        # normalized to [0, 1] — TODO confirm with the data pipeline.
        initial_centers = torch.linspace(0.1, 0.9, n_terms)
        self.means = nn.Parameter(initial_centers.repeat(n_features, 1))
        # Moderately wide bells so neighbouring terms overlap at init.
        self.sigmas = nn.Parameter(torch.full((n_features, n_terms), 0.3))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return Gaussian membership degrees for every (feature, term) pair."""
        # Broadcast (batch, n_features, 1) against (1, n_features, n_terms).
        deviation = x.unsqueeze(2) - self.means.unsqueeze(0)
        # Clamp widths away from zero so the division can never blow up.
        widths = torch.clamp(self.sigmas, min=1e-4).unsqueeze(0)
        z = deviation / widths
        # Gaussian MF: exp(-0.5 * ((x - c) / sigma)^2)
        return torch.exp(-0.5 * z * z)  # (batch, n_features, n_terms)
class RuleStrengthLayer(nn.Module):
    """
    Layer 2: Rule Firing Strength.

    Each fuzzy rule assigns one linguistic term to every input feature; its
    firing strength is the product (T-norm) of the selected membership
    degrees. The full rule base has n_terms ** n_features rules, which grows
    exponentially, so for large feature sets only a random subset of term
    combinations is kept.

    Parameters
    ----------
    n_features : number of input features
    n_terms    : fuzzy terms per feature
    n_rules    : requested number of rules; None = min(full base, 81).
                 The effective count is capped at the full rule-base size
                 and exposed as ``self.n_rules``.
    """

    def __init__(self, n_features: int, n_terms: int = 3, n_rules: Optional[int] = None):
        super().__init__()
        self.n_features = n_features
        self.n_terms = n_terms
        full_rules = n_terms ** n_features
        # Cap at 81 rules by default to keep the gather/product tractable.
        requested = min(full_rules, 81) if n_rules is None else n_rules
        if requested >= full_rules:
            # Full rule base via cartesian product of term indices.
            import itertools
            combos = list(itertools.product(range(n_terms), repeat=n_features))
            rule_indices = torch.tensor(combos, dtype=torch.long)
        else:
            # Random but reproducible subset of term combinations.
            # A local Generator fixes the original bug of calling
            # torch.manual_seed(42) here, which clobbered the *global* RNG
            # state for everything constructed after this layer.
            gen = torch.Generator().manual_seed(42)
            rule_indices = torch.randint(
                0, n_terms, (requested, n_features), generator=gen
            )
        # Keep n_rules in sync with the actual table: the original stored the
        # requested count, which crashed forward() whenever a caller asked
        # for more rules than the full base contains.
        self.n_rules = rule_indices.shape[0]
        # Buffer (not Parameter): saved/restored with state_dict, not trained.
        self.register_buffer('rule_indices', rule_indices)  # (n_rules, n_features)

    def forward(self, memberships: torch.Tensor) -> torch.Tensor:
        """
        memberships : (batch, n_features, n_terms) fuzzification output.
        Returns     : (batch, n_rules) rule firing strengths.
        """
        batch_size = memberships.shape[0]
        # For each rule r and feature f, select memberships[:, f, rule_indices[r, f]].
        rule_indices_expanded = self.rule_indices.unsqueeze(0).expand(batch_size, -1, -1)
        # (batch, n_rules, n_features)
        memberships_expanded = memberships.unsqueeze(1).expand(-1, self.n_rules, -1, -1)
        # (batch, n_rules, n_features, n_terms)
        # Gather along the term dimension.
        selected = memberships_expanded.gather(
            3,
            rule_indices_expanded.unsqueeze(3)
        ).squeeze(3)
        # (batch, n_rules, n_features)
        # Product T-norm across features -> rule firing strength.
        return selected.prod(dim=2)  # (batch, n_rules)
class ConsequentLayer(nn.Module):
    """
    Layers 3 + 4: normalization and Takagi-Sugeno consequence.

    Every rule r carries a linear model over the bias-augmented inputs:
        f_r(x) = w_r0 + w_r1 * x1 + ... + w_rn * xn
    The crisp output is the firing-strength-weighted average of all f_r.
    """

    def __init__(self, n_features: int, n_rules: int):
        super().__init__()
        self.n_features = n_features
        self.n_rules = n_rules
        # One row of (n_features + 1) coefficients per rule; the last column
        # multiplies the appended constant 1 and therefore acts as the bias.
        init = 0.1 * torch.randn(n_rules, n_features + 1)
        self.consequent_weights = nn.Parameter(init)

    def forward(self, rule_strengths: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
        """
        rule_strengths : (batch, n_rules) raw firing strengths.
        x              : (batch, n_features) original inputs.
        Returns        : (batch,) defuzzified crisp output.
        """
        # Layer 3: normalize firing strengths (clamp guards all-zero rows).
        totals = rule_strengths.sum(dim=1, keepdim=True).clamp(min=1e-8)
        weights = rule_strengths / totals  # (batch, n_rules)
        # Layer 4: evaluate each rule's linear consequent on [x, 1].
        ones = torch.ones(x.shape[0], 1, device=x.device)
        augmented = torch.cat([x, ones], dim=1)  # (batch, n_features + 1)
        rule_outputs = augmented @ self.consequent_weights.T  # (batch, n_rules)
        # Layer 5: weighted sum = defuzzification.
        return (weights * rule_outputs).sum(dim=1)  # (batch,)
class FuzzyNeuralNetwork(nn.Module):
    """
    Full ANFIS model with an optional deep head after defuzzification.

    A shallow ANFIS can underfit complex disaster-risk signals, so the scalar
    defuzzified output may be passed through extra dense layers before a
    final sigmoid squashes it into [0, 1].

    Parameters
    ----------
    n_features  : number of input features
    n_terms     : fuzzy terms per feature (default 3: LOW, MEDIUM, HIGH)
    n_rules     : number of fuzzy rules (None = auto)
    hidden_dims : hidden layer sizes for the post-defuzzification head;
                  None defaults to [64, 32], [] disables the head entirely
    dropout     : dropout rate (also drives MC-dropout uncertainty)
    """

    def __init__(
        self,
        n_features: int,
        n_terms: int = 3,
        n_rules: Optional[int] = None,
        hidden_dims: Optional[List[int]] = None,
        dropout: float = 0.2
    ):
        super().__init__()
        # Fix the original mutable-default-argument pitfall: a shared
        # [64, 32] list mutated by one caller would leak into all others.
        if hidden_dims is None:
            hidden_dims = [64, 32]
        self.n_features = n_features
        self.n_terms = n_terms
        # ANFIS layers
        self.fuzzify = GaussianMembershipLayer(n_features, n_terms)
        self.rule_layer = RuleStrengthLayer(n_features, n_terms, n_rules)
        # Use the rule layer's effective rule count (it may cap the request).
        self.consequent = ConsequentLayer(n_features, self.rule_layer.n_rules)
        # Optional deep head on top of the scalar defuzzified output.
        if hidden_dims:
            layers = []
            in_dim = 1  # defuzzification yields a scalar per sample
            for h in hidden_dims:
                layers += [nn.Linear(in_dim, h), nn.GELU(), nn.Dropout(dropout)]
                in_dim = h
            layers += [nn.Linear(in_dim, 1), nn.Sigmoid()]
            self.deep_head = nn.Sequential(*layers)
        else:
            # No hidden layers: just squash the defuzzified value into [0, 1].
            # (The original built and then discarded an extra nn.Linear here.)
            self.deep_head = nn.Sigmoid()
        self._init_weights()

    def _init_weights(self):
        # Xavier init for all dense layers; the fuzzy parameters keep their
        # linguistically meaningful initialization from the sub-layers.
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (batch, n_features) -> risk scores (batch,) in [0, 1]."""
        memberships = self.fuzzify(x)                  # (batch, n_features, n_terms)
        rule_strengths = self.rule_layer(memberships)  # (batch, n_rules)
        defuzz = self.consequent(rule_strengths, x)    # (batch,)
        out = self.deep_head(defuzz.unsqueeze(1))      # (batch, 1)
        return out.squeeze(1)                          # (batch,)

    def predict_with_uncertainty(
        self,
        x: torch.Tensor,
        n_samples: int = 50
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Monte Carlo dropout for uncertainty estimation.

        Runs ``n_samples`` stochastic forward passes with dropout active and
        returns (mean_prediction, std_deviation), both in [0, 1]. Enables the
        API to return confidence intervals with each prediction.
        """
        was_training = self.training
        self.train()  # enable dropout for MC sampling
        # no_grad: this is inference-only sampling — the original tracked
        # gradients for every one of the n_samples passes.
        with torch.no_grad():
            preds = torch.stack([self(x) for _ in range(n_samples)], dim=0)
        # Restore the caller's mode instead of unconditionally forcing eval.
        self.train(was_training)
        return preds.mean(dim=0), preds.std(dim=0)

    def get_membership_degrees(self, x: torch.Tensor) -> np.ndarray:
        """
        Return interpretable fuzzy membership degrees for a single input.

        x: (1, n_features). Returns an (n_features, n_terms) numpy array.
        (Annotation fixed: the original declared ``dict`` but returned an
        ndarray.) Useful for the API's /explain endpoint.
        """
        with torch.no_grad():
            memberships = self.fuzzify(x)  # (1, n_features, n_terms)
        return memberships.squeeze(0).cpu().numpy()  # (n_features, n_terms)
class FNNTrainer:
    """
    Training harness for FuzzyNeuralNetwork.

    Mini-batch training with gradient clipping, early stopping on validation
    loss, ReduceLROnPlateau LR scheduling, and per-epoch loss logging.
    """

    def __init__(
        self,
        model: "FuzzyNeuralNetwork",
        lr: float = 1e-3,
        weight_decay: float = 1e-4
    ):
        # String annotation above: keeps the class importable even when the
        # model class is defined later / elsewhere.
        self.model = model
        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=weight_decay
        )
        # Halve the LR after 10 epochs without validation improvement.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, patience=10, factor=0.5
        )
        # BCELoss is valid because the model ends in a sigmoid (outputs in [0, 1]).
        self.criterion = nn.BCELoss()
        self.history = {"train_loss": [], "val_loss": []}

    def train_epoch(
        self,
        X_train: torch.Tensor,
        y_train: torch.Tensor,
        batch_size: int = 64
    ) -> float:
        """
        Run one shuffled epoch over (X_train, y_train).

        Returns the mean per-batch BCE loss.
        Raises ValueError on an empty training set (the original divided by
        zero here).
        """
        if X_train.shape[0] == 0:
            raise ValueError("train_epoch requires a non-empty training set")
        self.model.train()
        total_loss = 0.0
        n_batches = 0
        # Fresh shuffle each epoch so batches differ between epochs.
        perm = torch.randperm(X_train.shape[0])
        X_train, y_train = X_train[perm], y_train[perm]
        for i in range(0, X_train.shape[0], batch_size):
            X_batch = X_train[i:i+batch_size]
            y_batch = y_train[i:i+batch_size]
            self.optimizer.zero_grad()
            preds = self.model(X_batch)
            loss = self.criterion(preds, y_batch)
            loss.backward()
            # Clip to stabilize training; rule-strength products can spike grads.
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()
            total_loss += loss.item()
            n_batches += 1
        return total_loss / n_batches

    def validate(self, X_val: torch.Tensor, y_val: torch.Tensor) -> float:
        """Return the BCE loss on the validation set (no gradient tracking)."""
        self.model.eval()
        with torch.no_grad():
            preds = self.model(X_val)
            loss = self.criterion(preds, y_val)
        return loss.item()

    def fit(
        self,
        X_train: torch.Tensor,
        y_train: torch.Tensor,
        X_val: torch.Tensor,
        y_val: torch.Tensor,
        epochs: int = 200,
        batch_size: int = 64,
        patience: int = 20
    ) -> dict:
        """
        Full training loop with early stopping.

        Tracks the best validation loss, restores the best weights at the
        end, and returns the loss-history dict.
        """
        best_val_loss = float('inf')
        best_state = None
        patience_counter = 0
        for epoch in range(epochs):
            train_loss = self.train_epoch(X_train, y_train, batch_size)
            val_loss = self.validate(X_val, y_val)
            self.history["train_loss"].append(train_loss)
            self.history["val_loss"].append(val_loss)
            self.scheduler.step(val_loss)
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Clone so later optimizer steps can't mutate the snapshot.
                best_state = {k: v.clone() for k, v in self.model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1
            if epoch % 20 == 0:
                print(f"Epoch {epoch:4d} | Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch}")
                break
        # Restore best weights (explicit None check: a state dict is never
        # empty here, but truthiness was the wrong test regardless).
        if best_state is not None:
            self.model.load_state_dict(best_state)
        print(f"Training complete. Best Val Loss: {best_val_loss:.4f}")
        return self.history
def save_model(model: "FuzzyNeuralNetwork", path: str, feature_names: List[str]):
    """
    Save model weights plus the metadata needed to rebuild it.

    Stores the state_dict together with n_features / n_terms and the ordered
    feature names so load_model() can reconstruct an identical architecture.
    """
    # Only create parent dirs when the path actually has one: the original
    # called os.makedirs("") for bare filenames like "model.pt", which
    # raises FileNotFoundError.
    parent = os.path.dirname(path)
    if parent:
        os.makedirs(parent, exist_ok=True)
    torch.save({
        "state_dict": model.state_dict(),
        "n_features": model.n_features,
        "n_terms": model.n_terms,
        "feature_names": feature_names,
    }, path)
def load_model(path: str) -> Tuple[FuzzyNeuralNetwork, List[str]]:
    """
    Load a saved model checkpoint from disk.

    Rebuilds the network from the stored architecture metadata, restores its
    weights, and switches it to eval mode. Returns (model, feature_names).

    NOTE(review): only n_features / n_terms are persisted — a model saved
    with non-default n_rules or hidden_dims cannot be reconstructed here;
    confirm that save/load always use the defaults.
    """
    checkpoint = torch.load(path, map_location="cpu")
    restored = FuzzyNeuralNetwork(
        n_features=checkpoint["n_features"],
        n_terms=checkpoint["n_terms"]
    )
    restored.load_state_dict(checkpoint["state_dict"])
    restored.eval()
    return restored, checkpoint["feature_names"]