File size: 1,547 Bytes
87e9f1f
1
{"cells": [{"cell_type": "markdown", "id": "302934307671667531413257853548643485645", "metadata": {}, "source": ["# Gradio Demo: interpretation_component\n", "\n", "Type some text and click **Interpret** to display SHAP attribution scores for it in an `Interpretation` component."]}, {"cell_type": "code", "execution_count": null, "id": "272996653310673477252411125948039410165", "metadata": {}, "outputs": [], "source": ["# gr.components.Interpretation is not available in Gradio 4+, so install a 3.x release.\n", "# %pip (rather than !pip) installs into the environment of the running kernel.\n", "%pip install -q \"gradio<4\" shap transformers torch"]}, {"cell_type": "code", "execution_count": null, "id": "288918539441861185822528903084949547379", "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "import shap\n", "from transformers import pipeline\n", "\n", "\n", "# No model name is passed, so the pipeline's default text-classification model is used.\n", "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n", "\n", "# Build the SHAP explainer once: it depends only on the pipeline, not on the text\n", "# being explained, so there is no need to recreate it on every button click.\n", "explainer = shap.Explainer(sentiment_classifier)\n", "\n", "\n", "def interpretation_function(text):\n", "    \"\"\"Compute SHAP attribution scores for `text`.\n", "\n", "    Args:\n", "        text: The string to explain.\n", "\n", "    Returns:\n", "        A dict with the keys \"original\" (the input text, unchanged) and\n", "        \"interpretation\" (a list of pairs, each combining one element of the\n", "        explained input with its score).\n", "    \"\"\"\n", "    shap_values = explainer([text])\n", "    tokens = shap_values.data[0]\n", "    # Column 1 keeps the scores for output index 1 only (presumably the positive\n", "    # label of the default model; confirm if a different model is used).\n", "    # .tolist() converts the NumPy values to plain Python floats.\n", "    token_scores = shap_values.values[0, :, 1].tolist()\n", "    return {\"original\": text, \"interpretation\": list(zip(tokens, token_scores))}\n", "\n", "\n", "with gr.Blocks() as demo:\n", "    with gr.Row():\n", "        with gr.Column():\n", "            input_text = gr.Textbox(label=\"Sentiment Analysis\", value=\"Wonderfully terrible\")\n", "            with gr.Row():\n", "                interpret = gr.Button(\"Interpret\")\n", "        with gr.Column():\n", "            interpretation = gr.components.Interpretation(input_text)\n", "\n", "    # Clicking the button explains the textbox content and shows the result.\n", "    interpret.click(interpretation_function, input_text, interpretation)\n", "\n", "if __name__ == \"__main__\":\n", "    demo.launch()"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}