# Q-TensorFormer / src/qkan.py
# Premchan369's picture
# Upload src/qkan.py
# 220eb7c verified
# Raw
# History Blame Contribute Delete
# 9.91 kB
"""
QKAN Integration: Quantum Variational Activation Functions.
Based on: QKAN (arXiv:2509.14026) — "Quantum Variational Activation Functions
Empower Kolmogorov-Arnold Networks"
DARUAN (DatA Re-Uploading Activation Networks):
Single-qubit data re-uploading circuits that serve as learnable activation
functions. Unlike multi-qubit VQCs, DARUANs:
- Avoid barren plateaus (single-qubit only)
- Run on classical simulators efficiently
- Have exponentially growing frequency spectrum with repetitions
- Can be transferred to classical B-spline KANs via distillation
HQKAN (Hybrid QKAN):
Drop-in replacement for MLP FFN layers in transformers.
Replaces standard activation + linear with QKAN-activated linear.
Integration with Q-TensorFormer:
The HQKAN FFN can optionally replace or augment the TT-FFN,
providing quantum-enhanced expressivity with fewer parameters.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
from typing import Optional, Tuple
class DARUAN(nn.Module):
    """
    Data Re-Uploading Activation Network.

    A learnable, element-wise activation inspired by single-qubit data
    re-uploading circuits. The input is "re-uploaded" ``n_repeats`` times,
    each time through its own scalar affine map and the base activation,
    and the results are mixed with learnable scalar weights.

    Computation (exactly what ``forward`` implements)::

        out = W_0 * x + sum_{r=1..R} W_r * S(w_r * x + b_r)

    where ``S`` is the base activation, ``R = n_repeats`` and every
    ``w_r``, ``b_r``, ``W_r`` is a learnable scalar shared by all elements
    of ``x``. The module therefore holds ``3 * R + 1`` parameters.

    This is a purely classical computation: no quantum hardware is needed
    and no qubit state is simulated explicitly.

    Parameters
    ----------
    n_repeats : int
        Number of data re-uploading repetitions (R). Must be >= 0; with 0
        the module reduces to a learnable scalar gain.
    base_activation : str
        Base activation function: "silu", "gelu", "relu", or "tanh".
        Any other value falls back to "silu" and emits a ``UserWarning``.
    dropout : float
        Dropout rate applied to the output; values <= 0 disable dropout.

    Raises
    ------
    ValueError
        If ``n_repeats`` is negative.
    """

    def __init__(self, n_repeats: int = 3, base_activation: str = "silu",
                 dropout: float = 0.0):
        super().__init__()
        if n_repeats < 0:
            # A negative count would leave post_weights empty and make
            # forward() fail with an opaque IndexError.
            raise ValueError(
                f"n_repeats must be non-negative, got {n_repeats}"
            )
        self.n_repeats = n_repeats
        self.dropout = nn.Dropout(dropout) if dropout > 0 else nn.Identity()

        # Map names to classes so only the selected activation is built.
        act_classes = {
            "silu": nn.SiLU,
            "gelu": nn.GELU,
            "relu": nn.ReLU,
            "tanh": nn.Tanh,
        }
        if base_activation not in act_classes:
            # Keep the historical SiLU fallback, but make it visible so a
            # typo in the config is not silently ignored.
            import warnings

            warnings.warn(
                f"Unknown base_activation {base_activation!r}; "
                f"falling back to 'silu'. "
                f"Supported: {sorted(act_classes)}",
                stacklevel=2,
            )
        self.activation = act_classes.get(base_activation, nn.SiLU)()

        # Learnable pre-activation scalars (w_r, b_r), one pair per repetition.
        self.pre_weights = nn.ParameterList([
            nn.Parameter(torch.full((1,), 0.1)) for _ in range(n_repeats)
        ])
        self.pre_biases = nn.ParameterList([
            nn.Parameter(torch.zeros(1)) for _ in range(n_repeats)
        ])
        # Learnable mixing scalars W_0 .. W_R (index 0 is the linear skip term).
        self.post_weights = nn.ParameterList([
            nn.Parameter(torch.full((1,), 0.5)) for _ in range(n_repeats + 1)
        ])
        self._init_weights()

    def _init_weights(self):
        """Initialize with small values for stable training."""
        for i in range(self.n_repeats):
            nn.init.uniform_(self.pre_weights[i], -0.1, 0.1)
            nn.init.zeros_(self.pre_biases[i])
        for i in range(self.n_repeats + 1):
            nn.init.uniform_(self.post_weights[i], 0.3, 0.7)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the DARUAN activation element-wise.

        Args:
            x: tensor of any shape (including 0-d).
        Returns:
            Tensor with the same shape as ``x``.
        """
        # Linear skip term W_0 * x.
        out = self.post_weights[0] * x
        for r in range(self.n_repeats):
            # Re-upload the input: affine map, nonlinearity, weighted sum.
            z = self.activation(self.pre_weights[r] * x + self.pre_biases[r])
            out = out + self.post_weights[r + 1] * z
        if out.shape != x.shape:
            # Parameters have shape (1,), so a 0-d input broadcasts to (1,);
            # restore the caller's shape to honor the "same shape" contract.
            out = out.reshape(x.shape)
        return self.dropout(out)

    def extra_repr(self) -> str:
        return f"n_repeats={self.n_repeats}"
class QKANLayer(nn.Module):
    """
    Quantum KAN Layer: per-feature DARUAN activation followed by a linear map.

    Architecture:
        x → DARUAN_i(x[..., i]) for every input feature i → Linear → output

    Note that the activation is applied BEFORE the linear projection, so this
    is not numerically interchangeable with ``nn.Sequential(nn.Linear, nn.GELU)``.

    Parameter count: the projection contributes ``in_features * out_features``
    (+ ``out_features`` with bias) and each of the ``in_features`` DARUANs
    contributes ``3 * n_repeats + 1`` scalars.

    Parameters
    ----------
    in_features : int
        Size of the last input dimension. Must be >= 1.
    out_features : int
        Size of the last output dimension.
    n_repeats : int
        DARUAN repetitions (default: 3).
    base_activation : str
        Base activation for DARUAN.
    bias : bool
        Include bias in the output projection.

    Raises
    ------
    ValueError
        If ``in_features`` is smaller than 1.
    """

    def __init__(self, in_features: int, out_features: int,
                 n_repeats: int = 3, base_activation: str = "silu",
                 bias: bool = True):
        super().__init__()
        if in_features < 1:
            raise ValueError(
                f"in_features must be at least 1, got {in_features}"
            )
        self.in_features = in_features
        self.out_features = out_features
        # One independent DARUAN per input feature. They are kept as separate
        # sub-modules so state_dict keys ("daruans.<i>....") stay stable.
        self.daruans = nn.ModuleList([
            DARUAN(n_repeats=n_repeats, base_activation=base_activation)
            for _ in range(in_features)
        ])
        # Output projection
        self.out_proj = nn.Linear(in_features, out_features, bias=bias)
        self._reset_parameters()

    def _reset_parameters(self):
        """Xavier-initialize the projection weight and zero its bias."""
        nn.init.xavier_uniform_(self.out_proj.weight)
        if self.out_proj.bias is not None:
            nn.init.zeros_(self.out_proj.bias)

    def _stacked(self, name: str, index: int) -> torch.Tensor:
        """Gather scalar ``<name>[index]`` of every DARUAN into (in_features,)."""
        return torch.cat([getattr(d, name)[index] for d in self.daruans])

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
            x: (*, in_features)
        Returns:
            (*, out_features)
        Raises:
            ValueError: if the last dimension of ``x`` is not ``in_features``.
        """
        if x.dim() == 0 or x.shape[-1] != self.in_features:
            # Checked explicitly: a last dim of 1 would otherwise broadcast
            # silently against the per-feature parameter vectors.
            raise ValueError(
                f"expected last dimension {self.in_features}, "
                f"got input of shape {tuple(x.shape)}"
            )

        # All DARUANs are built with identical settings (and without dropout)
        # in __init__, so their scalar parameters can be concatenated into
        # per-feature vectors and applied with one broadcasted pass instead of
        # a Python loop over features.
        reference = self.daruans[0]
        activation = reference.activation

        out = self._stacked("post_weights", 0) * x
        for r in range(reference.n_repeats):
            z = activation(
                self._stacked("pre_weights", r) * x
                + self._stacked("pre_biases", r)
            )
            out = out + self._stacked("post_weights", r + 1) * z

        # Output projection
        return self.out_proj(out)

    def parameter_count(self) -> int:
        """Total number of parameters (DARUAN scalars plus projection)."""
        return sum(p.numel() for p in self.parameters())

    def extra_repr(self) -> str:
        return (f"in={self.in_features}, out={self.out_features}, "
                f"n_repeats={self.daruans[0].n_repeats}")
class HQKANFFN(nn.Module):
    """
    Hybrid QKAN Feed-Forward Network.

    Transformer-style FFN in which the usual pointwise nonlinearity is a
    single shared DARUAN activation:

        Standard: Linear↑ → GELU   → Linear↓
        HQKAN:    Linear↑ → DARUAN → Linear↓ → Dropout

    The DARUAN acts element-wise on the expanded
    (``hidden_dim * ff_multiplier``) representation.

    Parameters
    ----------
    hidden_dim : int
        Input and output width of the block.
    ff_multiplier : int
        Expansion factor (default: 4).
    n_repeats : int
        DARUAN repetitions.
    dropout : float
        Dropout rate applied to the block output.
    """

    def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
                 n_repeats: int = 3, dropout: float = 0.1):
        super().__init__()
        inner_dim = ff_multiplier * hidden_dim
        # Sub-modules are created in pipeline order.
        self.up_proj = nn.Linear(hidden_dim, inner_dim)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")
        self.down_proj = nn.Linear(inner_dim, hidden_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate with DARUAN, project back, then apply dropout."""
        hidden = self.daruan(self.up_proj(x))
        return self.dropout(self.down_proj(hidden))

    @property
    def total_params(self) -> int:
        """Number of parameter elements held by this block."""
        count = 0
        for param in self.parameters():
            count += param.numel()
        return count
class QKANEmbedding(nn.Module):
    """
    Token embedding followed by a DARUAN activation.

    Looks up ``d_model``-dimensional vectors for the given token ids and
    passes them element-wise through one shared DARUAN before they are
    handed to the rest of the model.
    """

    def __init__(self, vocab_size: int, d_model: int, n_repeats: int = 2):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.daruan = DARUAN(n_repeats=n_repeats, base_activation="silu")

    def forward(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Return the DARUAN-activated embeddings of ``input_ids``."""
        return self.daruan(self.embedding(input_ids))
class TTQKANFFN(nn.Module):
    """
    TT-QKAN hybrid FFN: Linear↑ → DARUAN → TTLinear↓ → Dropout.

    Defined at module level (rather than inside ``create_qkan_ffn``) so that
    every instance shares one class object and can be pickled.

    Parameters
    ----------
    hidden_dim : int
        Input and output width of the block.
    ff_multiplier : int
        Expansion factor for the intermediate width.
    n_repeats : int
        DARUAN repetitions.
    dropout : float
        Dropout rate applied to the block output.
    tt_rank : int
        Passed to ``TTLinear`` as ``rank``.
    """

    def __init__(self, hidden_dim: int, ff_multiplier: int = 4,
                 n_repeats: int = 3, dropout: float = 0.1,
                 tt_rank: int = 4):
        super().__init__()
        # Imported lazily so the module can be used without the tensor-train
        # layers unless the TT variant is actually requested.
        from .tensor_layers import TTLinear

        expanded_dim = hidden_dim * ff_multiplier
        self.up_proj = nn.Linear(hidden_dim, expanded_dim)
        self.daruan = DARUAN(n_repeats=n_repeats)
        self.down_proj = TTLinear(expanded_dim, hidden_dim, rank=tt_rank)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Expand, activate with DARUAN, TT-project back, apply dropout."""
        x = self.up_proj(x)
        x = self.daruan(x)
        x = self.down_proj(x)
        return self.dropout(x)


def create_qkan_ffn(hidden_dim: int, ff_multiplier: int = 4,
                    n_repeats: int = 3, dropout: float = 0.1,
                    use_tt: bool = False, tt_rank: int = 4) -> nn.Module:
    """
    Factory for QKAN-based FFN.

    Args:
        hidden_dim: Hidden dimension.
        ff_multiplier: Expansion factor.
        n_repeats: DARUAN repetitions.
        dropout: Dropout rate.
        use_tt: If True, use a TT-decomposed down-projection (``TTLinear``)
            instead of a dense ``nn.Linear``.
        tt_rank: TT rank (only used if use_tt=True).

    Returns:
        A ``TTQKANFFN`` when ``use_tt`` is True, otherwise an ``HQKANFFN``.
    """
    if use_tt:
        # TT-QKAN hybrid: dense up-projection + DARUAN + TT down-projection.
        return TTQKANFFN(hidden_dim, ff_multiplier, n_repeats, dropout,
                         tt_rank)
    return HQKANFFN(hidden_dim, ff_multiplier, n_repeats, dropout)