|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class UltimateMOE(nn.Module):
    """Mixture-of-experts head fusing four expert embeddings multiplicatively.

    Each expert must return a tuple ``(logits, embedding)`` where the
    embedding has 64 features.  The four embeddings are projected to 32
    dims, fused element-wise, and fed to a gating network that produces
    per-expert mixture weights applied to the expert logits.
    """

    def __init__(self, experts):
        """
        Args:
            experts: iterable of modules; ``expert(x)`` must return
                ``(logits, embedding)`` with a 64-dim embedding.
                NOTE(review): fc1-fc4 and the forward pass assume exactly
                four experts — confirm callers always pass four.
        """
        super(UltimateMOE, self).__init__()

        # Decision threshold kept for external callers; unused in this module.
        self.threshold = 0.27

        self.experts = nn.ModuleList(experts)
        num_experts = len(experts)

        self.lrelu = nn.LeakyReLU()
        # NOTE(review): a single BatchNorm1d is shared by all four projected
        # embeddings, so their statistics are pooled — confirm intended.
        self.bn = nn.BatchNorm1d(32)

        # One 64 -> 32 projection per expert embedding.
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(64, 32)

        # Learnable per-feature scaling applied to the multiplicative fusion.
        self.pooling = nn.Parameter(torch.ones(32))

        # Gating input: the four projected embeddings plus the weighted
        # fusion, concatenated -> 32 * (num_experts + 1) features.
        self.gating_network = nn.Sequential(
            nn.Linear(32 * (num_experts + 1), 64),
            nn.Dropout(0.2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, num_experts),
        )

        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return softmax class scores obtained by gating the expert logits.

        Args:
            x: input tensor forwarded unchanged to every expert.

        Returns:
            Tensor of shape ``(batch, num_classes)`` of softmax scores.
        """
        # Run each expert exactly once and unpack (logits, embedding).
        # The original called every expert twice, doubling the cost and —
        # with stochastic layers active in train mode — letting the logits
        # and the embedding come from different random passes.
        results = [expert(x) for expert in self.experts]
        outputs = [res[0] for res in results]
        embeddings = [res[1] for res in results]

        emb_1 = self.lrelu(self.bn(self.fc1(embeddings[0])))
        emb_2 = self.lrelu(self.bn(self.fc2(embeddings[1])))
        emb_3 = self.lrelu(self.bn(self.fc3(embeddings[2])))
        emb_4 = self.lrelu(self.bn(self.fc4(embeddings[3])))

        # Element-wise multiplicative fusion, scaled per feature.
        combined = emb_1 * emb_2 * emb_3 * emb_4
        weighted_combined = combined * self.pooling.unsqueeze(0)

        concatenated_embeddings = torch.cat(
            (emb_1, emb_2, emb_3, emb_4, weighted_combined), dim=1)
        gating_weights = self.gating_network(concatenated_embeddings)
        gating_weights = F.softmax(gating_weights, dim=-1)

        # Mix logits: (b, n) gating weights x (b, c, n) stacked logits -> (b, c).
        weighted_logits = torch.stack(outputs, dim=-1)
        weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)

        score = self.softmax(weighted_logits)

        return score
|
|
|
|
|
|
|
|
|
|
|
class MOE_attention(nn.Module):
    """Mixture of three experts gated by a transformer over expert embeddings.

    Each expert returns ``(logits, embedding)``.  The embeddings are
    projected to 32 dims (expert 0 expects a 128-dim embedding, experts 1
    and 2 expect 256-dim), stacked as a length-3 "sequence", run through
    two transformer encoder layers, and reduced to one gating score per
    expert; the softmaxed scores weight the expert logits.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        """
        Args:
            experts: three modules returning ``(logits, embedding)``;
                embedding dims are expected to be 128, 256, 256 in order.
            device: target device, stored for external use.
            input_dim: kept for interface compatibility; unused here.
            freezing: if True, freeze all expert parameters.
        """
        super(MOE_attention, self).__init__()

        # Kept for external callers; unused inside this module.
        self.threshold = 0.1
        # Softens the gating softmax.
        self.temperature = 1.2

        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)

        # Per-expert projections down to 32 dims.  GLU halves the feature
        # dimension (128 -> 64), hence the final Linear(64, 32).
        self.proc_emb = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32)
            ),
            nn.Sequential(
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32)
            ),
            nn.Sequential(
                nn.Linear(256, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32)
            )
        ])

        # BUGFIX: batch_first=True so the (batch, num_experts, 32) input is
        # attended over the expert dimension.  With the default layout
        # (seq, batch, dim) the layers treated the batch axis as the
        # sequence, mixing information across samples in a batch.
        self.TransfEnc = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1,
                                       dim_feedforward=512, batch_first=True),
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1,
                                       dim_feedforward=512, batch_first=True)
        )
        # One gating score per expert token.
        self.linear_out = nn.Linear(32, 1)

        # Softmax over the expert dimension (dim=1).
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return the gated combination of the experts' logits.

        Args:
            x: input tensor forwarded unchanged to every expert.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # One pass per expert; each result is (logits, embedding).
        results = [expert(x) for expert in self.experts]
        outputs = [res[0] for res in results]
        embeddings = [res[1] for res in results]

        # (batch, num_experts, 32): one projected token per expert.
        processed_embs = torch.stack(
            [proc_emb(emb) for proc_emb, emb in zip(self.proc_emb, embeddings)], dim=1)

        transf_out = self.TransfEnc(processed_embs)

        # (batch, num_experts, 1) scores, temperature-scaled softmax over experts.
        gating_weights = self.linear_out(transf_out)
        gating_weights = self.softmax(gating_weights / self.temperature)

        # (batch, num_experts, num_classes); weights broadcast over classes.
        expert_outputs = torch.stack(outputs, dim=1)

        combined_output = torch.sum(gating_weights * expert_outputs, dim=1)

        return combined_output
|
|
|
|
|
|
|
|
|
|
|
class MOE_attention_FS(nn.Module):
    """Transformer-gated mixture of experts with per-expert input streams.

    Unlike ``MOE_attention``, each expert receives its own input tensor
    (``x_16``, ``x_22``, ``x_24``) and every expert embedding is expected
    to be 128-dim.  The gating mechanism is otherwise identical: project
    each embedding to 32 dims, attend over the expert tokens, and weight
    the expert logits with softmaxed per-expert scores.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        """
        Args:
            experts: three modules returning ``(logits, embedding)`` with
                128-dim embeddings, consumed in order by forward's three inputs.
            device: target device, stored for external use.
            input_dim: kept for interface compatibility; unused here.
            freezing: if True, freeze all expert parameters.
        """
        super(MOE_attention_FS, self).__init__()

        # Kept for external callers; unused inside this module.
        self.threshold = 0.5
        # Softens the gating softmax.
        self.temperature = 1.2

        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)

        # One 128 -> 32 projection per expert.  GLU halves the feature
        # dimension (128 -> 64), hence the final Linear(64, 32).
        self.proc_emb = nn.ModuleList([
            nn.Sequential(
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32)
            ) for _ in range(self.num_experts)
        ])

        # BUGFIX: batch_first=True so the (batch, num_experts, 32) input is
        # attended over the expert dimension.  With the default layout
        # (seq, batch, dim) the layers treated the batch axis as the
        # sequence, mixing information across samples in a batch.
        self.TransfEnc = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1,
                                       dim_feedforward=512, batch_first=True),
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1,
                                       dim_feedforward=512, batch_first=True)
        )
        # One gating score per expert token; softmax over the expert dim.
        self.linear_out = nn.Linear(32, 1)
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, x_16, x_22, x_24):
        """Return the gated combination of the three experts' logits.

        Args:
            x_16: input for expert 0.
            x_22: input for expert 1.
            x_24: input for expert 2.

        Returns:
            Tensor of shape ``(batch, num_classes)``.
        """
        # Each expert consumes its dedicated input stream.
        results = [self.experts[0](x_16), self.experts[1](x_22), self.experts[2](x_24)]

        outputs = [res[0] for res in results]
        embeddings = [res[1] for res in results]

        # (batch, num_experts, 32): one projected token per expert.
        processed_embs = torch.stack(
            [proc_emb(emb) for proc_emb, emb in zip(self.proc_emb, embeddings)], dim=1)

        transf_out = self.TransfEnc(processed_embs)

        # (batch, num_experts, 1) scores, temperature-scaled softmax over experts.
        gating_weights = self.linear_out(transf_out)
        gating_weights = self.softmax(gating_weights / self.temperature)

        # (batch, num_experts, num_classes); weights broadcast over classes.
        expert_outputs = torch.stack(outputs, dim=1)

        combined_output = torch.sum(gating_weights * expert_outputs, dim=1)

        return combined_output