Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from datasets import load_dataset | |
| # + | |
def get_methods_and_arch(dataset):
    """Return the distinct heatmap methods and model architectures in *dataset*.

    Columns from index 5 onward are treated as heatmap columns. The first five
    are skipped; presumably they are the image and metadata columns, so verify
    this against the dataset. Each heatmap column name is split on ``_``:

    * the method is the first token;
    * the architecture is every token between the method and the two trailing
      tokens, re-joined with ``_``.

    The meaning of the two trailing tokens is not visible here.

    Args:
        dataset: Object exposing ``column_names`` (e.g. a ``datasets.Dataset``).

    Returns:
        tuple[list[str], list[str]]: ``(methods, archs)``. Each is a sorted
        list of unique names.
    """
    methods = set()
    archs = set()
    for column in dataset.column_names[5:]:
        parts = column.split('_')
        methods.add(parts[0])
        archs.add('_'.join(parts[1:-2]))
    # Sort so the dropdown order, and any "first choice" fallback, is stable
    # across runs. Iteration order of a set of str depends on hash randomization.
    return sorted(methods), sorted(archs)
def get_columns(arch, method, columns=None):
    """Return the heatmap column name for an (architecture, method) pair.

    A column matches when its first ``_``-separated token equals *method* and
    the tokens between that and the final two equal *arch*. This is the same
    split used to build the method/architecture dropdown choices.

    Args:
        arch: Architecture name, e.g. ``'resnet50'``.
        method: Method name, e.g. ``'gradcam'``.
        columns: Optional candidate column names. Defaults to the module-level
            ``dataset.column_names[5:]``.

    Returns:
        str | None: The first matching column name, or ``None`` if none matches.
    """
    if columns is None:
        columns = dataset.column_names[5:]
    for col in columns:
        parts = col.split('_')
        # Compare parsed tokens rather than testing f'{method}_{arch}' as a
        # substring. The substring test also matched 'xgradcam_resnet50_...'
        # for method 'gradcam', and 'gradcam_resnet50_...' for arch 'resnet'.
        if parts[0] == method and '_'.join(parts[1:-2]) == arch:
            return col
    return None
def button_fn(arch, method):
    """Switch the displayed heatmap to the column for the chosen (arch, method).

    Args:
        arch: Architecture selected in the model dropdown.
        method: Method selected in the method dropdown.

    Returns:
        tuple: ``(column_name, index_default, input_image, heatmap_image)``,
        matching the click handler's outputs ``[column_textbox, hf_slider,
        image_input, image_output]``. Returning ``index_default`` resets the
        slider to the default row.

    Raises:
        gr.Error: If no heatmap column exists for the pair. Gradio shows this
            message to the user.
    """
    column_heatmap = get_columns(arch, method)
    if column_heatmap is None:
        # Without this check the row would be indexed with None, which fails
        # with an error that does not name the missing combination.
        raise gr.Error(f'No heatmap column for model {arch!r} and method {method!r}')
    # Read the row once instead of indexing the dataset twice.
    row = dataset[index_default]
    return column_heatmap, index_default, row["image"], row[column_heatmap]
def func_slider(index, column_textbox):
    """Return the input image and heatmap for dataset row *index*.

    Args:
        index: Row number from the slider.
        column_textbox: Heatmap column name currently shown in the textbox.

    Returns:
        tuple: ``(input_image, heatmap_image)`` for the slider's change outputs.
    """
    # NOTE(review): depending on the Gradio version, the slider value may
    # arrive as a float (e.g. 3.0). Cast it so the dataset is indexed by an int.
    example = dataset[int(index)]
    return example['image'], example[column_textbox]
| # - | |
# Shared state read by every callback: the stimuli/heatmap table is loaded
# once at import time, and the dropdown choices are derived from its columns.
_DATASET_REPO = "GazeLocation/stimuli_heatmaps"
dataset = load_dataset(_DATASET_REPO, split='train')
METHODS, ARCHS = get_methods_and_arch(dataset)

# Row shown at start-up and after the model/method button is pressed.
index_default = 0
# When True, launch with a public share link and debug output.
DEMO = False
if __name__ == '__main__':
    # Prefer resnet50 / gradcam at start-up, but fall back to the first
    # available choice so each dropdown's value is one of its choices.
    default_arch = 'resnet50' if 'resnet50' in ARCHS else ARCHS[0]
    default_method = 'gradcam' if 'gradcam' in METHODS else METHODS[0]
    # One column drives both the textbox and the initial heatmap. The slider
    # callback reads the textbox, so both must refer to the same column.
    default_column = get_columns(default_arch, default_method)
    default_row = dataset[index_default]
    default_heatmap = default_row[default_column] if default_column is not None else None

    demo = gr.Blocks()
    with demo:
        gr.Markdown("# Heatmap Gaze Location")
        with gr.Row():
            dropdown_arch = gr.Dropdown(choices=ARCHS,
                                        value=default_arch,
                                        label='Model')
            dropdown_method = gr.Dropdown(choices=METHODS,
                                          value=default_method,
                                          label='Method')
        with gr.Row():
            # A Button's caption is its `value`. `label` does not set the
            # caption, and newer Gradio versions reject it.
            button = gr.Button(value='Update Heatmap Model - Method')
        with gr.Row():
            hf_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1,
                                  step=1, value=index_default)
        with gr.Row():
            column_textbox = gr.Textbox(label='column name',
                                        value=default_column)
        with gr.Row():
            with gr.Column():
                image_input = gr.Image(label="Input Image",
                                       value=default_row["image"])
            with gr.Column():
                image_output = gr.Image(label="Output", value=default_heatmap)
        # Changing model/method updates the column name, resets the slider,
        # and refreshes both images.
        button.click(fn=button_fn,
                     inputs=[dropdown_arch, dropdown_method],
                     outputs=[column_textbox, hf_slider, image_input, image_output])
        # Moving the slider shows another row, using the column in the textbox.
        hf_slider.change(func_slider,
                         inputs=[hf_slider, column_textbox],
                         outputs=[image_input, image_output])
    # Launch after the Blocks context has closed, so the layout is finalized.
    if DEMO:
        demo.launch(share=True, debug=True)
    else:
        demo.launch()