"""
QKAN Integration: Quantum Variational Activation Functions.

Based on: QKAN (arXiv:2509.14026) —
"Quantum Variational Activation Functions Empower Kolmogorov-Arnold Networks"

DARUAN (DatA Re-Uploading Activation Networks):
    Single-qubit data re-uploading circuits that serve as learnable
    activation functions. Unlike multi-qubit VQCs, DARUANs:
    - Avoid barren plateaus (single-qubit only)
    - Run on classical simulators efficiently
    - Have exponentially growing frequency spectrum with repetitions
    - Can be transferred to classical B-spline KANs via distillation

HQKAN (Hybrid QKAN):
    Drop-in replacement for MLP FFN layers in transformers.
    Replaces standard activation + linear with QKAN-activated linear.

Integration with Q-TensorFormer:
    The HQKAN FFN can optionally replace or augment the TT-FFN,
    providing quantum-enhanced expressivity with fewer parameters.

NOTE: the ``DARUAN`` class in this file is a classical surrogate. It
evaluates a learnable weighted sum of base activations and does not simulate
a quantum circuit; see its docstring for the exact formula.
"""

import math
import warnings
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F


class DARUAN(nn.Module):
    """
    Data Re-Uploading Activation Network (classical surrogate).

    A learnable element-wise activation inspired by single-qubit data
    re-uploading circuits. No quantum circuit is simulated; the module
    evaluates

        output = W_0 * x + sum_{r=1..R} W_r * S(w_r * x + b_r)

    where S is the base activation, R is ``n_repeats``, and every ``w_r``,
    ``b_r`` and ``W_r`` is a learnable scalar (shape ``(1,)``) shared by all
    elements of the input. Dropout, when enabled, is applied to the result.

    Parameters
    ----------
    n_repeats : int
        Number of re-uploading terms (R). Must be >= 0; with 0 the module
        reduces to ``W_0 * x``.
    base_activation : str
        Base activation function: "silu", "gelu", "relu", or "tanh".
        Any other name falls back to SiLU and emits a ``UserWarning``.
    dropout : float
        Dropout rate applied to the output. Values <= 0 disable dropout.

    Raises
    ------
    ValueError
        If ``n_repeats`` is negative.
    """

    # Name -> module class. Classes (not instances) so that only the selected
    # activation is ever constructed.
    _ACTIVATIONS = {
        "silu": nn.SiLU,
        "gelu": nn.GELU,
        "relu": nn.ReLU,
        "tanh": nn.Tanh,
    }

    def __init__(self, n_repeats: int = 3, base_activation: str = "silu",
                 dropout: float = 0.0):
        # Fail early: a negative count would otherwise build empty parameter
        # lists and only surface as an IndexError inside forward().
        if n_repeats < 0:
            raise ValueError(
                f"n_repeats must be non-negative, got {n_repeats}"
            )

        super().__init__()
        self.n_repeats = n_repeats
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Base activation. Unknown names keep the historical SiLU fallback,
        # but warn so that typos do not go unnoticed.
        activation_cls = self._ACTIVATIONS.get(base_activation)
        if activation_cls is None:
            warnings.warn(
                f"Unknown base_activation {base_activation!r}; falling back "
                f"to 'silu'. Supported: {sorted(self._ACTIVATIONS)}",
                stacklevel=2,
            )
            activation_cls = nn.SiLU
        self.activation = activation_cls()

        # Learnable pre-activation scale and shift (w_r, b_r), one pair per
        # repetition. The constants here are placeholders that
        # _init_weights() overwrites below.
        self.pre_weights = nn.ParameterList([
            nn.Parameter(torch.ones(1) * 0.1) for _ in range(n_repeats)
        ])
        self.pre_biases = nn.ParameterList([
            nn.Parameter(torch.zeros(1)) for _ in range(n_repeats)
        ])

        # Learnable output weights W_0 .. W_R (index 0 scales the raw input).
        self.post_weights = nn.ParameterList([
            nn.Parameter(torch.ones(1) * 0.5) for _ in range(n_repeats + 1)
        ])

        self._init_weights()

    def _init_weights(self):
        """Initialize with small values for stable training.

        Pre-weights are drawn from U(-0.1, 0.1), pre-biases are zeroed and
        post-weights are drawn from U(0.3, 0.7).
        """
        for i in range(self.n_repeats):
            nn.init.uniform_(self.pre_weights[i], -0.1, 0.1)
            nn.init.zeros_(self.pre_biases[i])
        for i in range(self.n_repeats + 1):
            nn.init.uniform_(self.post_weights[i], 0.3, 0.7)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the DARUAN activation element-wise.

        Args:
            x: tensor of any shape.

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Linear skip term W_0 * x.
        out = self.post_weights[0] * x
        for r in range(self.n_repeats):
            # Affine pre-activation followed by the base nonlinearity.
            z = self.activation(self.pre_weights[r] * x + self.pre_biases[r])
            # Accumulate the weighted term W_{r+1} * S(...).
            out = out + self.post_weights[r + 1] * z
        return self.dropout(out)

    def extra_repr(self) -> str:
        return f"n_repeats={self.n_repeats}"


class QKANLayer(nn.Module):
    """
    Quantum KAN Layer: per-feature DARUAN activation followed by a linear map.

    Architecture:
        x → DARUAN (one independent instance per input feature) → Linear → output

    Note the order: the activation is applied to the *input* features and
    the linear projection comes last.

    Parameter cost: each per-feature DARUAN holds ``3 * n_repeats + 1``
    scalars, so the layer has ``in_features * (3 * n_repeats + 1)``
    parameters in addition to those of the ``nn.Linear`` projection.

    Parameters
    ----------
    in_features : int
        Size of the last dimension of the input.
    out_features : int
        Size of the last dimension of the output.
    n_repeats : int
        DARUAN repetitions (default: 3).
    base_activation : str
        Base activation for DARUAN.
    bias : bool
        Include bias in the output projection.
    """

    def __init__(self, in_features: int, out_features: int,
                 n_repeats: int = 3, base_activation: str = "silu",
                 bias: bool = True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Stored on the layer so extra_repr() does not depend on the
        # ModuleList being non-empty.
        self.n_repeats = n_repeats

        # Per-feature DARUAN activations.
        self.daruans = nn.ModuleList([
            DARUAN(n_repeats=n_repeats, base_activation=base_activation)
            for _ in range(in_features)
        ])

        # Output projection.
        self.out_proj = nn.Linear(in_features, out_features, bias=bias)

        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-uniform init for the projection weight; zero bias."""
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.zeros_(self.out_proj.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (*, in_features)

        Returns:
            (*, out_features)
        """
        # Split along the feature axis and run feature i through its own
        # DARUAN. Indexing (rather than zip) keeps a feature-count mismatch
        # loud instead of silently truncating.
        activated = [
            self.daruans[i](feat) for i, feat in enumerate(x.unbind(-1))
        ]
        x = torch.stack(activated, dim=-1)  # (..., in_features)

        # Output projection.
        return self.out_proj(x)

    def parameter_count(self) -> int:
        """Total number of parameters registered on this layer."""
        return sum(p.numel() for p in self.parameters())

    def extra_repr(self) -> str:
        return (f"in={self.in_features}, out={self.out_features}, "
                f"n_repeats={self.n_repeats}")


class HQKANFFN(nn.Module):
    """
    Hybrid QKAN Feed-Forward Network.

    Drop-in replacement for transformer FFN:
        Standard:  Linear↑ → GELU → Linear↓
        HQKAN:     Linear↑ → DARUAN → Linear↓ → Dropout

    A single DARUAN (SiLU base) is applied element-wise on the expanded
    dimension; its scalar parameters are shared by every expanded feature.

    Compared to TT-FFN:
        - HQKAN has better expressivity per parameter
        - TT-FFN has better compression ratio
        - Can be combined: QKAN on expanded dim, TT on down-projection

    Parameters
    ----------
    hidden_dim : int
        Input and output feature size.
    ff_multiplier : int
        Expansion factor (default: 4); the inner width is
        ``hidden_dim * ff_multiplier``.
    n_repeats : int
        DARUAN repetitions.
    dropout : float
        Dropout rate applied after the down-projection.
    """

    def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
                 n_repeats: int = 3, dropout: float = 0.1):
        super().__init__()
        expanded_dim = hidden_dim * ff_multiplier

        self.up_proj = nn.Linear(hidden_dim, expanded_dim)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")
        self.down_proj = nn.Linear(expanded_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map ``(*, hidden_dim)`` to ``(*, hidden_dim)``."""
        x = self.up_proj(x)
        x = self.daruan(x)
        x = self.down_proj(x)
        return self.dropout(x)

    @property
    def total_params(self) -> int:
        """Total number of parameters registered on this module."""
        return sum(p.numel() for p in self.parameters())


class QKANEmbedding(nn.Module):
    """
    Quantum-enhanced embedding layer.

    Looks up token embeddings with ``nn.Embedding`` and passes them through
    a single DARUAN (SiLU base) applied element-wise.
    """

    def __init__(self, vocab_size: int, d_model: int, n_repeats: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return DARUAN-activated embeddings of shape ``(*input_ids.shape, d_model)``."""
        x = self.embedding(input_ids)
        return self.daruan(x)


def create_qkan_ffn(hidden_dim: int, ff_multiplier: int = 4,
                    n_repeats: int = 3, dropout: float = 0.1,
                    use_tt: bool = False, tt_rank: int = 4) -> nn.Module:
    """
    Factory for QKAN-based FFN.

    Args:
        hidden_dim: Hidden dimension.
        ff_multiplier: Expansion factor.
        n_repeats: DARUAN repetitions.
        dropout: Dropout rate.
        use_tt: If True, use TT-decomposed down-projection for extra compression.
        tt_rank: TT rank (only if use_tt=True).

    Returns:
        FFN module: an ``HQKANFFN`` when ``use_tt`` is False, otherwise a
        locally defined ``TTQKANFFN`` whose down-projection is a ``TTLinear``.
    """
    if use_tt:
        # TT-QKAN hybrid: dense up-projection + DARUAN + TT down-projection.
        # Imported lazily so the plain HQKAN path does not require the
        # tensor_layers module.
        from .tensor_layers import TTLinear

        expanded_dim = hidden_dim * ff_multiplier

        class TTQKANFFN(nn.Module):
            """Linear↑ → DARUAN → TTLinear↓ → Dropout."""

            def __init__(self):
                super().__init__()
                self.up_proj = nn.Linear(hidden_dim, expanded_dim)
                self.daruan = DARUAN(n_repeats=n_repeats)
                self.down_proj = TTLinear(expanded_dim, hidden_dim, rank=tt_rank)
                self.dropout = nn.Dropout(dropout)

            def forward(self, x):
                x = self.up_proj(x)
                x = self.daruan(x)
                x = self.down_proj(x)
                return self.dropout(x)

        return TTQKANFFN()

    return HQKANFFN(hidden_dim, ff_multiplier, n_repeats, dropout)