Spaces:
Sleeping
Sleeping
| from __future__ import division | |
| import matplotlib | |
| matplotlib.use('agg') | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| import os | |
| from data_reader import DataConfig | |
| from detect_peaks import detect_peaks | |
| import logging | |
class EMA(object):
    """Exponential moving average with smoothing factor ``alpha``.

    The first sample seeds the average; every later sample ``x`` updates it
    as ``alpha * previous + (1 - alpha) * x``.
    """

    def __init__(self, alpha):
        self.alpha = alpha  # weight given to the previous average
        self.x = 0.         # current average (0. until the first sample arrives)
        self.count = 0      # number of samples folded in so far

    def value(self):
        """Return the current average."""
        return self.x

    def __call__(self, x):
        """Fold sample ``x`` into the average and return the updated value."""
        is_first = self.count == 0
        self.x = x if is_first else self.alpha * self.x + (1 - self.alpha) * x
        self.count += 1
        return self.x
class LMA(object):
    """Cumulative (linear) moving average: the arithmetic mean of all samples seen."""

    def __init__(self):
        self.x = 0.      # current mean (0. until the first sample arrives)
        self.count = 0   # number of samples folded in so far

    def value(self):
        """Return the current mean."""
        return self.x

    def __call__(self, x):
        """Fold sample ``x`` into the running mean and return the updated mean.

        The update rebinds ``self.x`` rather than using ``+=``. After the first
        call ``self.x`` is the caller's own object, so an in-place update would
        silently modify the caller's first NumPy array (and would raise a
        casting error for integer arrays).
        """
        if self.count == 0:
            self.x = x
        else:
            # Incremental mean: m_{n+1} = m_n + (x - m_n) / (n + 1)
            self.x = self.x + (x - self.x) / (self.count + 1)
        self.count += 1
        return self.x
def detect_peaks_thread(i, pred, fname=None, result_dir=None, args=None):
    """Pick P and S peaks for record ``i`` of ``pred`` and optionally save them.

    Channel 1 of ``pred[i, :, 0, :]`` is scanned for P peaks and channel 2 for
    S peaks with ``detect_peaks``. The probability threshold (``mph``) is 0.5
    for both phases unless ``args`` is given, in which case ``args.tp_prob`` /
    ``args.ts_prob`` are used. The minimum peak distance is 0.5 s converted to
    samples via ``DataConfig().dt``.

    If both ``fname`` and ``result_dir`` are given, ``pred[i]`` and the picks
    are written with ``np.savez`` to ``result_dir/fname[i]``. ``fname[i]`` may
    be ``bytes`` (decoded) or ``str``. Missing parent directories are created
    on demand.

    Returns ``[(itp, prob_p), (its, prob_s)]`` as produced by ``detect_peaks``
    (presumably peak indices and their probabilities — confirm in detect_peaks).
    """
    if args is None:
        mph_p, mph_s = 0.5, 0.5
    else:
        mph_p, mph_s = args.tp_prob, args.ts_prob
    mpd = 0.5 / DataConfig().dt  # 0.5 s minimum peak separation, in samples
    itp, prob_p = detect_peaks(pred[i, :, 0, 1], mph=mph_p, mpd=mpd, show=False)
    its, prob_s = detect_peaks(pred[i, :, 0, 2], mph=mph_s, mpd=mpd, show=False)

    if (fname is not None) and (result_dir is not None):
        name = fname[i]
        if isinstance(name, bytes):  # also covers numpy.bytes_
            name = name.decode()
        path = os.path.join(result_dir, name)
        try:
            np.savez(path, pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
        except FileNotFoundError:
            # fname may contain subdirectories that do not exist yet.
            os.makedirs(os.path.dirname(path), exist_ok=True)
            np.savez(path, pred=pred[i], itp=itp, its=its, prob_p=prob_p, prob_s=prob_s)
    return [(itp, prob_p), (its, prob_s)]
def _plot_panel_tag(tag, box):
    # Draw a panel label such as '(i)' near the upper-left corner of the current axes.
    plt.text(0.05, 0.77, tag, horizontalalignment='center',
             transform=plt.gca().transAxes, fontsize="small", fontweight="normal", bbox=box)


def _plot_waveform_panel(t, trace, channel, true_picks, dt, label_picks, tag, box):
    # Plot one waveform channel plus vertical lines at true P (blue) / S (red) picks.
    plt.plot(t, trace, 'k', label=channel, linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    tmp_min = np.min(trace)
    tmp_max = np.max(trace)
    if true_picks is not None:
        for color, phase, phase_picks in zip(('b', 'r'), ('P', 'S'), true_picks):
            for j, pick in enumerate(phase_picks):
                # Only the first line of each phase gets a legend entry, and only on request.
                extra = {'label': phase} if (label_picks and j == 0) else {}
                plt.plot([pick * dt, pick * dt], [tmp_min, tmp_max], color,
                         linewidth=0.5, **extra)
    plt.ylabel('Amplitude')
    plt.legend(loc='upper right', fontsize='small')
    plt.gca().set_xticklabels([])
    _plot_panel_tag(tag, box)


def _plot_probability_panel(t, pred, Y, i, itp_pred, its_pred, dt, box):
    # Plot target (if any) and predicted P/S probability curves plus predicted pick lines.
    if Y is not None:
        plt.plot(t, Y[i, :, 0, 1], 'b', label='P', linewidth=0.5)
        plt.plot(t, Y[i, :, 0, 2], 'r', label='S', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 1], '--g', label=r'$\hat{P}$', linewidth=0.5)
    plt.plot(t, pred[i, :, 0, 2], '-.m', label=r'$\hat{S}$', linewidth=0.5)
    plt.autoscale(enable=True, axis='x', tight=True)
    if (itp_pred is not None) and (its_pred is not None):
        for pick in itp_pred:
            plt.plot([pick * dt, pick * dt], [-0.1, 1.1], '--g', linewidth=0.5)
        for pick in its_pred:
            plt.plot([pick * dt, pick * dt], [-0.1, 1.1], '-.m', linewidth=0.5)
    plt.ylim([-0.05, 1.05])
    _plot_panel_tag('(iv)', box)
    plt.legend(loc='upper right', fontsize='small')
    plt.xlabel('Time (s)')
    plt.ylabel('Probability')


def _figure_path(figure_dir, name):
    # Map a record name (bytes or str) to '<figure_dir>/<name minus .npz suffix>.png'.
    if isinstance(name, bytes):
        name = name.decode()
    # Strip the '.npz' *suffix*; str.rstrip('.npz') would strip any trailing
    # '.', 'n', 'p', 'z' characters (e.g. 'station.npz' -> 'statio').
    if name.endswith('.npz'):
        name = name[:-len('.npz')]
    return os.path.join(figure_dir, name + '.png')


def plot_result_thread(i, pred, X, Y=None, itp=None, its=None,
                       itp_pred=None, its_pred=None, fname=None, figure_dir=None):
    """Plot record ``i`` (three waveform channels + probabilities) and save it as PNG.

    Panels (i)-(iii) show ``X[i, :, 0, 0..2]`` (labelled E, N, Z). When both
    ``itp`` and ``its`` are given, they carry vertical lines at the true picks
    ``itp[i]`` / ``its[i]`` (sample indices). Panel (iv) shows the targets ``Y``
    (if given), the predicted probabilities ``pred[i, :, 0, 1..2]`` and, when
    both are given, lines at the predicted picks ``itp_pred`` / ``its_pred``.

    The figure is written to ``figure_dir/<fname[i] without .npz>.png``;
    missing directories are created. Figure number ``i`` is always closed,
    even if plotting or saving fails. Returns 0.
    """
    dt = DataConfig().dt
    t = np.arange(0, pred.shape[1]) * dt
    box = dict(boxstyle='round', facecolor='white', alpha=1)
    # True picks are drawn only when both phases are supplied.
    true_picks = (itp[i], its[i]) if (itp is not None) and (its is not None) else None

    plt.figure(i)
    plt.clf()
    try:
        panels = (('E', '(i)'), ('N', '(ii)'), ('Z', '(iii)'))
        for channel_idx, (channel, tag) in enumerate(panels):
            plt.subplot(411 + channel_idx)
            # Only the top panel adds 'P'/'S' legend entries for the true picks.
            _plot_waveform_panel(t, X[i, :, 0, channel_idx], channel, true_picks, dt,
                                 label_picks=(channel_idx == 0), tag=tag, box=box)
        plt.subplot(414)
        _plot_probability_panel(t, pred, Y, i, itp_pred, its_pred, dt, box)
        plt.tight_layout()
        plt.gcf().align_labels()

        fig_path = _figure_path(figure_dir, fname[i])
        try:
            plt.savefig(fig_path, bbox_inches='tight')
        except FileNotFoundError:
            # fname may contain subdirectories that do not exist yet.
            os.makedirs(os.path.dirname(fig_path), exist_ok=True)
            plt.savefig(fig_path, bbox_inches='tight')
    finally:
        plt.close(i)
    return 0
def postprocessing_thread(i, pred, X, Y=None, itp=None, its=None, fname=None, result_dir=None, figure_dir=None, args=None):
    """Detect peaks for record ``i`` and, when possible, plot the result.

    Peaks are picked (and optionally saved) via ``detect_peaks_thread``. A
    figure is drawn with ``plot_result_thread`` only when both ``fname`` and
    ``figure_dir`` are given.

    Returns ``[(itp_pred, prob_p), (its_pred, prob_s)]`` from peak detection.
    """
    (p_idx, p_prob), (s_idx, s_prob) = detect_peaks_thread(i, pred, fname, result_dir, args)
    want_figure = not (fname is None or figure_dir is None)
    if want_figure:
        plot_result_thread(i, pred, X, Y, itp, its, p_idx, s_idx, fname, figure_dir)
    return [(p_idx, p_prob), (s_idx, s_prob)]
def clean_queue(picks):
    """Return a copy of ``picks`` with every zero entry removed from each row.

    The result is a new list of lists; ``picks`` itself is not modified.
    """
    return [[pick for pick in row if pick != 0] for row in picks]
def clean_queue_thread(picks):
    """Return the non-zero entries of ``picks`` as a new list, preserving order."""
    return [pick for pick in picks if pick != 0]
def metrics(TP, nP, nT):
    '''
    Compute precision, recall and F1 score from pick counts.

    TP: true positive
    nP: number of positive picks
    nT: number of true picks

    Returns [precision, recall, F1]. A ratio whose denominator is zero (no
    predicted picks, no true picks, or precision + recall == 0) is reported
    as 0.0 instead of raising ZeroDivisionError or producing nan/inf.
    '''
    precision = TP / nP if nP else 0.
    recall = TP / nT if nT else 0.
    denom = precision + recall
    F1 = 2 * precision * recall / denom if denom else 0.
    return [precision, recall, F1]
def correct_picks(picks, true_p, true_s, tol, diff_window=0.5, dt=None):
    """Count matched picks and collect pick offsets, per phase.

    picks: per-record sequence; ``picks[i][0][0]`` holds the predicted P
        sample indices and ``picks[i][1][0]`` the predicted S sample indices
        (presumably a list of postprocessing_thread results — confirm at caller).
    true_p, true_s: per-record sequences of true P / S sample indices.
    tol: tolerance in units of ``dt`` (seconds, presumably); every
        (predicted, true) pair with ``|offset| < tol / dt`` samples counts
        as a true positive. One true pick matched by several predictions
        therefore contributes several true positives.
    diff_window: pairs with ``|offset| < diff_window / dt`` samples have their
        offset (predicted - true, in samples) collected in diff_p / diff_s.
        Defaults to 0.5, the previously hard-coded value.
    dt: sampling interval; defaults to ``DataConfig().dt``.

    Returns [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s].
    """
    if dt is None:
        dt = DataConfig().dt
    if len(true_p) != len(true_s):
        print("The length of true P and S pickers are not the same")
    num = len(true_p)
    # Loop-invariant thresholds converted from time to samples.
    tol_samples = tol / dt
    window_samples = diff_window / dt
    TP_p = 0; TP_s = 0; nP_p = 0; nP_s = 0; nT_p = 0; nT_s = 0
    diff_p = []; diff_s = []
    for i in range(num):
        nT_p += len(true_p[i])
        nT_s += len(true_s[i])
        nP_p += len(picks[i][0][0])
        nP_s += len(picks[i][1][0])
        if len(true_p[i]) > 1 or len(true_s[i]) > 1:
            print(i, picks[i], true_p[i], true_s[i])
        # Pairwise offsets: rows = true picks, columns = predicted picks.
        tmp_p = np.array(picks[i][0][0]) - np.array(true_p[i])[:, np.newaxis]
        tmp_s = np.array(picks[i][1][0]) - np.array(true_s[i])[:, np.newaxis]
        abs_p = np.abs(tmp_p)
        abs_s = np.abs(tmp_s)
        TP_p += np.sum(abs_p < tol_samples)
        TP_s += np.sum(abs_s < tol_samples)
        diff_p.append(tmp_p[abs_p < window_samples])
        diff_s.append(tmp_s[abs_s < window_samples])
    return [TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, diff_p, diff_s]
def calculate_metrics(picks, itp, its, tol=0.1):
    """Score predicted picks against the true P/S picks and log a summary.

    Matching is delegated to ``correct_picks`` (with tolerance ``tol``) and
    scoring to ``metrics``. The counts and scores for each phase are written
    with ``logging.info``.

    Returns ``([precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s])``.
    """
    TP_p, TP_s, nP_p, nP_s, nT_p, nT_s, _diff_p, _diff_s = correct_picks(picks, itp, its, tol)
    # Both phases are scored before anything is logged.
    precision_p, recall_p, f1_p = metrics(TP_p, nP_p, nT_p)
    precision_s, recall_s, f1_s = metrics(TP_s, nP_s, nT_s)

    logging.info("Total records: {}".format(len(picks)))
    summaries = (
        ("P-phase:", nT_p, nP_p, TP_p, precision_p, recall_p, f1_p),
        ("S-phase:", nT_s, nP_s, TP_s, precision_s, recall_s, f1_s),
    )
    for header, n_true, n_pred, n_tp, prec, rec, f1 in summaries:
        logging.info(header)
        logging.info("True={}, Predict={}, TruePositive={}".format(n_true, n_pred, n_tp))
        logging.info("Precision={:.3f}, Recall={:.3f}, F1={:.3f}".format(prec, rec, f1))
    return [precision_p, recall_p, f1_p], [precision_s, recall_s, f1_s]