davesalvi committed on
Commit
1e0c4e0
·
1 Parent(s): 9d78131

attention moe

Browse files
Files changed (3) hide show
  1. .idea/workspace.xml +3 -2
  2. script.py +20 -14
  3. src/moe_model.py +97 -1
.idea/workspace.xml CHANGED
@@ -6,7 +6,8 @@
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
8
  <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
- <change beforePath="$PROJECT_DIR$/src/lcnn_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/lcnn_model.py" afterDir="false" />
 
10
  </list>
11
  <option name="SHOW_DIALOG" value="false" />
12
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
@@ -136,7 +137,7 @@
136
  <workItem from="1743062628099" duration="35000" />
137
  <workItem from="1743063082652" duration="7000" />
138
  <workItem from="1743092790258" duration="1395000" />
139
- <workItem from="1743151940209" duration="31000" />
140
  </task>
141
  <servers />
142
  </component>
 
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
8
  <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
+ <change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
10
+ <change beforePath="$PROJECT_DIR$/src/moe_model.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/moe_model.py" afterDir="false" />
11
  </list>
12
  <option name="SHOW_DIALOG" value="false" />
13
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
137
  <workItem from="1743062628099" duration="35000" />
138
  <workItem from="1743063082652" duration="7000" />
139
  <workItem from="1743092790258" duration="1395000" />
140
+ <workItem from="1743151940209" duration="996000" />
141
  </task>
142
  <servers />
143
  </component>
script.py CHANGED
@@ -14,7 +14,7 @@ from preprocess import preprocess
14
 
15
  from src.rawnet_model import RawNet
16
  from src.lcnn_model import LCNN
17
- from src.moe_model import UltimateMOE
18
 
19
  # os.environ['OMP_NUM_THREADS'] = '1'
20
  # os.environ['OPENBLAS_NUM_THREADS'] = '1'
@@ -51,22 +51,28 @@ print('Define Model')
51
  # model_path = './checkpoints/RAWNET_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
52
  # model.load_state_dict(torch.load(model_path, map_location=device))
53
 
54
- # LCNN MODEL
55
- model = LCNN(return_emb=False).to(device)
56
- # model_path = './checkpoints/LCNN_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
57
- model_path = './checkpoints/LCNN_ALL_DATA.pth'
58
- model.load_state_dict(torch.load(model_path, map_location=device))
 
 
 
 
 
 
 
 
59
 
60
- # # MOE MODEL
61
- # expert_1 = LCNN(return_emb=True).to(device)
62
- # expert_2 = LCNN(return_emb=True).to(device)
63
- # expert_3 = LCNN(return_emb=True).to(device)
64
- # expert_4 = LCNN(return_emb=True).to(device)
65
- #
66
  # model = UltimateMOE(experts=[expert_1, expert_2, expert_3, expert_4])
67
- # model = (model).to(device)
68
  # model_path = './checkpoints/MOE_ULTIMATE.pth'
69
- # model.load_state_dict(torch.load(model_path, map_location=device))
 
 
 
 
 
70
 
71
  model.eval()
72
 
 
14
 
15
  from src.rawnet_model import RawNet
16
  from src.lcnn_model import LCNN
17
+ from src.moe_model import UltimateMOE, MOE_attention
18
 
19
  # os.environ['OMP_NUM_THREADS'] = '1'
20
  # os.environ['OPENBLAS_NUM_THREADS'] = '1'
 
51
  # model_path = './checkpoints/RAWNET_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
52
  # model.load_state_dict(torch.load(model_path, map_location=device))
53
 
54
+ # # LCNN MODEL
55
+ # model = LCNN(return_emb=False).to(device)
56
+ # # model_path = './checkpoints/LCNN_ASVSPOOF_FOR_INTHEWILD_PURDUE.pth'
57
+ # model_path = './checkpoints/LCNN_ALL_DATA.pth'
58
+ # model.load_state_dict(torch.load(model_path, map_location=device))
59
+
60
+ # MOE MODEL
61
+ expert_1 = LCNN(return_emb=True).to(device)
62
+ expert_2 = LCNN(return_emb=True).to(device)
63
+ expert_3 = LCNN(return_emb=True).to(device)
64
+ expert_4 = LCNN(return_emb=True).to(device)
65
+ expert_5 = LCNN(return_emb=True).to(device)
66
+ expert_6 = LCNN(return_emb=True).to(device)
67
 
 
 
 
 
 
 
68
  # model = UltimateMOE(experts=[expert_1, expert_2, expert_3, expert_4])
 
69
  # model_path = './checkpoints/MOE_ULTIMATE.pth'
70
+
71
+ model = MOE_attention(experts=[expert_1, expert_2, expert_3, expert_4, expert_5, expert_6], device=device, freezing=True)
72
+ model_path = './checkpoints/MOE_ATTENTION.pth'
73
+
74
+ model = (model).to(device)
75
+ model.load_state_dict(torch.load(model_path, map_location=device))
76
 
77
  model.eval()
78
 
src/moe_model.py CHANGED
@@ -54,4 +54,100 @@ class UltimateMOE(nn.Module):
54
 
55
  score = self.softmax(weighted_logits)
56
 
57
- return score
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
54
 
55
  score = self.softmax(weighted_logits)
56
 
57
+ return score
58
+
59
+
60
+
61
class MOE_attention(nn.Module):
    """Mixture-of-experts head that gates expert logits with attention weights.

    Every expert is called on the same input and must return a sequence whose
    first item is its logits and whose second item is its embedding. Each
    embedding is passed through its own small projection head; the projected
    embeddings act as keys/values of a multi-head attention layer whose single
    query is a linear projection of all raw embeddings concatenated. The
    attention weights over the experts are then used to mix the expert logits,
    and a softmax over dim 1 yields the final score.

    Args:
        experts: Non-empty sequence of expert modules. Each ``expert(x)`` must
            return ``(logits, embedding, ...)`` with ``embedding`` of shape
            ``(batch, input_dim)`` and ``logits`` of shape ``(batch, classes)``.
        device: Stored on the instance as ``self.device``; this class does not
            move any tensor itself.
        input_dim: Width of each expert embedding.
        freezing: When true, ``requires_grad`` is disabled on every expert
            parameter so only the gating layers receive gradients.

    Raises:
        ValueError: If ``experts`` is empty.
    """

    def __init__(self, experts, device, input_dim=128, freezing=False):
        super(MOE_attention, self).__init__()

        if len(experts) == 0:
            raise ValueError("MOE_attention requires at least one expert")

        self.device = device
        self.experts = nn.ModuleList(experts)
        self.num_experts = len(experts)

        # One projection head per expert. They are registered under the names
        # proc_emb_1 ... proc_emb_N (not inside a ModuleList) so the state_dict
        # keys stay identical to checkpoints saved with the six hand-written
        # heads.
        for idx in range(1, self.num_experts + 1):
            self.add_module(f"proc_emb_{idx}", self._make_projection(input_dim))

        self.MHead_Attn = nn.MultiheadAttention(
            embed_dim=32, num_heads=4,
            dropout=0.1, batch_first=True
        )

        self.query_proj = nn.Linear(input_dim * self.num_experts, 32)
        self.softmax = nn.Softmax(dim=1)

        if freezing:
            for expert in self.experts:
                for param in expert.parameters():
                    param.requires_grad = False

    @staticmethod
    def _make_projection(input_dim):
        """Build the per-expert embedding head (input_dim -> 128 -> GLU -> 32)."""
        return nn.Sequential(
            nn.Linear(input_dim, 128),
            nn.BatchNorm1d(128),
            nn.GLU(),  # halves the feature width: 128 -> 64
            nn.Linear(64, 32)
        )

    def forward(self, x):
        """Return softmax scores of the attention-weighted expert logits."""
        # Run every expert exactly once and reuse both of its results; calling
        # it separately for logits and for the embedding doubles the cost and
        # can pair values coming from two different stochastic passes.
        results = [expert(x) for expert in self.experts]
        outputs = [result[0] for result in results]
        embeddings = [result[1] for result in results]

        projected = [
            getattr(self, f"proc_emb_{idx}")(embedding)
            for idx, embedding in enumerate(embeddings, start=1)
        ]
        # (batch, num_experts, 32): one key/value token per expert.
        processed_embs = torch.stack(projected, dim=1)

        # Single query token built from all raw embeddings: (batch, 1, 32).
        query_input = torch.cat(embeddings, dim=1)
        query = self.query_proj(query_input).unsqueeze(1)

        # Only the attention weights are needed; the attended output is unused.
        _, attn_weights = self.MHead_Attn(
            query=query,
            key=processed_embs,
            value=processed_embs
        )

        # attn_weights is (batch, 1, num_experts); collapse the query axis and
        # restore it so bmm mixes the expert logits per sample.
        gating_weights = attn_weights.mean(dim=1).unsqueeze(1)
        expert_outputs = torch.stack(outputs, dim=1)

        combined_output = torch.bmm(gating_weights, expert_outputs).squeeze(1)
        score = self.softmax(combined_output)

        return score