Spaces:
Runtime error
Runtime error
Commit ·
229e14c
1
Parent(s): 11174d4
feat/fix: fixing code issues, adding plotting functions
Browse files- .gitignore +1 -0
- backend/controller.py +3 -3
- explanation/interpret_captum.py +0 -40
- explanation/interpret_shap.py +0 -72
- explanation/markup.py +12 -12
- explanation/plotting.py +0 -0
- explanation/visualize.py +0 -52
- explanation/visualize_att.py +0 -0
- model/mistral.py +10 -10
- requirements.txt +1 -3
.gitignore
CHANGED
|
@@ -2,3 +2,4 @@
|
|
| 2 |
__pycache__/
|
| 3 |
/start-venv.sh
|
| 4 |
/components/iframe/dist/
|
|
|
|
|
|
| 2 |
__pycache__/
|
| 3 |
/start-venv.sh
|
| 4 |
/components/iframe/dist/
|
| 5 |
+
.venv
|
backend/controller.py
CHANGED
|
@@ -10,7 +10,7 @@ from model import mistral
|
|
| 10 |
from explanation import (
|
| 11 |
interpret_shap as shap_int,
|
| 12 |
interpret_captum as cpt_int,
|
| 13 |
-
|
| 14 |
)
|
| 15 |
|
| 16 |
|
|
@@ -33,10 +33,10 @@ def interference(
|
|
| 33 |
|
| 34 |
if model_selection.lower() == "mistral":
|
| 35 |
model = mistral
|
| 36 |
-
print("
|
| 37 |
else:
|
| 38 |
model = godel
|
| 39 |
-
print("
|
| 40 |
|
| 41 |
# if a XAI approach is selected, grab the XAI module instance
|
| 42 |
if xai_selection in ("SHAP", "Attention"):
|
|
|
|
| 10 |
from explanation import (
|
| 11 |
interpret_shap as shap_int,
|
| 12 |
interpret_captum as cpt_int,
|
| 13 |
+
visualize_att as viz,
|
| 14 |
)
|
| 15 |
|
| 16 |
|
|
|
|
| 33 |
|
| 34 |
if model_selection.lower() == "mistral":
|
| 35 |
model = mistral
|
| 36 |
+
print("Indentified model as Mistral")
|
| 37 |
else:
|
| 38 |
model = godel
|
| 39 |
+
print("Indentified model as GODEL")
|
| 40 |
|
| 41 |
# if a XAI approach is selected, grab the XAI module instance
|
| 42 |
if xai_selection in ("SHAP", "Attention"):
|
explanation/interpret_captum.py
CHANGED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
# external imports
|
| 2 |
-
from captum.attr import LLMAttribution, TextTokenInput, KernelShap
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
-
# internal imports
|
| 6 |
-
from utils import formatting as fmt
|
| 7 |
-
from .markup import markup_text
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
# main explain function that returns a chat with explanations
|
| 11 |
-
def chat_explained(model, prompt):
|
| 12 |
-
model.set_config({})
|
| 13 |
-
|
| 14 |
-
# creating llm attribution class with KernelSHAP and Mistal Model, Tokenizer
|
| 15 |
-
llm_attribution = LLMAttribution(KernelShap(model.MODEL), model.TOKENIZER)
|
| 16 |
-
|
| 17 |
-
# generation attribution
|
| 18 |
-
attribution_input = TextTokenInput(prompt, model.TOKENIZER)
|
| 19 |
-
attribution_result = llm_attribution.attribute(
|
| 20 |
-
attribution_input, gen_args=model.CONFIG.to_dict()
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
# extracting values and input tokens
|
| 24 |
-
values = attribution_result.seq_attr.to(torch.device("cpu")).numpy()
|
| 25 |
-
input_tokens = fmt.format_tokens(attribution_result.input_tokens)
|
| 26 |
-
|
| 27 |
-
# raising error if mismatch occurs
|
| 28 |
-
if len(attribution_result.input_tokens) != len(values):
|
| 29 |
-
raise RuntimeError("values and input len mismatch")
|
| 30 |
-
|
| 31 |
-
# getting response text, graphic placeholder and marked text object
|
| 32 |
-
response_text = fmt.format_output_text(attribution_result.output_tokens)
|
| 33 |
-
graphic = (
|
| 34 |
-
"<div style='text-align: center; font-family:arial;'><h4>Attention"
|
| 35 |
-
"Intepretation with Captum doesn't support an interactive graphic.</h4></div>"
|
| 36 |
-
)
|
| 37 |
-
marked_text = markup_text(input_tokens, values, variant="captum")
|
| 38 |
-
|
| 39 |
-
# return response, graphic and marked_text array
|
| 40 |
-
return response_text, graphic, marked_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explanation/interpret_shap.py
CHANGED
|
@@ -1,72 +0,0 @@
|
|
| 1 |
-
# interpret module that implements the interpretability method
|
| 2 |
-
|
| 3 |
-
# external imports
|
| 4 |
-
from shap import models, maskers, plots, PartitionExplainer
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
# internal imports
|
| 8 |
-
from utils import formatting as fmt
|
| 9 |
-
from .markup import markup_text
|
| 10 |
-
|
| 11 |
-
# global variables
|
| 12 |
-
TEACHER_FORCING = None
|
| 13 |
-
TEXT_MASKER = None
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# main explain function that returns a chat with explanations
|
| 17 |
-
def chat_explained(model, prompt):
|
| 18 |
-
model.set_config({})
|
| 19 |
-
|
| 20 |
-
# create the shap explainer
|
| 21 |
-
shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
|
| 22 |
-
|
| 23 |
-
# get the shap values for the prompt
|
| 24 |
-
shap_values = shap_explainer([prompt])
|
| 25 |
-
|
| 26 |
-
# create the explanation graphic and marked text array
|
| 27 |
-
graphic = create_graphic(shap_values)
|
| 28 |
-
marked_text = markup_text(
|
| 29 |
-
shap_values.data[0], shap_values.values[0], variant="shap"
|
| 30 |
-
)
|
| 31 |
-
|
| 32 |
-
# create the response text
|
| 33 |
-
response_text = fmt.format_output_text(shap_values.output_names)
|
| 34 |
-
|
| 35 |
-
# return response, graphic and marked_text array
|
| 36 |
-
return response_text, graphic, marked_text
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
# function used to wrap the model with a shap model
|
| 40 |
-
def wrap_shap(model):
|
| 41 |
-
# calling global variants
|
| 42 |
-
global TEXT_MASKER, TEACHER_FORCING
|
| 43 |
-
|
| 44 |
-
# set the device to cuda if gpu is available
|
| 45 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 46 |
-
|
| 47 |
-
# updating the model settings
|
| 48 |
-
model.set_config()
|
| 49 |
-
|
| 50 |
-
# (re)initialize the shap models and masker
|
| 51 |
-
# creating a shap text_generation model
|
| 52 |
-
text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
|
| 53 |
-
# wrapping the text generation model in a teacher forcing model
|
| 54 |
-
TEACHER_FORCING = models.TeacherForcing(
|
| 55 |
-
text_generation,
|
| 56 |
-
model.TOKENIZER,
|
| 57 |
-
device=str(device),
|
| 58 |
-
similarity_model=model.MODEL,
|
| 59 |
-
similarity_tokenizer=model.TOKENIZER,
|
| 60 |
-
)
|
| 61 |
-
# setting the text masker as an empty string
|
| 62 |
-
TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
# graphic plotting function that creates a html graphic (as string) for the explanation
|
| 66 |
-
def create_graphic(shap_values):
|
| 67 |
-
|
| 68 |
-
# create the html graphic using shap text plot function
|
| 69 |
-
graphic_html = plots.text(shap_values, display=False)
|
| 70 |
-
|
| 71 |
-
# return the html graphic as string to display in iFrame
|
| 72 |
-
return str(graphic_html)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explanation/markup.py
CHANGED
|
@@ -66,16 +66,16 @@ def color_codes():
|
|
| 66 |
return {
|
| 67 |
# -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
|
| 68 |
# 0: white (assuming default light mode)
|
| 69 |
-
# +1 to +5 light pink to
|
| 70 |
-
"-5": "#
|
| 71 |
-
"-4": "#
|
| 72 |
-
"-3": "#
|
| 73 |
-
"-2": "#
|
| 74 |
-
"-1": "#
|
| 75 |
-
"0": "#
|
| 76 |
-
"
|
| 77 |
-
"
|
| 78 |
-
"
|
| 79 |
-
"
|
| 80 |
-
"
|
| 81 |
}
|
|
|
|
| 66 |
return {
|
| 67 |
# -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
|
| 68 |
# 0: white (assuming default light mode)
|
| 69 |
+
# +1 to +5 light pink to strng magenta
|
| 70 |
+
"-5": "#008bfb",
|
| 71 |
+
"-4": "#68a1fd",
|
| 72 |
+
"-3": "#96b7fe",
|
| 73 |
+
"-2": "#bcceff",
|
| 74 |
+
"-1:": "#dee6ff",
|
| 75 |
+
"0": "#ffffff",
|
| 76 |
+
"1": "#ffd9d9",
|
| 77 |
+
"2": "#ffb3b5",
|
| 78 |
+
"3": "#ff8b92",
|
| 79 |
+
"4": "#ff5c71",
|
| 80 |
+
"5": "#ff0051",
|
| 81 |
}
|
explanation/plotting.py
ADDED
|
File without changes
|
explanation/visualize.py
DELETED
|
@@ -1,52 +0,0 @@
|
|
| 1 |
-
# visualization module that creates an attention visualization
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
# internal imports
|
| 5 |
-
from utils import formatting as fmt
|
| 6 |
-
from .markup import markup_text
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
# chat function that returns an answer
|
| 10 |
-
# and marked text based on attention
|
| 11 |
-
def chat_explained(model, prompt):
|
| 12 |
-
|
| 13 |
-
# get encoded input
|
| 14 |
-
encoder_input_ids = model.TOKENIZER(
|
| 15 |
-
prompt, return_tensors="pt", add_special_tokens=True
|
| 16 |
-
).input_ids
|
| 17 |
-
# generate output together with attentions of the model
|
| 18 |
-
decoder_input_ids = model.MODEL.generate(
|
| 19 |
-
encoder_input_ids, output_attentions=True, **model.CONFIG
|
| 20 |
-
)
|
| 21 |
-
|
| 22 |
-
# get input and output text as list of strings
|
| 23 |
-
encoder_text = fmt.format_tokens(
|
| 24 |
-
model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
|
| 25 |
-
)
|
| 26 |
-
decoder_text = fmt.format_tokens(
|
| 27 |
-
model.TOKENIZER.convert_ids_to_tokens(decoder_input_ids[0])
|
| 28 |
-
)
|
| 29 |
-
|
| 30 |
-
# get attention values for the input and output vectors
|
| 31 |
-
# using already generated input and output
|
| 32 |
-
attention_output = model.MODEL(
|
| 33 |
-
input_ids=encoder_input_ids,
|
| 34 |
-
decoder_input_ids=decoder_input_ids,
|
| 35 |
-
output_attentions=True,
|
| 36 |
-
)
|
| 37 |
-
|
| 38 |
-
# averaging attention across layers
|
| 39 |
-
averaged_attention = fmt.avg_attention(attention_output)
|
| 40 |
-
|
| 41 |
-
# format response text for clean output
|
| 42 |
-
response_text = fmt.format_output_text(decoder_text)
|
| 43 |
-
# setting placeholder for iFrame graphic
|
| 44 |
-
graphic = (
|
| 45 |
-
"<div style='text-align: center; font-family:arial;'><h4>Attention"
|
| 46 |
-
" Visualization doesn't support an interactive graphic.</h4></div>"
|
| 47 |
-
)
|
| 48 |
-
# creating marked text using markup_text function and attention
|
| 49 |
-
marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
|
| 50 |
-
|
| 51 |
-
# returning response, graphic and marked text array
|
| 52 |
-
return response_text, graphic, marked_text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
explanation/visualize_att.py
ADDED
|
File without changes
|
model/mistral.py
CHANGED
|
@@ -41,13 +41,11 @@ CONFIG.update(**{
|
|
| 41 |
|
| 42 |
|
| 43 |
# function to (re) set config
|
| 44 |
-
def set_config(
|
| 45 |
|
| 46 |
-
# if config dict is given,
|
| 47 |
-
if
|
| 48 |
-
|
| 49 |
-
else:
|
| 50 |
-
CONFIG.update(**{
|
| 51 |
"temperature": 0.7,
|
| 52 |
"max_new_tokens": 50,
|
| 53 |
"max_length": 50,
|
|
@@ -55,7 +53,9 @@ def set_config(config: dict):
|
|
| 55 |
"repetition_penalty": 1.2,
|
| 56 |
"do_sample": True,
|
| 57 |
"seed": 42,
|
| 58 |
-
}
|
|
|
|
|
|
|
| 59 |
|
| 60 |
|
| 61 |
# advanced formatting function that takes into a account a conversation history
|
|
@@ -77,9 +77,9 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: st
|
|
| 77 |
"""
|
| 78 |
else:
|
| 79 |
# takes the very first exchange and the system prompt as base
|
| 80 |
-
prompt =
|
| 81 |
-
|
| 82 |
-
|
| 83 |
|
| 84 |
# adds conversation history to the prompt
|
| 85 |
for conversation in history[1:]:
|
|
|
|
| 41 |
|
| 42 |
|
| 43 |
# function to (re) set config
|
| 44 |
+
def set_config(config_dict: dict):
|
| 45 |
|
| 46 |
+
# if config dict is not given, set to default
|
| 47 |
+
if config_dict == {}:
|
| 48 |
+
config_dict = {
|
|
|
|
|
|
|
| 49 |
"temperature": 0.7,
|
| 50 |
"max_new_tokens": 50,
|
| 51 |
"max_length": 50,
|
|
|
|
| 53 |
"repetition_penalty": 1.2,
|
| 54 |
"do_sample": True,
|
| 55 |
"seed": 42,
|
| 56 |
+
}
|
| 57 |
+
|
| 58 |
+
CONFIG.update(**dict)
|
| 59 |
|
| 60 |
|
| 61 |
# advanced formatting function that takes into a account a conversation history
|
|
|
|
| 77 |
"""
|
| 78 |
else:
|
| 79 |
# takes the very first exchange and the system prompt as base
|
| 80 |
+
prompt = f"""
|
| 81 |
+
<s>[INST] {system_prompt} {history[0][0]} [/INST] {history[0][1]}</s>
|
| 82 |
+
"""
|
| 83 |
|
| 84 |
# adds conversation history to the prompt
|
| 85 |
for conversation in history[1:]:
|
requirements.txt
CHANGED
|
@@ -2,7 +2,7 @@ gradio~=4.7.1
|
|
| 2 |
transformers~=4.35.2
|
| 3 |
torch~=2.1.1
|
| 4 |
shap
|
| 5 |
-
captum
|
| 6 |
bertviz~=1.4.0
|
| 7 |
accelerate~=0.24.1
|
| 8 |
bitsandbytes
|
|
@@ -13,9 +13,7 @@ uvicorn~=0.24.0
|
|
| 13 |
tinydb~=4.8.0
|
| 14 |
black~=23.12.0
|
| 15 |
pylint~=3.0.0
|
| 16 |
-
seaborn~=0.13.0
|
| 17 |
numpy
|
| 18 |
matplotlib
|
| 19 |
pre-commit
|
| 20 |
-
ipython
|
| 21 |
gradio-iframe~=0.0.10
|
|
|
|
| 2 |
transformers~=4.35.2
|
| 3 |
torch~=2.1.1
|
| 4 |
shap
|
| 5 |
+
captum @ git+https://github.com/LennardZuendorf/thesis-captum.git
|
| 6 |
bertviz~=1.4.0
|
| 7 |
accelerate~=0.24.1
|
| 8 |
bitsandbytes
|
|
|
|
| 13 |
tinydb~=4.8.0
|
| 14 |
black~=23.12.0
|
| 15 |
pylint~=3.0.0
|
|
|
|
| 16 |
numpy
|
| 17 |
matplotlib
|
| 18 |
pre-commit
|
|
|
|
| 19 |
gradio-iframe~=0.0.10
|