File size: 634 Bytes
df6cf36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
import torch.nn as nn
from core.nse.engine import NeuroSynthEngine

class TripleHeadSyncManager(nn.Module):
    """Run several ``NeuroSynthEngine`` heads on the same input and merge their outputs.

    Every head receives the identical ``(x, fe_stats)`` arguments. Item 0 of
    each head's output (the reconstruction) is averaged across heads. Item 1
    (the LR multiplier) is reduced with an element-wise maximum across heads.

    Args:
        d_model: Width forwarded to each ``NeuroSynthEngine`` constructor.
        num_heads: Number of independent heads to build (keyword-only).
            Defaults to 3, the original hard-coded count that the class name
            refers to.

    Raises:
        ValueError: If ``num_heads`` is less than 1; ``forward`` would
            otherwise have nothing to stack.
    """

    def __init__(self, d_model=256, *, num_heads=3):
        super().__init__()
        if num_heads < 1:
            raise ValueError(f"num_heads must be >= 1, got {num_heads}")
        self.heads = nn.ModuleList(
            [NeuroSynthEngine(d_model) for _ in range(num_heads)]
        )

    def forward(self, x, fe_stats):
        """Evaluate every head and consolidate their outputs.

        Args:
            x: Input passed unchanged to every head.
            fe_stats: Extra argument passed unchanged to every head. Its
                structure is defined by ``NeuroSynthEngine``, which is not
                visible here.

        Returns:
            A tuple ``(recons, lr_multipliers)``:

            - ``recons`` is the mean over heads of each head's ``output[0]``.
            - ``lr_multipliers`` is the element-wise max over heads of each
              head's ``output[1]``.
        """
        # NOTE(review): assumes each head returns an indexable whose items 0
        # and 1 are tensors with the same shape across heads, as torch.stack
        # requires. Confirm against NeuroSynthEngine.
        outputs = [head(x, fe_stats) for head in self.heads]

        # Consolidate: average the reconstructions over the stacked head axis.
        recons = torch.stack([out[0] for out in outputs]).mean(dim=0)

        # Max-pool the LR multipliers. torch.max over a dim returns
        # (values, indices), so gradient reaches only the head that supplied
        # each maximum.
        lr_multipliers = torch.stack([out[1] for out in outputs]).max(dim=0).values

        return recons, lr_multipliers