MOE debug
Browse files
- .idea/workspace.xml +3 -2
- src/moe_model.py +6 -1
.idea/workspace.xml
CHANGED
|
@@ -5,7 +5,8 @@
|
|
| 5 |
</component>
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
-
<change beforePath="$PROJECT_DIR
|
|
|
|
| 9 |
</list>
|
| 10 |
<option name="SHOW_DIALOG" value="false" />
|
| 11 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
@@ -131,7 +132,7 @@
|
|
| 131 |
<workItem from="1742889011586" duration="41000" />
|
| 132 |
<workItem from="1742890414523" duration="4382000" />
|
| 133 |
<workItem from="1742941362366" duration="1283000" />
|
| 134 |
-
<workItem from="1742970904084" duration="
|
| 135 |
</task>
|
| 136 |
<servers />
|
| 137 |
</component>
|
|
|
|
| 5 |
</component>
|
| 6 |
<component name="ChangeListManager">
|
| 7 |
<list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
|
| 8 |
+
<change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
|
| 9 |
+
<change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
|
| 10 |
</list>
|
| 11 |
<option name="SHOW_DIALOG" value="false" />
|
| 12 |
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
|
|
|
| 132 |
<workItem from="1742889011586" duration="41000" />
|
| 133 |
<workItem from="1742890414523" duration="4382000" />
|
| 134 |
<workItem from="1742941362366" duration="1283000" />
|
| 135 |
+
<workItem from="1742970904084" duration="2326000" />
|
| 136 |
</task>
|
| 137 |
<servers />
|
| 138 |
</component>
|
src/moe_model.py
CHANGED
|
@@ -7,6 +7,8 @@ class UltimateMOE(nn.Module):
|
|
| 7 |
def __init__(self, experts):
|
| 8 |
super(UltimateMOE, self).__init__()
|
| 9 |
|
|
|
|
|
|
|
| 10 |
self.experts = nn.ModuleList(experts)
|
| 11 |
num_experts = len(experts)
|
| 12 |
|
|
@@ -28,6 +30,7 @@ class UltimateMOE(nn.Module):
|
|
| 28 |
nn.Linear(64, num_experts),
|
| 29 |
)
|
| 30 |
|
|
|
|
| 31 |
|
| 32 |
def forward(self, x):
|
| 33 |
|
|
@@ -49,4 +52,6 @@ class UltimateMOE(nn.Module):
|
|
| 49 |
weighted_logits = torch.stack(outputs, dim=-1)
|
| 50 |
weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)
|
| 51 |
|
| 52 |
-
|
|
|
|
|
|
|
|
|
| 7 |
def __init__(self, experts):
|
| 8 |
super(UltimateMOE, self).__init__()
|
| 9 |
|
| 10 |
+
self.threshold = 0.3
|
| 11 |
+
|
| 12 |
self.experts = nn.ModuleList(experts)
|
| 13 |
num_experts = len(experts)
|
| 14 |
|
|
|
|
| 30 |
nn.Linear(64, num_experts),
|
| 31 |
)
|
| 32 |
|
| 33 |
+
self.softmax = nn.Softmax(dim=1)
|
| 34 |
|
| 35 |
def forward(self, x):
|
| 36 |
|
|
|
|
| 52 |
weighted_logits = torch.stack(outputs, dim=-1)
|
| 53 |
weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)
|
| 54 |
|
| 55 |
+
score = self.softmax(weighted_logits)
|
| 56 |
+
|
| 57 |
+
return score
|