Spaces:
Running
Running
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| from matplotlib import patches | |
| from sklearn.svm import LinearSVC | |
| from matplotlib.axes._axes import _log as matplotlib_axes_logger | |
| matplotlib_axes_logger.setLevel('ERROR') | |
| colors = ['#e6194B', '#3cb44b', '#ffe119', '#4363d8', '#f58231', '#911eb4', | |
| '#42d4f4', '#f032e6', '#bfef45', '#fabebe', '#469990', '#e6beff', | |
| '#9A6324', '#fffac8', '#800000', '#aaffc3', '#808000', '#ffd8b1', | |
| '#000075', '#a9a9a9'] | |
| class InteractivePlotter: | |
| def __init__(self, feats_ds, feats, spec_slices, call_info, freq_lims, allow_training): | |
| """ | |
| Plots 2D low dimensional features on left and corresponding spectgrams on | |
| the right. | |
| """ | |
| self.feats_ds = feats_ds | |
| self.feats = feats | |
| self.clf = None | |
| self.spec_slices = spec_slices | |
| self.call_info = call_info | |
| #_, self.labels = np.unique([cc['class'] for cc in call_info], return_inverse=True) | |
| self.labels = np.zeros(len(call_info), dtype=np.int) | |
| self.annotated = np.zeros(self.labels.shape[0], dtype=np.int) # can populate this with 1's where we have labels | |
| self.labels_cols = [colors[self.labels[ii]] for ii in range(len(self.labels))] | |
| self.freq_lims = freq_lims | |
| self.allow_training = allow_training | |
| self.pt_size = 5.0 | |
| self.spec_pad = 0.2 # this much padding has been applied to the spec slices | |
| self.fig_width = 12 | |
| self.fig_height = 8 | |
| self.current_id = 0 | |
| max_ind = np.argmax([ss.shape[1] for ss in self.spec_slices]) | |
| self.max_width = self.spec_slices[max_ind].shape[1] | |
| self.blank_spec = np.zeros((self.spec_slices[0].shape[0], self.max_width)) | |
| def plot(self, fig_id): | |
| self.fig, self.ax = plt.subplots(nrows=1, ncols=2, num=fig_id, figsize=(self.fig_width, self.fig_height), | |
| gridspec_kw={'width_ratios': [2, 1]}) | |
| plt.tight_layout() | |
| # plot 2D TNSE features | |
| self.low_dim_plt = self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1], | |
| c=self.labels_cols, s=self.pt_size, picker=5) | |
| self.ax[0].set_title('TSNE of Call Features') | |
| self.ax[0].set_xticks([]) | |
| self.ax[0].set_yticks([]) | |
| # plot clip from spectrogram | |
| spec_min_max = (0, self.blank_spec.shape[1], self.freq_lims[0], self.freq_lims[1]) | |
| self.ax[1].imshow(self.blank_spec, extent=spec_min_max, cmap='plasma', aspect='auto') | |
| self.spec_im = self.ax[1].get_images()[0] | |
| self.ax[1].set_title('Spectrogram') | |
| self.ax[1].grid(color='w', linewidth=0.5) | |
| self.ax[1].set_xticks([]) | |
| self.ax[1].set_ylabel('kHz') | |
| bbox_orig = patches.Rectangle((0,0),0,0, edgecolor='w', linewidth=0, fill=False) | |
| self.ax[1].add_patch(bbox_orig) | |
| self.annot = self.ax[0].annotate('', xy=(0,0), xytext=(20,20),textcoords='offset points', | |
| bbox=dict(boxstyle='round', fc='w'), arrowprops=dict(arrowstyle='->')) | |
| self.annot.set_visible(False) | |
| self.fig.canvas.mpl_connect('motion_notify_event', self.mouse_hover) | |
| self.fig.canvas.mpl_connect('key_press_event', self.key_press) | |
| def mouse_hover(self, event): | |
| vis = self.annot.get_visible() | |
| if event.inaxes == self.ax[0]: | |
| cont, ind = self.low_dim_plt.contains(event) | |
| if cont: | |
| self.current_id = ind['ind'][0] | |
| # copy spec into full window - probably a better way of doing this | |
| new_spec = self.blank_spec.copy() | |
| w_diff = (self.blank_spec.shape[1] - self.spec_slices[self.current_id].shape[1])//2 | |
| new_spec[:, w_diff:self.spec_slices[self.current_id].shape[1]+w_diff] = self.spec_slices[self.current_id] | |
| self.spec_im.set_data(new_spec) | |
| self.spec_im.set_clim(vmin=0, vmax=new_spec.max()) | |
| # draw bounding box around call | |
| self.ax[1].patches[0].remove() | |
| spec_width_orig = self.spec_slices[self.current_id].shape[1]/(1.0+2.0*self.spec_pad) | |
| xx = w_diff + self.spec_pad*spec_width_orig | |
| ww = spec_width_orig | |
| yy = self.call_info[self.current_id]['low_freq']/1000 | |
| hh = (self.call_info[self.current_id]['high_freq']-self.call_info[self.current_id]['low_freq'])/1000 | |
| bbox = patches.Rectangle((xx,yy),ww,hh, edgecolor='r', linewidth=0.5, fill=False) | |
| self.ax[1].add_patch(bbox) | |
| # update annotation arrow | |
| pos = self.low_dim_plt.get_offsets()[self.current_id] | |
| self.annot.xy = pos | |
| self.annot.set_visible(True) | |
| # write call info | |
| info_str = self.call_info[self.current_id]['file_name'] + ', time=' \ | |
| + str(round(self.call_info[self.current_id]['start_time'],3)) \ | |
| + ', prob=' + str(round(self.call_info[self.current_id]['det_prob'],3)) | |
| self.ax[0].set_xlabel(info_str) | |
| # redraw | |
| self.fig.canvas.draw_idle() | |
| def key_press(self, event): | |
| if event.key.isdigit(): | |
| self.labels_cols[self.current_id] = colors[int(event.key)] | |
| self.labels[self.current_id] = int(event.key) | |
| self.annotated[self.current_id] = 1 | |
| elif event.key == 'enter' and self.allow_training: | |
| self.train_classifier() | |
| elif event.key == 'x' and self.allow_training: | |
| self.get_classifier_params() | |
| self.ax[0].scatter(self.feats_ds[:, 0], self.feats_ds[:, 1], | |
| c=self.labels_cols, s=self.pt_size) | |
| self.fig.canvas.draw_idle() | |
| def train_classifier(self): | |
| # TODO maybe it's better to classify in 2D space - but then can't be linear ... | |
| inds = np.where(self.annotated == 1)[0] | |
| labs_un, labs_inds = np.unique(self.labels[inds], return_inverse=True) | |
| if labs_un.shape[0] > 1: # needs at least 2 classes | |
| self.clf = LinearSVC(C=1.0, penalty='l2', loss='squared_hinge', tol=0.0001, | |
| intercept_scaling=1.0, max_iter=2000) | |
| self.clf.fit(self.feats[inds, :], self.labels[inds]) | |
| # update labels | |
| inds_unlab = np.where(self.annotated == 0)[0] | |
| self.labels[inds_unlab] = self.clf.predict(self.feats[inds_unlab]) | |
| for ii in inds_unlab: | |
| self.labels_cols[ii] = colors[self.labels[ii]] | |
| else: | |
| print('Not enough data - please label more classes.') | |
| def get_classifier_params(self): | |
| res = {} | |
| if self.clf is None: | |
| print('Model not trained!') | |
| else: | |
| res['weights'] = self.clf.coef_.astype(np.float32) | |
| res['biases'] = self.clf.intercept_.astype(np.float32) | |
| return res | |