attention moe
Browse files- .idea/workspace.xml +3 -2
- script.py +20 -14
- src/moe_model.py +97 -1
.idea/workspace.xml
CHANGED
|
@@ -6,7 +6,8 @@
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
-
<change beforePath="$PROJECT_DIR$/
|
|
|
|
| 10 |
</list>
|
| 11 |
<option name="SHOW_DIALOG" value="false" />
|
| 12 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
@@ -136,7 +137,7 @@
|
|
| 136 |
<workItem from="1743062628099" duration="35000" />
|
| 137 |
<workItem from="1743063082652" duration="7000" />
|
| 138 |
<workItem from="1743092790258" duration="1395000" />
|
| 139 |
-
<workItem from="1743151940209" duration="
|
| 140 |
</task>
|
| 141 |
<servers />
|
| 142 |
</component>
|
|
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
+
<change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
|
| 10 |
+
<change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
|
| 11 |
</list>
|
| 12 |
<option name="SHOW_DIALOG" value="false" />
|
| 13 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
|
| 137 |
<workItem from="1743062628099" duration="35000" />
|
| 138 |
<workItem from="1743063082652" duration="7000" />
|
| 139 |
<workItem from="1743092790258" duration="1395000" />
|
| 140 |
+
<workItem from="1743151940209" duration="996000" />
|
| 141 |
</task>
|
| 142 |
<servers />
|
| 143 |
</component>
|
script.py
CHANGED
|
@@ -14,7 +14,7 @@ from preprocess import preprocess
|
|
| 14 |
|
| 15 |
from src.rawnet_model import RawNet
|
| 16 |
from src.lcnn_model import LCNN
|
| 17 |
-
from src.moe_model import UltimateMOE
|
| 18 |
|
| 19 |
# os.environ['OMP_NUM_THREADS'] = '1'
|
| 20 |
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
|
@@ -51,22 +51,28 @@ print('Define Model')
|
|
| 51 |
# model_path = './checkpoints/RAWNET_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
|
| 52 |
# model.load_state_dict(torch.load(model_path, map_location=device))
|
| 53 |
|
| 54 |
-
# LCNN MODEL
|
| 55 |
-
model = LCNN(return_emb=False).to(device)
|
| 56 |
-
# model_path = './checkpoints/LCNN_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
|
| 57 |
-
model_path = './checkpoints/LCNN_ALL_DATA.pth'
|
| 58 |
-
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 59 |
|
| 60 |
-
# # MOE MODEL
|
| 61 |
-
# expert_1 = LCNN(return_emb=True).to(device)
|
| 62 |
-
# expert_2 = LCNN(return_emb=True).to(device)
|
| 63 |
-
# expert_3 = LCNN(return_emb=True).to(device)
|
| 64 |
-
# expert_4 = LCNN(return_emb=True).to(device)
|
| 65 |
-
#
|
| 66 |
# model = UltimateMOE(experts=[expert_1, expert_2, expert_3, expert_4])
|
| 67 |
-
# model = (model).to(device)
|
| 68 |
# model_path = './checkpoints/MOE_ULTIMATE.pth'
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
model.eval()
|
| 72 |
|
|
|
|
| 14 |
|
| 15 |
from src.rawnet_model import RawNet
|
| 16 |
from src.lcnn_model import LCNN
|
| 17 |
+
from src.moe_model import UltimateMOE, MOE_attention
|
| 18 |
|
| 19 |
# os.environ['OMP_NUM_THREADS'] = '1'
|
| 20 |
# os.environ['OPENBLAS_NUM_THREADS'] = '1'
|
|
|
|
| 51 |
# model_path = './checkpoints/RAWNET_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
|
| 52 |
# model.load_state_dict(torch.load(model_path, map_location=device))
|
| 53 |
|
| 54 |
+
# # LCNN MODEL
|
| 55 |
+
# model = LCNN(return_emb=False).to(device)
|
| 56 |
+
# # model_path = './checkpoints/LCNN_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
|
| 57 |
+
# model_path = './checkpoints/LCNN_ALL_DATA.pth'
|
| 58 |
+
# model.load_state_dict(torch.load(model_path, map_location=device))
|
| 59 |
+
|
| 60 |
+
# MOE MODEL
|
| 61 |
+
expert_1 = LCNN(return_emb=True).to(device)
|
| 62 |
+
expert_2 = LCNN(return_emb=True).to(device)
|
| 63 |
+
expert_3 = LCNN(return_emb=True).to(device)
|
| 64 |
+
expert_4 = LCNN(return_emb=True).to(device)
|
| 65 |
+
expert_5 = LCNN(return_emb=True).to(device)
|
| 66 |
+
expert_6 = LCNN(return_emb=True).to(device)
|
| 67 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 68 |
# model = UltimateMOE(experts=[expert_1, expert_2, expert_3, expert_4])
|
|
|
|
| 69 |
# model_path = './checkpoints/MOE_ULTIMATE.pth'
|
| 70 |
+
|
| 71 |
+
model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device, freezing=True)
|
| 72 |
+
model_path = './checkpoints/MOE_ATTENTION.pth'
|
| 73 |
+
|
| 74 |
+
model = (model).to(device)
|
| 75 |
+
model.load_state_dict(torch.load(model_path, map_location=device))
|
| 76 |
|
| 77 |
model.eval()
|
| 78 |
|
src/moe_model.py
CHANGED
|
@@ -54,4 +54,100 @@ class UltimateMOE(nn.Module):
|
|
| 54 |
|
| 55 |
score = self.softmax(weighted_logits)
|
| 56 |
|
| 57 |
-
return score
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 54 |
|
| 55 |
score = self.softmax(weighted_logits)
|
| 56 |
|
| 57 |
+
return score
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class MOE_attention(nn.Module):
    """Mixture-of-experts head that gates expert logits with multi-head attention.

    Each expert must return a ``(logits, embedding)`` pair on forward, where
    ``embedding`` has ``input_dim`` features. Every embedding is projected
    through its own small MLP branch, the projections are attended over with
    multi-head attention (the query is built from the concatenated raw
    embeddings), and the resulting attention weights form a convex combination
    of the expert logits, which is then softmaxed into class scores.
    """

    # forward() is wired for exactly this many per-expert projection branches.
    _NUM_BRANCHES = 6

    def __init__(self, experts, device, input_dim=128, freezing=False):
        """
        Args:
            experts: sequence of exactly six ``nn.Module`` experts, each
                returning ``(logits, embedding)`` on forward.
            device: torch device this model is intended for (stored on the
                instance; callers still move the model with ``.to(device)``).
            input_dim: feature size of each expert embedding (default 128).
            freezing: when True, all expert parameters get
                ``requires_grad = False`` so only the gating head trains.

        Raises:
            ValueError: if ``experts`` does not contain exactly six modules.
        """
        super(MOE_attention, self).__init__()

        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)

        # forward() indexes six fixed projection branches; fail fast here
        # rather than with an IndexError in the middle of a forward pass.
        if self.num_experts != self._NUM_BRANCHES:
            raise ValueError(
                f'MOE_attention requires exactly {self._NUM_BRANCHES} experts, '
                f'got {self.num_experts}'
            )

        # One projection branch per expert. nn.GLU halves the last dimension
        # (128 -> 64), hence the following Linear(64, 32).
        def _branch():
            return nn.Sequential(
                nn.Linear(input_dim, 128),
                nn.BatchNorm1d(128),
                nn.GLU(),
                nn.Linear(64, 32),
            )

        # Distinct attribute names (proc_emb_1..6) are kept — NOT collapsed
        # into a ModuleList — so existing checkpoints keep loading.
        self.proc_emb_1 = _branch()
        self.proc_emb_2 = _branch()
        self.proc_emb_3 = _branch()
        self.proc_emb_4 = _branch()
        self.proc_emb_5 = _branch()
        self.proc_emb_6 = _branch()

        self.MHead_Attn = nn.MultiheadAttention(
            embed_dim=32, num_heads=4,
            dropout=0.1, batch_first=True
        )

        # Query is projected from all raw embeddings concatenated together.
        self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    def forward(self, x):
        """Return softmax scores of shape ``(batch, num_classes)`` for input x.

        ``num_classes`` is whatever width the experts' logits have
        (presumably 2 for bona-fide/spoof — confirm against the experts).
        """
        # Run every expert exactly once and unpack (logits, embedding).
        # (The previous implementation invoked each expert twice — once for
        # logits and once for embeddings — doubling the compute.)
        expert_results = [expert(x) for expert in self.experts]
        outputs = [logits for logits, _ in expert_results]
        embeddings = [emb for _, emb in expert_results]

        branches = (self.proc_emb_1, self.proc_emb_2, self.proc_emb_3,
                    self.proc_emb_4, self.proc_emb_5, self.proc_emb_6)
        # (batch, num_experts, 32): one projected embedding per expert.
        processed_embs = torch.stack(
            [branch(emb) for branch, emb in zip(branches, embeddings)], dim=1
        )

        # Single query per sample built from the raw embeddings: (batch, 1, 32).
        query_input = torch.cat(embeddings, dim=1)
        query = self.query_proj(query_input).unsqueeze(1)

        attn_output, attn_weights = self.MHead_Attn(
            query=query,
            key=processed_embs,
            value=processed_embs
        )

        # attn_weights: (batch, 1, num_experts). Averaging over the singleton
        # query dimension and unsqueezing back is a shape no-op, kept so the
        # intent (one gating row per sample) stays explicit.
        gating_weights = attn_weights.mean(dim=1).unsqueeze(1)

        # Convex combination of expert logits, then softmax over classes.
        expert_outputs = torch.stack(outputs, dim=1)
        combined_output = torch.bmm(gating_weights, expert_outputs).squeeze(1)
        score = self.softmax(combined_output)

        return score
|