|
|
import json |
|
|
from typing import List, Optional |
|
|
|
|
|
import gradio as gr |
|
|
from gradio.components import Markdown |
|
|
|
|
|
from src.dto.dto import ExplanationGranularity, ExplanationDto |
|
|
from src.utils.registry import EXPLAINERS, MODELS, PERTURBERS, COMPARATORS |
|
|
from src.utils.segregate import PercentileBasedSegregator |
|
|
from src.utils.visualizer import Visualizer |
|
|
|
|
|
|
|
|
class MockExplainerUI:
    """Gradio front end that replays precomputed explanations.

    Nothing is explained live: :meth:`run` loads a serialized
    :class:`ExplanationDto` from
    ``data/<model>/<lang>_pert_<perturber>_comp_<comparator>_exp_dto.json``
    and renders it with the injected :class:`Visualizer`.
    """

    def __init__(
        self,
        logo_path: str,
        css_path: str,
        visualizer: Visualizer,
        window_title: str,
        title: str,
        examples: Optional[List[str]] = None,
    ):
        """Store the UI configuration and build the Gradio app.

        Args:
            logo_path: Image shown next to the page title.
            css_path: Passed straight to ``gr.Blocks(css=...)``.
            visualizer: Renders an explanation DTO to HTML.
            window_title: Browser window/tab title.
            title: Heading rendered at the top of the page.
            examples: Stored on the instance; not used by the code below.
        """
        self.__logo_path = logo_path
        self.__css_path = css_path
        self.__examples = examples
        self.__window_title = window_title
        self.__title = title
        self.__visualizer = visualizer

        # Built eagerly so callers can launch ``self.app`` right away.
        self.app: gr.Blocks = self.build_app()

    def build_app(self):
        """Assemble the page layout and wire the Submit button to :meth:`run`.

        Returns:
            The constructed ``gr.Blocks`` application.
        """
        with gr.Blocks(
            theme=gr.themes.Monochrome().set(
                button_primary_background_fill="#009374",
                button_primary_background_fill_hover="#009374C4",
                checkbox_label_background_fill_selected="#028A6EFF",
            ),
            css=self.__css_path,
            title=self.__window_title,
        ) as demo:
            self.__build_app_title()
            (
                qn_choice,
                user_input,
                system_response,
                granularity,
                upper_percentile,
                middle_percentile,
                lower_percentile,
                explainer_name,
                model_name,
                perturber_name,
                comparator_name,
                generator_vis,
                submit_btn,
            ) = self.__build_chat_and_explain()

            # The order of ``inputs`` must match the positional parameters of ``run``.
            submit_btn.click(
                fn=self.run,
                inputs=[
                    qn_choice,
                    user_input,
                    granularity,
                    upper_percentile,
                    middle_percentile,
                    lower_percentile,
                    explainer_name,
                    model_name,
                    perturber_name,
                    comparator_name,
                ],
                outputs=[user_input, system_response, generator_vis],
            )

        return demo

    @staticmethod
    def __question_index(qn_choice: str) -> int:
        """Return the first of the digits 1..5 found in *qn_choice*, or 1 if none."""
        # Digits are tested in ascending order, so "1" wins over any later digit.
        return next((idx for idx in range(1, 6) if str(idx) in qn_choice), 1)

    def run(
        self,
        qn_choice: Optional[str],
        user_input: str,
        granularity: ExplanationGranularity,
        upper_percentile: str,
        middle_percentile: str,
        lower_percentile: str,
        explainer_name: str,
        model_name: str,
        perturber_name: str,
        comparator_name: str,
    ):
        """Load the precomputed explanation for the chosen example and render it.

        ``user_input``, ``granularity`` and ``explainer_name`` are received
        from the UI but do not influence which file is loaded.

        Args:
            qn_choice: Example label such as ``"EN Example 3"``; ``None`` when
                nothing has been selected yet.
            upper_percentile: Text of the "Upper" box; must parse as an int.
            middle_percentile: Text of the "Middle" box; must parse as an int.
            lower_percentile: Text of the "Lower" box; must parse as an int.
            model_name: Sub-directory of ``data/`` to read from.
            perturber_name: Perturber part of the data file name.
            comparator_name: Comparator part of the data file name.

        Returns:
            ``(input_text, output_text, html)`` taken from the stored DTO and
            the visualizer.

        Raises:
            gr.Error: If no example is selected, a percentile is not an
                integer, or no data file exists for the chosen settings.
        """
        if not qn_choice:
            # The example Radio is created without a default value, so Submit
            # can be clicked before anything is selected.
            raise gr.Error("Please choose an example from the list first.")

        try:
            upper, middle, lower = (
                int(value)
                for value in (upper_percentile, middle_percentile, lower_percentile)
            )
        except ValueError as err:
            raise gr.Error(
                "Upper, Middle and Lower percentiles must be whole numbers."
            ) from err

        language = "en" if "EN" in qn_choice else "de"
        q_idx = self.__question_index(qn_choice)

        file_path = f"data/{model_name}/{language}_pert_{perturber_name}_comp_{comparator_name}_exp_dto.json"
        try:
            # Explicit UTF-8: the German examples must not depend on the OS locale.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError as err:
            raise gr.Error(
                f"No precomputed explanations found at '{file_path}'."
            ) from err

        # Presumably the JSON file holds a list indexed by example number; TODO confirm.
        explanation_dto = ExplanationDto.parse_obj(data[q_idx])

        user_input = explanation_dto.input_text
        system_response = explanation_dto.output_text
        generator_vis = self.__visualize_explanations(
            user_input=user_input,
            system_response=system_response,
            generator_explanations=explanation_dto,
            upper_percentile=upper,
            middle_percentile=middle,
            lower_percentile=lower,
        )
        return user_input, system_response, generator_vis

    def __build_app_title(self):
        """Render the header row: logo on the left, bold title on the right."""
        with gr.Row():
            with gr.Column(min_width=50, scale=1):
                gr.Image(
                    value=self.__logo_path,
                    width=50,
                    height=50,
                    show_download_button=False,
                    show_label=False,
                    show_share_button=False,
                    container=False,
                )
            with gr.Column(scale=2):
                Markdown(
                    f'<p style="text-align: left; font-size:200%; font-weight: bold"'
                    f">{self.__title}"
                    f"</p>"
                )

    def __build_chat_and_explain(self):
        """Create the example picker, settings accordion and output widgets.

        Returns:
            A 13-tuple of components in the order unpacked by :meth:`build_app`.
        """
        with gr.Row():
            with gr.Column(scale=2):
                # No ``value`` is given, so the initial selection is empty.
                qn_choice = gr.Radio(
                    label="Choose from these examples",
                    container=True,
                    choices=[
                        "EN Example 1",
                        "EN Example 2",
                        "EN Example 3",
                        "EN Example 4",
                        "EN Example 5",
                        "DE Example 1",
                        "DE Example 2",
                        "DE Example 3",
                        "DE Example 4",
                        "DE Example 5"
                    ],
                )

        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown(
                    value="**Note:** This is a demo version of the tool with "
                    "limited functionalities. For building the full "
                    "version, please visit [here](https://github.com/fraunhofer-iais/explainable-lms/tree/master).",
                )

        with gr.Row():
            with gr.Column(scale=2):
                # Read-only: filled from the stored DTO by ``run``.
                user_input = gr.Textbox(
                    placeholder="Choose an example from the list above.",
                    label="Question",
                    container=True,
                    lines=10,
                    interactive=False
                )
            with gr.Column(scale=1):
                granularity = gr.Radio(
                    choices=[e for e in ExplanationGranularity],
                    value=ExplanationGranularity.SENTENCE_LEVEL,
                    label="Explanation Granularity",
                    interactive=False
                )

        with gr.Accordion(label="Settings", open=False, elem_id="accordion"):
            with gr.Row(variant="compact"):
                # Hidden: the demo always uses the first registered explainer.
                explainer_name = gr.Radio(
                    label="Explainer",
                    choices=list(EXPLAINERS.keys()),
                    value=list(EXPLAINERS.keys())[0],
                    container=True,
                    visible=False
                )
            with gr.Row(variant="compact"):
                upper_percentile = gr.Textbox(label="Upper", value="85", container=True)
                middle_percentile = gr.Textbox(
                    label="Middle", value="75", container=True
                )
                lower_percentile = gr.Textbox(label="Lower", value="10", container=True)

            with gr.Row(variant="compact"):
                model_name = gr.Radio(
                    label="Model",
                    choices=list(MODELS.keys()),
                    value=list(MODELS.keys())[0],
                    container=True,
                )
            with gr.Row(variant="compact"):
                perturber_name = gr.Radio(
                    label="Perturber",
                    choices=list(PERTURBERS.keys()),
                    value=list(PERTURBERS.keys())[0],
                    container=True,
                )
            with gr.Row(variant="compact"):
                comparator_name = gr.Radio(
                    label="Comparator",
                    choices=list(COMPARATORS.keys()),
                    value=list(COMPARATORS.keys())[0],
                    container=True,
                )

        with gr.Row(variant="compact"):
            submit_btn = gr.Button(
                value="🛠 Submit",
                variant="secondary",
                elem_id="button",
                interactive=True,
            )

        with gr.Row():
            generator_vis = gr.HTML(label="Explanations")

        with gr.Row():
            system_response = gr.Textbox(
                label="System Response",
                container=True,
                interactive=False,
            )

        return (
            qn_choice,
            user_input,
            system_response,
            granularity,
            upper_percentile,
            middle_percentile,
            lower_percentile,
            explainer_name,
            model_name,
            perturber_name,
            comparator_name,
            generator_vis,
            submit_btn,
        )

    def __visualize_explanations(
        self,
        user_input: str,
        system_response: Optional[str],
        generator_explanations: ExplanationDto,
        upper_percentile: Optional[int],
        middle_percentile: Optional[int],
        lower_percentile: Optional[int],
    ) -> str:
        """Render *generator_explanations* as HTML using percentile-based bucketing.

        Args:
            user_input: Text forwarded to the visualizer as ``output_from_explanations``.
            system_response: Accepted but currently unused.
            generator_explanations: The explanation DTO to render.
            upper_percentile: Upper bound percentile for the segregator.
            middle_percentile: Middle bound percentile for the segregator.
            lower_percentile: Lower bound percentile for the segregator.

        Returns:
            Whatever ``Visualizer.visualize`` returns (annotated as ``str``).
        """
        segregator = PercentileBasedSegregator(
            upper_bound_percentile=upper_percentile,
            middle_bound_percentile=middle_percentile,
            lower_bound_percentile=lower_percentile,
        )
        # NOTE(review): ``system_response`` is ignored and the *input* text is
        # passed as ``output_from_explanations``; confirm this is intended.
        return self.__visualizer.visualize(
            segregator=segregator,
            explanations=generator_explanations,
            output_from_explanations=user_input,
        )
|
|
|