davesalvi commited on
Commit
2137fb1
·
1 Parent(s): 645a259

simplify utils

Browse files
Files changed (3) hide show
  1. .idea/workspace.xml +1 -1
  2. script.py +1 -1
  3. src/utils.py +0 -231
.idea/workspace.xml CHANGED
@@ -5,8 +5,8 @@
5
  </component>
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
8
- <change beforePath="$PROJECT_DIR$/.idea/workspace.xml" beforeDir="false" afterPath="$PROJECT_DIR$/.idea/workspace.xml" afterDir="false" />
9
  <change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
 
10
  </list>
11
  <option name="SHOW_DIALOG" value="false" />
12
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
 
5
  </component>
6
  <component name="ChangeListManager">
7
  <list default="true" id="23565123-73ab-4f40-a9ef-1086e0c9e1ec" name="Changes" comment="">
 
8
  <change beforePath="$PROJECT_DIR$/script.py" beforeDir="false" afterPath="$PROJECT_DIR$/script.py" afterDir="false" />
9
+ <change beforePath="$PROJECT_DIR$/src/utils.py" beforeDir="false" afterPath="$PROJECT_DIR$/src/utils.py" afterDir="false" />
10
  </list>
11
  <option name="SHOW_DIALOG" value="false" />
12
  <option name="HIGHLIGHT_CONFLICTS" value="true" />
script.py CHANGED
@@ -10,7 +10,7 @@ from preprocess import preprocess
10
  # from pathlib import Path
11
 
12
  # from src.rawnet_model import RawNet
13
- # from src.utils import *
14
 
15
  # Import your model and anything else you want
16
  # You can even install other packages included in your repo
 
10
  # from pathlib import Path
11
 
12
  # from src.rawnet_model import RawNet
13
+ from src.utils import *
14
 
15
  # Import your model and anything else you want
16
  # You can even install other packages included in your repo
src/utils.py CHANGED
@@ -1,74 +1,4 @@
1
- import os
2
- import torch
3
- import random
4
- import GPUtil
5
  import yaml
6
- import matplotlib.pyplot as plt
7
- import numpy as np
8
- from sklearn.metrics import roc_curve, auc, confusion_matrix
9
- import pandas as pd
10
- import torch.nn as nn
11
-
12
-
13
- def set_gpu(id=-1):
14
- """
15
- Set GPU device or select the one with the lowest memory usage (None for CPU-only)
16
-
17
- :param id: if specified, corresponds to the GPU index desired.
18
- """
19
- if id is None:
20
- # CPU only
21
- print('GPU not selected')
22
- os.environ["CUDA_VISIBLE_DEVICES"] = str(-1)
23
- else:
24
- # -1 for automatic choice
25
- device = id if id != -1 else GPUtil.getFirstAvailable(order='memory')[0]
26
- try:
27
- name = GPUtil.getGPUs()[device].name
28
- except IndexError:
29
- print('The selected GPU does not exist. Switching to the most available one.')
30
- device = GPUtil.getFirstAvailable(order='memory')[0]
31
- name = GPUtil.getGPUs()[device].name
32
- print('GPU selected: %d - %s' % (device, name))
33
- os.environ["CUDA_VISIBLE_DEVICES"] = str(device)
34
- return device
35
-
36
-
37
- def prepare_asvspoof_data(config):
38
-
39
- data_dir_2019 = '/nas/public/dataset/asvspoof2019/LA/ASVspoof2019_LA_cm_protocols'
40
- data_eval_2021 = '/nas/public/dataset/asvspoof2021/DF_cm_eval_labels.txt'
41
- files = [os.path.join(data_dir_2019, 'ASVspoof2019.LA.cm.train.trn.txt'),
42
- os.path.join(data_dir_2019, 'ASVspoof2019.LA.cm.dev.trl.txt'), data_eval_2021]
43
-
44
- audio_dir_2019 = '/nas/public/dataset/asvspoof2019/LA'
45
- audio_dir_2021 = '/nas/public/dataset/asvspoof2021/ASVspoof2021_DF_eval/flac/'
46
- set_dirs = [os.path.join(audio_dir_2019, 'ASVspoof2019_LA_train/flac/'),
47
- os.path.join(audio_dir_2019, 'ASVspoof2019_LA_dev/flac/'), audio_dir_2021]
48
-
49
- save_paths = [config['df_train_path'], config['df_dev_path'], config['df_eval_path']]
50
-
51
- for file_path, set_dir, save_path in zip(files, set_dirs, save_paths):
52
-
53
- txt_file = pd.read_csv(file_path, sep=' ', header=None)
54
- txt_file = txt_file.replace({'bonafide': 0, 'spoof': 1})
55
-
56
- txt_file.iloc[:,1] = set_dir + txt_file.iloc[:,1].astype(str) + '.flac'
57
-
58
- if not file_path == data_eval_2021:
59
- df = txt_file[[1, 4]]
60
- df = df.rename({1: 'path', 4: 'label'}, axis='columns')
61
- else:
62
- df = txt_file[[1, 5]]
63
- df = df.rename({1: 'path', 5: 'label'}, axis='columns')
64
-
65
- df.to_csv(save_path)
66
-
67
-
68
- def init_weights(module):
69
- if isinstance(module, nn.Linear):
70
- torch.nn.init.xavier_uniform_(module.weight)
71
- module.bias.data.fill_(0.01)
72
 
73
 
74
  def read_yaml(config_path):
@@ -84,164 +14,3 @@ def read_yaml(config_path):
84
  config = yaml.safe_load(f)
85
  return config
86
 
87
-
88
- def sigmoid(x, factor=1):
89
- """
90
- Compute sigmoid function.
91
-
92
- :param x: input signal
93
- :param factor: sigmoid parameter
94
- :return: sigmoid(x)
95
- :rtype np.array
96
- """
97
- z = 1 / (1 + np.exp(-factor*x))
98
- return z
99
-
100
-
101
- def plot_roc_curve(labels, pred, legend=None):
102
- """
103
- Plot ROC curve.
104
-
105
- :param labels: groundtruth labels
106
- :type labels: list
107
- :param pred: predicted score
108
- :type pred: list
109
- :param legend: if True, add legend to the plot
110
- :type legend: bool
111
- :return:
112
- """
113
- # labels and pred bust be given in (N, ) shape
114
-
115
- def tpr5(y_true, y_pred):
116
- fpr, tpr, thr = roc_curve(y_true, y_pred)
117
- fp_sort = sorted(fpr)
118
- tp_sort = sorted(tpr)
119
- tpr_ind = [i for (i, val) in enumerate(fp_sort) if val >= 0.1][0]
120
- tpr01 = tp_sort[tpr_ind]
121
- return tpr01
122
-
123
- lw = 3
124
-
125
-
126
- fpr, tpr, thres = roc_curve(labels, pred)
127
- rocauc = auc(fpr, tpr)
128
- fnr = 1 - tpr
129
- eer = fpr[np.nanargmin(np.absolute(fnr - fpr))]
130
- optimal_index = np.argmax(tpr - fpr)
131
- optimal_threshold = thres[optimal_index]
132
-
133
- print('TPR5 = {:.3f}'.format(tpr5(labels, pred)))
134
- print('AUC = {:.3f}'.format(rocauc))
135
- print('EER = {:.3f}'.format(eer))
136
- print('Best Thres. = {:.3f}'.format(optimal_threshold))
137
- print()
138
- if legend:
139
- plt.plot(fpr, tpr, lw=lw, label='$\mathrm{' + legend + ' - AUC = %0.2f}$' % rocauc)
140
- else:
141
- plt.plot(fpr, tpr, lw=lw, label='$\mathrm{AUC = %0.2f}$' % rocauc)
142
- plt.plot([0, 1], [0, 1], color='black', lw=lw, linestyle='--')
143
- plt.xlim([-0.02, 1.0])
144
- plt.ylim([0.0, 1.03])
145
- plt.xlabel(r'$\mathrm{False\;Positive\;Rate}$', fontsize=18)
146
- plt.ylabel(r'$\mathrm{True\;Positive\;Rate}$', fontsize=18)
147
- plt.legend(loc="lower right", fontsize=15)
148
- plt.xticks(fontsize=15)
149
- plt.yticks(fontsize=15)
150
- plt.grid(True)
151
- # plt.show()
152
-
153
- return optimal_threshold
154
-
155
- def plot_confusion_matrix(y_true, y_pred, normalize=False, cmap=plt.cm.Blues):
156
- """
157
- Plot confusion matrix.
158
-
159
- :param y_true: ground-truth labels
160
- :type y_true: list
161
- :param y_pred: predicted labels
162
- :type y_pred: list
163
- :param normalize: if set to True, normalise the confusion matrix.
164
- :type normalize: bool
165
- :param cmap: matplotlib cmap to be used for plot
166
- :type cmap:
167
- :return:
168
- """
169
- cm = confusion_matrix(y_true, y_pred)
170
- # Only use the labels that appear in the data
171
- # classes = classes[unique_labels(y_true, y_pred)]
172
- classes = ['$\it{Real}$','$\it{Fake}$']
173
- if normalize:
174
- cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
175
- print(cm)
176
-
177
- fsize = 25 # fontsize
178
- fig, ax = plt.subplots()
179
- im = ax.imshow(cm, interpolation='nearest', cmap=cmap, clim=(0,1))
180
- cbar = ax.figure.colorbar(im, ax=ax)
181
- cbar.ax.tick_params(labelsize=fsize)
182
- ax.set(xticks=np.arange(cm.shape[1]),
183
- yticks=np.arange(cm.shape[0]),
184
- )
185
- ax.set_xlabel('$\mathrm{True\;label}$', fontsize=fsize)
186
- ax.set_ylabel('$\mathrm{Predicted\;label}$', fontsize=fsize)
187
- ax.set_xticklabels(classes, fontsize=fsize)
188
- ax.set_yticklabels(classes, fontsize=fsize)
189
- # Rotate the tick labels and set their alignment.
190
- # plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
191
- # rotation_mode="anchor")
192
- # Loop over data dimensions and create text annotations.
193
- fmt = '.3f' if normalize else 'd'
194
- thresh = cm.max() / 2.
195
- for i in range(cm.shape[0]):
196
- for j in range(cm.shape[1]):
197
- ax.text(j, i, format('$\mathrm{' + str(format(cm[i, j], fmt)) + '}$'),
198
- ha="center", va="center",
199
- fontsize=fsize,
200
- color="white" if np.array(cm[i, j]) > thresh else "black")
201
- fig.tight_layout()
202
- # plt.show()
203
-
204
- return ax
205
-
206
-
207
- def reconstruct_from_pred(pred_array, win_len, hop_size, fs=16000):
208
- """
209
- Create a score array with length equal to the original signal length starting from predictions aggregated on
210
- rectangular windows.
211
-
212
- :param pred_array: aggregated prediction array
213
- :type pred_array: list
214
- :param win_len: length of the window used for aggregation
215
- :type win_len: int
216
- :param hop_size: length of the hop used for aggregation
217
- :type hop_size: int
218
- :param fs: sampling frequency
219
- :type fs: int
220
- :return: reconstructed array
221
- """
222
-
223
- pred_array = np.array(pred_array)
224
- audio_shape = (len(pred_array)-1) * hop_size * fs + win_len * fs
225
-
226
- window_pred = np.zeros((len(pred_array), int(audio_shape)))
227
- for idx, pred in enumerate(pred_array):
228
- window_pred[idx, int(idx*hop_size*fs):int((idx*hop_size+win_len)*fs)] = pred
229
-
230
- window_pred = np.nanmean(np.where(window_pred != 0, window_pred, np.nan), 0)
231
-
232
- return window_pred
233
-
234
-
235
- def seed_everything(seed: int):
236
- """
237
- Set seed for everything.
238
- :param seed: seed value
239
- :type seed: int
240
- """
241
- random.seed(seed)
242
- os.environ['PYTHONHASHSEED'] = str(seed)
243
- np.random.seed(seed)
244
- torch.manual_seed(seed)
245
- torch.cuda.manual_seed(seed)
246
- torch.backends.cudnn.deterministic = True
247
- torch.backends.cudnn.benchmark = True
 
 
 
 
 
1
  import yaml
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
 
4
  def read_yaml(config_path):
 
14
  config = yaml.safe_load(f)
15
  return config
16