Spaces:
Running
Running
| import gradio as gr | |
| import os | |
| import matplotlib.pyplot as plt | |
| import pandas as pd | |
| import numpy as np | |
| import bat_detect.utils.detector_utils as du | |
| import bat_detect.utils.audio_utils as au | |
| import bat_detect.utils.plot_utils as viz | |
| # setup the arguments | |
| args = {} | |
| args = du.get_default_bd_args() | |
| args['detection_threshold'] = 0.3 | |
| args['time_expansion_factor'] = 1 | |
| args['model_path'] = 'models/Net2DFast_UK_same.pth.tar' | |
| max_duration = 2.0 | |
| # load the model | |
| model, params = du.load_model(args['model_path']) | |
| df = gr.Dataframe( | |
| headers=["species", "time", "detection_prob", "species_prob"], | |
| datatype=["str", "str", "str", "str"], | |
| row_count=1, | |
| col_count=(4, "fixed"), | |
| label='Predictions' | |
| ) | |
| examples = [['example_data/audio/20170701_213954-MYOMYS-LR_0_0.5.wav', 0.3], | |
| ['example_data/audio/20180530_213516-EPTSER-LR_0_0.5.wav', 0.3], | |
| ['example_data/audio/20180627_215323-RHIFER-LR_0_0.5.wav', 0.3]] | |
| def make_prediction(file_name=None, detection_threshold=0.3): | |
| if file_name is not None: | |
| audio_file = file_name | |
| else: | |
| return "You must provide an input audio file." | |
| if detection_threshold is not None and detection_threshold != '': | |
| args['detection_threshold'] = float(detection_threshold) | |
| # process the file to generate predictions | |
| results = du.process_file(audio_file, model, params, args, max_duration=max_duration) | |
| anns = [ann for ann in results['pred_dict']['annotation']] | |
| clss = [aa['class'] for aa in anns] | |
| st_time = [aa['start_time'] for aa in anns] | |
| cls_prob = [aa['class_prob'] for aa in anns] | |
| det_prob = [aa['det_prob'] for aa in anns] | |
| data = {'species': clss, 'time': st_time, 'detection_prob': det_prob, 'species_prob': cls_prob} | |
| df = pd.DataFrame(data=data) | |
| im = generate_results_image(audio_file, anns) | |
| return [df, im] | |
| def generate_results_image(audio_file, anns): | |
| # load audio | |
| sampling_rate, audio = au.load_audio_file(audio_file, args['time_expansion_factor'], | |
| params['target_samp_rate'], params['scale_raw_audio'], max_duration=max_duration) | |
| duration = audio.shape[0] / sampling_rate | |
| # generate spec | |
| spec, spec_viz = au.generate_spectrogram(audio, sampling_rate, params, True, False) | |
| # create fig | |
| plt.close('all') | |
| fig = plt.figure(1, figsize=(spec.shape[1]/100, spec.shape[0]/100), dpi=100, frameon=False) | |
| spec_duration = au.x_coords_to_time(spec.shape[1], sampling_rate, params['fft_win_length'], params['fft_overlap']) | |
| viz.create_box_image(spec, fig, anns, 0, spec_duration, spec_duration, params, spec.max()*1.1, False, True) | |
| plt.ylabel('Freq - kHz') | |
| plt.xlabel('Time - secs') | |
| plt.tight_layout() | |
| # convert fig to image | |
| fig.canvas.draw() | |
| data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8) | |
| w, h = fig.canvas.get_width_height() | |
| im = data.reshape((int(h), int(w), -1)) | |
| return im | |
| descr_txt = "Demo of BatDetect2 deep learning-based bat echolocation call detection. " \ | |
| "<br>This model is only trained on bat species from the UK. If the input " \ | |
| "file is longer than 2 seconds, only the first 2 seconds will be processed." \ | |
| "<br>Check out the paper [here](https://www.biorxiv.org/content/10.1101/2022.12.14.520490v1)." | |
| gr.Interface( | |
| fn = make_prediction, | |
| inputs = [gr.Audio(sources=["upload"], type="filepath"), | |
| gr.Dropdown([0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9])], | |
| outputs = [df, gr.Image(label="Visualisation")], | |
| theme = "huggingface", | |
| title = "BatDetect2 Demo", | |
| description = descr_txt, | |
| examples = examples, | |
| allow_flagging = 'never', | |
| ).launch() | |