File size: 2,759 Bytes
e25d079
1
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "302934307671667531413257853548643485645",
   "metadata": {},
   "source": [
    "# Gradio Demo: blocks_interpretation\n",
    "\n",
    "Classifies the sentiment of a piece of text with a `transformers` pipeline and explains the prediction with SHAP. The explanation is shown both in Gradio's built-in interpretation component and as a bar chart of the highest-scoring tokens."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "272996653310673477252411125948039410165",
   "metadata": {},
   "outputs": [],
   "source": [
    "# `gr.components.Interpretation` (used below) is not available in Gradio 4+, so stay on 3.x.\n",
    "%pip install -q \"gradio<4\" shap matplotlib transformers torch"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "288918539441861185822528903084949547379",
   "metadata": {},
   "outputs": [],
   "source": [
    "import gradio as gr\n",
    "import shap\n",
    "from matplotlib.figure import Figure\n",
    "from transformers import pipeline\n",
    "\n",
    "\n",
    "# Number of highest-scoring tokens shown in the bar chart.\n",
    "TOP_N_WORDS = 5\n",
    "\n",
    "sentiment_classifier = pipeline(\"text-classification\", return_all_scores=True)\n",
    "\n",
    "# Built once and reused by every \"Interpret\" click instead of being\n",
    "# re-created on each request.\n",
    "explainer = shap.Explainer(sentiment_classifier)\n",
    "\n",
    "\n",
    "def classifier(text):\n",
    "    \"\"\"Run the sentiment pipeline on ``text`` and return a ``{label: score}`` dict.\n",
    "\n",
    "    The dict is built from the first element of the pipeline output, whose\n",
    "    entries are read through their ``\"label\"`` and ``\"score\"`` keys.\n",
    "    \"\"\"\n",
    "    pred = sentiment_classifier(text)\n",
    "    return {p[\"label\"]: p[\"score\"] for p in pred[0]}\n",
    "\n",
    "\n",
    "def interpretation_function(text):\n",
    "    \"\"\"Explain the prediction for ``text`` with SHAP.\n",
    "\n",
    "    Returns a 2-tuple:\n",
    "\n",
    "    * ``{\"original\": text, \"interpretation\": [(token, score), ...]}`` with the\n",
    "      tokens in input order (including any empty tokens produced by shap);\n",
    "    * a matplotlib figure with a bar chart of the ``TOP_N_WORDS`` non-empty\n",
    "      tokens that have the highest SHAP value.\n",
    "    \"\"\"\n",
    "    shap_values = explainer([text])\n",
    "    # Dimensions are (batch size, text size, number of classes)\n",
    "    # Since we care about positive sentiment, use index 1\n",
    "    # Tokens/values are converted to built-in str/float so the returned\n",
    "    # payload does not depend on the serializer understanding NumPy scalars.\n",
    "    scores = [\n",
    "        (str(token), float(value))\n",
    "        for token, value in zip(shap_values.data[0], shap_values.values[0, :, 1])\n",
    "    ]\n",
    "\n",
    "    # Filter out empty string added by shap, then keep the highest scores\n",
    "    top_scores = sorted(\n",
    "        (s for s in scores if s[0] != \"\"), key=lambda s: s[1], reverse=True\n",
    "    )[:TOP_N_WORDS]\n",
    "\n",
    "    # Build the figure without pyplot: pyplot keeps a global reference to every\n",
    "    # figure it creates, so one unclosed figure per click would pile up for as\n",
    "    # long as the server runs.\n",
    "    fig_m = Figure()\n",
    "    ax = fig_m.subplots()\n",
    "    ax.bar(x=[s[0] for s in top_scores], height=[s[1] for s in top_scores])\n",
    "    ax.set_title(\"Top words contributing to positive sentiment\")\n",
    "    ax.set_ylabel(\"Shap Value\")\n",
    "    ax.set_xlabel(\"Word\")\n",
    "    return {\"original\": text, \"interpretation\": scores}, fig_m\n",
    "\n",
    "\n",
    "with gr.Blocks() as demo:\n",
    "    with gr.Row():\n",
    "        with gr.Column():\n",
    "            input_text = gr.Textbox(label=\"Input Text\")\n",
    "            with gr.Row():\n",
    "                classify = gr.Button(\"Classify Sentiment\")\n",
    "                interpret = gr.Button(\"Interpret\")\n",
    "        with gr.Column():\n",
    "            label = gr.Label(label=\"Predicted Sentiment\")\n",
    "        with gr.Column():\n",
    "            with gr.Tab(\"Display interpretation with built-in component\"):\n",
    "                interpretation = gr.components.Interpretation(input_text)\n",
    "            with gr.Tab(\"Display interpretation with plot\"):\n",
    "                interpretation_plot = gr.Plot()\n",
    "\n",
    "    # Each button feeds the textbox into its handler; the interpretation\n",
    "    # handler fills both tabs from its two return values.\n",
    "    classify.click(classifier, input_text, label)\n",
    "    interpret.click(interpretation_function, input_text, [interpretation, interpretation_plot])\n",
    "\n",
    "if __name__ == \"__main__\":\n",
    "    demo.launch()"
   ]
  }
 ],
 "metadata": {},
 "nbformat": 4,
 "nbformat_minor": 5
}