File size: 634 Bytes
df6cf36 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 | import torch
import torch.nn as nn
from core.nse.engine import NeuroSynthEngine
class TripleHeadSyncManager(nn.Module):
    """Run three NeuroSynthEngine heads on one input and merge their results.

    Every head receives the same ``(x, fe_stats)`` pair. Each head's result
    is indexed at position 0 (treated as a reconstruction) and position 1
    (treated as a learning-rate multiplier); the two groups are then reduced
    across the head axis by mean and by maximum respectively.
    """

    def __init__(self, d_model=256):
        """Build the three heads, each constructed as ``NeuroSynthEngine(d_model)``."""
        super().__init__()
        engines = []
        for _ in range(3):
            engines.append(NeuroSynthEngine(d_model))
        # ModuleList registers the heads as submodules of this module.
        self.heads = nn.ModuleList(engines)

    def forward(self, x, fe_stats):
        """Call every head with ``(x, fe_stats)`` and consolidate the outputs.

        Returns:
            A ``(recons, lr_multipliers)`` tuple: the element-wise mean of the
            heads' first outputs and the element-wise maximum of the heads'
            second outputs.

        NOTE(review): assumes all heads return same-shaped tensors at index 0
        and at index 1 (``torch.stack`` requires it) -- confirm against
        NeuroSynthEngine.
        """
        recon_parts = []
        lr_parts = []
        for engine in self.heads:
            result = engine(x, fe_stats)
            recon_parts.append(result[0])
            lr_parts.append(result[1])

        # Stacking adds a leading "head" axis; both reductions collapse it.
        merged_recon = torch.mean(torch.stack(recon_parts), dim=0)
        # torch.max with a dim returns (values, indices); only values are kept.
        peak_lr = torch.max(torch.stack(lr_parts), dim=0).values
        return merged_recon, peak_lr
|