moe transformer
Browse files- .idea/workspace.xml +2 -1
- script.py +2 -1
- src/moe_model.py +28 -66
.idea/workspace.xml
CHANGED
|
@@ -6,6 +6,7 @@
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
|
|
|
| 9 |
<change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
|
| 10 |
</list>
|
| 11 |
<option name="SHOW_DIALOG" value="false" />
|
|
@@ -136,7 +137,7 @@
|
|
| 136 |
<workItem from="1743062628099" duration="35000" />
|
| 137 |
<workItem from="1743063082652" duration="7000" />
|
| 138 |
<workItem from="1743092790258" duration="1395000" />
|
| 139 |
-
<workItem from="1743151940209" duration="
|
| 140 |
</task>
|
| 141 |
<servers />
|
| 142 |
</component>
|
|
|
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
+
<change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
|
| 10 |
<change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
|
| 11 |
</list>
|
| 12 |
<option name="SHOW_DIALOG" value="false" />
|
|
|
|
| 137 |
<workItem from="1743062628099" duration="35000" />
|
| 138 |
<workItem from="1743063082652" duration="7000" />
|
| 139 |
<workItem from="1743092790258" duration="1395000" />
|
| 140 |
+
<workItem from="1743151940209" duration="1772000" />
|
| 141 |
</task>
|
| 142 |
<servers />
|
| 143 |
</component>
|
script.py
CHANGED
|
@@ -69,7 +69,8 @@ expert_6 = LCNN(return_emb=True).to(device)
|
|
| 69 |
# model_path = './checkpoints/MOE_ULTIMATE.pth'
|
| 70 |
|
| 71 |
model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device, freezing=True)
|
| 72 |
-
model_path = './checkpoints/MOE_ATTENTION.pth'
|
|
|
|
| 73 |
|
| 74 |
model = (model).to(device)
|
| 75 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
|
|
|
| 69 |
# model_path = './checkpoints/MOE_ULTIMATE.pth'
|
| 70 |
|
| 71 |
model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device, freezing=True)
|
| 72 |
+
# model_path = './checkpoints/MOE_ATTENTION.pth'
|
| 73 |
+
model_path = './checkpoints/MOE_TRANSF.pth'
|
| 74 |
|
| 75 |
model = (model).to(device)
|
| 76 |
model.load_state_dict(torch.load(model_path, map_location=device))
|
src/moe_model.py
CHANGED
|
@@ -61,61 +61,34 @@ class UltimateMOE(nn.Module):
|
|
| 61 |
class MOE_attention(nn.Module):
|
| 62 |
def __init__(self, experts, device, input_dim=128, freezing=False):
|
| 63 |
super(MOE_attention, self).__init__()
|
| 64 |
-
|
| 65 |
self.threshold = 0.4
|
| 66 |
|
| 67 |
self.device = device
|
| 68 |
self.experts = nn.ModuleList(experts)
|
| 69 |
self.num_experts = len(experts)
|
| 70 |
|
| 71 |
-
self.
|
| 72 |
-
nn.
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
)
|
|
|
|
| 77 |
|
| 78 |
-
self.
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
nn.Linear(64, 32)
|
| 83 |
-
)
|
| 84 |
|
| 85 |
-
self.
|
| 86 |
-
nn.Linear(128, 128),
|
| 87 |
-
nn.BatchNorm1d(128),
|
| 88 |
-
nn.GLU(),
|
| 89 |
-
nn.Linear(64, 32)
|
| 90 |
-
)
|
| 91 |
-
|
| 92 |
-
self.proc_emb_4 = nn.Sequential(
|
| 93 |
-
nn.Linear(128, 128),
|
| 94 |
-
nn.BatchNorm1d(128),
|
| 95 |
-
nn.GLU(),
|
| 96 |
-
nn.Linear(64, 32)
|
| 97 |
-
)
|
| 98 |
-
|
| 99 |
-
self.proc_emb_5 = nn.Sequential(
|
| 100 |
-
nn.Linear(128, 128),
|
| 101 |
-
nn.BatchNorm1d(128),
|
| 102 |
-
nn.GLU(),
|
| 103 |
-
nn.Linear(64, 32)
|
| 104 |
-
)
|
| 105 |
-
|
| 106 |
-
self.proc_emb_6 = nn.Sequential(
|
| 107 |
-
nn.Linear(128, 128),
|
| 108 |
-
nn.BatchNorm1d(128),
|
| 109 |
-
nn.GLU(),
|
| 110 |
-
nn.Linear(64, 32)
|
| 111 |
-
)
|
| 112 |
-
|
| 113 |
-
self.MHead_Attn = nn.MultiheadAttention(
|
| 114 |
-
embed_dim=32, num_heads=4,
|
| 115 |
-
dropout=0.1, batch_first=True
|
| 116 |
-
)
|
| 117 |
-
|
| 118 |
-
self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
|
| 119 |
self.softmax = nn.Softmax(dim=1)
|
| 120 |
|
| 121 |
if freezing:
|
|
@@ -125,31 +98,20 @@ class MOE_attention(nn.Module):
|
|
| 125 |
|
| 126 |
def forward(self, x):
|
| 127 |
|
| 128 |
-
|
| 129 |
-
|
|
|
|
| 130 |
|
| 131 |
-
|
| 132 |
-
proc_emb_2 = self.proc_emb_2(embeddings[1])
|
| 133 |
-
proc_emb_3 = self.proc_emb_3(embeddings[2])
|
| 134 |
-
proc_emb_4 = self.proc_emb_4(embeddings[3])
|
| 135 |
-
proc_emb_5 = self.proc_emb_5(embeddings[4])
|
| 136 |
-
proc_emb_6 = self.proc_emb_6(embeddings[5])
|
| 137 |
|
| 138 |
-
|
| 139 |
|
| 140 |
-
|
| 141 |
-
query = self.query_proj(query_input).unsqueeze(1)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
key=processed_embs,
|
| 146 |
-
value=processed_embs
|
| 147 |
-
)
|
| 148 |
|
| 149 |
-
gating_weights = attn_weights.mean(dim=1).unsqueeze(1)
|
| 150 |
expert_outputs = torch.stack(outputs, dim=1)
|
| 151 |
-
|
| 152 |
combined_output = torch.bmm(gating_weights, expert_outputs).squeeze(1)
|
| 153 |
-
score = self.softmax(combined_output)
|
| 154 |
|
| 155 |
-
return
|
|
|
|
| 61 |
class MOE_attention(nn.Module):
|
| 62 |
def __init__(self, experts, device, input_dim=128, freezing=False):
|
| 63 |
super(MOE_attention, self).__init__()
|
| 64 |
+
|
| 65 |
self.threshold = 0.4
|
| 66 |
|
| 67 |
self.device = device
|
| 68 |
self.experts = nn.ModuleList(experts)
|
| 69 |
self.num_experts = len(experts)
|
| 70 |
|
| 71 |
+
self.proc_emb = nn.ModuleList([
|
| 72 |
+
nn.Sequential(
|
| 73 |
+
nn.Linear(128, 128),
|
| 74 |
+
nn.BatchNorm1d(128),
|
| 75 |
+
nn.GLU(),
|
| 76 |
+
nn.Linear(64, 32)
|
| 77 |
+
) for _ in range(self.num_experts)
|
| 78 |
+
])
|
| 79 |
+
|
| 80 |
+
self.TransfEnc = nn.Sequential(
|
| 81 |
+
nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
|
| 82 |
+
nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=128)
|
| 83 |
)
|
| 84 |
+
self.linear_out = nn.Linear(32, 1)
|
| 85 |
|
| 86 |
+
# self.MHead_Attn = nn.MultiheadAttention(
|
| 87 |
+
# embed_dim=32, num_heads=4,
|
| 88 |
+
# dropout=0.1, batch_first=True
|
| 89 |
+
# )
|
|
|
|
|
|
|
| 90 |
|
| 91 |
+
# self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 92 |
self.softmax = nn.Softmax(dim=1)
|
| 93 |
|
| 94 |
if freezing:
|
|
|
|
| 98 |
|
| 99 |
def forward(self, x):
|
| 100 |
|
| 101 |
+
results = [expert(x) for expert in self.experts]
|
| 102 |
+
outputs = [res[0] for res in results]
|
| 103 |
+
embeddings = [res[1] for res in results]
|
| 104 |
|
| 105 |
+
processed_embs = torch.stack([proc_emb(emb) for proc_emb, emb in zip(self.proc_emb, embeddings)], dim=1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 106 |
|
| 107 |
+
# pdb.set_trace()
|
| 108 |
|
| 109 |
+
transf_out = self.TransfEnc(processed_embs)
|
|
|
|
| 110 |
|
| 111 |
+
gating_weights = self.linear_out(transf_out).squeeze(2)
|
| 112 |
+
gating_weights = self.softmax(gating_weights).unsqueeze(1)
|
|
|
|
|
|
|
|
|
|
| 113 |
|
|
|
|
| 114 |
expert_outputs = torch.stack(outputs, dim=1)
|
|
|
|
| 115 |
combined_output = torch.bmm(gating_weights, expert_outputs).squeeze(1)
|
|
|
|
| 116 |
|
| 117 |
+
return combined_output
|