| """ |
| Model definitions for Experiment 1: Scene Recognition. |
Backbones: CNN1D, BiLSTM, Transformer, TinyHAR (plus DeepConvLSTM / InceptionTime baselines)
Fusion: Early (default), Late, Attention, WeightedLate, GatedLate, Stacking, Product, MoE, FeatureConcat
| |
| Supports optional per-modality projection via proj_dim parameter: |
| proj_dim > 0: project each modality to proj_dim before backbone |
| proj_dim = 0: no projection, use raw features (original behavior) |
| """ |
|
|
| import math |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
|
|
| |
| |
| |
|
|
class ModalityProjector(nn.Module):
    """Project each modality from its raw dimension to proj_dim.

    Splits the packed feature vector back into its per-modality slices,
    maps every slice through Linear -> LayerNorm -> ReLU, and
    re-concatenates along the feature axis.
    """

    def __init__(self, modality_dims, proj_dim):
        # modality_dims: ordered mapping of modality name -> raw feature width.
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.proj_dim = proj_dim
        self.projectors = nn.ModuleList(
            nn.Sequential(
                nn.Linear(raw_dim, proj_dim),
                nn.LayerNorm(proj_dim),
                nn.ReLU(),
            )
            for raw_dim in self.mod_dims
        )

    @property
    def output_dim(self):
        # Total width once every modality has been projected.
        return self.proj_dim * len(self.mod_dims)

    def forward(self, x):
        """x: (B, T, total_raw_dim) -> (B, T, proj_dim * M)"""
        chunks = x.split(self.mod_dims, dim=-1)
        projected = [proj(chunk) for proj, chunk in zip(self.projectors, chunks)]
        return torch.cat(projected, dim=-1)
|
|
|
|
| |
| |
| |
|
|
| def _compute_per_modality_hidden(mod_dim, base_hidden_dim): |
| if mod_dim >= 128: |
| return max(base_hidden_dim, 48) |
| elif mod_dim >= 32: |
| return base_hidden_dim |
| else: |
| return max(16, base_hidden_dim // 2) |
|
|
|
|
| |
| |
| |
|
|
class CNN1DBackbone(nn.Module):
    """Three-stage 1D CNN over time with (masked) mean pooling.

    Input:  (B, T, input_dim) with optional boolean mask (B, T), True = valid.
    Output: (B, hidden_dim)
    """

    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()

        def _stage(c_in, c_out, ks, drop):
            # Conv -> BN -> ReLU (+ optional Dropout), "same" padding.
            layers = [
                nn.Conv1d(c_in, c_out, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(c_out),
                nn.ReLU(),
            ]
            if drop:
                layers.append(nn.Dropout(drop))
            return nn.Sequential(*layers)

        self.conv1 = _stage(input_dim, 64, 7, 0.1)
        self.conv2 = _stage(64, 128, 5, 0.1)
        self.conv3 = _stage(128, hidden_dim, 3, 0.0)
        self.output_dim = hidden_dim

    def forward(self, x, mask=None):
        # Conv1d expects channels-first: (B, C, T).
        h = x.permute(0, 2, 1)
        for stage in (self.conv1, self.conv2, self.conv3):
            h = stage(h)
        if mask is None:
            return h.mean(2)
        # Masked mean over time: zero padded steps, divide by valid count.
        valid = mask.unsqueeze(1).float()
        count = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (h * valid).sum(2) / count
|
|
|
|
class LSTMBackbone(nn.Module):
    """Bidirectional LSTM with additive attention pooling over time.

    Input:  (B, T, input_dim) with optional boolean mask (B, T), True = valid.
    Output: (B, hidden_dim * 2)
    """

    def __init__(self, input_dim, hidden_dim=128, num_layers=2, dropout=0.2):
        super().__init__()
        # Inter-layer dropout only applies to stacked LSTMs.
        inter_dropout = dropout if num_layers > 1 else 0
        self.lstm = nn.LSTM(
            input_dim, hidden_dim, num_layers=num_layers,
            batch_first=True, bidirectional=True, dropout=inter_dropout,
        )
        self.attn = nn.Linear(hidden_dim * 2, 1)
        self.output_dim = hidden_dim * 2

    def forward(self, x, mask=None):
        seq, _ = self.lstm(x)
        # One scalar attention score per time step.
        scores = self.attn(seq).squeeze(-1)
        if mask is not None:
            # Padded steps get -inf so softmax assigns them zero weight.
            scores = scores.masked_fill(~mask, float('-inf'))
        alpha = torch.softmax(scores, dim=1).unsqueeze(-1)
        return (seq * alpha).sum(dim=1)
|
|
|
|
class TinyHARBackbone(nn.Module):
    """TinyHAR backbone (Zhou et al., ISWC 2022 Best Paper).

    Lightweight model for human activity recognition from wearable sensors.
    Uses multi-scale temporal convolutions + cross-channel interaction + temporal pooling.

    Input: (B, T, C) with optional mask
    Output: (B, hidden_dim)
    """

    def __init__(self, input_dim, hidden_dim=128, num_scales=4):
        super().__init__()
        # hidden_dim is split evenly across scales (floor, min 4 per scale),
        # so the realized width (self.output_dim) may differ from hidden_dim.
        scale_dim = max(4, hidden_dim // num_scales)
        actual_hidden = scale_dim * num_scales

        # Multi-scale temporal convolutions: odd kernel sizes 3, 5, 7, ...
        # with "same" padding so every scale preserves the sequence length.
        self.convs = nn.ModuleList()
        for i in range(num_scales):
            ks = 2 * (i + 1) + 1
            self.convs.append(nn.Sequential(
                nn.Conv1d(input_dim, scale_dim, kernel_size=ks, padding=ks // 2),
                nn.BatchNorm1d(scale_dim),
                nn.ReLU(),
            ))

        # Interaction block: residual self-attention + residual feed-forward.
        nhead = max(1, min(4, actual_hidden // 8))
        # MultiheadAttention requires embed_dim % num_heads == 0.
        while actual_hidden % nhead != 0 and nhead > 1:
            nhead -= 1
        self.channel_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=nhead, batch_first=True, dropout=0.1,
        )
        self.channel_norm = nn.LayerNorm(actual_hidden)
        self.channel_ff = nn.Sequential(
            nn.Linear(actual_hidden, actual_hidden),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(actual_hidden, actual_hidden),
        )
        self.ff_norm = nn.LayerNorm(actual_hidden)

        # Temporal pooling: a single learned query attends over all time steps.
        self.temporal_query = nn.Parameter(torch.randn(1, 1, actual_hidden) * 0.02)
        self.temporal_attn = nn.MultiheadAttention(
            actual_hidden, num_heads=1, batch_first=True, dropout=0.1,
        )

        self.output_dim = actual_hidden

    def forward(self, x, mask=None):
        # x: (B, T, C); mask: (B, T) boolean, True = valid step.
        B, T, C = x.shape
        x_t = x.permute(0, 2, 1)  # Conv1d wants (B, C, T)

        # Run every scale and concatenate along channels, back to (B, T, actual_hidden).
        scale_features = [conv(x_t) for conv in self.convs]
        x = torch.cat(scale_features, dim=1)
        x = x.permute(0, 2, 1)

        # key_padding_mask expects True at PAD positions, hence the inversion.
        key_padding_mask = ~mask if mask is not None else None
        attn_out, _ = self.channel_attn(x, x, x, key_padding_mask=key_padding_mask)
        x = self.channel_norm(x + attn_out)
        x = self.ff_norm(x + self.channel_ff(x))

        # Attention pooling: learned query over the sequence -> (B, 1, H) -> (B, H).
        query = self.temporal_query.expand(B, -1, -1)
        pooled, _ = self.temporal_attn(query, x, x, key_padding_mask=key_padding_mask)
        return pooled.squeeze(1)
|
|
|
|
class PositionalEncoding(nn.Module):
    """Sinusoidal positional encoding (Vaswani et al., 2017), added to the
    input and followed by dropout.

    Fix: the original sliced assignment ``pe[:, 1::2] = cos(...)`` assumed an
    even d_model and raised a shape-mismatch error for odd d_model (the cosine
    slice has one fewer column than the sine slice). The cosine table is now
    truncated to ``d_model // 2`` columns, which is a no-op for even d_model.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe[:, 0::2] = torch.sin(angles)
        # Odd d_model has one fewer cosine slot than sine slot.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])
        pe = pe.unsqueeze(0)
        # Buffer (not Parameter): moves with the module, never trained.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """x: (B, T, d_model) -> same shape with positions added, then dropout."""
        x = x + self.pe[:, :x.size(1)]
        return self.dropout(x)
|
|
|
|
class TransformerBackbone(nn.Module):
    """Linear projection + positional encoding + Transformer encoder,
    finished with a (masked) mean pool over time.

    Input:  (B, T, input_dim) with optional boolean mask (B, T), True = valid.
    Output: (B, d_model)
    """

    def __init__(self, input_dim, d_model=128, nhead=4, num_layers=2, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, d_model)
        self.pos_enc = PositionalEncoding(d_model, dropout=dropout)
        layer = nn.TransformerEncoderLayer(
            d_model=d_model, nhead=nhead, dim_feedforward=d_model * 4,
            dropout=dropout, batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers=num_layers)
        self.output_dim = d_model

    def forward(self, x, mask=None):
        h = self.pos_enc(self.input_proj(x))
        if mask is None:
            return self.encoder(h).mean(1)
        # src_key_padding_mask expects True at PAD positions, hence the inversion.
        h = self.encoder(h, src_key_padding_mask=~mask)
        valid = mask.unsqueeze(-1).float()
        count = mask.sum(1, keepdim=True).float().clamp(min=1)
        return (h * valid).sum(1) / count
|
|
|
|
| |
| |
| |
|
|
def get_backbone(name, input_dim, hidden_dim=128):
    """Instantiate a temporal backbone by name.

    Known names: 'cnn', 'lstm', 'transformer', 'tinyhar', plus the published
    baselines 'deepconvlstm' / 'inceptiontime' (imported lazily so the extra
    module is only required when actually requested).

    Raises:
        ValueError: for an unrecognized backbone name.
    """
    local_backbones = {
        'cnn': CNN1DBackbone,
        'lstm': LSTMBackbone,
        'transformer': TransformerBackbone,
        'tinyhar': TinyHARBackbone,
    }
    if name in local_backbones:
        return local_backbones[name](input_dim, hidden_dim)
    if name == 'deepconvlstm':
        from experiments.published_models import DeepConvLSTMBackbone
        return DeepConvLSTMBackbone(input_dim, hidden_dim)
    if name == 'inceptiontime':
        from experiments.published_models import InceptionTimeBackbone
        return InceptionTimeBackbone(input_dim, hidden_dim)
    raise ValueError(f"Unknown backbone: {name}")
|
|
|
|
def _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim):
    """Create optional projector + backbone for one modality branch.

    Returns (projector_or_None, backbone). With proj_dim > 0 the branch first
    projects raw features to proj_dim and the backbone uses the full
    hidden_dim; otherwise the backbone consumes raw features with a hidden
    size scaled to the modality's width.
    """
    if proj_dim <= 0:
        branch_hidden = _compute_per_modality_hidden(raw_dim, hidden_dim)
        return None, get_backbone(backbone_name, raw_dim, branch_hidden)
    projector = nn.Sequential(
        nn.Linear(raw_dim, proj_dim),
        nn.LayerNorm(proj_dim),
        nn.ReLU(),
    )
    return projector, get_backbone(backbone_name, proj_dim, hidden_dim)
|
|
|
|
class SingleModel(nn.Module):
    """Single backbone + classifier (early fusion or single-modality).

    With proj_dim > 0 and modality_dims given, each modality is projected to
    proj_dim before the shared backbone; otherwise raw features are used.
    """

    def __init__(self, backbone_name, input_dim, num_classes, hidden_dim=128,
                 modality_dims=None, proj_dim=0):
        super().__init__()
        use_projection = proj_dim > 0 and bool(modality_dims)
        self.projector = ModalityProjector(modality_dims, proj_dim) if use_projection else None
        backbone_input = self.projector.output_dim if use_projection else input_dim
        self.backbone = get_backbone(backbone_name, backbone_input, hidden_dim)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(self.backbone.output_dim, num_classes),
        )

    def forward(self, x, mask=None):
        h = x if self.projector is None else self.projector(x)
        return self.classifier(self.backbone(h, mask))
|
|
|
|
class LateFusionModel(nn.Module):
    """Late fusion: separate backbone per modality, configurable logit aggregation.

    late_agg='mean': simple average (original)
    late_agg='confidence': entropy-based confidence weighting (0 extra params)
    late_agg='learned': temperature-scaled learned weights (M+1 extra params)
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64,
                 proj_dim=0, late_agg='mean'):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.late_agg = late_agg
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(backbone.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0

        if late_agg == 'learned':
            # Per-modality mixing logits plus a shared softmax temperature.
            self.modality_logits = nn.Parameter(torch.zeros(len(self.mod_dims)))
            self.temperature = nn.Parameter(torch.ones(1))

    def _per_modality_logits(self, x, mask):
        # Slice the packed feature axis into modalities and run each branch.
        logits = []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            logits.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        return logits

    def forward(self, x, mask=None):
        stacked = torch.stack(self._per_modality_logits(x, mask), dim=0)  # (M, B, C)
        if self.late_agg == 'confidence':
            # Lower-entropy (more confident) modalities get larger weights.
            probs = F.softmax(stacked, dim=-1)
            entropy = -(probs * (probs + 1e-8).log()).sum(dim=-1)
            weights = F.softmax(-entropy, dim=0).unsqueeze(-1)
            return (stacked * weights).sum(dim=0)
        if self.late_agg == 'learned':
            weights = F.softmax(self.modality_logits / self.temperature, dim=0)
            return (stacked * weights.view(-1, 1, 1)).sum(dim=0)
        return stacked.mean(dim=0)
|
|
|
|
class AttentionFusionModel(nn.Module):
    """Attention fusion: separate encoder per modality -> cross-modal attention -> classifier.

    Each branch's pooled feature becomes one token; a single transformer
    encoder layer mixes the modality tokens before mean pooling.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        unified_dim = hidden_dim
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.feat_projections = nn.ModuleList()
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            # Map every branch output to the shared token width.
            if backbone.output_dim == unified_dim:
                self.feat_projections.append(nn.Identity())
            else:
                self.feat_projections.append(nn.Linear(backbone.output_dim, unified_dim))
        self._has_proj = proj_dim > 0
        # Largest head count in {4, 2, 1} that divides the token width.
        nhead = 4 if unified_dim % 4 == 0 else (2 if unified_dim % 2 == 0 else 1)
        self.cross_attn = nn.TransformerEncoderLayer(
            d_model=unified_dim, nhead=nhead, dim_feedforward=unified_dim * 2,
            dropout=0.1, batch_first=True,
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5), nn.Linear(unified_dim, num_classes),
        )

    def forward(self, x, mask=None):
        tokens = []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            tokens.append(self.feat_projections[i](self.backbones[i](chunk, mask)))
        fused = self.cross_attn(torch.stack(tokens, dim=1))  # (B, M, unified_dim)
        return self.classifier(fused.mean(dim=1))
|
|
|
|
class WeightedLateFusionModel(nn.Module):
    """Late fusion with one learned, input-independent softmax weight per modality."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(backbone.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # One raw weight per modality, normalized with softmax at use time.
        self.modality_weights = nn.Parameter(torch.ones(len(self.mod_dims)))

    def forward(self, x, mask=None):
        logits = []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            logits.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        weights = F.softmax(self.modality_weights, dim=0).view(-1, 1, 1)
        return (torch.stack(logits, dim=0) * weights).sum(dim=0)
|
|
|
|
class GatedLateFusionModel(nn.Module):
    """Late fusion with a sample-dependent gate: a small MLP over the
    concatenated branch features produces softmax weights used to mix the
    per-modality logits."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        total_feat_dim = 0
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            total_feat_dim += backbone.output_dim
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(backbone.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Gate MLP: concatenated features -> one score per modality.
        self.gate = nn.Sequential(
            nn.Linear(total_feat_dim, 32), nn.ReLU(), nn.Linear(32, M),
        )

    def forward(self, x, mask=None):
        feats, logits = [], []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            feat = self.backbones[i](chunk, mask)
            feats.append(feat)
            logits.append(self.classifiers[i](feat))
        gate_weights = F.softmax(self.gate(torch.cat(feats, dim=1)), dim=1)
        stacked = torch.stack(logits, dim=1)  # (B, M, num_classes)
        return (stacked * gate_weights.unsqueeze(-1)).sum(dim=1)
|
|
|
|
class StackingFusionModel(nn.Module):
    """Stacking ensemble: per-modality logits are concatenated and passed to a
    small meta-learner MLP that emits the final class logits."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(backbone.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Meta-learner over the M * num_classes concatenated base logits.
        self.meta_learner = nn.Sequential(
            nn.Linear(M * num_classes, 32), nn.ReLU(),
            nn.Dropout(0.5), nn.Linear(32, num_classes),
        )

    def forward(self, x, mask=None):
        logits = []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            logits.append(self.classifiers[i](self.backbones[i](chunk, mask)))
        return self.meta_learner(torch.cat(logits, dim=1))
|
|
|
|
class ProductOfExpertsModel(nn.Module):
    """Product-of-experts fusion: per-modality class log-probabilities are
    summed (equivalent to multiplying probabilities), so all experts must
    agree for a class to score highly. Note the forward returns summed
    log-probabilities, not raw logits."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(backbone.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0

    def forward(self, x, mask=None):
        joint = None
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            branch_logits = self.classifiers[i](self.backbones[i](chunk, mask))
            log_p = F.log_softmax(branch_logits, dim=1)
            joint = log_p if joint is None else joint + log_p
        return joint
|
|
|
|
class MoEFusionModel(nn.Module):
    """Mixture-of-experts fusion: one expert (backbone + classifier) per
    modality; a linear router over the concatenated branch features selects
    the top-k experts per sample and mixes their logits with softmax weights.
    """

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        M = len(self.mod_dims)
        # At most two experts are active per sample (fewer if M < 2).
        self.top_k = min(2, M)
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        self.classifiers = nn.ModuleList()
        total_feat_dim = 0
        for dim in self.mod_dims:
            proj, bb = _make_branch(backbone_name, dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(bb)
            total_feat_dim += bb.output_dim
            self.classifiers.append(nn.Sequential(
                nn.Dropout(0.5), nn.Linear(bb.output_dim, num_classes),
            ))
        self._has_proj = proj_dim > 0
        # Router scores each modality from the concatenation of all features.
        self.router = nn.Linear(total_feat_dim, M)

    def forward(self, x, mask=None):
        offset = 0
        all_feats, all_logits = [], []
        for i, dim in enumerate(self.mod_dims):
            # Slice this modality's columns out of the packed feature axis.
            x_mod = x[:, :, offset:offset + dim]
            offset += dim
            if self._has_proj:
                x_mod = self.projectors[i](x_mod)
            feat = self.backbones[i](x_mod, mask)
            all_feats.append(feat)
            all_logits.append(self.classifiers[i](feat))
        cat_feats = torch.cat(all_feats, dim=1)
        router_logits = self.router(cat_feats)
        # Per-sample top-k selection; weights renormalized over the chosen k.
        top_vals, top_idx = router_logits.topk(self.top_k, dim=1)
        top_weights = F.softmax(top_vals, dim=1)
        stacked = torch.stack(all_logits, dim=1)  # (B, M, num_classes)
        # Gather the chosen experts' logits -> (B, top_k, num_classes).
        top_idx_exp = top_idx.unsqueeze(-1).expand(-1, -1, stacked.size(-1))
        selected = stacked.gather(1, top_idx_exp)
        return (selected * top_weights.unsqueeze(-1)).sum(dim=1)
|
|
|
|
class FeatureConcatFusionModel(nn.Module):
    """Feature-level late fusion: separate backbones, concatenate features, joint classifier."""

    def __init__(self, backbone_name, modality_dims, num_classes, hidden_dim=64, proj_dim=0):
        super().__init__()
        self.mod_names = list(modality_dims.keys())
        self.mod_dims = list(modality_dims.values())
        self.projectors = nn.ModuleList()
        self.backbones = nn.ModuleList()
        total_feat_dim = 0
        for raw_dim in self.mod_dims:
            proj, backbone = _make_branch(backbone_name, raw_dim, hidden_dim, proj_dim)
            self.projectors.append(proj if proj else nn.Identity())
            self.backbones.append(backbone)
            total_feat_dim += backbone.output_dim
        self._has_proj = proj_dim > 0
        # Joint head over the concatenated branch features.
        self.classifier = nn.Sequential(
            nn.LayerNorm(total_feat_dim),
            nn.Dropout(0.5),
            nn.Linear(total_feat_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x, mask=None):
        feats = []
        offset = 0
        for i, raw_dim in enumerate(self.mod_dims):
            chunk = x[:, :, offset:offset + raw_dim]
            offset += raw_dim
            if self._has_proj:
                chunk = self.projectors[i](chunk)
            feats.append(self.backbones[i](chunk, mask))
        return self.classifier(torch.cat(feats, dim=1))
|
|
|
|
def build_model(backbone_name, fusion, input_dim, modality_dims, num_classes,
                hidden_dim=128, proj_dim=0, late_agg='mean'):
    """Factory function. proj_dim=0 means no projection (raw features).

    Raises:
        ValueError: for an unrecognized fusion name.
    """
    # 'early' and 'late' take extra keyword arguments; handle them directly.
    if fusion == 'early':
        return SingleModel(backbone_name, input_dim, num_classes, hidden_dim,
                           modality_dims=modality_dims, proj_dim=proj_dim)
    if fusion == 'late':
        return LateFusionModel(backbone_name, modality_dims, num_classes, hidden_dim,
                               proj_dim, late_agg=late_agg)
    # The remaining fusions share one constructor signature.
    multimodal = {
        'attention': AttentionFusionModel,
        'weighted_late': WeightedLateFusionModel,
        'gated_late': GatedLateFusionModel,
        'stacking': StackingFusionModel,
        'product': ProductOfExpertsModel,
        'moe': MoEFusionModel,
        'feat_concat': FeatureConcatFusionModel,
    }
    if fusion not in multimodal:
        raise ValueError(f"Unknown fusion: {fusion}")
    return multimodal[fusion](backbone_name, modality_dims, num_classes, hidden_dim, proj_dim)
|
|