import json
from typing import List, Optional

import gradio as gr
from gradio.components import Markdown

from src.dto.dto import ExplanationGranularity, ExplanationDto
from src.utils.registry import EXPLAINERS, MODELS, PERTURBERS, COMPARATORS
from src.utils.segregate import PercentileBasedSegregator
from src.utils.visualizer import Visualizer


class MockExplainerUI:
    """Gradio demo UI that replays precomputed explanations.

    Instead of running an explainer live, :meth:`run` loads a stored
    ``ExplanationDto`` from
    ``data/<model>/<lang>_pert_<perturber>_comp_<comparator>_exp_dto.json``
    and renders it with the injected ``Visualizer``.
    """

    # Number of canned examples offered per language in the example picker.
    _NUM_EXAMPLES = 5
    # Language prefixes used in the example labels ("EN Example 1", ...).
    _EXAMPLE_LANGUAGES = ("EN", "DE")

    def __init__(
        self,
        logo_path: str,
        css_path: str,
        visualizer: Visualizer,
        window_title: str,
        title: str,
        examples: Optional[List[str]] = None,
    ):
        """Store the UI configuration and build the Gradio app.

        Args:
            logo_path: Path/URL of the logo image shown next to the title.
            css_path: Passed to ``gr.Blocks(css=...)``.
            visualizer: Renders explanations to HTML.
            window_title: Browser window title.
            title: Text shown next to the logo.
            examples: Stored on the instance; not used by this class's code.
        """
        self.__logo_path = logo_path
        self.__css_path = css_path
        self.__examples = examples
        self.__window_title = window_title
        self.__title = title
        self.__visualizer = visualizer
        self.app: gr.Blocks = self.build_app()

    def build_app(self) -> gr.Blocks:
        """Build the ``gr.Blocks`` app and wire the submit button to :meth:`run`."""
        with gr.Blocks(
            theme=gr.themes.Monochrome().set(
                button_primary_background_fill="#009374",
                button_primary_background_fill_hover="#009374C4",
                checkbox_label_background_fill_selected="#028A6EFF",
            ),
            css=self.__css_path,
            title=self.__window_title,
        ) as demo:
            self.__build_app_title()
            (
                qn_choice,
                user_input,
                system_response,
                granularity,
                upper_percentile,
                middle_percentile,
                lower_percentile,
                explainer_name,
                model_name,
                perturber_name,
                comparator_name,
                generator_vis,
                submit_btn,
            ) = self.__build_chat_and_explain()

            submit_btn.click(
                fn=self.run,
                inputs=[
                    qn_choice,
                    user_input,
                    granularity,
                    upper_percentile,
                    middle_percentile,
                    lower_percentile,
                    explainer_name,
                    model_name,
                    perturber_name,
                    comparator_name,
                ],
                outputs=[user_input, system_response, generator_vis],
            )

        return demo

    def run(
        self,
        qn_choice: Optional[str],
        user_input: str,
        granularity: ExplanationGranularity,
        upper_percentile: str,
        middle_percentile: str,
        lower_percentile: str,
        explainer_name: str,
        model_name: str,
        perturber_name: str,
        comparator_name: str,
    ):
        """Load the stored explanation for the chosen example and render it.

        ``user_input``, ``granularity`` and ``explainer_name`` are received from
        the UI but not used: the data file is selected only by model, perturber,
        comparator, language and example number.

        Returns:
            ``(input_text, output_text, html)`` for the question textbox, the
            system-response textbox and the explanation HTML component.

        Raises:
            gr.Error: if no example is selected or a percentile is not an
                integer (shown to the user by Gradio).
        """
        if not qn_choice:
            # gr.Radio has no default value here, so it yields None until the
            # user picks an option; `"EN" in None` would raise a TypeError.
            raise gr.Error("Please choose an example first.")

        language = "en" if "EN" in qn_choice else "de"
        # Labels look like "EN Example 3"; the trailing number selects the
        # entry. Anything unexpected falls back to 1, as before.
        last_token = qn_choice.split()[-1]
        q_idx = int(last_token) if last_token.isdecimal() else 1

        try:
            upper, middle, lower = (
                int(upper_percentile),
                int(middle_percentile),
                int(lower_percentile),
            )
        except (TypeError, ValueError) as err:
            raise gr.Error(
                "Percentiles must be whole numbers, e.g. 85, 75 and 10."
            ) from err

        file_path = (
            f"data/{model_name}/{language}_pert_{perturber_name}"
            f"_comp_{comparator_name}_exp_dto.json"
        )
        # Explicit UTF-8: the German examples contain non-ASCII characters and
        # the platform default encoding is not guaranteed to be UTF-8.
        with open(file_path, "r", encoding="utf-8") as f:
            data = json.load(f)

        # NOTE(review): "Example N" reads data[N], so entry 0 is never shown and
        # N == len(data) would raise IndexError. Confirm the JSON layout
        # (possible off-by-one).
        explanation_dto = ExplanationDto.parse_obj(data[q_idx])

        user_input = explanation_dto.input_text
        system_response = explanation_dto.output_text
        generator_vis = self.__visualize_explanations(
            user_input=user_input,
            system_response=system_response,
            generator_explanations=explanation_dto,
            upper_percentile=upper,
            middle_percentile=middle,
            lower_percentile=lower,
        )
        return user_input, system_response, generator_vis

    def __build_app_title(self):
        """Render the header row: logo image on the left, title on the right."""
        with gr.Row():
            with gr.Column(min_width=50, scale=1):
                gr.Image(
                    value=self.__logo_path,
                    width=50,
                    height=50,
                    show_download_button=False,
                    show_label=False,
                    show_share_button=False,
                    container=False,
                )
            with gr.Column(scale=2):
                # NOTE(review): the markup that originally wrapped the title
                # seems to have been lost (the literal was unterminated and did
                # not parse). Restore the intended HTML/Markdown styling here.
                Markdown(f"\n\n{self.__title}\n\n")

    def __build_chat_and_explain(self):
        """Create the example picker, question/answer boxes, settings and output.

        Returns the created components in the order unpacked by
        :meth:`build_app`.
        """
        example_choices = [
            f"{lang} Example {i}"
            for lang in self._EXAMPLE_LANGUAGES
            for i in range(1, self._NUM_EXAMPLES + 1)
        ]
        with gr.Row():
            with gr.Column(scale=2):
                qn_choice = gr.Radio(
                    label="Choose from these examples",
                    container=True,
                    choices=example_choices,
                )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown(
                    value="**Note:** This is a demo version of the tool with "
                    "limited functionalities. For building the full "
                    "version, please visit [here](https://github.com/fraunhofer-iais/explainable-lms/tree/master).",
                )
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    placeholder="Choose an example from the list above.",
                    label="Question",
                    container=True,
                    lines=10,
                    interactive=False,
                )
            with gr.Column(scale=1):
                granularity = gr.Radio(
                    choices=list(ExplanationGranularity),
                    value=ExplanationGranularity.SENTENCE_LEVEL,
                    label="Explanation Granularity",
                    interactive=False,
                )

                with gr.Accordion(label="Settings", open=False, elem_id="accordion"):
                    with gr.Row(variant="compact"):
                        # Hidden: the demo only replays precomputed explanations.
                        explainer_name = gr.Radio(
                            label="Explainer",
                            choices=list(EXPLAINERS),
                            value=next(iter(EXPLAINERS)),
                            container=True,
                            visible=False,
                        )
                    with gr.Row(variant="compact"):
                        upper_percentile = gr.Textbox(
                            label="Upper", value="85", container=True
                        )
                        middle_percentile = gr.Textbox(
                            label="Middle", value="75", container=True
                        )
                        lower_percentile = gr.Textbox(
                            label="Lower", value="10", container=True
                        )
                    with gr.Row(variant="compact"):
                        model_name = gr.Radio(
                            label="Model",
                            choices=list(MODELS),
                            value=next(iter(MODELS)),
                            container=True,
                        )
                    with gr.Row(variant="compact"):
                        perturber_name = gr.Radio(
                            label="Perturber",
                            choices=list(PERTURBERS),
                            value=next(iter(PERTURBERS)),
                            container=True,
                        )
                    with gr.Row(variant="compact"):
                        comparator_name = gr.Radio(
                            label="Comparator",
                            choices=list(COMPARATORS),
                            value=next(iter(COMPARATORS)),
                            container=True,
                        )

                with gr.Row(variant="compact"):
                    # "elem_id" lets the custom CSS passed to gr.Blocks style
                    # this button.
                    submit_btn = gr.Button(
                        value="🛠 Submit",
                        variant="secondary",
                        elem_id="button",
                        interactive=True,
                    )

        with gr.Row():
            generator_vis = gr.HTML(label="Explanations")

        with gr.Row():
            system_response = gr.Textbox(
                label="System Response",
                container=True,
                interactive=False,
            )

        return (
            qn_choice,
            user_input,
            system_response,
            granularity,
            upper_percentile,
            middle_percentile,
            lower_percentile,
            explainer_name,
            model_name,
            perturber_name,
            comparator_name,
            generator_vis,
            submit_btn,
        )

    def __visualize_explanations(
        self,
        user_input: str,
        system_response: Optional[str],
        generator_explanations: ExplanationDto,
        upper_percentile: Optional[int],
        middle_percentile: Optional[int],
        lower_percentile: Optional[int],
    ) -> str:
        """Bucket explanation scores by percentile and render them via the visualizer.

        NOTE(review): ``system_response`` is accepted but not forwarded;
        ``output_from_explanations`` receives ``user_input``. Confirm this is
        intended.
        """
        segregator = PercentileBasedSegregator(
            upper_bound_percentile=upper_percentile,
            middle_bound_percentile=middle_percentile,
            lower_bound_percentile=lower_percentile,
        )
        return self.__visualizer.visualize(
            segregator=segregator,
            explanations=generator_explanations,
            output_from_explanations=user_input,
        )