# PULSE-code / experiments / nets / models.py
# (Hugging Face upload header: uploaded by velvet-pine-22 via huggingface_hub,
#  revision b4b2877, verified)
"""
Model definitions for Experiment 1: Scene Recognition.
Backbones: CNN1D, BiLSTM, Transformer
Fusion: Early (default), Late, Attention, WeightedLate, GatedLate, Stacking, Product, MoE
Supports optional per-modality projection via proj_dim parameter:
proj_dim > 0: project each modality to proj_dim before backbone
proj_dim = 0: no projection, use raw features (original behavior)
"""
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
# ============================================================
# Per-modality projection
# ============================================================
class ModalityProjector(nn.Module):
    """Embed each modality's slice of the input into a shared proj_dim space.

    The input feature axis is assumed to be the concatenation of the raw
    modality features, in `modality_dims` order; each slice gets its own
    Linear -> LayerNorm -> ReLU projector.
    """

    def __init__(self, modality_dims, proj_dim):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.proj_dim = proj_dim
        self.projectors = nn.ModuleList(
            nn.Sequential(
                nn.Linear(width, proj_dim),
                nn.LayerNorm(proj_dim),
                nn.ReLU(),
            )
            for width in self.mod_dims
        )

    @property
    def output_dim(self):
        # One proj_dim-wide embedding per modality, concatenated.
        return self.proj_dim * len(self.mod_dims)

    def forward(self, x):
        """x: (B, T, total_raw_dim) -> (B, T, proj_dim * M)"""
        pieces = []
        start = 0
        for proj, width in zip(self.projectors, self.mod_dims):
            pieces.append(proj(x[:, :, start:start + width]))
            start += width
        return torch.cat(pieces, dim=-1)
# ============================================================
# Per-modality hidden dim scaling (used when proj_dim=0)
# ============================================================
def _compute_per_modality_hidden(mod_dim, base_hidden_dim):
if mod_dim >= 128:
return max(base_hidden_dim, 48)
elif mod_dim >= 32:
return base_hidden_dim
else:
return max(16, base_hidden_dim // 2)
# ============================================================
# Backbones
# ============================================================
class CNN1DBackbone(nn.Module):
    """Three-stage 1D conv encoder with (optionally masked) mean pooling.

    Input (B, T, C); output (B, hidden_dim). Convolutions use same-padding,
    so the temporal length is preserved until the pooling step.
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()

        def conv_stage(c_in, c_out, ks, drop):
            # Conv -> BN -> ReLU, with an optional trailing Dropout.
            mods = [
                nn.Conv1d(c_in, c_out, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
            if drop:
                mods.append(nn.Dropout(drop))
            return nn.Sequential(*mods)

        self.conv1 = conv_stage(input_dim, 64, 7, 0.1)
        self.conv2 = conv_stage(64, 128, 5, 0.1)
        self.conv3 = conv_stage(128, hidden_dim, 3, 0.0)
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        # (B, T, C) -> (B, C, T) for Conv1d.
        h = self.conv3(self.conv2(self.conv1(x.permute(0, 2, 1))))
        if mask is None:
            return h.mean(2)
        # Masked mean over valid timesteps; clamp avoids division by zero.
        valid = mask.unsqueeze(1).float()
        return (h * valid).sum(2) / mask.sum(1, keepdim=True).float().clamp(min=1)
class LSTMBackbone(nn.Module):
    """BiLSTM encoder pooled with additive (single-score) attention.

    Input (B, T, C); output (B, 2 * hidden_dim) — forward and backward
    states are concatenated by the bidirectional LSTM.
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.2):
        super().__init__()
        # nn.LSTM only applies inter-layer dropout when num_layers > 1.
        inter_drop = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True, dropout=inter_drop,
        )
        self.attn = nn.Linear(hidden_dim * 2, 1)
        self.output_dim = hidden_dim * 2

    def forward(self, x, mask=None):
        seq, _ = self.lstm(x)                   # (B, T, 2H)
        scores = self.attn(seq).squeeze(-1)     # (B, T)
        if mask is not None:
            # Padding positions get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(~mask, float('-inf'))
        alpha = torch.softmax(scores, dim=1).unsqueeze(-1)  # (B, T, 1)
        return (seq * alpha).sum(dim=1)
class TinyHARBackbone(nn.Module):
    """TinyHAR backbone (Zhou et al., ISWC 2022 Best Paper).
    Lightweight model for human activity recognition from wearable sensors.
    Uses multi-scale temporal convolutions + cross-channel interaction + temporal pooling.
    Input: (B, T, C) with optional mask
    Output: (B, hidden_dim)
    """
    def __init__(self, input_dim, hidden_dim=128, num_scales=4):
        super().__init__()
        # Each scale branch gets an equal slice of the hidden width (>= 4),
        # so the effective output width is scale_dim * num_scales, which may
        # differ from hidden_dim when hidden_dim is not divisible by num_scales.
        scale_dim = max(4, hidden_dim // num_scales)
        actual_hidden = scale_dim * num_scales
        # Multi-scale temporal convolution feature extraction
        self.convs = nn.ModuleList()
        for i in range(num_scales):
            ks = 2 * (i + 1) + 1  # kernel sizes: 3, 5, 7, 9
            self.convs.append(nn.Sequential(
                # padding = ks // 2 with odd ks keeps the temporal length T.
                nn.Conv1d(input_dim, scale_dim, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(scale_dim),
                nn.ReLU(),
            ))
        # Cross-channel interaction via multi-head self-attention
        nhead = max(1, min(4, actual_hidden // 8))
        # Ensure actual_hidden is divisible by nhead
        # (nn.MultiheadAttention requires embed_dim % num_heads == 0).
        while actual_hidden % nhead != 0 and nhead > 1:
            nhead -= 1
        self.channel_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=nhead, batch_first=True, dropout=0.1,
        )
        self.channel_norm = nn.LayerNorm(actual_hidden)
        self.channel_ff = nn.Sequential(
            nn.Linear(actual_hidden, actual_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(actual_hidden, actual_hidden),
        )
        self.ff_norm = nn.LayerNorm(actual_hidden)
        # Temporal attention pooling: one learned query attends over the time
        # axis to collapse (B, T, D) -> (B, D).
        self.temporal_query = nn.Parameter(torch.randn(1, 1, actual_hidden) * 0.02)
        self.temporal_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=1, batch_first=True, dropout=0.1,
        )
        self.output_dim = actual_hidden
    def forward(self, x, mask=None):
        # x: (B, T, C); mask: optional (B, T) bool with True = valid timestep.
        B, T, C = x.shape
        x_t = x.permute(0, 2, 1)  # (B, C, T)
        # Multi-scale feature extraction
        scale_features = [conv(x_t) for conv in self.convs]
        x = torch.cat(scale_features, dim=1)  # (B, actual_hidden, T)
        x = x.permute(0, 2, 1)  # (B, T, actual_hidden)
        # Cross-channel interaction (post-norm residual attention + FF block).
        # MultiheadAttention's key_padding_mask marks positions to IGNORE with
        # True, hence the inversion of `mask`.
        key_padding_mask = ~mask if mask is not None else None
        attn_out, _ = self.channel_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.channel_norm(x + attn_out)
        x = self.ff_norm(x + self.channel_ff(x))
        # Temporal attention pooling
        query = self.temporal_query.expand(B, -1, -1)  # (B, 1, actual_hidden)
        pooled, _ = self.temporal_attn(query, x, x, key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)  # (B, actual_hidden)
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (sin on even dims, cos on odd dims),
    added to the input and followed by dropout.

    Fix: supports odd d_model. `div_term` has ceil(d_model / 2) entries, but
    the odd-column slice `pe[:, 1::2]` only has floor(d_model / 2) columns, so
    the original assignment crashed with a shape mismatch whenever d_model was
    odd; the cosine term is now truncated to the slice width. Even d_model
    behavior is unchanged.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        # Geometric frequency schedule: 10000^(-2i / d_model).
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # Odd columns number d_model // 2; drop the extra cos term if d_model is odd.
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # (1, max_len, d_model) buffer: moves with .to(device), not a parameter.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        # x: (B, T, d_model) -> same shape with positions 0..T-1 added.
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
class TransformerBackbone(nn.Module):
    """Linear input projection + sinusoidal positions + Transformer encoder,
    pooled by a (masked) mean over time.

    Input (B, T, input_dim); output (B, d_model).

    Robustness fix: nn.TransformerEncoderLayer requires d_model % nhead == 0.
    The constructor now lowers nhead to the largest value <= the requested one
    that divides d_model instead of raising at construction time (behavior is
    unchanged whenever the requested nhead already divides d_model).
    """

    def __init__(self, input_dim, d_model=128, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        # Fall back to a valid head count rather than crashing on odd widths.
        while d_model % nhead != 0 and nhead > 1:
            nhead -= 1
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
        self.output_dim = d_model

    def forward(self, x, mask=None):
        x = self.input_proj(x)
        x = self.pos_enc(x)
        # src_key_padding_mask: True marks padding positions, hence the inversion.
        src_key_padding_mask = ~mask if mask is not None else None
        x = self.encoder(x, src_key_padding_mask=src_key_padding_mask)
        if mask is not None:
            # Mean over valid timesteps only; clamp guards all-padding rows.
            x = (x * mask.unsqueeze(-1).float()).sum(1) / mask.sum(1, keepdim=True).float().clamp(min=1)
        else:
            x = x.mean(1)
        return x
# ============================================================
# Full models
# ============================================================
def get_backbone(name, input_dim, hidden_dim=128):
    """Instantiate the backbone encoder registered under `name`.

    Known names: 'cnn', 'lstm', 'transformer', 'tinyhar', plus the lazily
    imported 'deepconvlstm' and 'inceptiontime' from published_models.
    Raises ValueError for anything else.
    """
    if name == 'deepconvlstm':
        from experiments.published_models import DeepConvLSTMBackbone
        return DeepConvLSTMBackbone(input_dim, hidden_dim)
    if name == 'inceptiontime':
        from experiments.published_models import InceptionTimeBackbone
        return InceptionTimeBackbone(input_dim, hidden_dim)
    if name == 'cnn':
        return CNN1DBackbone(input_dim, hidden_dim)
    if name == 'lstm':
        return LSTMBackbone(input_dim, hidden_dim)
    if name == 'transformer':
        return TransformerBackbone(input_dim, hidden_dim)
    if name == 'tinyhar':
        return TinyHARBackbone(input_dim, hidden_dim)
    raise ValueError(f"Unknown backbone: {name}")
def _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim):
    """Build (projector-or-None, backbone) for a single modality branch.

    proj_dim > 0: a Linear -> LayerNorm -> ReLU projector maps the raw
    features to proj_dim and the backbone runs at the shared hidden_dim.
    Otherwise: no projector; the backbone consumes raw features with a
    per-modality hidden width scaled to raw_dim.
    """
    if proj_dim <= 0:
        branch_hidden = _compute_per_modality_hidden(raw_dim, hidden_dim)
        return None, get_backbone(backbone_name, raw_dim, branch_hidden)
    projector = nn.Sequential(
        nn.Linear(raw_dim, proj_dim),
        nn.LayerNorm(proj_dim),
        nn.ReLU(),
    )
    return projector, get_backbone(backbone_name, proj_dim, hidden_dim)
class SingleModel(nn.Module):
    """Single backbone + classifier (early fusion or single-modality).

    With proj_dim > 0 and modality_dims given, each modality is projected to
    proj_dim before the shared backbone; otherwise raw features are used.
    """

    def __init__(self, backbone_name, input_dim, num_classes, hidden_dim=128,
                 modality_dims=None, proj_dim=0):
        super().__init__()
        use_projection = proj_dim > 0 and bool(modality_dims)
        self.projector = ModalityProjector(modality_dims, proj_dim) if use_projection else None
        backbone_in = self.projector.output_dim if use_projection else input_dim
        self.backbone = get_backbone(backbone_name, backbone_in, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.output_dim, num_classes),
        )

    def forward(self, x, mask=None):
        h = x if self.projector is None else self.projector(x)
        return self.classifier(self.backbone(h, mask))
class LateFusionModel(nn.Module):
    """Late fusion: separate backbone per modality, configurable logit aggregation.
    late_agg='mean': simple average (original)
    late_agg='confidence': entropy-based confidence weighting (0 extra params)
    late_agg='learned': temperature-scaled learned weights (M+1 extra params)
    """
    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64,
                 proj_dim=0, late_agg='mean'):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.late_agg = late_agg
        # One (projector, backbone, classifier) triple per modality.
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            # nn.Identity placeholder keeps list indices aligned when proj_dim == 0.
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        M = len(self.mod_dims)
        if late_agg == 'learned':
            # Input-independent per-modality weights plus a learned softmax temperature.
            self.modality_logits = nn.Parameter(torch.zeros(M))
            self.temperature = nn.Parameter(torch.ones(1))
    def forward(self, x, mask=None):
        # x: (B, T, sum(mod_dims)), features laid out in mod_dims order.
        offset = 0
        all_logits = []
        for i, dim in enumerate(self.mod_dims):
            x_mod = x[:, :, offset:offset + dim]
            offset += dim
            if self._has_proj:
                x_mod = self.projectors[i](x_mod)
            feat = self.backbones[i](x_mod, mask)
            all_logits.append(self.classifiers[i](feat))
        stacked = torch.stack(all_logits, dim=0)  # (M, B, C)
        if self.late_agg == 'confidence':
            # Weight by confidence: low entropy → high weight
            probs = F.softmax(stacked, dim=-1)  # (M, B, C)
            # +1e-8 guards log(0); entropy is per (modality, sample).
            entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1)  # (M, B)
            weights = F.softmax(-entropy, dim=0).unsqueeze(-1)  # (M, B, 1)
            return (stacked * weights).sum(dim=0)
        elif self.late_agg == 'learned':
            # Softmax over modalities; temperature sharpens/flattens the mix.
            weights = F.softmax(self.modality_logits / self.temperature, dim=0)
            return (stacked * weights.view(-1, 1, 1)).sum(dim=0)
        else:  # 'mean'
            return stacked.mean(dim=0)
class AttentionFusionModel(nn.Module):
    """Attention fusion: separate encoder per modality -> cross-modal attention -> classifier."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        unified_dim = hidden_dim
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.feat_projections = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            # Bridge each branch's feature width onto the shared token width.
            if bb.output_dim == unified_dim:
                self.feat_projections.append(nn.Identity())
            else:
                self.feat_projections.append(nn.Linear(bb.output_dim, unified_dim))
        self._has_proj = proj_dim > 0
        # Largest of 4 / 2 / 1 heads that divides the token width.
        if unified_dim % 4 == 0:
            nhead = 4
        elif unified_dim % 2 == 0:
            nhead = 2
        else:
            nhead = 1
        self.cross_attn = nn.TransformerEncoderLayer(
            d_model=unified_dim, nhead=nhead, dim_feedforward=unified_dim * 2,
            dropout=0.1, batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(unified_dim, num_classes),
        )

    def forward(self, x, mask=None):
        token_list = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            token_list.append(self.feat_projections[i](self.backbones[i](chunk, mask)))
        # One token per modality: (B, M, unified_dim).
        fused = self.cross_attn(torch.stack(token_list, dim=1))
        return self.classifier(fused.mean(dim=1))
class WeightedLateFusionModel(nn.Module):
    """Late fusion with one learned, input-independent softmax weight per modality."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Raw logits; softmax-normalized at every forward pass.
        self.modality_weights = nn.Parameter(torch.ones(len(self.mod_dims)))

    def forward(self, x, mask=None):
        per_modality = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            per_modality.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        mix = F.softmax(self.modality_weights, dim=0).view(-1, 1, 1)  # (M, 1, 1)
        return (torch.stack(per_modality, dim=0) * mix).sum(dim=0)
class GatedLateFusionModel(nn.Module):
    """Late fusion with sample-conditional gates: a small MLP over all branch
    features produces per-sample modality weights for mixing branch logits."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        total_feat_dim = 0
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            total_feat_dim += bb.output_dim
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Gate network: concatenated branch features -> M modality scores.
        self.gate = nn.Sequential(
            nn.Linear(total_feat_dim, 32), nn.ReLU(), nn.Linear(32, M),
        )

    def forward(self, x, mask=None):
        feats, logits = [], []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            f = self.backbones[i](chunk, mask)
            feats.append(f)
            logits.append(self.classifiers[i](f))
        gates = F.softmax(self.gate(torch.cat(feats, dim=1)), dim=1)  # (B, M)
        branch_logits = torch.stack(logits, dim=1)                    # (B, M, C)
        return (branch_logits * gates.unsqueeze(-1)).sum(dim=1)
class StackingFusionModel(nn.Module):
    """Stacked generalization: per-modality class logits are concatenated and
    re-scored by a small meta-learner MLP."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Meta-learner sees all M * num_classes base logits at once.
        self.meta_learner = nn.Sequential(
            nn.Linear(M * num_classes, 32), nn.ReLU(),
            nn.Dropout(0.5), nn.Linear(32, num_classes),
        )

    def forward(self, x, mask=None):
        base_logits = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            base_logits.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        return self.meta_learner(torch.cat(base_logits, dim=1))
class ProductOfExpertsModel(nn.Module):
    """Product-of-experts fusion: per-modality class distributions are
    multiplied by summing their log-probabilities (returned unnormalized)."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0

    def forward(self, x, mask=None):
        total = None
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            logits = self.classifiers[i](self.backbones[i](chunk, mask))
            log_p = F.log_softmax(logits, dim=1)
            # Sum of log-probs == log of the product of expert distributions.
            total = log_p if total is None else total + log_p
        return total
class MoEFusionModel(nn.Module):
    """Mixture-of-experts fusion: each modality branch is an expert. A linear
    router over the concatenated branch features scores all experts, and only
    the top-k experts' logits are mixed (softmax over the selected scores).
    Note: all branches are still computed densely; only the mixing is sparse.
    """
    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        # At most 2 experts are mixed per sample (fewer when M < 2).
        self.top_k = min(2, M)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        total_feat_dim = 0
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            # nn.Identity placeholder keeps list indices aligned when proj_dim == 0.
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            total_feat_dim += bb.output_dim
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Router scores every expert from the features of ALL branches.
        self.router = nn.Linear(total_feat_dim, M)
    def forward(self, x, mask=None):
        # x: (B, T, sum(mod_dims)), features laid out in mod_dims order.
        offset = 0
        all_feats, all_logits = [], []
        for i, dim in enumerate(self.mod_dims):
            x_mod = x[:, :, offset:offset + dim]
            offset += dim
            if self._has_proj:
                x_mod = self.projectors[i](x_mod)
            feat = self.backbones[i](x_mod, mask)
            all_feats.append(feat)
            all_logits.append(self.classifiers[i](feat))
        cat_feats = torch.cat(all_feats, dim=1)
        router_logits = self.router(cat_feats)  # (B, M)
        # Keep the k best experts per sample; renormalize over the kept scores.
        top_vals, top_idx = router_logits.topk(self.top_k, dim=1)
        top_weights = F.softmax(top_vals, dim=1)  # (B, k)
        stacked = torch.stack(all_logits, dim=1)  # (B, M, C)
        # Gather the chosen experts' logits: (B, k, C).
        top_idx_exp = top_idx.unsqueeze(-1).expand(-1, -1, stacked.size(-1))
        selected = stacked.gather(1, top_idx_exp)
        return (selected * top_weights.unsqueeze(-1)).sum(dim=1)
class FeatureConcatFusionModel(nn.Module):
    """Feature-level late fusion: separate backbones, concatenate features, joint classifier."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        total_feat_dim = 0
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            total_feat_dim += bb.output_dim
        self._has_proj = proj_dim > 0
        # Joint head over the concatenated branch features.
        self.classifier = nn.Sequential(
            nn.LayerNorm(total_feat_dim),
            nn.Dropout(0.5),
            nn.Linear(total_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, mask=None):
        feats = []
        start = 0
        for i, width in enumerate(self.mod_dims):
            chunk = x[:, :, start:start + width]
            start += width
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            feats.append(self.backbones[i](chunk, mask))
        return self.classifier(torch.cat(feats, dim=1))
def build_model(backbone_name, fusion, input_dim, modality_dims, num_classes,
                hidden_dim=128, proj_dim=0, late_agg='mean'):
    """Factory function. proj_dim=0 means no projection (raw features).

    `late_agg` is only consumed by the 'late' fusion variant. Raises
    ValueError for an unrecognized fusion name.
    """
    if fusion == 'early':
        return SingleModel(backbone_name, input_dim, num_classes, hidden_dim,
                           modality_dims=modality_dims, proj_dim=proj_dim)
    if fusion == 'late':
        return LateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim,
                               proj_dim, late_agg=late_agg)
    # Remaining variants share the same constructor signature.
    if fusion not in ('attention', 'weighted_late', 'gated_late', 'stacking',
                      'product', 'moe', 'feat_concat'):
        raise ValueError(f"Unknown fusion: {fusion}")
    registry = {
        'attention': AttentionFusionModel,
        'weighted_late': WeightedLateFusionModel,
        'gated_late': GatedLateFusionModel,
        'stacking': StackingFusionModel,
        'product': ProductOfExpertsModel,
        'moe': MoEFusionModel,
        'feat_concat': FeatureConcatFusionModel,
    }
    return registry[fusion](backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)