{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: interpretation_component\n",
    "\n",
    "Explains a sentiment classifier's prediction with [SHAP](https://shap.readthedocs.io/) and displays the per-token attributions in Gradio's `Interpretation` component. Enter some text and click **Interpret** to see how much each token pushes the prediction towards the `POSITIVE` class.\n",
    "\n",
    "**Note:** `gr.components.Interpretation` was removed in Gradio 4.0, so this demo requires Gradio 3.x (pinned in the install cell below)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -q \"gradio<4\" shap transformers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import shap\n",
    "from transformers import pipeline\n",
    "\n",
    "# Pin the checkpoint explicitly (it is the one the \"text-classification\"\n",
    "# pipeline falls back to) so results don't change if that default does.\n",
    "MODEL_NAME = \"distilbert/distilbert-base-uncased-finetuned-sst-2-english\"\n",
    "# Class whose score is explained; each token's SHAP value is relative to it.\n",
    "TARGET_LABEL = \"POSITIVE\"\n",
    "\n",
    "# top_k=None returns every label's score (replaces the deprecated\n",
    "# return_all_scores=True); SHAP needs all of them to explain one class.\n",
    "sentiment_classifier = pipeline(\"text-classification\", model=MODEL_NAME, top_k=None)\n",
    "\n",
    "# Building the explainer is expensive, so do it once rather than on every click.\n",
    "explainer = shap.Explainer(sentiment_classifier)\n",
    "# Look the class index up by name instead of hard-coding it.\n",
    "target_index = sentiment_classifier.model.config.label2id[TARGET_LABEL]\n",
    "\n",
    "\n",
    "def interpretation_function(text):\n",
    "    \"\"\"Explain the classifier's TARGET_LABEL score for `text` with SHAP.\n",
    "\n",
    "    Args:\n",
    "        text: The string entered in the textbox.\n",
    "\n",
    "    Returns:\n",
    "        A dict for the Interpretation component below: the original text under\n",
    "        \"original\" and a list of (token, SHAP value) pairs under \"interpretation\".\n",
    "    \"\"\"\n",
    "    # Skip SHAP for blank input, which has no words to attribute.\n",
    "    if not text or not text.strip():\n",
    "        return {\"original\": text, \"interpretation\": []}\n",
    "    shap_values = explainer([text])\n",
    "    # Index 0 selects the single example passed in; target_index selects the\n",
    "    # explained class's column for every token.\n",
    "    scores = list(zip(shap_values.data[0], shap_values.values[0, :, target_index]))\n",
    "    return {\"original\": text, \"interpretation\": scores}\n",
    "\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            input_text = gr.Textbox(label=\"Sentiment Analysis\", value=\"Wonderfully terrible\")\n",
    "            with gr.Row():\n",
    "                interpret = gr.Button(\"Interpret\")\n",
    "        with gr.Column():\n",
    "            interpretation = gr.components.Interpretation(input_text)\n",
    "\n",
    "    interpret.click(interpretation_function, input_text, interpretation)\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}