# rag-ex / src/ui/explainer_ui.py
# Viju Sudhi
# feat: adding explainer dtos
# 7f98908
import json
from typing import List, Optional
import gradio as gr
from gradio.components import Markdown
from src.dto.dto import ExplanationGranularity, ExplanationDto
from src.utils.registry import EXPLAINERS, MODELS, PERTURBERS, COMPARATORS
from src.utils.segregate import PercentileBasedSegregator
from src.utils.visualizer import Visualizer
class MockExplainerUI:
    """Gradio demo UI that replays precomputed explanations.

    Instead of running models, ``run`` loads an ``ExplanationDto`` for the
    selected example from a JSON file under ``data/`` and renders it with the
    supplied ``Visualizer``.
    """

    # Example labels offered in the UI. ``run`` derives the language from the
    # "EN"/"DE" marker and the example number from the trailing integer.
    _EXAMPLE_CHOICES: List[str] = [
        f"{lang} Example {idx}" for lang in ("EN", "DE") for idx in range(1, 6)
    ]

    def __init__(
        self,
        logo_path: str,
        css_path: str,
        visualizer: Visualizer,
        window_title: str,
        title: str,
        examples: Optional[List[str]] = None,
    ):
        """Store UI configuration and build the Gradio app.

        :param logo_path: Image shown next to the title.
        :param css_path: Passed to ``gr.Blocks(css=...)``.
        :param visualizer: Renders explanations to HTML.
        :param window_title: Browser window title.
        :param title: Heading text rendered at the top of the page.
        :param examples: Stored but not used by the visible code.
        """
        self.__logo_path = logo_path
        self.__css_path = css_path
        self.__examples = examples
        self.__window_title = window_title
        self.__title = title
        self.__visualizer = visualizer
        self.app: gr.Blocks = self.build_app()

    def build_app(self):
        """Assemble the page layout and wire the submit button to ``run``."""
        with gr.Blocks(
            theme=gr.themes.Monochrome().set(
                button_primary_background_fill="#009374",
                button_primary_background_fill_hover="#009374C4",
                checkbox_label_background_fill_selected="#028A6EFF",
            ),
            css=self.__css_path,
            title=self.__window_title,
        ) as demo:
            self.__build_app_title()
            (
                qn_choice,
                user_input,
                system_response,
                granularity,
                upper_percentile,
                middle_percentile,
                lower_percentile,
                explainer_name,
                model_name,
                perturber_name,
                comparator_name,
                generator_vis,
                submit_btn,
            ) = self.__build_chat_and_explain()
            submit_btn.click(
                fn=self.run,
                inputs=[
                    qn_choice,
                    user_input,
                    granularity,
                    upper_percentile,
                    middle_percentile,
                    lower_percentile,
                    explainer_name,
                    model_name,
                    perturber_name,
                    comparator_name,
                ],
                outputs=[user_input, system_response, generator_vis],
            )
        return demo

    def run(
        self,
        qn_choice: str,
        user_input: str,
        granularity: ExplanationGranularity,
        upper_percentile: str,
        middle_percentile: str,
        lower_percentile: str,
        explainer_name: str,
        model_name: str,
        perturber_name: str,
        comparator_name: str,
    ):
        """Load the precomputed explanation for the chosen example and render it.

        ``user_input``, ``granularity`` and ``explainer_name`` are accepted to
        match the wired Gradio inputs but are not used here.

        :returns: ``(question text, system response text, visualization)``
            taken from the loaded ``ExplanationDto``.
        :raises gr.Error: if no example is selected, a percentile is not an
            integer, or no precomputed file exists for the chosen settings.
        """
        if not qn_choice:
            # gr.Radio yields None when nothing is selected; the substring
            # checks below would otherwise fail with a TypeError.
            raise gr.Error("Please choose an example first.")

        language = "en" if "EN" in qn_choice else "de"
        q_idx = self.__parse_example_index(qn_choice)
        upper = self.__parse_percentile(upper_percentile, "Upper")
        middle = self.__parse_percentile(middle_percentile, "Middle")
        lower = self.__parse_percentile(lower_percentile, "Lower")

        file_path = f"data/{model_name}/{language}_pert_{perturber_name}_comp_{comparator_name}_exp_dto.json"
        try:
            # Explicit UTF-8: the DE examples contain non-ASCII text, and the
            # platform default encoding is not guaranteed to be UTF-8.
            with open(file_path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except FileNotFoundError as err:
            raise gr.Error(
                f"No precomputed explanations found at {file_path}."
            ) from err

        # NOTE(review): "Example N" reads element N of the JSON list (not
        # N - 1); presumably the files carry an extra leading entry — confirm.
        explanation_dto = ExplanationDto.parse_obj(data[q_idx])
        user_input = explanation_dto.input_text
        system_response = explanation_dto.output_text
        generator_vis = self.__visualize_explanations(
            user_input=user_input,
            system_response=system_response,
            generator_explanations=explanation_dto,
            upper_percentile=upper,
            middle_percentile=middle,
            lower_percentile=lower,
        )
        return user_input, system_response, generator_vis

    @staticmethod
    def __parse_example_index(qn_choice: str) -> int:
        """Return the trailing example number of ``qn_choice``, or 1 if absent."""
        tokens = qn_choice.split()
        if tokens and tokens[-1].isdecimal():
            return int(tokens[-1])
        return 1

    @staticmethod
    def __parse_percentile(value: str, label: str) -> int:
        """Convert a percentile textbox value to int, reporting bad input in the UI."""
        try:
            return int(value)
        except (TypeError, ValueError) as err:
            raise gr.Error(
                f"{label} percentile must be an integer, got {value!r}."
            ) from err

    def __build_app_title(self):
        """Render the logo and the page heading side by side."""
        with gr.Row():
            with gr.Column(min_width=50, scale=1):
                gr.Image(
                    value=self.__logo_path,
                    width=50,
                    height=50,
                    show_download_button=False,
                    show_label=False,
                    show_share_button=False,
                    container=False,
                )
            with gr.Column(scale=2):
                Markdown(
                    f'<p style="text-align: left; font-size:200%; font-weight: bold"'
                    f">{self.__title}"
                    f"</p>"
                )

    def __build_chat_and_explain(self):
        """Create the example picker, settings, submit button and output widgets.

        :returns: the components in the order unpacked by ``build_app``.
        """
        explainer_keys = list(EXPLAINERS)
        model_keys = list(MODELS)
        perturber_keys = list(PERTURBERS)
        comparator_keys = list(COMPARATORS)

        with gr.Row():
            with gr.Column(scale=2):
                qn_choice = gr.Radio(
                    label="Choose from these examples",
                    container=True,
                    choices=list(self._EXAMPLE_CHOICES),
                )
        with gr.Row():
            with gr.Column(scale=2):
                gr.Markdown(
                    value="**Note:** This is a demo version of the tool with "
                    "limited functionalities. For building the full "
                    "version, please visit [here](https://github.com/fraunhofer-iais/explainable-lms/tree/master).",
                )
        with gr.Row():
            with gr.Column(scale=2):
                user_input = gr.Textbox(
                    placeholder="Choose an example from the list above.",
                    label="Question",
                    container=True,
                    lines=10,
                    interactive=False,
                )
            with gr.Column(scale=1):
                granularity = gr.Radio(
                    choices=list(ExplanationGranularity),
                    value=ExplanationGranularity.SENTENCE_LEVEL,
                    label="Explanation Granularity",
                    interactive=False,
                )
                with gr.Accordion(label="Settings", open=False, elem_id="accordion"):
                    with gr.Row(variant="compact"):
                        explainer_name = gr.Radio(
                            label="Explainer",
                            choices=explainer_keys,
                            value=explainer_keys[0],
                            container=True,
                            visible=False,
                        )
                    with gr.Row(variant="compact"):
                        upper_percentile = gr.Textbox(label="Upper", value="85", container=True)
                        middle_percentile = gr.Textbox(
                            label="Middle", value="75", container=True
                        )
                        lower_percentile = gr.Textbox(label="Lower", value="10", container=True)
                    with gr.Row(variant="compact"):
                        model_name = gr.Radio(
                            label="Model",
                            choices=model_keys,
                            value=model_keys[0],
                            container=True,
                        )
                    with gr.Row(variant="compact"):
                        perturber_name = gr.Radio(
                            label="Perturber",
                            choices=perturber_keys,
                            value=perturber_keys[0],
                            container=True,
                        )
                    with gr.Row(variant="compact"):
                        comparator_name = gr.Radio(
                            label="Comparator",
                            choices=comparator_keys,
                            value=comparator_keys[0],
                            container=True,
                        )
                with gr.Row(variant="compact"):
                    # "elem_id" lets the CSS passed to gr.Blocks style this button.
                    submit_btn = gr.Button(
                        value="🛠 Submit",
                        variant="secondary",
                        elem_id="button",
                        interactive=True,
                    )
        with gr.Row():
            generator_vis = gr.HTML(label="Explanations")
        with gr.Row():
            system_response = gr.Textbox(
                label="System Response",
                container=True,
                interactive=False,
            )
        return (
            qn_choice,
            user_input,
            system_response,
            granularity,
            upper_percentile,
            middle_percentile,
            lower_percentile,
            explainer_name,
            model_name,
            perturber_name,
            comparator_name,
            generator_vis,
            submit_btn,
        )

    def __visualize_explanations(
        self,
        user_input: str,
        system_response: Optional[str],
        generator_explanations: ExplanationDto,
        upper_percentile: Optional[int],
        middle_percentile: Optional[int],
        lower_percentile: Optional[int],
    ) -> str:
        """Bucket explanation scores by percentile and render them via the visualizer.

        NOTE(review): ``system_response`` is unused, and ``user_input`` is passed
        as ``output_from_explanations`` — confirm this is intended.
        """
        segregator = PercentileBasedSegregator(
            upper_bound_percentile=upper_percentile,
            middle_bound_percentile=middle_percentile,
            lower_bound_percentile=lower_percentile,
        )
        return self.__visualizer.visualize(
            segregator=segregator,
            explanations=generator_explanations,
            output_from_explanations=user_input,
        )