Spaces:
Paused
Paused
| from gyraudio.default_locations import EXPERIMENT_STORAGE_ROOT | |
| from gyraudio.audio_separation.parser import shared_parser | |
| from gyraudio.audio_separation.infer import launch_infer, RECORD_KEYS, SNR_OUT, SNR_IN, NBATCH, SAVE_IDX | |
| from gyraudio.audio_separation.properties import TEST, NAME, SHORT_NAME, CURRENT_EPOCH, SNR_FILTER | |
| import sys | |
| import os | |
| from dash import Dash, html, dcc, callback, Output, Input, dash_table | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import pandas as pd | |
| from typing import List | |
| import torch | |
| from pathlib import Path | |
| DIFF_SNR = 'SNR out - SNR in' | |
| def get_app(record_row_dfs : pd.DataFrame, eval_dfs : List[pd.DataFrame]) : | |
| app = Dash(__name__) | |
| # names_options = [{'label' : f"{record[SHORT_NAME]} - {record[NAME]} epoch {record[CURRENT_EPOCH]:04d}", 'value' : record[NAME] } for idx, record in record_row_dfs.iterrows()] | |
| app.layout = html.Div([ | |
| html.H1(children='Inference results', style={'textAlign':'center'}), | |
| # dcc.Dropdown(names_options, names_options[0]['value'], id='exp-selection'), | |
| # dcc.RadioItems(['scatter', 'box'], 'box', inline=True, id='radio-plot-type'), | |
| dcc.RadioItems([SNR_OUT, DIFF_SNR], DIFF_SNR, inline=True, id='radio-plot-out'), | |
| dcc.Graph(id='graph-content') | |
| ]) | |
| def update_graph(radio_plot_out) : | |
| fig = make_subplots(rows = 2, cols = 1) | |
| colors = px.colors.qualitative.Plotly | |
| for id, record in record_row_dfs.iterrows() : | |
| color = colors[id % len(colors)] | |
| eval_df = eval_dfs[id].sort_values(by=SNR_IN) | |
| eval_df[DIFF_SNR] = eval_df[SNR_OUT] - eval_df[SNR_IN] | |
| legend = f'{record[SHORT_NAME]}_{record[NAME]}' | |
| fig.add_trace( | |
| go.Scatter( | |
| x=eval_df[SNR_IN], | |
| y=eval_df[radio_plot_out], | |
| mode="markers", marker={'color' : color}, | |
| name=legend, | |
| hovertemplate = 'File : %{text}'+ | |
| '<br>%{y}<br>', | |
| text = [f"{eval[SAVE_IDX]:.0f}" for idx, eval in eval_df.iterrows()] | |
| ), | |
| row = 1, col = 1 | |
| ) | |
| eval_df_bins = eval_df | |
| eval_df_bins[SNR_IN] = eval_df_bins[SNR_IN].apply(lambda snr : round(snr)) | |
| fig.add_trace( | |
| go.Box( | |
| x=eval_df[SNR_IN], | |
| y=eval_df[radio_plot_out], | |
| fillcolor = color, | |
| marker={'color' : color}, | |
| name = legend | |
| ), | |
| row = 2, col = 1 | |
| ) | |
| title = f"SNR performances" | |
| fig.update_layout( | |
| title=title, | |
| xaxis2_title = SNR_IN, | |
| yaxis_title = radio_plot_out, | |
| hovermode='x unified' | |
| ) | |
| return fig | |
| return app | |
| def main(argv): | |
| default_device = "cuda" if torch.cuda.is_available() else "cpu" | |
| parser_def = shared_parser(help="Launch training \nCheck results at: https://wandb.ai/balthazarneveu/audio-sep" | |
| + ("\n<<<Cuda available>>>" if default_device == "cuda" else "")) | |
| parser_def.add_argument("-i", "--input-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) | |
| parser_def.add_argument("-o", "--output-dir", type=str, default=EXPERIMENT_STORAGE_ROOT) | |
| parser_def.add_argument("-d", "--device", type=str, default=default_device, | |
| help="Training device", choices=["cpu", "cuda"]) | |
| parser_def.add_argument("-b", "--nb-batch", type=int, default=None, | |
| help="Number of batches to process") | |
| parser_def.add_argument("-s", "--snr-filter", type=float, nargs="+", default=None, | |
| help="SNR filters on the inference dataloader") | |
| args = parser_def.parse_args(argv) | |
| record_row_dfs = pd.DataFrame(columns = RECORD_KEYS) | |
| eval_dfs = [] | |
| for exp in args.experiments: | |
| record_row_df, evaluation_path = launch_infer( | |
| exp, | |
| model_dir=Path(args.input_dir), | |
| output_dir=Path(args.output_dir), | |
| device=args.device, | |
| max_batches=args.nb_batch, | |
| snr_filter=args.snr_filter | |
| ) | |
| eval_df = pd.read_csv(evaluation_path) | |
| # Careful, list order for concat is important for index matching eval_dfs list | |
| record_row_dfs = pd.concat([record_row_dfs.loc[:], record_row_df], ignore_index=True) | |
| eval_dfs.append(eval_df) | |
| app = get_app(record_row_dfs, eval_dfs) | |
| app.run(debug=True) | |
| if __name__ == '__main__': | |
| os.environ["KMP_DUPLICATE_LIB_OK"] = "True" | |
| main(sys.argv[1:]) |