Spaces:
Running
Running
| import torch | |
| import torch.nn.functional as F | |
| import os | |
| import numpy as np | |
| import pandas as pd | |
| import json | |
| import sys | |
| from bat_detect.detector import models | |
| import bat_detect.detector.compute_features as feats | |
| import bat_detect.detector.post_process as pp | |
| import bat_detect.utils.audio_utils as au | |
def get_default_bd_args():
    """Return a fresh dictionary holding the default detector arguments.

    A new dictionary is built on every call, so callers may modify the
    returned value freely.
    """
    # the annotation directory is routed through os.path.join so that a
    # non-empty value would end with a path separator; the empty default
    # stays empty
    ann_dir = ''
    return {
        'detection_threshold': 0.001,
        'time_expansion_factor': 1,
        'audio_dir': '',
        'ann_dir': os.path.join(ann_dir, ''),
        'spec_slices': False,
        'chunk_size': 3,
        'spec_features': False,
        'cnn_features': False,
        'quiet': True,
        'save_preds_if_empty': True,
    }
def get_audio_files(ip_dir):
    """Recursively collect the paths of all '.wav' files below ip_dir.

    The extension check is case-insensitive. Paths are returned in the
    order os.walk yields them; a directory that does not exist gives an
    empty list.
    """
    return [
        os.path.join(folder, name)
        for folder, _subdirs, names in os.walk(ip_dir)
        for name in names
        if name.lower().endswith('.wav')
    ]
def load_model(model_path, load_weights=True):
    """Load a detection model and its parameters from a checkpoint file.

    Args:
        model_path: path to a checkpoint saved with torch. It must hold a
            'params' dictionary and, when load_weights is True, a
            'state_dict' entry.
        load_weights: if True the stored weights are loaded into the model.

    Returns:
        (model, params): the model moved to the selected device and set to
        eval mode, and the parameter dictionary with 'device' added.

    Exits the process with status 1 (after printing an error) if the file
    does not exist or if the stored model name is not supported.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    if not os.path.isfile(model_path):
        print('Error: model not found.')
        sys.exit(1)

    # NOTE: torch.load may unpickle arbitrary objects - only load
    # checkpoints that come from a trusted source
    net_params = torch.load(model_path, map_location=device)

    params = net_params['params']
    params['device'] = device

    # every supported architecture is a class in `models` with the same
    # name as the stored 'model_name' and the same constructor arguments
    supported_models = ('Net2DFast', 'Net2DFastNoAttn', 'Net2DFastNoCoordConv')
    if params['model_name'] not in supported_models:
        # stop here instead of falling through with `model` undefined
        print('Error: unknown model.')
        sys.exit(1)

    model_class = getattr(models, params['model_name'])
    model = model_class(params['num_filters'],
                        num_classes=len(params['class_names']),
                        emb_dim=params['emb_dim'],
                        ip_height=params['ip_height'],
                        resize_factor=params['resize_factor'])

    if load_weights:
        model.load_state_dict(net_params['state_dict'])

    model = model.to(params['device'])
    model.eval()

    return model, params
def merge_results(predictions, spec_feats, cnn_feats, spec_slices):
    """Combine the per-chunk outputs of a file into single containers.

    predictions is a list with one dictionary of arrays per chunk. When at
    least one chunk has detections, every key is merged with np.hstack over
    the chunks that have detections. Otherwise the first chunk's
    dictionary is returned unchanged so that the expected keys still exist.

    Non-empty spec_feats and cnn_feats lists are stacked row-wise with
    np.vstack; spec_slices is passed through untouched.
    """
    total_detections = sum(len(chunk['det_probs']) for chunk in predictions)

    if total_detections > 0:
        with_calls = [chunk for chunk in predictions
                      if chunk['det_probs'].shape[0] > 0]
        merged = {key: np.hstack([chunk[key] for chunk in with_calls])
                  for key in predictions[0].keys()}
    else:
        # no calls detected anywhere - reuse the first chunk for its keys
        merged = predictions[0]

    if len(spec_feats) > 0:
        spec_feats = np.vstack(spec_feats)
    if len(cnn_feats) > 0:
        cnn_feats = np.vstack(cnn_feats)

    return merged, spec_feats, cnn_feats, spec_slices
def convert_results(file_id, time_exp, duration, params, predictions, spec_feats, cnn_feats, spec_slices):
    """Package the merged predictions of one file into a results dictionary.

    The returned dictionary always has a 'pred_dict' entry (the format used
    by the annotation tool) with file-level metadata, the overall class name
    and one annotation per detection. 'spec_feats', 'cnn_feats' and
    'spec_slices' (plus the matching feature names) are only added when the
    corresponding input is non-empty.
    """
    class_names = params['class_names']
    det_probs = predictions['det_probs']
    class_probs = predictions['class_probs']

    # best class per detection - axis 0 of class_probs indexes the classes
    best_prob = class_probs.max(0)
    best_ind = class_probs.argmax(0)
    overall = pp.overall_class_pred(det_probs, class_probs)

    # one entry per detection
    annotations = []
    for idx in range(det_probs.shape[0]):
        annotations.append({
            'start_time': round(float(predictions['start_times'][idx]), 4),
            'end_time': round(float(predictions['end_times'][idx]), 4),
            'low_freq': int(predictions['low_freqs'][idx]),
            'high_freq': int(predictions['high_freqs'][idx]),
            'class': str(class_names[int(best_ind[idx])]),
            'class_prob': round(float(best_prob[idx]), 3),
            'det_prob': round(float(det_probs[idx]), 3),
            'individual': '-1',
            'event': 'Echolocation',
        })

    # single dictionary in the annotation tool format
    pred_dict = {
        'id': file_id,
        'annotated': False,
        'issues': False,
        'notes': 'Automatically generated.',
        'time_exp': time_exp,
        'duration': round(duration, 4),
        'annotation': annotations,
        'class_name': class_names[np.argmax(overall)],
    }

    # attach the optional extras only when they were computed
    results = {'pred_dict': pred_dict}
    if len(spec_feats) > 0:
        results['spec_feats'] = spec_feats
        results['spec_feat_names'] = feats.get_feature_names()
    if len(cnn_feats) > 0:
        results['cnn_feats'] = cnn_feats
        results['cnn_feat_names'] = [str(col) for col in range(cnn_feats.shape[1])]
    if len(spec_slices) > 0:
        results['spec_slices'] = spec_slices

    return results
def save_results_to_file(results, op_path):
    """Write the results of one file to disk.

    Args:
        results: dictionary with a 'pred_dict' entry and, optionally,
            'spec_feats'/'spec_feat_names' and 'cnn_feats'/'cnn_feat_names'.
        op_path: output path without extension. Suffixes are appended to
            it: '.csv', '_spec_features.csv', '_cnn_features.csv', '.json'.

    The '.json' file is always written. The '.csv' file is only written
    when the annotations contain a 'class_prob' column, i.e. when there are
    predictions. Feature files are written when the matching key exists.
    """
    # make the output directory if needed; a bare file name has no
    # directory part and os.makedirs('') would raise, so skip it then
    op_dir = os.path.dirname(op_path)
    if op_dir:
        os.makedirs(op_dir, exist_ok=True)

    # save csv file - if there are predictions
    df = pd.DataFrame(list(results['pred_dict']['annotation']))
    if 'class_prob' in df.columns:
        df = df[['det_prob', 'start_time', 'end_time', 'high_freq',
                 'low_freq', 'class', 'class_prob']]
        df.index.name = 'id'
        df.to_csv(op_path + '.csv', sep=',')

    # save features
    if 'spec_feats' in results:
        df = pd.DataFrame(results['spec_feats'], columns=results['spec_feat_names'])
        df.to_csv(op_path + '_spec_features.csv', sep=',', index=False, float_format='%.5f')

    if 'cnn_feats' in results:
        df = pd.DataFrame(results['cnn_feats'], columns=results['cnn_feat_names'])
        df.to_csv(op_path + '_cnn_features.csv', sep=',', index=False, float_format='%.5f')

    # save json file
    with open(op_path + '.json', 'w') as da:
        json.dump(results['pred_dict'], da, indent=2, sort_keys=True)
def compute_spectrogram(audio, sampling_rate, params, return_np=False):
    """Turn an audio clip into a resized spectrogram tensor.

    Returns (duration, spec, spec_np):
        duration: length of the unpadded clip in seconds.
        spec: spectrogram tensor on params['device'] with two leading
            singleton dimensions added, bilinearly resized by
            params['resize_factor'].
        spec_np: numpy copy of spec[0, 0] if return_np is True, else None.
    """
    # measured before padding so it reflects the real clip length
    duration = audio.shape[0] / float(sampling_rate)

    # pad audio so it is evenly divisible by downsampling factors
    padded = au.pad_audio(audio, sampling_rate, params['fft_win_length'],
                          params['fft_overlap'], params['resize_factor'],
                          params['spec_divide_factor'])

    # generate spectrogram and move it to the target device
    spec_array, _ = au.generate_spectrogram(padded, sampling_rate, params)
    spec = torch.from_numpy(spec_array).to(params['device'])
    spec = spec.unsqueeze(0).unsqueeze(0)

    # resize the spec
    scale = params['resize_factor']
    target_height = int(params['spec_height'] * scale)
    target_width = int(spec.shape[-1] * scale)
    spec = F.interpolate(spec, size=(target_height, target_width),
                         mode='bilinear', align_corners=False)

    spec_np = spec[0, 0, :].cpu().data.numpy() if return_np else None

    return duration, spec, spec_np
def process_file(audio_file, model, params, args, time_exp=None, top_n=5, return_raw_preds=False, max_duration=False):
    """Run the detector on one audio file, processing it chunk by chunk.

    Args:
        audio_file: path of the audio file to process.
        model: detection model, called as model(spec, return_feats=...).
        params: model parameter dictionary. Note that
            params['detection_threshold'] is overwritten with the value
            from args.
        args: run-time options; reads 'time_expansion_factor',
            'detection_threshold', 'chunk_size', 'spec_features',
            'cnn_features', 'spec_slices' and 'quiet'.
        time_exp: time expansion factor; defaults to the value in args.
        top_n: number of classes listed in the printed summary.
        return_raw_preds: if True return the merged prediction arrays
            instead of the formatted results dictionary.
        max_duration: if given (not False/None), only this many seconds from
            the start of the file are processed.

    Returns:
        The merged predictions dictionary if return_raw_preds is True,
        otherwise the dictionary produced by convert_results.

    Raises:
        ValueError: if args['chunk_size'] is shorter than one audio sample.
    """
    # store temporary results here
    predictions = []
    spec_feats = []
    cnn_feats = []
    spec_slices = []

    # get time expansion factor
    if time_exp is None:
        time_exp = args['time_expansion_factor']

    # params is handed to the post-processing below - presumably that is
    # where the threshold is consumed
    params['detection_threshold'] = args['detection_threshold']

    # load audio file
    sampling_rate, audio_full = au.load_audio_file(audio_file, time_exp,
                                                   params['target_samp_rate'],
                                                   params['scale_raw_audio'])

    # clipping maximum duration (False and None both mean "no limit")
    if max_duration is not False and max_duration is not None:
        max_samples = min(int(sampling_rate * max_duration), audio_full.shape[0])
        audio_full = audio_full[:max_samples]

    num_samples = audio_full.shape[0]
    duration_full = num_samples / float(sampling_rate)

    return_np_spec = args['spec_features'] or args['spec_slices']

    # loop through larger file and split into chunks
    # TODO fix so that it overlaps correctly and takes care of duplicate detections at borders
    chunk_length = int(sampling_rate * args['chunk_size'])
    if chunk_length < 1:
        raise ValueError('chunk_size must span at least one audio sample.')

    # count chunks in whole samples: dividing float durations can round up
    # to an extra empty chunk, or leave trailing samples unprocessed
    num_chunks = (num_samples + chunk_length - 1) // chunk_length

    for chunk_id in range(num_chunks):
        start_sample = chunk_id * chunk_length
        end_sample = min(start_sample + chunk_length, num_samples)
        # offset of this chunk in seconds, taken from the real sample offset
        chunk_time = start_sample / float(sampling_rate)
        audio = audio_full[start_sample:end_sample]

        # compute spectrogram for this chunk
        _, spec, spec_np = compute_spectrogram(audio, sampling_rate, params, return_np_spec)

        # evaluate model
        with torch.no_grad():
            outputs = model(spec, return_feats=args['cnn_features'])

        # run non-max suppression
        pred_nms, features = pp.run_nms(outputs, params, np.array([float(sampling_rate)]))
        pred_nms = pred_nms[0]

        # chunk-relative times -> file-relative times
        pred_nms['start_times'] += chunk_time
        pred_nms['end_times'] += chunk_time

        # if we have a background class, drop its (last) row
        if pred_nms['class_probs'].shape[0] > len(params['class_names']):
            pred_nms['class_probs'] = pred_nms['class_probs'][:-1, :]

        predictions.append(pred_nms)

        # extract features - if there are any calls detected
        if pred_nms['det_probs'].shape[0] > 0:
            if args['spec_features']:
                spec_feats.append(feats.get_feats(spec_np, pred_nms, params))

            if args['cnn_features']:
                cnn_feats.append(features[0])

            if args['spec_slices']:
                spec_slices.extend(feats.extract_spec_slices(spec_np, pred_nms, params))

    # convert the predictions into output dictionary
    file_id = os.path.basename(audio_file)
    predictions, spec_feats, cnn_feats, spec_slices = \
        merge_results(predictions, spec_feats, cnn_feats, spec_slices)
    results = convert_results(file_id, time_exp, duration_full, params,
                              predictions, spec_feats, cnn_feats, spec_slices)

    # summarize results
    if not args['quiet']:
        num_detections = len(results['pred_dict']['annotation'])
        print('{}'.format(num_detections) + ' call(s) detected above the threshold.')

        # print results for top n classes
        if num_detections > 0:
            class_overall = pp.overall_class_pred(predictions['det_probs'], predictions['class_probs'])
            print('species name'.ljust(30) + 'probablity present')
            for cc in np.argsort(class_overall)[::-1][:top_n]:
                print(params['class_names'][cc].ljust(30) + str(round(class_overall[cc], 3)))

    if return_raw_preds:
        return predictions
    return results