# Streamlit showcase app for the *ferret* explainability library
# (single-instance mode: explain and evaluate one textual query at a time).
# NOTE(review): `DEFAULT_MODE` from ctypes is never used anywhere below — it
# looks like an accidental IDE auto-import triggered by the similarly named
# `DEFAULT_MODEL`; confirm and remove.
from ctypes import DEFAULT_MODE
import streamlit as st
from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig
from ferret import Benchmark
# NOTE(review): `softmax` is not used in this chunk — it may be used elsewhere
# in the file, or be another leftover import; verify before removing.
from torch.nn.functional import softmax

# Checkpoint used to pre-populate the "HF Model" text input.
DEFAULT_MODEL = "distilbert-base-uncased-finetuned-sst-2-english"
def get_model(model_name):
    """Download (or load from cache) a sequence-classification model."""
    loaded_model = AutoModelForSequenceClassification.from_pretrained(model_name)
    return loaded_model
def get_config(model_name):
    """Fetch the model configuration (used below for its `id2label` map)."""
    loaded_config = AutoConfig.from_pretrained(model_name)
    return loaded_config
def get_tokenizer(tokenizer_name):
    """Fetch the fast (Rust-backed) tokenizer for the given checkpoint."""
    loaded_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, use_fast=True)
    return loaded_tokenizer
def body():
    """Render the single-instance *ferret* showcase page.

    Flow: the user picks a Hugging Face checkpoint and a target class index,
    types a query, and on "Compute" the app (1) scores the text with the
    model, (2) produces post-hoc explanations with ferret's built-in
    explainers, and (3) evaluates those explanations with faithfulness
    metrics. All output is rendered through Streamlit widgets.
    """
    st.markdown(
        """
# Welcome to the *ferret* showcase

You are working now on the *single instance* mode -- i.e., you will work and
inspect one textual query at a time.

## Sentiment Analysis

Post-hoc explanation techniques disclose the rationale behind a given prediction a model
makes while detecting a sentiment out of a text. In a sense they let you *poke* inside the model.

But **who watches the watchers**?

Let's find out!

Let's choose your favourite sentiment classification model and let ferret do the rest.

We will:

1. download your model - if you're impatient, here it is a [cute video](https://www.youtube.com/watch?v=0Xks8t-SWHU) π¦ for you;
2. explain using *ferret*'s built-in methods βοΈ
3. evaluate explanations with state-of-the-art **faithfulness metrics** π
"""
    )
    col1, col2 = st.columns([3, 1])
    with col1:
        model_name = st.text_input("HF Model", DEFAULT_MODEL)
    with col2:
        # Positional index of the class to explain. Most sentiment models
        # expose two labels, so index 1 is offered as the default; range(5)
        # leaves headroom for models with more classes.
        target = st.selectbox(
            "Target",
            options=range(5),
            index=1,
            help="Positional index of your target class.",
        )
    text = st.text_input("Text")
    compute = st.button("Compute")

    if compute and model_name:
        # Model download/instantiation is the slow part — wrap it in a spinner.
        with st.spinner("Preparing the magic. Hang in there..."):
            model = get_model(model_name)
            tokenizer = get_tokenizer(model_name)
            config = get_config(model_name)
            bench = Benchmark(model, tokenizer)

        st.markdown("### Prediction")
        scores = bench.score(text)
        # Map each positional class score to its human-readable label.
        scores_str = ", ".join(
            [f"{config.id2label[l]}: {s:.2f}" for l, s in enumerate(scores)]
        )
        st.text(scores_str)

        with st.spinner("Computing Explanations.."):
            explanations = bench.explain(text, target=target)

        st.markdown("### Explanations")
        st.dataframe(bench.show_table(explanations))

        with st.spinner("Evaluating Explanations..."):
            # apply_style=False: we want the raw dataframe, styling is done
            # by show_evaluation_table below.
            evaluations = bench.evaluate_explanations(
                explanations, target=target, apply_style=False
            )

        st.markdown("### Faithfulness Metrics")
        st.dataframe(bench.show_evaluation_table(evaluations))
        st.markdown(
            """
**Legend**

- **AOPC Comprehensiveness** (aopc_compr) measures *comprehensiveness*, i.e., if the
explanation captures
- **AOPC Sufficiency** (aopc_suff) measures *sufficiency*, i.e.,
- **Leave-One-Out TAU Correlation** (taucorr_loo) measures

See the paper for details.
"""
        )