davesalvi committed on
Commit
7bd22e4
·
1 Parent(s): 13257bc

MOE debug

Browse files
Files changed (2) hide show
  1. .idea/workspace.xml +3 -2
  2. src/moe_model.py +6 -1
.idea/workspace.xml CHANGED
@@ -5,7 +5,8 @@
5
  </component>
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
8
- <change beforePath="$PROJECT_DIR$/src/lcnn_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/lcnn_model.py" afterDir="false" />
 
9
  </list>
10
  <option name="SHOW_DIALOG" value="false" />
11
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -131,7 +132,7 @@
131
  <workItem from="1742889011586" duration="41000" />
132
  <workItem from="1742890414523" duration="4382000" />
133
  <workItem from="1742941362366" duration="1283000" />
134
- <workItem from="1742970904084" duration="1716000" />
135
  </task>
136
  <servers />
137
  </component>
 
5
  </component>
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
8
+ <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
+ <change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
10
  </list>
11
  <option name="SHOW_DIALOG" value="false" />
12
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
132
  <workItem from="1742889011586" duration="41000" />
133
  <workItem from="1742890414523" duration="4382000" />
134
  <workItem from="1742941362366" duration="1283000" />
135
+ <workItem from="1742970904084" duration="2326000" />
136
  </task>
137
  <servers />
138
  </component>
src/moe_model.py CHANGED
@@ -7,6 +7,8 @@ class UltimateMOE(nn.Module):
7
  def __init__(self, experts):
8
  super(UltimateMOE, self).__init__()
9
 
 
 
10
  self.experts = nn.ModuleList(experts)
11
  num_experts = len(experts)
12
 
@@ -28,6 +30,7 @@ class UltimateMOE(nn.Module):
28
  nn.Linear(64, num_experts),
29
  )
30
 
 
31
 
32
  def forward(self, x):
33
 
@@ -49,4 +52,6 @@ class UltimateMOE(nn.Module):
49
  weighted_logits = torch.stack(outputs, dim=-1)
50
  weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)
51
 
52
- return weighted_logits, gating_weights
 
 
 
7
  def __init__(self, experts):
8
  super(UltimateMOE, self).__init__()
9
 
10
+ self.threshold = 0.3
11
+
12
  self.experts = nn.ModuleList(experts)
13
  num_experts = len(experts)
14
 
 
30
  nn.Linear(64, num_experts),
31
  )
32
 
33
+ self.softmax = nn.Softmax(dim=1)
34
 
35
  def forward(self, x):
36
 
 
52
  weighted_logits = torch.stack(outputs, dim=-1)
53
  weighted_logits = torch.einsum('bn,bcn->bc', gating_weights, weighted_logits)
54
 
55
+ score = self.softmax(weighted_logits)
56
+
57
+ return score