| """ |
| Copyright $today.year LY Corporation |
| |
| LY Corporation licenses this file to you under the Apache License, |
| version 2.0 (the "License"); you may not use this file except in compliance |
| with the License. You may obtain a copy of the License at: |
| |
| https://www.apache.org/licenses/LICENSE-2.0 |
| |
| Unless required by applicable law or agreed to in writing, software |
| distributed under the License is distributed on an "AS IS" BASIS, WITHOUT |
| WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the |
| License for the specific language governing permissions and limitations |
| under the License. |
| """ |
import os
import torch
import subprocess
import gradio as gr
import librosa
from tqdm import tqdm
from lighthouse.models import *


# Run inference on a GPU when available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
MODEL_NAMES = ['qd_detr']
FEATURES = ['clap']
TOPK_MOMENT = 5  # number of retrieved moments shown in the UI
sample_path = "sample_data/1a-ODBWMUAE.wav"
sample_query = "Water cascades down from a waterfall."


"""
Helper functions
"""
def load_pretrained_weights():
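    """Download the pretrained checkpoints from Zenodo into gradio_demo/weights/ unless they are already cached."""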
    file_urls = []
    for model_name in MODEL_NAMES:
        for feature in FEATURES:
            file_urls.append(
                "https://zenodo.org/records/13961029/files/{}_{}_clotho-moment.ckpt".format(feature, model_name)
            )
    for file_url in tqdm(file_urls):
        if not os.path.exists('gradio_demo/weights/' + os.path.basename(file_url)):
            command = 'wget -P gradio_demo/weights/ {}'.format(file_url)
            subprocess.run(command, shell=True)
    return file_urls


def flatten(array2d):
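    """Flatten a list of lists into a single flat list."""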
    list1d = []
    for elem in array2d:
        list1d += elem
    return list1d


"""
Model initialization
"""
load_pretrained_weights()
model = QDDETRPredictor('gradio_demo/weights/clap_qd_detr_clotho-moment.ckpt', device=device, feature_name='clap')
# Encoded features of the currently loaded audio; set by audio_upload().
loaded_audio = None


"""
Gradio functions
"""
def audio_upload(audio):
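    """Encode the uploaded audio and cache its features; clear the cache when the audio is removed."""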
    global loaded_audio
    if audio is None:
        loaded_audio = None
        yield gr.update(value="Removed the audio", visible=True)
    else:
        yield gr.update(value="Processing the audio. This may take a minute...", visible=True)
        audio_feats = model.encode_audio(audio)
        loaded_audio = audio_feats
        yield gr.update(value="Finished audio processing!", visible=True)


def model_load(radio):
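    """Reload the global model with the checkpoint selected in the radio group."""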
    if radio is not None:
        yield gr.update(value="Loading the new model. This may take a minute...", visible=True)
        global model
        feature, model_name = radio.split('+')
        feature, model_name = feature.strip(), model_name.strip()

        if model_name == 'qd_detr':
            model_class = QDDETRPredictor
        else:
            raise gr.Error("Select one of the listed models.")

        model = model_class('gradio_demo/weights/{}_{}_clotho-moment.ckpt'.format(feature, model_name),
                            device=device, feature_name=feature)
        yield gr.update(value="Model loaded: {}".format(radio), visible=True)


def predict(textbox):
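    """Run moment retrieval for the query and return a button for each of the top-k predicted moments."""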
    global loaded_audio
    if loaded_audio is None:
        raise gr.Error('Upload the audio before pushing the `Retrieve moment` button.')
    else:
        prediction = model.predict(textbox, loaded_audio)
        mr_results = prediction['pred_relevant_windows']

        buttons = []
        for i, pred in enumerate(mr_results[:TOPK_MOMENT]):
            buttons.append(gr.Button(value='moment {}: [{}, {}] Score: {}'.format(i+1, pred[0], pred[1], pred[2]), visible=True))
        return buttons


def show_trimmed_audio(audio, button):
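    """Cut the moment encoded in the clicked button's label out of the audio and return it as a playable clip."""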
    s, sr = librosa.load(audio, sr=None)
    # Parse the start/end seconds back out of the button label,
    # e.g. 'moment 1: [12.5, 20.0] Score: 0.97'.
    _seconds = button.split(': [')[1].split(']')[0].split(', ')
    start_sec = float(_seconds[0])
    end_sec = float(_seconds[1])
    start_frame = int(start_sec * sr)
    end_frame = int(end_sec * sr)
    return gr.Audio((sr, s[start_frame:end_frame]), interactive=False, visible=True)


def main():
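    """Build the Gradio Blocks UI, wire up the callbacks, and launch the demo."""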
| title = """# Audio Moment Retrieval Demo""" |
|
|
| with gr.Blocks(theme=gr.themes.Soft()) as demo: |
| gr.Markdown(title) |
|
|
| with gr.Row(): |
| with gr.Column(): |
| with gr.Group(): |
| gr.Markdown("## Model selection") |
| radio_list = flatten([["{} + {}".format(feature, model_name) for model_name in MODEL_NAMES] for feature in FEATURES]) |
| radio = gr.Radio(radio_list, label="models", value="clap + qd_detr", info="Which model do you want to use?") |
| load_status_text = gr.Textbox(label='Model load status', value='Model loaded: clap + qd_detr') |
|
|
| with gr.Group(): |
| gr.Markdown("## Audio and query") |
| audio_input = gr.Audio(sample_path, type='filepath') |
| output = gr.Textbox(label='Audio processing progress') |
| query_input = gr.Textbox(sample_query, label='query') |
| button = gr.Button("Retrieve moment", variant="primary") |
|
|
| with gr.Column(): |
| with gr.Group(): |
| gr.Markdown("## Retrieved moments") |
| gr.Markdown("Click on the moment button to listen to the trimmed audio.") |
| button_1 = gr.Button(value='moment 1', visible=False, elem_id='result_0') |
| button_2 = gr.Button(value='moment 2', visible=False, elem_id='result_1') |
| button_3 = gr.Button(value='moment 3', visible=False, elem_id='result_2') |
| button_4 = gr.Button(value='moment 4', visible=False, elem_id='result_3') |
| button_5 = gr.Button(value='moment 5', visible=False, elem_id='result_4') |
| result = gr.Audio(None, label='Trimmed audio', interactive=False, visible=False) |
|
|
| button_1.click(show_trimmed_audio, inputs=[audio_input, button_1], outputs=[result]) |
| button_2.click(show_trimmed_audio, inputs=[audio_input, button_2], outputs=[result]) |
| button_3.click(show_trimmed_audio, inputs=[audio_input, button_3], outputs=[result]) |
| button_4.click(show_trimmed_audio, inputs=[audio_input, button_4], outputs=[result]) |
| button_5.click(show_trimmed_audio, inputs=[audio_input, button_5], outputs=[result]) |
|
|
| audio_input.change(audio_upload, inputs=[audio_input], outputs=output) |
| radio.select(model_load, inputs=[radio], outputs=load_status_text) |
| |
| button.click(predict, |
| inputs=[query_input], |
| outputs=[button_1, button_2, button_3, button_4, button_5]) |
| demo.load(audio_upload, inputs=[audio_input], outputs=output) |
|
|
| demo.launch() |
|
|
| if __name__ == "__main__": |
| main() |
|
|