import torch
import torch.nn as nn
import torch.nn.functional as F


def _make_projection(in_features):
    """Build the head that maps one expert embedding to a 32-d gate token."""
    return nn.Sequential(
        nn.Linear(in_features, 128),
        nn.BatchNorm1d(128),
        nn.GLU(),  # gated linear unit halves the last dimension: 128 -> 64
        nn.Linear(64, 32),
    )


def _make_gate_encoder():
    """Build the two-layer transformer encoder that mixes the expert tokens.

    ``batch_first=True`` is essential: the tokens are stacked as
    (batch, experts, 32).  With PyTorch's default ``batch_first=False`` the
    encoder would read that tensor as (sequence, batch, feature) and run
    self-attention across the *samples of the batch*, so a sample's gate
    would depend on whatever else happened to be in the batch.

    NOTE: the flag does not change parameter names or shapes, so existing
    checkpoints still load, but gates trained under the old layout should
    be re-validated or fine-tuned.
    """
    return nn.Sequential(
        *(
            nn.TransformerEncoderLayer(
                d_model=32,
                nhead=4,
                dropout=0.1,
                dim_feedforward=512,
                batch_first=True,
            )
            for _ in range(2)
        )
    )


def _freeze(modules):
    """Disable gradients for every parameter of ``modules``.

    Only ``requires_grad`` is touched; layers such as BatchNorm or Dropout
    inside the experts still follow the usual ``train()`` / ``eval()`` mode.
    """
    for module in modules:
        for param in module.parameters():
            param.requires_grad = False


def _attention_fusion(moe, results):
    """Fuse expert outputs with attention-derived gate weights.

    Args:
        moe: module exposing ``proc_emb``, ``TransfEnc``, ``linear_out``,
            ``softmax`` and ``temperature``.
        results: one ``(output, embedding)`` pair per expert, in the same
            order as ``moe.proc_emb``.

    Returns:
        The gate-weighted sum of the expert outputs over the expert axis.
    """
    outputs = [res[0] for res in results]
    embeddings = [res[1] for res in results]
    # One 32-d token per expert, stacked on dim 1 -> (batch, experts, 32)
    # for 2-D embeddings.
    tokens = torch.stack(
        [proj(emb) for proj, emb in zip(moe.proc_emb, embeddings)], dim=1
    )
    gate_scores = moe.linear_out(moe.TransfEnc(tokens))
    # Temperature-scaled softmax over the expert axis (dim=1).
    gate_weights = moe.softmax(gate_scores / moe.temperature)
    stacked_outputs = torch.stack(outputs, dim=1)
    return torch.sum(gate_weights * stacked_outputs, dim=1)


class UltimateMOE(nn.Module):
    """Mixture of experts gated by an MLP over projected expert embeddings.

    Every expert must return an ``(output, embedding)`` pair whose embedding
    has 64 features.  Each embedding is projected to 32 features; the
    projections, together with their ``pooling``-scaled element-wise product,
    feed a gating MLP that yields one weight per expert.

    Args:
        experts: non-empty sequence of expert modules.

    Raises:
        ValueError: if ``experts`` is empty.
    """

    def __init__(self, experts):
        super().__init__()
        experts = list(experts)
        if not experts:
            raise ValueError("UltimateMOE requires at least one expert")
        self.threshold = 0.27  # not read by any code in this file
        self.experts = nn.ModuleList(experts)
        num_experts = len(experts)
        self.lrelu = nn.LeakyReLU()
        # A single BatchNorm instance is shared by all projection branches.
        self.bn = nn.BatchNorm1d(32)
        # One projection per expert, registered as fc1, fc2, ... so that the
        # state-dict keys of the original four-expert layout are preserved.
        for idx in range(1, num_experts + 1):
            setattr(self, f"fc{idx}", nn.Linear(64, 32))
        self.pooling = nn.Parameter(torch.ones(32))
        self.gating_network = nn.Sequential(
            nn.Linear(32 * (num_experts + 1), 64),
            nn.Dropout(0.2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, num_experts),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return the softmax of the gate-weighted combination of expert outputs.

        Args:
            x: input forwarded unchanged to every expert.

        Returns:
            Tensor produced by contracting the expert axis of the stacked
            outputs with the gate weights and applying softmax over dim 1.
        """
        # Run each expert exactly once so that outputs and embeddings come
        # from the same forward pass (matters with Dropout/BatchNorm in
        # train mode) and no compute is wasted.
        results = [expert(x) for expert in self.experts]
        outputs = [res[0] for res in results]
        projected = [
            self.lrelu(self.bn(getattr(self, f"fc{idx}")(res[1])))
            for idx, res in enumerate(results, start=1)
        ]

        # Element-wise product of all projections, scaled per feature.
        combined = projected[0]
        for emb in projected[1:]:
            combined = combined * emb
        weighted_combined = combined * self.pooling.unsqueeze(0)

        gate_input = torch.cat(projected + [weighted_combined], dim=1)
        gating_weights = F.softmax(self.gating_network(gate_input), dim=-1)

        # Experts are stacked on the last axis: 'bcn' with n = expert index.
        stacked_outputs = torch.stack(outputs, dim=-1)
        fused = torch.einsum('bn,bcn->bc', gating_weights, stacked_outputs)
        return self.softmax(fused)


class MOE_attention(nn.Module):
    """Mixture of experts whose gate is a small transformer over expert tokens.

    All experts receive the same input and must return an
    ``(output, embedding)`` pair.  The embedding widths are fixed per
    position: 128, 256 and 256 features for experts 0, 1 and 2.

    Args:
        experts: sequence of one to three expert modules.
        device: stored on the instance; not used by this class itself.
        input_dim: accepted for interface compatibility; currently unused.
        freezing: when true, gradients are disabled for all expert parameters.

    Raises:
        ValueError: if the number of experts is not between 1 and 3.
    """

    # Embedding width expected from each expert, by position.
    _EMBEDDING_DIMS = (128, 256, 256)

    def __init__(self, experts, device, input_dim=128, freezing=False):
        super().__init__()
        self.threshold = 0.1  # not read by any code in this file
        self.temperature = 1.2
        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)
        if not 1 <= self.num_experts <= len(self._EMBEDDING_DIMS):
            raise ValueError(
                f"MOE_attention supports 1 to {len(self._EMBEDDING_DIMS)} "
                f"experts, got {self.num_experts}"
            )

        # All three projections are always created so that state-dict keys
        # do not depend on how many experts are supplied.
        self.proc_emb = nn.ModuleList(
            [_make_projection(dim) for dim in self._EMBEDDING_DIMS]
        )
        self.TransfEnc = _make_gate_encoder()
        self.linear_out = nn.Linear(32, 1)
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            _freeze(self.experts)

    def forward(self, x):
        """Run every expert on ``x`` and return the gate-weighted output sum."""
        results = [expert(x) for expert in self.experts]
        return _attention_fusion(self, results)


class MOE_attention_FS(nn.Module):
    """Attention-gated mixture where each expert receives its own input.

    ``forward`` feeds ``x_16``, ``x_22`` and ``x_24`` to experts 0, 1 and 2
    respectively.  Every expert must return an ``(output, embedding)`` pair
    with a 128-feature embedding.

    Args:
        experts: sequence of at least three expert modules; only the first
            three are used by ``forward``.
        device: stored on the instance; not used by this class itself.
        input_dim: accepted for interface compatibility; currently unused.
        freezing: when true, gradients are disabled for all expert parameters.

    Raises:
        ValueError: if fewer than three experts are given.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        super().__init__()
        self.threshold = 0.5  # not read by any code in this file
        self.temperature = 1.2
        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)
        if self.num_experts < 3:
            raise ValueError(
                "MOE_attention_FS requires at least 3 experts, "
                f"got {self.num_experts}"
            )

        self.proc_emb = nn.ModuleList(
            [_make_projection(128) for _ in range(self.num_experts)]
        )
        self.TransfEnc = _make_gate_encoder()
        self.linear_out = nn.Linear(32, 1)
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            _freeze(self.experts)

    def forward(self, x_16, x_22, x_24):
        """Run experts 0-2 on their own inputs and return the gated output sum."""
        results = [
            self.experts[0](x_16),
            self.experts[1](x_22),
            self.experts[2](x_24),
        ]
        return _attention_fusion(self, results)