Spaces:
Sleeping
Sleeping
| """ | |
| Kolmogorov-Arnold Network (KAN) Classifier | |
| Stage 4b: Interpretable Decision Engine | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from typing import Tuple, Optional, Dict | |
| import math | |
| import logging | |
| logger = logging.getLogger(__name__) | |
class BSplineBasis(nn.Module):
    """
    B-Spline basis functions for KAN.

    B-splines provide smooth, localized basis functions that enable
    interpretable learned transformations. The basis is evaluated via
    the Cox-de Boor recursion on a uniform knot grid extended by
    `spline_order` intervals on each side for boundary support.
    """

    def __init__(
        self,
        grid_size: int = 5,
        spline_order: int = 3,
        grid_range: Tuple[float, float] = (-1.0, 1.0),
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.grid_range = grid_range

        # Total number of basis functions produced per input feature.
        self.num_basis = grid_size + spline_order

        # Uniform knot vector, padded by `spline_order` intervals on each
        # side so the basis is well-defined across the whole grid range.
        lo, hi = grid_range
        step = (hi - lo) / grid_size
        knots = torch.linspace(
            lo - spline_order * step,
            hi + spline_order * step,
            grid_size + 2 * spline_order + 1,
        )
        self.register_buffer('grid', knots)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate B-spline basis functions at points x.

        Args:
            x: Input points (batch, features); expected within grid_range.
        Returns:
            Basis values (batch, features, num_basis)
        """
        x = x.unsqueeze(-1)  # (batch, features, 1) for broadcasting vs knots
        knots = self.grid    # (num_knots,)

        # Order-0 (piecewise-constant) basis:
        # B_i,0(x) = 1 if knots[i] <= x < knots[i+1] else 0
        values = ((x >= knots[:-1]) & (x < knots[1:])).float()

        # Cox-de Boor recursion:
        # B_i,k(x) = (x - t_i)/(t_{i+k} - t_i) * B_i,k-1(x)
        #          + (t_{i+k+1} - x)/(t_{i+k+1} - t_{i+1}) * B_{i+1},k-1(x)
        for order in range(1, self.spline_order + 1):
            lo_knot = knots[:-(order + 1)]
            hi_knot = knots[order + 1:]
            # Epsilon guards against zero-width knot spans.
            left = (x - lo_knot) / (knots[order:-1] - lo_knot + 1e-8)
            right = (hi_knot - x) / (hi_knot - knots[1:-order] + 1e-8)
            values = left * values[..., :-1] + right * values[..., 1:]

        return values
class KANLayer(nn.Module):
    """
    Kolmogorov-Arnold Network Layer

    Unlike MLPs which use fixed activations, KAN learns
    univariate functions on the edges using B-splines.

    Reference: "KAN: Kolmogorov-Arnold Networks" (Liu et al., 2024)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
    ):
        """
        Args:
            in_features: Number of input features.
            out_features: Number of output features.
            grid_size: Number of grid intervals for the B-spline basis.
            spline_order: Polynomial order of the B-splines.
        """
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order

        # Shared B-spline basis evaluator (expects inputs in [-1, 1]).
        self.basis = BSplineBasis(
            grid_size=grid_size,
            spline_order=spline_order,
        )

        # Number of basis functions per edge.
        num_basis = grid_size + spline_order

        # Learnable spline coefficients for each (output, input) edge.
        # Shape: (out_features, in_features, num_basis)
        self.coef = nn.Parameter(
            torch.randn(out_features, in_features, num_basis) * 0.1
        )

        # Weight of the SiLU residual path (kept small for stability).
        self.residual_weight = nn.Parameter(torch.ones(out_features, in_features) * 0.1)

        # Per-output scale factor.
        self.scale = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through KAN layer.

        Args:
            x: Input (batch, in_features)
        Returns:
            Output (batch, out_features)
        """
        # Squash input into [-1, 1] so it lies on the spline grid's support.
        x_norm = torch.tanh(x)

        # Evaluate B-spline basis: (batch, in_features, num_basis)
        basis = self.basis(x_norm)

        # Spline output: output[b, j] = sum_i sum_k coef[j, i, k] * basis[b, i, k]
        spline_out = torch.einsum('bik,jik->bj', basis, self.coef)

        # Residual SiLU path; note it uses the raw (un-squashed) input x,
        # not x_norm — presumably intentional for gradient flow outside the grid.
        residual = torch.einsum('bi,ji->bj', F.silu(x), self.residual_weight)

        # Combine with per-output scale.
        return self.scale * (spline_out + residual)

    def get_feature_importance(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Get importance of each input feature.

        Importance is derived purely from spline-coefficient magnitudes,
        so the input tensor is not actually used; `x` is retained (now
        optional) for backward compatibility with existing callers.

        Returns:
            Normalized feature importance (in_features,), summing to 1.
        """
        # Coefficient magnitudes as a proxy for importance.
        importance = self.coef.abs().mean(dim=(0, 2))  # (in_features,)
        # Epsilon prevents division by zero if all coefficients are zero.
        return importance / (importance.sum() + 1e-12)
class KANClassifier(nn.Module):
    """
    Complete KAN Classifier for Deepfake Detection

    Key Features:
    - Learnable B-spline functions for interpretability
    - Feature importance extraction for explanations
    - Dropout for regularization (Layer 1)
    """

    def __init__(self, config: dict):
        """
        Args:
            config: Nested config dict; reads config["model"]["kan"] with keys
                input_dim, hidden_dim, output_dim, grid_size, spline_order, dropout.
        """
        super().__init__()
        kan_config = config["model"]["kan"]
        self.input_dim = kan_config["input_dim"]
        self.hidden_dim = kan_config["hidden_dim"]
        self.output_dim = kan_config["output_dim"]
        self.grid_size = kan_config["grid_size"]
        self.spline_order = kan_config["spline_order"]
        self.dropout_rate = kan_config["dropout"]

        # Two-layer KAN: input -> hidden -> output.
        self.kan1 = KANLayer(
            in_features=self.input_dim,
            out_features=self.hidden_dim,
            grid_size=self.grid_size,
            spline_order=self.spline_order,
        )
        self.kan2 = KANLayer(
            in_features=self.hidden_dim,
            out_features=self.output_dim,
            grid_size=self.grid_size,
            spline_order=self.spline_order,
        )

        # Dropout (Layer 1)
        self.dropout = nn.Dropout(self.dropout_rate)
        # Layer normalization between the two KAN layers.
        self.ln1 = nn.LayerNorm(self.hidden_dim)

        # Lazy %-style args avoid formatting work when INFO is disabled.
        logger.info("KAN Classifier initialized:")
        logger.info(
            "  Input: %s -> Hidden: %s -> Output: %s",
            self.input_dim, self.hidden_dim, self.output_dim,
        )
        logger.info(
            "  Grid Size: %s, Spline Order: %s",
            self.grid_size, self.spline_order,
        )

    def forward(
        self,
        x: torch.Tensor,
        return_logits: bool = True,
    ) -> torch.Tensor:
        """
        Classify input.

        Args:
            x: Input features (batch, input_dim)
            return_logits: If True, return raw logits; else probabilities
        Returns:
            Logits or probabilities (batch, output_dim)
        """
        # First KAN layer, normalized and regularized.
        x = self.kan1(x)
        x = self.ln1(x)
        x = self.dropout(x)
        # Second KAN layer produces class logits.
        logits = self.kan2(x)
        return logits if return_logits else F.softmax(logits, dim=-1)

    def get_explanation(
        self,
        x: torch.Tensor,
        class_idx: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Generate explanation for classification.

        NOTE(review): dropout is still active if the module is in train
        mode; callers presumably invoke this under `.eval()` — confirm.

        Args:
            x: Input features (batch, input_dim)
            class_idx: Optional class to explain. Either a per-sample
                LongTensor (batch,) or a single int broadcast to the batch.
                Defaults to the argmax prediction.
        Returns:
            Dictionary with:
                - 'feature_importance': Input feature importance
                - 'hidden_importance': Hidden layer importance
                - 'prediction': Predicted (or requested) class per sample
                - 'confidence': Confidence score per sample
        """
        # Everything here is for explanation only — no autograd graph needed.
        with torch.no_grad():
            logits = self.forward(x, return_logits=True)
            probs = F.softmax(logits, dim=-1)

            if class_idx is None:
                class_idx = logits.argmax(dim=-1)
            elif not torch.is_tensor(class_idx):
                # Bug fix: an int class_idx previously crashed on
                # `.unsqueeze(-1)`; broadcast the scalar across the batch.
                class_idx = torch.full(
                    (x.size(0),), int(class_idx),
                    dtype=torch.long, device=x.device,
                )

            confidence = probs.gather(1, class_idx.unsqueeze(-1)).squeeze(-1)

            # Feature importance at both layers (derived from coefficients).
            input_importance = self.kan1.get_feature_importance(x)
            hidden_importance = self.kan2.get_feature_importance(
                self.ln1(self.kan1(x))
            )

        return {
            'feature_importance': input_importance,
            'hidden_importance': hidden_importance,
            'prediction': class_idx,
            'confidence': confidence,
        }