# Source: Hugging Face Space "ispl_safe" — src/moe_model.py
# Author: davesalvi — commit 5e70df1 ("change thres")
import torch
import torch.nn as nn
import torch.nn.functional as F
class UltimateMOE(nn.Module):
    """Mixture-of-experts classifier over four experts.

    Each expert must be a module returning a tuple ``(logits, embedding)``
    where the embedding is 64-dimensional (the per-expert projection layers
    are ``Linear(64, 32)``).  The gating network consumes the concatenation
    of the four projected embeddings plus their learnably-weighted
    element-wise product — i.e. ``num_experts + 1`` groups of 32 features —
    so this architecture assumes exactly 4 experts.
    """

    def __init__(self, experts):
        super(UltimateMOE, self).__init__()
        # Decision threshold for downstream callers; not read in this module.
        self.threshold = 0.27
        self.experts = nn.ModuleList(experts)
        num_experts = len(experts)
        self.lrelu = nn.LeakyReLU()
        # NOTE(review): a single BatchNorm1d shared by all four projection
        # branches mixes running statistics across branches — presumably
        # intentional, confirm against training results.
        self.bn = nn.BatchNorm1d(32)
        self.fc1 = nn.Linear(64, 32)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(64, 32)
        self.fc4 = nn.Linear(64, 32)
        # Learnable per-channel scaling of the element-wise product of the
        # four projected embeddings.
        self.pooling = nn.Parameter(torch.ones(32))
        self.gating_network = nn.Sequential(
            nn.Linear(32 * (num_experts + 1), 64),
            nn.Dropout(0.2),
            nn.BatchNorm1d(64),
            nn.LeakyReLU(),
            nn.Linear(64, num_experts),
        )
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        """Return per-class scores of shape (batch, num_classes).

        Scores are a softmax over the gating-weighted mixture of the
        experts' logits.
        """
        # Run each expert ONCE and unpack (logits, embedding).  The previous
        # code called every expert twice — once for logits and once for
        # embeddings — doubling compute and, with stochastic layers
        # (dropout/batchnorm) in train mode, letting the logits and the
        # embeddings come from two inconsistent forward passes.
        results = [expert(x) for expert in self.experts]
        outputs = [res[0] for res in results]
        embeddings = [res[1] for res in results]
        emb_1 = self.lrelu(self.bn(self.fc1(embeddings[0])))
        emb_2 = self.lrelu(self.bn(self.fc2(embeddings[1])))
        emb_3 = self.lrelu(self.bn(self.fc3(embeddings[2])))
        emb_4 = self.lrelu(self.bn(self.fc4(embeddings[3])))
        combined = emb_1 * emb_2 * emb_3 * emb_4
        weighted_combined = combined * self.pooling.unsqueeze(0)
        concatenated_embeddings = torch.cat(
            (emb_1, emb_2, emb_3, emb_4, weighted_combined), dim=1
        )
        gating_weights = self.gating_network(concatenated_embeddings)
        gating_weights = F.softmax(gating_weights, dim=-1)
        # (batch, classes, experts) logits mixed by (batch, experts) weights.
        weighted_logits = torch.stack(outputs, dim=-1)
        weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)
        score = self.softmax(weighted_logits)
        return score
class MOE_attention(nn.Module):
    """Mixture-of-experts head with attention-derived gating.

    Each expert returns ``(logits, embedding)``.  Expert embeddings are
    projected to 32 dims (expert 0 emits 128-d embeddings, experts 1-2 emit
    256-d embeddings), attended over by a two-layer transformer encoder, and
    reduced to one gating weight per expert; the experts' logits are then
    combined as a convex mixture.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        super(MOE_attention, self).__init__()
        # Decision threshold for downstream callers; not read here.
        self.threshold = 0.1
        # Softens the gating softmax.
        self.temperature = 1.2
        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)

        def _projection(in_dim):
            # 128-wide GLU bottleneck mapping an expert embedding to 32 dims.
            return nn.Sequential(
                nn.Linear(in_dim, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32),
            )

        self.proc_emb = nn.ModuleList(
            [_projection(128), _projection(256), _projection(256)]
        )
        # NOTE(review): TransformerEncoderLayer defaults to batch_first=False,
        # so dim 0 of the stacked (batch, experts, 32) tensor — the batch —
        # is treated as the sequence axis and attended over.  Confirm this
        # matches the behavior the checkpoint was trained with.
        self.TransfEnc = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
        )
        self.linear_out = nn.Linear(32, 1)
        self.softmax = nn.Softmax(dim=1)
        if freezing:
            # Optionally train only the gating path, keeping experts fixed.
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return the gating-weighted mixture of the experts' logits."""
        pairs = [expert(x) for expert in self.experts]
        logits = [p[0] for p in pairs]
        embs = [p[1] for p in pairs]
        # One 32-d "token" per expert: (batch, num_experts, 32).
        tokens = torch.stack(
            [head(e) for head, e in zip(self.proc_emb, embs)], dim=1
        )
        attended = self.TransfEnc(tokens)
        # (batch, num_experts, 1) weights, normalized over the expert axis.
        weights = self.softmax(self.linear_out(attended) / self.temperature)
        stacked_logits = torch.stack(logits, dim=1)
        return (weights * stacked_logits).sum(dim=1)
class MOE_attention_FS(nn.Module):
    """Per-input variant of the attention-gated mixture of experts.

    Unlike ``MOE_attention``, each of the three experts receives its own
    input tensor (``x_16``, ``x_22``, ``x_24``), and every expert embedding
    is expected to be 128-dimensional.  Experts return ``(logits,
    embedding)``; logits are combined as a convex mixture using
    transformer-attention-derived gating weights.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        super(MOE_attention_FS, self).__init__()
        # Decision threshold for downstream callers; not read here.
        self.threshold = 0.5
        # Softens the gating softmax.
        self.temperature = 1.2
        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)
        # One 128-wide GLU bottleneck per expert, projecting each 128-d
        # embedding to 32 dims.
        self.proc_emb = nn.ModuleList(
            nn.Sequential(
                nn.Linear(128, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32),
            )
            for _ in range(self.num_experts)
        )
        # NOTE(review): TransformerEncoderLayer defaults to batch_first=False,
        # so dim 0 of the stacked (batch, experts, 32) tensor — the batch —
        # is treated as the sequence axis and attended over.  Confirm this
        # matches the behavior the checkpoint was trained with.
        self.TransfEnc = nn.Sequential(
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
        )
        self.linear_out = nn.Linear(32, 1)
        self.softmax = nn.Softmax(dim=1)
        if freezing:
            # Optionally train only the gating path, keeping experts fixed.
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, x_16, x_22, x_24):
        """Return the gating-weighted mixture of the three experts' logits."""
        # Fixed expert/input pairing; assumes exactly three experts.
        pairs = [
            self.experts[0](x_16),
            self.experts[1](x_22),
            self.experts[2](x_24),
        ]
        logits = [p[0] for p in pairs]
        embs = [p[1] for p in pairs]
        # One 32-d "token" per expert: (batch, num_experts, 32).
        tokens = torch.stack(
            [head(e) for head, e in zip(self.proc_emb, embs)], dim=1
        )
        attended = self.TransfEnc(tokens)
        # (batch, num_experts, 1) weights, normalized over the expert axis.
        weights = self.softmax(self.linear_out(attended) / self.temperature)
        stacked_logits = torch.stack(logits, dim=1)
        return (weights * stacked_logits).sum(dim=1)