{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "IZ6SNYq_tVVC" }, "source": [ "# Intent Classification with BERT\n", "\n", "This notebook demonstrates the fine-tuning of BERT to perform intent classification.\n", "Intent classification tries to map given instructions (sentence in natural language) to a set of predefined intents. \n", "\n", "## What you will learn\n", "\n", "- Load data from csv and preprocess it for training and test\n", "- Load a BERT model from TensorFlow Hub\n", "- Build your own model by combining BERT with a classifier\n", "- Train your own model, fine-tuning BERT as part of that\n", "- Save your model and use it to recognize the intend of instructions\n" ] }, { "cell_type": "markdown", "metadata": { "id": "2PHBpLPuQdmK" }, "source": [ "## About BERT\n", "\n", "[BERT](https://arxiv.org/abs/1810.04805) and other Transformer encoder architectures have been shown to be successful on a variety of tasks in NLP (natural language processing). They compute vector-space representations of natural language that are suitable for use in deep learning models. The BERT family of models uses the Transformer encoder architecture to process each token of input text in the full context of all tokens before and after, hence the name: Bidirectional Encoder Representations from Transformers. 
\n", "\n", "BERT models are usually pre-trained on a large corpus of text, then fine-tuned for specific tasks.\n" ] }, { "cell_type": "markdown", "metadata": { "id": "SCjmX4zTCkRK" }, "source": [ "## Setup\n" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "execution": { "iopub.execute_input": "2021-01-13T03:07:33.561260Z", "iopub.status.busy": "2021-01-13T03:07:33.560567Z", "iopub.status.idle": "2021-01-13T03:07:40.852309Z", "shell.execute_reply": "2021-01-13T03:07:40.851601Z" }, "id": "_XgTpm9ZxoN9" }, "outputs": [], "source": [ "import os\n", "#import shutil\n", "import pandas as pd\n", "\n", "import tensorflow as tf\n", "import tensorflow_hub as hub\n", "import tensorflow_text as text\n", "import seaborn as sns\n", "from pylab import rcParams\n", "\n", "import matplotlib.pyplot as plt\n", "tf.get_logger().setLevel('ERROR')\n", "\n", "sns.set(style='whitegrid', palette='muted', font_scale=1.2)\n", "HAPPY_COLORS_PALETTE = [\"#01BEFE\", \"#FFDD00\", \"#FF7D00\", \"#FF006D\", \"#ADFF02\", \"#8F00FF\"]\n", "sns.set_palette(sns.color_palette(HAPPY_COLORS_PALETTE))\n", "rcParams['figure.figsize'] = 12, 8\n", "import warnings\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Data Access\n", "The data contains various user queries categorized into seven intents. It is hosted on [GitHub](https://github.com/snipsco/nlu-benchmark/tree/master/2017-06-custom-intent-engines) and is first presented in [this paper](https://arxiv.org/abs/1805.10190). 
The classes, together with an example for each, are listed below:\n", "\n", "* `class`: SearchCreativeWork - `example`: *play hell house song*\n", "* `class`: GetWeather - `example`: *is it windy in boston, mb right now*\n", "* `class`: BookRestaurant - `example`: *book a restaurant for eight people in six years*\n", "* `class`: PlayMusic - `example`: *play the song little robin redbreast*\n", "* `class`: AddToPlaylist - `example`: *add step to me to the 50 clásicos playlist*\n", "* `class`: RateBook - `example`: *give 6 stars to of mice and men*\n", "* `class`: SearchScreeningEvent - `example`: *find fish story*" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The data can be downloaded from Google Drive using [gdown](https://pypi.org/project/gdown/). In the following code cells, the download is invoked only if the corresponding file does not yet exist at the given location." ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "datafolder = \"\"" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "trainfile = datafolder + \"train.csv\"\n", "testfile = datafolder + \"test.csv\"\n", "validfile = datafolder + \"valid.csv\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, the downloaded .csv files for training, validation, and test are imported into pandas DataFrames:" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [], "source": [ "traindf = pd.read_csv(trainfile)\n", "validdf = pd.read_csv(validfile)\n", "testdf = pd.read_csv(testfile)" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
| \n", " | text | \n", "intent | \n", "
|---|---|---|
| 0 | \n", "listen to westbam alumb allergic on google music | \n", "PlayMusic | \n", "
| 1 | \n", "add step to me to the 50 clásicos playlist | \n", "AddToPlaylist | \n", "
| 2 | \n", "i give this current textbook a rating value of... | \n", "RateBook | \n", "
| 3 | \n", "play the song little robin redbreast | \n", "PlayMusic | \n", "
| 4 | \n", "please add iris dement to my playlist this is ... | \n", "AddToPlaylist | \n", "