NLTM-NITG committed on
Commit
a8b0e92
·
verified ·
1 Parent(s): 19a4bf3

Update HuggingFace/model.py

Browse files

Send the wav2vec2 model path when loading weights so predictions can reuse it

Files changed (1) hide show
  1. HuggingFace/model.py +239 -236
HuggingFace/model.py CHANGED
@@ -1,237 +1,240 @@
1
- import torch
2
- import torch.nn as nn
3
- import numpy as np
4
- import torchaudio
5
- import soundfile as sf
6
- from torch import Tensor
7
-
8
- # Define your device
9
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
10
-
11
- # Define constants based on the loaded checkpoint
12
- e_dim = 512 # Update with the correct embedding dimension based on your model
13
- n_classes = 2 # Number of language classes, based on your requirement
14
- look_back1 = 30
15
- look_back2 = 60
16
- lan2id = {'MA': 0, 'PU': 1}
17
-
18
- # Function to preprocess input data
19
- def Get_data(X):
20
- if isinstance(X, torch.Tensor):
21
- X = X.cpu().numpy()
22
-
23
- mu = X.mean(axis=0)
24
- std = X.std(axis=0)
25
- np.place(std, std == 0, 1)
26
- X = (X - mu) / std
27
-
28
- Xdata1 = []
29
- Xdata2 = []
30
- for i in range(0, len(X)-look_back1, 1):
31
- a = X[i:(i+look_back1), :]
32
- Xdata1.append(a)
33
- Xdata1 = np.array(Xdata1)
34
-
35
- for i in range(0, len(X)-look_back2, 2):
36
- b = X[i+1:(i+look_back2):3, :]
37
- Xdata2.append(b)
38
- Xdata2 = np.array(Xdata2)
39
-
40
- return Xdata1, Xdata2
41
-
42
-
43
-
44
- class LSTMNet(nn.Module):
45
- def __init__(self):
46
- super(LSTMNet, self).__init__()
47
- self.lstm1 = nn.LSTM(1024, 512, bidirectional=True)
48
- self.lstm2 = nn.LSTM(1024, 256, bidirectional=True)
49
-
50
- self.fc_ha = nn.Linear(e_dim, 256)
51
- self.fc_1 = nn.Linear(256, 1)
52
- self.softmax = nn.Softmax(dim=1)
53
-
54
- def forward(self, x):
55
- x1, _ = self.lstm1(x)
56
- x2, _ = self.lstm2(x1)
57
- ht = x2[-1]
58
- ht = torch.unsqueeze(ht, 0)
59
-
60
- ha = torch.tanh(self.fc_ha(ht))
61
- alp = self.fc_1(ha)
62
- al = self.softmax(alp)
63
-
64
- T = list(ht.shape)[1]
65
- batch_size = list(ht.shape)[0]
66
- D = list(ht.shape)[2]
67
- c = torch.bmm(al.view(batch_size, 1, T), ht.view(batch_size, T, D))
68
- c = torch.squeeze(c, 0)
69
- return c
70
-
71
- class CCSL_Net(nn.Module):
72
- def __init__(self, model1, model2):
73
- super(CCSL_Net, self).__init__()
74
- self.model1 = model1
75
- self.model2 = model2
76
-
77
- self.att1 = nn.Linear(e_dim, 256)
78
- self.att2 = nn.Linear(256, 1)
79
-
80
- self.softmax = nn.Softmax(dim=1)
81
- self.lang_classifier = nn.Linear(e_dim, n_classes, bias=False)
82
-
83
- def forward(self, x1, x2):
84
- e1 = self.model1(x1)
85
- e2 = self.model2(x2)
86
-
87
- ht_e = torch.cat((e1, e2), dim=0)
88
- ht_e = torch.unsqueeze(ht_e, 0)
89
- ha_e = torch.tanh(self.att1(ht_e))
90
- alp = torch.tanh(self.att2(ha_e))
91
- al = self.softmax(alp)
92
- Tb = list(ht_e.shape)[1]
93
- batch_size = list(ht_e.shape)[0]
94
- D = list(ht_e.shape)[2]
95
- u_vec = torch.bmm(al.view(batch_size, 1, Tb), ht_e.view(batch_size, Tb, D))
96
- u_vec = torch.squeeze(u_vec, 0)
97
-
98
- lan_prim = self.lang_classifier(u_vec)
99
-
100
- return lan_prim
101
-
102
- class DID_Model(nn.Module):
103
- def __init__(self):
104
- super(DID_Model, self).__init__()
105
- self.model1 = LSTMNet()
106
- self.model2 = LSTMNet()
107
- self.ccslnet = CCSL_Net(self.model1, self.model2)
108
-
109
- def forward(self, x1, x2):
110
- output = self.ccslnet(x1, x2)
111
- return output
112
-
113
- def load_weights(self, checkpoint_path):
114
- checkpoint = torch.load(checkpoint_path, map_location=device)
115
-
116
- # Load weights for model1
117
- self.model1.lstm1.load_state_dict({
118
- 'weight_ih_l0': checkpoint['model1.lstm1.weight_ih_l0'],
119
- 'weight_hh_l0': checkpoint['model1.lstm1.weight_hh_l0'],
120
- 'bias_ih_l0': checkpoint['model1.lstm1.bias_ih_l0'],
121
- 'bias_hh_l0': checkpoint['model1.lstm1.bias_hh_l0'],
122
- 'weight_ih_l0_reverse': checkpoint['model1.lstm1.weight_ih_l0_reverse'],
123
- 'weight_hh_l0_reverse': checkpoint['model1.lstm1.weight_hh_l0_reverse'],
124
- 'bias_ih_l0_reverse': checkpoint['model1.lstm1.bias_ih_l0_reverse'],
125
- 'bias_hh_l0_reverse': checkpoint['model1.lstm1.bias_hh_l0_reverse']
126
- })
127
- self.model1.lstm2.load_state_dict({
128
- 'weight_ih_l0': checkpoint['model1.lstm2.weight_ih_l0'],
129
- 'weight_hh_l0': checkpoint['model1.lstm2.weight_hh_l0'],
130
- 'bias_ih_l0': checkpoint['model1.lstm2.bias_ih_l0'],
131
- 'bias_hh_l0': checkpoint['model1.lstm2.bias_hh_l0'],
132
- 'weight_ih_l0_reverse': checkpoint['model1.lstm2.weight_ih_l0_reverse'],
133
- 'weight_hh_l0_reverse': checkpoint['model1.lstm2.weight_hh_l0_reverse'],
134
- 'bias_ih_l0_reverse': checkpoint['model1.lstm2.bias_ih_l0_reverse'],
135
- 'bias_hh_l0_reverse': checkpoint['model1.lstm2.bias_hh_l0_reverse']
136
- })
137
- self.model1.fc_ha.load_state_dict({
138
- 'weight': checkpoint['model1.fc_ha.weight'],
139
- 'bias': checkpoint['model1.fc_ha.bias']
140
- })
141
- self.model1.fc_1.load_state_dict({
142
- 'weight': checkpoint['model1.fc_1.weight'],
143
- 'bias': checkpoint['model1.fc_1.bias']
144
- })
145
-
146
- # Load weights for model2
147
- self.model2.lstm1.load_state_dict({
148
- 'weight_ih_l0': checkpoint['model2.lstm1.weight_ih_l0'],
149
- 'weight_hh_l0': checkpoint['model2.lstm1.weight_hh_l0'],
150
- 'bias_ih_l0': checkpoint['model2.lstm1.bias_ih_l0'],
151
- 'bias_hh_l0': checkpoint['model2.lstm1.bias_hh_l0'],
152
- 'weight_ih_l0_reverse': checkpoint['model2.lstm1.weight_ih_l0_reverse'],
153
- 'weight_hh_l0_reverse': checkpoint['model2.lstm1.weight_hh_l0_reverse'],
154
- 'bias_ih_l0_reverse': checkpoint['model2.lstm1.bias_ih_l0_reverse'],
155
- 'bias_hh_l0_reverse': checkpoint['model2.lstm1.bias_hh_l0_reverse']
156
- })
157
- self.model2.lstm2.load_state_dict({
158
- 'weight_ih_l0': checkpoint['model2.lstm2.weight_ih_l0'],
159
- 'weight_hh_l0': checkpoint['model2.lstm2.weight_hh_l0'],
160
- 'bias_ih_l0': checkpoint['model2.lstm2.bias_ih_l0'],
161
- 'bias_hh_l0': checkpoint['model2.lstm2.bias_hh_l0'],
162
- 'weight_ih_l0_reverse': checkpoint['model2.lstm2.weight_ih_l0_reverse'],
163
- 'weight_hh_l0_reverse': checkpoint['model2.lstm2.weight_hh_l0_reverse'],
164
- 'bias_ih_l0_reverse': checkpoint['model2.lstm2.bias_ih_l0_reverse'],
165
- 'bias_hh_l0_reverse': checkpoint['model2.lstm2.bias_hh_l0_reverse']
166
- })
167
- self.model2.fc_ha.load_state_dict({
168
- 'weight': checkpoint['model2.fc_ha.weight'],
169
- 'bias': checkpoint['model2.fc_ha.bias']
170
- })
171
- self.model2.fc_1.load_state_dict({
172
- 'weight': checkpoint['model2.fc_1.weight'],
173
- 'bias': checkpoint['model2.fc_1.bias']
174
- })
175
-
176
- # Load attention weights
177
- self.ccslnet.att1.load_state_dict({
178
- 'weight': checkpoint['att1.weight'],
179
- 'bias': checkpoint['att1.bias']
180
- })
181
- self.ccslnet.att2.load_state_dict({
182
- 'weight': checkpoint['att2.weight'],
183
- 'bias': checkpoint['att2.bias']
184
- })
185
-
186
- # Load language classifier weights
187
- self.ccslnet.lang_classifier.load_state_dict({
188
- 'weight': checkpoint['lang_classifier.weight']
189
- })
190
-
191
- print("Weights loaded successfully!")
192
- print("Dialect Identification Model loaded!")
193
-
194
- def predict_dialect(self, audio_path, wave2vec_model_path):
195
-
196
- input_features = self.extract_wav2vec_features(audio_path, wave2vec_model_path)
197
- X1, X2 = Get_data(input_features)
198
- X1 = np.swapaxes(X1, 0, 1)
199
- X2 = np.swapaxes(X2, 0, 1)
200
-
201
- x1 = torch.from_numpy(X1).to(device)
202
- x2 = torch.from_numpy(X2).to(device)
203
- # Pass inputs through the model
204
- with torch.no_grad():
205
- output = self.forward(x1, x2)
206
-
207
- predicted_value = output.argmax().cpu().item()
208
-
209
- # Convert predicted value to dialect
210
- dialect = next(key for key, value in lan2id.items() if value == predicted_value)
211
- return dialect
212
-
213
- def extract_wav2vec_features(self, audio_path, wave2vec_model_path):
214
-
215
- wave2vec2_bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H
216
- wave2vec2_model = wave2vec2_bundle.get_model()
217
-
218
- # Load the state dictionary from the given path
219
- wave2vec2_model.load_state_dict(torch.load(wave2vec_model_path, map_location=device))
220
- wave2vec2_model = wave2vec2_model.to(device)
221
- wave2vec2_model.eval()
222
- print("Wav2Vec 2.0 model loaded!")
223
-
224
- print(f"\n\nLoading audio from {audio_path}.")
225
- X, sample_rate = sf.read(audio_path)
226
- waveform = Tensor(X)
227
- waveform = waveform.unsqueeze(0)
228
-
229
- if sample_rate != wave2vec2_bundle.sample_rate:
230
- waveform = torchaudio.functional.resample(waveform, sample_rate, wave2vec2_bundle.sample_rate)
231
- waveform = waveform.squeeze(-1)
232
-
233
- with torch.inference_mode():
234
- features, _ = wave2vec2_model.extract_features(waveform)
235
-
236
- input_features = torch.squeeze(features[2])
 
 
 
237
  return input_features
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import numpy as np
4
+ import torchaudio
5
+ import soundfile as sf
6
+ from torch import Tensor
7
+
8
# Select the compute device once at import time.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Constants matching the trained checkpoint.
e_dim = 512  # Update with the correct embedding dimension based on your model
n_classes = 2  # Number of language classes, based on your requirement
look_back1 = 30
look_back2 = 60
lan2id = {'MA': 0, 'PU': 1}


# Function to preprocess input data
def Get_data(X):
    """Standardize the feature matrix and slice it into two window sets.

    Returns a pair ``(Xdata1, Xdata2)``: dense windows of ``look_back1``
    consecutive frames (stride 1), and subsampled windows (every 3rd frame)
    drawn from ``look_back2``-frame spans (stride 2). Either array is empty
    when ``X`` has too few frames for the corresponding window length.
    """
    if isinstance(X, torch.Tensor):
        X = X.cpu().numpy()

    # Per-dimension z-score normalization; zero-variance dims divide by 1.
    mu = X.mean(axis=0)
    std = X.std(axis=0)
    np.place(std, std == 0, 1)
    X = (X - mu) / std

    Xdata1 = np.array([X[start:start + look_back1, :]
                       for start in range(len(X) - look_back1)])
    Xdata2 = np.array([X[start + 1:start + look_back2:3, :]
                       for start in range(0, len(X) - look_back2, 2)])

    return Xdata1, Xdata2
41
+
42
+
43
+
44
class LSTMNet(nn.Module):
    """Two stacked bidirectional LSTMs with a self-attentive pooling head.

    Consumes a (time, batch, 1024) sequence and returns a (1, e_dim)
    embedding built from the final timestep's hidden states.
    """

    def __init__(self):
        super(LSTMNet, self).__init__()
        # 1024-dim inputs -> 512x2 bidirectional -> 256x2 (= e_dim) outputs.
        self.lstm1 = nn.LSTM(1024, 512, bidirectional=True)
        self.lstm2 = nn.LSTM(1024, 256, bidirectional=True)

        self.fc_ha = nn.Linear(e_dim, 256)
        self.fc_1 = nn.Linear(256, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        seq1, _ = self.lstm1(x)
        seq2, _ = self.lstm2(seq1)

        # Keep only the last timestep and add a leading singleton axis.
        ht = torch.unsqueeze(seq2[-1], 0)

        # Attention scores over axis 1 (here the original batch axis).
        scores = self.fc_1(torch.tanh(self.fc_ha(ht)))
        weights = self.softmax(scores)

        batch_size, T, D = ht.shape
        pooled = torch.bmm(weights.view(batch_size, 1, T),
                           ht.view(batch_size, T, D))
        return torch.squeeze(pooled, 0)
70
+
71
class CCSL_Net(nn.Module):
    """Fuse two branch embeddings with additive attention and classify.

    ``model1`` and ``model2`` each map their input to an (N, e_dim)
    embedding; the fused utterance vector feeds a bias-free linear head
    that returns unnormalized class scores (logits).
    """

    def __init__(self, model1, model2):
        super(CCSL_Net, self).__init__()
        self.model1 = model1
        self.model2 = model2

        self.att1 = nn.Linear(e_dim, 256)
        self.att2 = nn.Linear(256, 1)

        self.softmax = nn.Softmax(dim=1)
        self.lang_classifier = nn.Linear(e_dim, n_classes, bias=False)

    def forward(self, x1, x2):
        # Branch embeddings stacked along dim 0, then a leading batch axis.
        stacked = torch.cat((self.model1(x1), self.model2(x2)), dim=0)
        ht_e = torch.unsqueeze(stacked, 0)

        # Additive attention with a tanh-squashed score.
        scores = torch.tanh(self.att2(torch.tanh(self.att1(ht_e))))
        weights = self.softmax(scores)

        batch_size, Tb, D = ht_e.shape
        u_vec = torch.bmm(weights.view(batch_size, 1, Tb),
                          ht_e.view(batch_size, Tb, D))
        u_vec = torch.squeeze(u_vec, 0)

        return self.lang_classifier(u_vec)
101
+
102
class DID_Model(nn.Module):
    """Dialect identification model: two LSTMNet branches fused by CCSL_Net.

    Call ``load_weights(checkpoint_path, wave2vec_model_path)`` once; the
    wav2vec2 feature-extractor path is remembered and reused by every
    subsequent ``predict_dialect`` call.
    """

    # Parameter names of a single-layer bidirectional nn.LSTM, as they
    # appear (per submodule prefix) in the flat checkpoint.
    _LSTM_KEYS = (
        'weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0',
        'weight_ih_l0_reverse', 'weight_hh_l0_reverse',
        'bias_ih_l0_reverse', 'bias_hh_l0_reverse',
    )

    def __init__(self):
        super(DID_Model, self).__init__()
        self.model1 = LSTMNet()
        self.model2 = LSTMNet()
        self.ccslnet = CCSL_Net(self.model1, self.model2)
        self.wave2vec_model_path = ""
        # Lazily loaded wav2vec2 feature extractor (loading it per call
        # re-read a large checkpoint from disk on every prediction).
        self._wave2vec2_bundle = None
        self._wave2vec2_model = None

    def forward(self, x1, x2):
        """Run both window sets through the fused classifier; returns logits."""
        return self.ccslnet(x1, x2)

    @staticmethod
    def _sub_state(checkpoint, prefix, keys):
        """Collect ``{key: checkpoint[f'{prefix}.{key}']}`` for one submodule."""
        return {key: checkpoint[f'{prefix}.{key}'] for key in keys}

    def load_weights(self, checkpoint_path, wave2vec_model_path):
        """Load DID weights from a flat checkpoint and remember the wav2vec2 path.

        Raises KeyError if the checkpoint is missing any expected key.
        """
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # trusted checkpoints (consider weights_only=True on recent torch).
        checkpoint = torch.load(checkpoint_path, map_location=device)
        self.wave2vec_model_path = wave2vec_model_path

        # Both branches share the same layout; load them in one loop instead
        # of repeating the key lists per submodule.
        for name, branch in (('model1', self.model1), ('model2', self.model2)):
            branch.lstm1.load_state_dict(
                self._sub_state(checkpoint, f'{name}.lstm1', self._LSTM_KEYS))
            branch.lstm2.load_state_dict(
                self._sub_state(checkpoint, f'{name}.lstm2', self._LSTM_KEYS))
            branch.fc_ha.load_state_dict(
                self._sub_state(checkpoint, f'{name}.fc_ha', ('weight', 'bias')))
            branch.fc_1.load_state_dict(
                self._sub_state(checkpoint, f'{name}.fc_1', ('weight', 'bias')))

        # Fusion attention and (bias-free) language classifier weights.
        self.ccslnet.att1.load_state_dict(
            self._sub_state(checkpoint, 'att1', ('weight', 'bias')))
        self.ccslnet.att2.load_state_dict(
            self._sub_state(checkpoint, 'att2', ('weight', 'bias')))
        self.ccslnet.lang_classifier.load_state_dict(
            self._sub_state(checkpoint, 'lang_classifier', ('weight',)))

        print("Weights loaded successfully!")
        print("Dialect Identification Model loaded!")

    def predict_dialect(self, audio_path):
        """Predict the dialect label ('MA' or 'PU') for one audio file.

        Raises ValueError when the audio yields too few wav2vec2 frames to
        build the context windows (previously a cryptic np.swapaxes error).
        """
        input_features = self.extract_wav2vec_features(
            audio_path, self.wave2vec_model_path)
        X1, X2 = Get_data(input_features)
        if X1.size == 0 or X2.size == 0:
            raise ValueError(
                f"Audio at {audio_path} is too short: need more than "
                f"{look_back2} feature frames, got {len(input_features)}.")
        X1 = np.swapaxes(X1, 0, 1)
        X2 = np.swapaxes(X2, 0, 1)

        x1 = torch.from_numpy(X1).to(device)
        x2 = torch.from_numpy(X2).to(device)
        # Inference only — no gradients needed.
        with torch.no_grad():
            output = self.forward(x1, x2)

        predicted_value = output.argmax().cpu().item()

        # Map the class index back to its dialect label.
        return next(key for key, value in lan2id.items()
                    if value == predicted_value)

    def _get_wave2vec2(self, wave2vec_model_path):
        """Load the wav2vec2 extractor on first use and cache it afterwards."""
        if self._wave2vec2_model is None:
            self._wave2vec2_bundle = torchaudio.pipelines.WAV2VEC2_ASR_LARGE_960H
            model = self._wave2vec2_bundle.get_model()
            # Load the state dictionary from the given path.
            model.load_state_dict(
                torch.load(wave2vec_model_path, map_location=device))
            self._wave2vec2_model = model.to(device).eval()
            print("Wav2Vec 2.0 model loaded!")
        return self._wave2vec2_bundle, self._wave2vec2_model

    def extract_wav2vec_features(self, audio_path, wave2vec_model_path):
        """Extract wav2vec2 features for one audio file.

        Returns a 2-D tensor of per-frame features taken from
        ``extract_features(...)[0][2]`` — the third entry of the returned
        layer list (presumably the third transformer layer; confirm against
        the training setup).
        """
        wave2vec2_bundle, wave2vec2_model = self._get_wave2vec2(
            wave2vec_model_path)

        print(f"\n\nLoading audio from {audio_path}.")
        X, sample_rate = sf.read(audio_path)
        waveform = Tensor(X).unsqueeze(0)

        if sample_rate != wave2vec2_bundle.sample_rate:
            waveform = torchaudio.functional.resample(
                waveform, sample_rate, wave2vec2_bundle.sample_rate)
            # NOTE(review): the trailing-axis squeeze (for channel-last reads
            # from soundfile) only runs on the resample path, as in the
            # original — same-rate multi-channel input is passed through as-is.
            waveform = waveform.squeeze(-1)

        with torch.inference_mode():
            features, _ = wave2vec2_model.extract_features(waveform)

        input_features = torch.squeeze(features[2])
        return input_features