davesalvi commited on
Commit
d1fc27c
·
1 Parent(s): bdcf32d

new moe model

Browse files
Files changed (2) hide show
  1. script.py +13 -6
  2. src/moe_model.py +24 -3
script.py CHANGED
@@ -56,8 +56,8 @@ print('Define Model')
56
  # model = ResNet_LogSpec(sample_rate=24000, return_emb=False).to(device)
57
  # model_path = './checkpoints/RESNET_LOGSPEC_ALL_DATA_FS_24000.pth'
58
 
59
- model = ResNet_MelSpec(sample_rate=24000, return_emb=False).to(device)
60
- model_path = './checkpoints/RESNET_MELSPEC_ALL_DATA_FS_24000.pth'
61
 
62
  ## LCNN MODEL
63
  # model = LCNN(return_emb=False, fs=24000).to(device)
@@ -72,9 +72,16 @@ model_path = './checkpoints/RESNET_MELSPEC_ALL_DATA_FS_24000.pth'
72
  # model_path = './checkpoints/LCNN_ALL_DATA_FS_22050.pth'
73
  # model_path = './checkpoints/LCNN_ALL_DATA_FS_24000.pth'
74
 
75
- model.load_state_dict(torch.load(model_path, map_location=device))
76
 
77
  # # MOE MODEL
 
 
 
 
 
 
 
78
  # expert_1 = LCNN(return_emb=True, fs=16000).to(device)
79
  # expert_2 = LCNN(return_emb=True, fs=22050).to(device)
80
  # expert_3 = LCNN(return_emb=True, fs=24000).to(device)
@@ -100,9 +107,9 @@ model.load_state_dict(torch.load(model_path, map_location=device))
100
  # # model_path = './checkpoints/MOE_TRANSF_7EXP_AUG_NO_FREEZE.pth'
101
  # # model_path = './checkpoints/MOE_TRANSF_8EXP_AUG.pth'
102
  # model_path = './checkpoints/MOE_TRANSF_8EXP_AUG_NO_FREEZE.pth'
103
- #
104
- # model = (model).to(device)
105
- # model.load_state_dict(torch.load(model_path, map_location=device))
106
 
107
  model.eval()
108
 
 
56
  # model = ResNet_LogSpec(sample_rate=24000, return_emb=False).to(device)
57
  # model_path = './checkpoints/RESNET_LOGSPEC_ALL_DATA_FS_24000.pth'
58
 
59
+ # model = ResNet_MelSpec(sample_rate=24000, return_emb=False).to(device)
60
+ # model_path = './checkpoints/RESNET_MELSPEC_ALL_DATA_FS_24000.pth'
61
 
62
  ## LCNN MODEL
63
  # model = LCNN(return_emb=False, fs=24000).to(device)
 
72
  # model_path = './checkpoints/LCNN_ALL_DATA_FS_22050.pth'
73
  # model_path = './checkpoints/LCNN_ALL_DATA_FS_24000.pth'
74
 
75
+ # model.load_state_dict(torch.load(model_path, map_location=device))
76
 
77
  # # MOE MODEL
78
+ expert_1 = LCNN(return_emb=True, fs=24000)
79
+ expert_2 = ResNet_LogSpec(return_emb=True, sample_rate=24000)
80
+ expert_3 = ResNet_MelSpec(return_emb=True, sample_rate=24000)
81
+
82
+ model = MOE_attention(experts=[expert_1, expert_2, expert_3], device=device)
83
+ model_path = './checkpoints/MOE_TRANSF_3EXP_MODELS_AUG.pth'
84
+
85
  # expert_1 = LCNN(return_emb=True, fs=16000).to(device)
86
  # expert_2 = LCNN(return_emb=True, fs=22050).to(device)
87
  # expert_3 = LCNN(return_emb=True, fs=24000).to(device)
 
107
  # # model_path = './checkpoints/MOE_TRANSF_7EXP_AUG_NO_FREEZE.pth'
108
  # # model_path = './checkpoints/MOE_TRANSF_8EXP_AUG.pth'
109
  # model_path = './checkpoints/MOE_TRANSF_8EXP_AUG_NO_FREEZE.pth'
110
+
111
+ model = (model).to(device)
112
+ model.load_state_dict(torch.load(model_path, map_location=device))
113
 
114
  model.eval()
115
 
src/moe_model.py CHANGED
@@ -62,20 +62,41 @@ class MOE_attention(nn.Module):
62
  def __init__(self, experts, device, input_dim=128, freezing=False):
63
  super(MOE_attention, self).__init__()
64
 
65
- self.threshold = 0.25
66
  self.temperature = 1.2
67
 
68
  self.device = device
69
  self.experts = nn.ModuleList(experts)
70
  self.num_experts = len(experts)
71
 
 
 
 
 
 
 
 
 
 
72
  self.proc_emb = nn.ModuleList([
73
  nn.Sequential(
74
- nn.Linear(input_dim, 128),
75
  nn.BatchNorm1d(128),
76
  nn.GLU(),
77
  nn.Linear(64, 32)
78
- ) for _ in range(self.num_experts)
 
 
 
 
 
 
 
 
 
 
 
 
79
  ])
80
 
81
  self.TransfEnc = nn.Sequential(
 
62
  def __init__(self, experts, device, input_dim=128, freezing=False):
63
  super(MOE_attention, self).__init__()
64
 
65
+ self.threshold = 0.5
66
  self.temperature = 1.2
67
 
68
  self.device = device
69
  self.experts = nn.ModuleList(experts)
70
  self.num_experts = len(experts)
71
 
72
+ # self.proc_emb = nn.ModuleList([
73
+ # nn.Sequential(
74
+ # nn.Linear(input_dim, 128),
75
+ # nn.BatchNorm1d(128),
76
+ # nn.GLU(),
77
+ # nn.Linear(64, 32)
78
+ # ) for _ in range(self.num_experts)
79
+ # ])
80
+
81
  self.proc_emb = nn.ModuleList([
82
  nn.Sequential(
83
+ nn.Linear(128, 128),
84
  nn.BatchNorm1d(128),
85
  nn.GLU(),
86
  nn.Linear(64, 32)
87
+ ),
88
+ nn.Sequential(
89
+ nn.Linear(256, 128),
90
+ nn.BatchNorm1d(128),
91
+ nn.GLU(),
92
+ nn.Linear(64, 32)
93
+ ),
94
+ nn.Sequential(
95
+ nn.Linear(256, 128),
96
+ nn.BatchNorm1d(128),
97
+ nn.GLU(),
98
+ nn.Linear(64, 32)
99
+ )
100
  ])
101
 
102
  self.TransfEnc = nn.Sequential(