# indicguard/src/models/kan_classifier.py
"""
Kolmogorov-Arnold Network (KAN) Classifier
Stage 4b: Interpretable Decision Engine
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Dict
import math
import logging
logger = logging.getLogger(__name__)
class BSplineBasis(nn.Module):
    """
    B-spline basis functions evaluated on a fixed uniform knot grid.

    The learned univariate transformations in a KAN are expressed as
    linear combinations of these smooth, locally-supported functions.
    """

    def __init__(
        self,
        grid_size: int = 5,
        spline_order: int = 3,
        grid_range: Tuple[float, float] = (-1.0, 1.0),
    ):
        super().__init__()
        self.grid_size = grid_size
        self.spline_order = spline_order
        self.grid_range = grid_range
        # Number of basis functions produced per scalar input.
        self.num_basis = grid_size + spline_order
        # Uniform knot vector, padded by `spline_order` steps on each side
        # so the boundary splines are fully supported.
        lo, hi = grid_range
        step = (hi - lo) / grid_size
        knots = torch.linspace(
            lo - spline_order * step,
            hi + spline_order * step,
            grid_size + 2 * spline_order + 1,
        )
        self.register_buffer('grid', knots)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Evaluate every basis function at the given points.

        Args:
            x: Input points (batch, features)

        Returns:
            Basis values (batch, features, num_basis)
        """
        pts = x.unsqueeze(-1)  # (batch, features, 1)
        knots = self.grid  # (num_knots,)
        # Cox-de Boor base case: order-0 splines are interval indicators,
        # B_i,0(x) = 1 iff knots[i] <= x < knots[i+1].
        values = ((pts >= knots[:-1]) & (pts < knots[1:])).float()
        # Cox-de Boor recursion raises the spline order one step at a time.
        for order in range(1, self.spline_order + 1):
            # Epsilon keeps repeated-knot denominators finite.
            lo_span = knots[order:-1] - knots[:-order-1] + 1e-8
            hi_span = knots[order+1:] - knots[1:-order] + 1e-8
            lo_w = (pts - knots[:-order-1].unsqueeze(0).unsqueeze(0)) / lo_span.unsqueeze(0).unsqueeze(0)
            hi_w = (knots[order+1:].unsqueeze(0).unsqueeze(0) - pts) / hi_span.unsqueeze(0).unsqueeze(0)
            values = lo_w * values[..., :-1] + hi_w * values[..., 1:]
        return values
class KANLayer(nn.Module):
    """
    Kolmogorov-Arnold Network Layer

    Unlike MLPs which use fixed activations, KAN learns
    univariate functions on the edges using B-splines.
    Reference: "KAN: Kolmogorov-Arnold Networks" (Liu et al., 2024)
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        grid_size: int = 5,
        spline_order: int = 3,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.grid_size = grid_size
        self.spline_order = spline_order
        # Shared B-spline basis; forward() tanh-squashes inputs into its
        # default (-1, 1) grid range before evaluation.
        self.basis = BSplineBasis(
            grid_size=grid_size,
            spline_order=spline_order,
        )
        # Number of basis functions per edge.
        num_basis = grid_size + spline_order
        # Learnable spline coefficients, one set per (output, input) edge.
        # Shape: (out_features, in_features, num_basis)
        self.coef = nn.Parameter(
            torch.randn(out_features, in_features, num_basis) * 0.1
        )
        # Weight of the SiLU residual path added alongside the splines
        # for training stability.
        self.residual_weight = nn.Parameter(torch.ones(out_features, in_features) * 0.1)
        # Per-output-neuron scale factor.
        self.scale = nn.Parameter(torch.ones(out_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through KAN layer.

        Args:
            x: Input (batch, in_features)

        Returns:
            Output (batch, out_features)
        """
        # Map inputs into the spline grid range [-1, 1].
        x_norm = torch.tanh(x)
        # Evaluate B-spline basis: (batch, in_features, num_basis).
        basis = self.basis(x_norm)
        # Spline path: output[b, j] = sum_i sum_k coef[j, i, k] * basis[b, i, k]
        spline_out = torch.einsum('bik,jik->bj', basis, self.coef)
        # Residual path on the raw (un-squashed) input keeps gradients
        # flowing where the tanh/splines saturate.
        residual = torch.einsum('bi,ji->bj', F.silu(x), self.residual_weight)
        return self.scale * (spline_out + residual)

    def get_feature_importance(self, x: torch.Tensor) -> torch.Tensor:
        """
        Get importance of each input feature.

        NOTE(review): importance is derived from the learned spline
        coefficient magnitudes only; ``x`` is accepted for API
        compatibility but is not used by this implementation.

        Args:
            x: Input batch (currently unused).

        Returns:
            Normalized feature importance (in_features,); sums to 1.
        """
        # Coefficient magnitudes act as a proxy for edge importance.
        importance = self.coef.abs().mean(dim=(0, 2))  # (in_features,)
        return importance / importance.sum()
class KANClassifier(nn.Module):
    """
    Complete KAN Classifier for Deepfake Detection

    Key Features:
    - Learnable B-spline functions for interpretability
    - Feature importance extraction for explanations
    - Dropout for regularization (Layer 1)
    """

    def __init__(self, config: dict):
        """
        Args:
            config: Nested configuration; reads config["model"]["kan"]
                with keys input_dim, hidden_dim, output_dim, grid_size,
                spline_order, dropout.
        """
        super().__init__()
        kan_config = config["model"]["kan"]
        self.input_dim = kan_config["input_dim"]
        self.hidden_dim = kan_config["hidden_dim"]
        self.output_dim = kan_config["output_dim"]
        self.grid_size = kan_config["grid_size"]
        self.spline_order = kan_config["spline_order"]
        self.dropout_rate = kan_config["dropout"]
        # Two stacked KAN layers: input -> hidden -> output.
        self.kan1 = KANLayer(
            in_features=self.input_dim,
            out_features=self.hidden_dim,
            grid_size=self.grid_size,
            spline_order=self.spline_order,
        )
        self.kan2 = KANLayer(
            in_features=self.hidden_dim,
            out_features=self.output_dim,
            grid_size=self.grid_size,
            spline_order=self.spline_order,
        )
        # Dropout (Layer 1) between the two KAN layers.
        self.dropout = nn.Dropout(self.dropout_rate)
        # Layer normalization after the first KAN layer.
        self.ln1 = nn.LayerNorm(self.hidden_dim)
        logger.info("KAN Classifier initialized:")
        logger.info(f"  Input: {self.input_dim} -> Hidden: {self.hidden_dim} -> Output: {self.output_dim}")
        logger.info(f"  Grid Size: {self.grid_size}, Spline Order: {self.spline_order}")

    def forward(
        self,
        x: torch.Tensor,
        return_logits: bool = True,
    ) -> torch.Tensor:
        """
        Classify input.

        Args:
            x: Input features (batch, input_dim)
            return_logits: If True, return raw logits; else probabilities

        Returns:
            Logits or probabilities (batch, output_dim)
        """
        # First KAN layer with normalization and dropout.
        x = self.kan1(x)
        x = self.ln1(x)
        x = self.dropout(x)
        # Second KAN layer produces class logits.
        logits = self.kan2(x)
        if return_logits:
            return logits
        return F.softmax(logits, dim=-1)

    def get_explanation(
        self,
        x: torch.Tensor,
        class_idx: Optional[int] = None,
    ) -> Dict[str, torch.Tensor]:
        """
        Generate explanation for classification.

        Args:
            x: Input features (batch, input_dim)
            class_idx: Class to explain. May be None (use the argmax
                prediction per sample), a plain int (applied to every
                sample), or a (batch,) LongTensor of class indices.

        Returns:
            Dictionary with:
            - 'feature_importance': Input feature importance
            - 'hidden_importance': Hidden layer importance
            - 'prediction': Predicted/selected class per sample
            - 'confidence': Confidence score per sample
        """
        with torch.no_grad():
            logits = self.forward(x, return_logits=True)
            probs = F.softmax(logits, dim=-1)
            if class_idx is None:
                class_idx = logits.argmax(dim=-1)
            elif not torch.is_tensor(class_idx):
                # BUG FIX: a plain int (the annotated type) previously crashed
                # on .unsqueeze(); broadcast it across the batch instead.
                class_idx = torch.full(
                    (logits.size(0),), int(class_idx),
                    dtype=torch.long, device=logits.device,
                )
            confidence = probs.gather(1, class_idx.unsqueeze(-1)).squeeze(-1)
            # Importance proxies from each KAN layer's coefficients.
            input_importance = self.kan1.get_feature_importance(x)
            hidden_importance = self.kan2.get_feature_importance(
                self.ln1(self.kan1(x))
            )
        return {
            'feature_importance': input_importance,
            'hidden_importance': hidden_importance,
            'prediction': class_idx,
            'confidence': confidence,
        }