Spaces:
Runtime error
Runtime error
| from libs import * | |
| import configVars | |
| import ecg_plot | |
def remove_baseline_filter(sample_rate):
    """Design the high-pass elliptic filter used for baseline removal.

    The specification is fixed: 0.8 Hz cutoff, 0.2 Hz rejection-band edge,
    0.5 dB of pass-band ripple and 40 dB of rejection-band attenuation.
    ``sgn.ellipord`` picks the filter order for that specification and
    ``sgn.iirfilter`` builds the filter.

    ``sgn`` comes from the wildcard import of ``libs`` (presumably
    ``scipy.signal`` -- confirm in ``libs``).

    Args:
        sample_rate: Sampling frequency of the signal, in Hz.

    Returns:
        The filter in second-order-sections form (``output='sos'``).
    """
    nyquist = sample_rate / 2

    cutoff_hz, rejection_hz = 0.8, 0.2      # [Hz] cutoff / rejection band
    ripple_db, attenuation_db = 0.5, 40     # [dB] pass-band ripple / rejection-band attenuation

    # Band edges expressed as a fraction of the Nyquist frequency.
    cutoff = cutoff_hz / nyquist
    rejection = rejection_hz / nyquist

    order, _ = sgn.ellipord(cutoff, rejection, ripple_db, attenuation_db)
    return sgn.iirfilter(order, cutoff, ripple_db, attenuation_db,
                         btype='high', ftype='ellip', output='sos')
# Row layouts used by preprocess_ecg: the position of a name in the list is
# the row index of that lead in the signal array.
reduced_leads = ['DI', 'DII', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
all_leads = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']


def preprocess_ecg(ecg, sample_rate, leads, scale=1,
                   use_all_leads=True, remove_baseline=False):
    """Filter, rescale and resample an ECG to 500 Hz, then derive limb leads.

    Steps, in order:
      1. optional baseline removal with the filter from
         ``remove_baseline_filter`` (zero-phase, via ``sgn.sosfiltfilt``);
      2. multiplication by ``scale``;
      3. resampling to 500 Hz when ``sample_rate`` differs from 500;
      4. when ``use_all_leads`` is set and ``leads`` does not exceed the
         number of target leads, rows DIII, AVR, AVL and AVF are overwritten
         with the values derived from rows DI and DII.

    Args:
        ecg: 2-D array-like indexed as (lead, sample). Rows are assumed to be
            already laid out in the target order (``all_leads`` when
            ``use_all_leads`` is true, otherwise ``reduced_leads``).
        sample_rate: Sampling frequency of ``ecg`` in Hz. It is passed to
            ``sgn.resample_poly`` as ``down``, so it has to be an integer
            whenever it differs from 500.
        leads: Compared numerically against the number of target leads.
            NOTE(review): this is therefore treated as a lead count --
            confirm against the callers.
        scale: Multiplicative factor applied to the signal.
        use_all_leads: Select the 12-lead layout instead of the reduced one.
        remove_baseline: Apply the baseline-removal high-pass filter first.

    Returns:
        A new float ``numpy.ndarray`` of shape (n_rows, n_samples_at_500_Hz).
        The input array is never modified.
    """
    signal = np.asarray(ecg)

    # Remove baseline
    if remove_baseline:
        sos = remove_baseline_filter(sample_rate)
        signal = sgn.sosfiltfilt(sos, signal, padtype='constant', axis=-1)

    # Rescale
    signal = scale * signal

    # Resample
    if sample_rate != 500:
        signal = sgn.resample_poly(signal, up=500, down=sample_rate, axis=-1)

    # Work on a float copy of the *resampled* signal. Previously the
    # pre-resampling array was returned, silently discarding the resampling
    # step, and an integer input truncated the halved derived leads.
    ecg_targetleads = np.array(signal, dtype=float)

    # Add leads if needed
    target_leads = all_leads if use_all_leads else reduced_leads
    n_leads_target = len(target_leads)
    l2p = {lead: row for row, lead in enumerate(target_leads)}

    if n_leads_target >= leads and use_all_leads:
        di = ecg_targetleads[l2p['DI'], :]
        dii = ecg_targetleads[l2p['DII'], :]
        diii = dii - di
        ecg_targetleads[l2p['DIII'], :] = diii
        ecg_targetleads[l2p['AVR'], :] = -(di + dii) / 2
        ecg_targetleads[l2p['AVL'], :] = (di - diii) / 2
        ecg_targetleads[l2p['AVF'], :] = (dii + diii) / 2

    return ecg_targetleads
def generateH5(input_file,out_file,new_freq=None,new_len=None,scale=1,sample_rate=None):
    """Write an ECG to an HDF5 file as a single-record ``tracings`` dataset.

    The signal is optionally resampled, multiplied by ``scale`` and then
    centre-padded with zeros or centre-cropped to ``new_len`` samples. The
    result is stored transposed, as a float64 dataset named ``tracings`` of
    shape ``(1, n_samples, n_leads)``.

    All signal processing happens before the output file is opened, so a
    failure while preparing the data no longer truncates an existing file,
    and the file handle is always closed.

    Args:
        input_file: 2-D array-like indexed as (lead, sample).
        out_file: File name, appended to ``configVars.pathCasos``.
        new_freq: Target sampling frequency. ``None`` skips resampling.
        new_len: Target number of samples. ``None`` keeps the current length.
        scale: Multiplicative factor applied to the signal.
        sample_rate: Sampling frequency of ``input_file``; required (as an
            integer) whenever ``new_freq`` is given.
    """
    ecg = np.asarray(input_file)

    # Resample
    if new_freq is not None:
        ecg = sgn.resample_poly(ecg, up=new_freq, down=sample_rate, axis=-1)
    n_leads, length = ecg.shape

    # Rescale
    ecg = scale * ecg

    # Reshape: centre the signal inside (or cut it out of) the target length
    if new_len is not None and new_len != length:
        if new_len > length:
            padded = np.zeros([n_leads, new_len])
            pad = (new_len - length) // 2
            padded[..., pad:length + pad] = ecg
            ecg = padded
        else:
            extra = (length - new_len) // 2
            ecg = ecg[:, extra:new_len + extra]
    n_leads, n_samples = ecg.shape

    # Mode 'w' creates the file or truncates an existing one, which is what
    # the former "open 'r+' and clear, else open 'w'" fallback achieved.
    with h5py.File(f"{configVars.pathCasos}{out_file}", 'w') as h5f:
        tracings = h5f.create_dataset('tracings', (1, n_samples, n_leads), dtype='f8')
        tracings[0, :, :] = ecg.T
def LightX3ECG(
    train_loaders,
    config,
    save_ckp_dir,
):
    """Run the saved model on CPU and return the prediction for the first sample.

    The model is loaded from ``{save_ckp_dir}/best.ptl`` and applied to every
    batch yielded by ``train_loaders["pred"]``. Only the first collected
    prediction is turned into the return value.

    NOTE(review): ``torch.load`` unpickles the file; only load checkpoints
    from a trusted source.

    Args:
        train_loaders: Mapping whose ``"pred"`` entry iterates over input
            batches (objects with a ``.cpu()`` method that the model accepts).
        config: Mapping with the boolean entry ``"is_multilabel"``.
        save_ckp_dir: Directory that contains ``best.ptl``.

    Returns:
        Multilabel: the list of ``Pred`` values from the thresholds CSV whose
        ``Threshold`` is less than or equal to the first sample's sigmoid
        output. Otherwise: one of ``'AFIB'``, ``'GSVT'``, ``'SB'``, ``'SR'``,
        selected by the first sample's arg-max class index.
    """
    checkpoint_path = f"{save_ckp_dir}/best.ptl"
    model = torch.load(checkpoint_path, map_location='cpu')
    model.to(torch.device('cpu'))

    with torch.no_grad():
        model.eval()
        running_preds = []
        for batch in train_loaders["pred"]:
            logits = model(batch.cpu())
            if config["is_multilabel"]:
                # Per-class probabilities
                batch_output = torch.sigmoid(logits)
            else:
                # Index of the highest logit along dimension 1
                batch_output = torch.max(logits, 1)[1]
            running_preds.extend(list(batch_output.detach().cpu().numpy()))

    if config["is_multilabel"]:
        scores = np.array(running_preds)
        thresholds = pd.read_csv(f"{configVars.pathThresholds}CPSC-2018/optimal_thresholds_best.csv")
        # Keep the rows whose threshold is reached by the first sample.
        selected = thresholds[thresholds["Threshold"] <= scores[0]]
        return selected['Pred'].values.tolist()

    class_names = ['AFIB','GSVT','SB','SR']
    class_ids = np.array(running_preds)
    return class_names[class_ids[0]]
def ecgPlot(source,sample):
    """Plot the ECG stored in a NumPy file as a 12-lead chart and save it.

    The array is read with ``np.load`` and handed to ``ecg_plot.plot_12``
    with the title "Muestra"; the figure is then saved through
    ``ecg_plot.save_as_png`` under the name "ecg".

    Args:
        source: Path (or file object) accepted by ``np.load``.
        sample: Value forwarded as ``sample_rate`` to ``ecg_plot.plot_12``.
    """
    # Lead labels handed to the plot as ``lead_index``.
    lead_names = ['DI', 'DII', 'DIII', 'AVR', 'AVL', 'AVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6']
    signal = np.load(source)
    ecg_plot.plot_12(signal, sample_rate=sample, lead_index=lead_names, title="Muestra")
    ecg_plot.save_as_png("ecg")