davesalvi committed
Commit b1fcba0 · 1 Parent(s): afd55da

moe transformer
Files changed (3):
  1. .idea/workspace.xml +2 -1
  2. script.py +2 -1
  3. src/moe_model.py +28 -66
.idea/workspace.xml CHANGED
@@ -6,6 +6,7 @@
   <component name="ChangeListManager">
     <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
       <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
+      <change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
       <change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
     </list>
     <option name="SHOW_DIALOG" value="false" />
@@ -136,7 +137,7 @@
       <workItem from="1743062628099" duration="35000" />
       <workItem from="1743063082652" duration="7000" />
       <workItem from="1743092790258" duration="1395000" />
-      <workItem from="1743151940209" duration="1035000" />
+      <workItem from="1743151940209" duration="1772000" />
     </task>
     <servers />
   </component>
script.py CHANGED
@@ -69,7 +69,8 @@ expert_6 = LCNN(return_emb=True).to(device)
 # model_path = './checkpoints/MOE_ULTIMATE.pth'
 
 model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device, freezing=True)
-model_path = './checkpoints/MOE_ATTENTION.pth'
+# model_path = './checkpoints/MOE_ATTENTION.pth'
+model_path = './checkpoints/MOE_TRANSF.pth'
 
 model = (model).to(device)
 model.load_state_dict(torch.load(model_path, map_location=device))
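The checkpoint path changes here because the gating sub-modules of MOE_attention are renamed in this commit (see src/moe_model.py below), so a state dict saved from the old class would no longer match the new one. A minimal sketch of a non-strict load check, assuming `model` and `device` are the ones built in script.py above (the paths come from the diff; everything else is illustrative):

import torch

# Illustrative only: MOE_TRANSF.pth is expected to match the refactored class;
# the old MOE_ATTENTION.pth would report missing keys (TransfEnc.*, linear_out.*, proc_emb.*)
# and unexpected keys (proc_emb_1..6.*, MHead_Attn.*, query_proj.*).
model_path = './checkpoints/MOE_TRANSF.pth'
state_dict = torch.load(model_path, map_location=device)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
print('missing keys:', missing)        # expect [] when the checkpoint matches
print('unexpected keys:', unexpected)  # expect [] when the checkpoint matches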
src/moe_model.py CHANGED
@@ -61,61 +61,34 @@ class UltimateMOE(nn.Module):
 class MOE_attention(nn.Module):
     def __init__(self, experts, device, input_dim=128, freezing=False):
         super(MOE_attention, self).__init__()
 
         self.threshold = 0.4
 
         self.device = device
         self.experts = nn.ModuleList(experts)
         self.num_experts = len(experts)
 
-        self.proc_emb_1 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.proc_emb_2 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.proc_emb_3 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.proc_emb_4 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.proc_emb_5 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.proc_emb_6 = nn.Sequential(
-            nn.Linear(128, 128),
-            nn.BatchNorm1d(128),
-            nn.GLU(),
-            nn.Linear(64, 32)
-        )
-
-        self.MHead_Attn = nn.MultiheadAttention(
-            embed_dim=32, num_heads=4,
-            dropout=0.1, batch_first=True
-        )
-
-        self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
+        self.proc_emb = nn.ModuleList([
+            nn.Sequential(
+                nn.Linear(128, 128),
+                nn.BatchNorm1d(128),
+                nn.GLU(),
+                nn.Linear(64, 32)
+            ) for _ in range(self.num_experts)
+        ])
+
+        self.TransfEnc = nn.Sequential(
+            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
+            nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=128)
+        )
+        self.linear_out = nn.Linear(32, 1)
+
+        # self.MHead_Attn = nn.MultiheadAttention(
+        #     embed_dim=32, num_heads=4,
+        #     dropout=0.1, batch_first=True
+        # )
+
+        # self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
         self.softmax = nn.Softmax(dim=1)
 
         if freezing:
@@ -125,31 +98,20 @@ class MOE_attention(nn.Module):
 
     def forward(self, x):
 
-        outputs = [expert(x)[0] for expert in self.experts]
-        embeddings = [expert(x)[1] for expert in self.experts]
+        results = [expert(x) for expert in self.experts]
+        outputs = [res[0] for res in results]
+        embeddings = [res[1] for res in results]
 
-        proc_emb_1 = self.proc_emb_1(embeddings[0])
-        proc_emb_2 = self.proc_emb_2(embeddings[1])
-        proc_emb_3 = self.proc_emb_3(embeddings[2])
-        proc_emb_4 = self.proc_emb_4(embeddings[3])
-        proc_emb_5 = self.proc_emb_5(embeddings[4])
-        proc_emb_6 = self.proc_emb_6(embeddings[5])
+        processed_embs = torch.stack([proc_emb(emb) for proc_emb, emb in zip(self.proc_emb, embeddings)], dim=1)
 
-        processed_embs = torch.stack([proc_emb_1, proc_emb_2, proc_emb_3, proc_emb_4, proc_emb_5, proc_emb_6], dim=1)
+        # pdb.set_trace()
 
-        query_input = torch.cat(embeddings, dim=1)
-        query = self.query_proj(query_input).unsqueeze(1)
+        transf_out = self.TransfEnc(processed_embs)
 
-        attn_output, attn_weights = self.MHead_Attn(
-            query=query,
-            key=processed_embs,
-            value=processed_embs
-        )
-
-        gating_weights = attn_weights.mean(dim=1).unsqueeze(1)
+        gating_weights = self.linear_out(transf_out).squeeze(2)
+        gating_weights = self.softmax(gating_weights).unsqueeze(1)
+
         expert_outputs = torch.stack(outputs, dim=1)
-
         combined_output = torch.bmm(gating_weights, expert_outputs).squeeze(1)
-        score = self.softmax(combined_output)
 
-        return score
+        return combined_output
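For reference, a minimal shape trace of the new transformer-based gating path, following the layers in the diff above. The DummyExpert below is a hypothetical stand-in for the real LCNN experts (which return class scores plus a 128-dim embedding), and the batch size and number of output classes are arbitrary; everything else mirrors the committed code:

import torch
import torch.nn as nn

class DummyExpert(nn.Module):
    # Hypothetical stand-in for an LCNN expert: returns (scores, 128-dim embedding).
    def __init__(self, num_classes=2, emb_dim=128):
        super().__init__()
        self.head = nn.Linear(emb_dim, num_classes)

    def forward(self, x):
        emb = x                      # pretend feature-extractor output, (B, 128)
        return self.head(emb), emb

B, E = 4, 6                          # batch size, number of experts
experts = [DummyExpert() for _ in range(E)]
x = torch.randn(B, 128)

outputs = torch.stack([e(x)[0] for e in experts], dim=1)           # (B, E, num_classes)
embeddings = [e(x)[1] for e in experts]                            # E tensors of (B, 128)

# Per-expert projection, as in self.proc_emb: 128 -> GLU halves to 64 -> 32
proc_emb = nn.ModuleList([
    nn.Sequential(nn.Linear(128, 128), nn.BatchNorm1d(128), nn.GLU(), nn.Linear(64, 32))
    for _ in range(E)
])
processed_embs = torch.stack([p(emb) for p, emb in zip(proc_emb, embeddings)], dim=1)  # (B, E, 32)

# Two stacked encoder layers, as in self.TransfEnc.
# Note: nn.TransformerEncoderLayer defaults to batch_first=False, so dim 0 is
# treated as the sequence axis; the tensor still passes through as (B, E, 32).
TransfEnc = nn.Sequential(
    nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=512),
    nn.TransformerEncoderLayer(d_model=32, nhead=4, dropout=0.1, dim_feedforward=128),
)
transf_out = TransfEnc(processed_embs)                              # (B, E, 32)

linear_out = nn.Linear(32, 1)
gating_weights = torch.softmax(linear_out(transf_out).squeeze(2), dim=1).unsqueeze(1)  # (B, 1, E)

combined_output = torch.bmm(gating_weights, outputs).squeeze(1)     # (B, num_classes)
print(combined_output.shape)                                        # torch.Size([4, 2])

In short, each expert embedding becomes one 32-dim token, the two encoder layers mix the six tokens, and linear_out plus softmax turn them into per-expert gating weights, replacing the multi-head-attention weights used in the previous version.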