{ "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import tensorflow as tf" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [], "source": [ "path_to_file = \"C:/Users/balde/Desktop/DSTI/Msc Applied Data Science & AI/Deep Learning/NLP/NPL-Text_Generation/datasets/shakespeare.txt\"" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [], "source": [ "with open(path_to_file, 'r', encoding='utf-8') as f:\n", "    text = f.read()" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "reward.\n", " HELENA. Inspired merit so by breath is barr'd.\n", " It is not so with Him that all things knows,\n", " As 'tis with us that square our guess by shows;\n", " But most it is presumption in us when\n", " The help of heaven we count the act of men.\n", " Dear sir, to my endeavours give consent;\n", " Of heaven, not me, make an experiment.\n", " I am not an impostor, that proclaim \n", " Myself against the level of mine aim;\n", " But know I think, and think I know most sure,\n", " My art is not past power nor you past cure.\n", " KING. Art thou so confident? Within what space\n", " Hop'st thou my cure?\n", " HELENA. 
The greatest Grace lending grace.\n", " Ere twice the horses of the sun shall bring\n", " Their fiery torcher his diurnal ring,\n", " Ere twice in murk and occidental damp\n", " Moist Hesperus hath quench'd his sleepy lamp,\n", " Or four and twenty times the pilot's glass\n", " Hath told the thievish minutes how they pass,\n", " What is infirm from your sound parts shall fly,\n", " Health shall live free, and s\n" ] } ], "source": [ "# print(text[:500])\n", "print(text[140500:141500])" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "['\\n',\n", " ' ',\n", " '!',\n", " '\"',\n", " '&',\n", " \"'\",\n", " '(',\n", " ')',\n", " ',',\n", " '-',\n", " '.',\n", " '0',\n", " '1',\n", " '2',\n", " '3',\n", " '4',\n", " '5',\n", " '6',\n", " '7',\n", " '8',\n", " '9',\n", " ':',\n", " ';',\n", " '<',\n", " '>',\n", " '?',\n", " 'A',\n", " 'B',\n", " 'C',\n", " 'D',\n", " 'E',\n", " 'F',\n", " 'G',\n", " 'H',\n", " 'I',\n", " 'J',\n", " 'K',\n", " 'L',\n", " 'M',\n", " 'N',\n", " 'O',\n", " 'P',\n", " 'Q',\n", " 'R',\n", " 'S',\n", " 'T',\n", " 'U',\n", " 'V',\n", " 'W',\n", " 'X',\n", " 'Y',\n", " 'Z',\n", " '[',\n", " ']',\n", " '_',\n", " '`',\n", " 'a',\n", " 'b',\n", " 'c',\n", " 'd',\n", " 'e',\n", " 'f',\n", " 'g',\n", " 'h',\n", " 'i',\n", " 'j',\n", " 'k',\n", " 'l',\n", " 'm',\n", " 'n',\n", " 'o',\n", " 'p',\n", " 'q',\n", " 'r',\n", " 's',\n", " 't',\n", " 'u',\n", " 'v',\n", " 'w',\n", " 'x',\n", " 'y',\n", " 'z',\n", " '|',\n", " '}']" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab = sorted(set(text))\n", "vocab" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "84" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vocab)" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": 
"stream", "text": [ "(0, '\\n')\n", "(1, ' ')\n", "(2, '!')\n", "(3, '\"')\n", "(4, '&')\n", "(5, \"'\")\n", "(6, '(')\n", "(7, ')')\n", "(8, ',')\n", "(9, '-')\n", "(10, '.')\n", "(11, '0')\n", "(12, '1')\n", "(13, '2')\n", "(14, '3')\n", "(15, '4')\n", "(16, '5')\n", "(17, '6')\n", "(18, '7')\n", "(19, '8')\n", "(20, '9')\n", "(21, ':')\n", "(22, ';')\n", "(23, '<')\n", "(24, '>')\n", "(25, '?')\n", "(26, 'A')\n", "(27, 'B')\n", "(28, 'C')\n", "(29, 'D')\n", "(30, 'E')\n", "(31, 'F')\n", "(32, 'G')\n", "(33, 'H')\n", "(34, 'I')\n", "(35, 'J')\n", "(36, 'K')\n", "(37, 'L')\n", "(38, 'M')\n", "(39, 'N')\n", "(40, 'O')\n", "(41, 'P')\n", "(42, 'Q')\n", "(43, 'R')\n", "(44, 'S')\n", "(45, 'T')\n", "(46, 'U')\n", "(47, 'V')\n", "(48, 'W')\n", "(49, 'X')\n", "(50, 'Y')\n", "(51, 'Z')\n", "(52, '[')\n", "(53, ']')\n", "(54, '_')\n", "(55, '`')\n", "(56, 'a')\n", "(57, 'b')\n", "(58, 'c')\n", "(59, 'd')\n", "(60, 'e')\n", "(61, 'f')\n", "(62, 'g')\n", "(63, 'h')\n", "(64, 'i')\n", "(65, 'j')\n", "(66, 'k')\n", "(67, 'l')\n", "(68, 'm')\n", "(69, 'n')\n", "(70, 'o')\n", "(71, 'p')\n", "(72, 'q')\n", "(73, 'r')\n", "(74, 's')\n", "(75, 't')\n", "(76, 'u')\n", "(77, 'v')\n", "(78, 'w')\n", "(79, 'x')\n", "(80, 'y')\n", "(81, 'z')\n", "(82, '|')\n", "(83, '}')\n" ] } ], "source": [ "for pair in enumerate(vocab):\n", " print(pair)" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [], "source": [ "char_to_ind = {char:ind for ind, char in enumerate(vocab)}" ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'\\n': 0,\n", " ' ': 1,\n", " '!': 2,\n", " '\"': 3,\n", " '&': 4,\n", " \"'\": 5,\n", " '(': 6,\n", " ')': 7,\n", " ',': 8,\n", " '-': 9,\n", " '.': 10,\n", " '0': 11,\n", " '1': 12,\n", " '2': 13,\n", " '3': 14,\n", " '4': 15,\n", " '5': 16,\n", " '6': 17,\n", " '7': 18,\n", " '8': 19,\n", " '9': 20,\n", " ':': 21,\n", " ';': 22,\n", " '<': 23,\n", " '>': 24,\n", " '?': 25,\n", " 'A': 
26,\n", " 'B': 27,\n", " 'C': 28,\n", " 'D': 29,\n", " 'E': 30,\n", " 'F': 31,\n", " 'G': 32,\n", " 'H': 33,\n", " 'I': 34,\n", " 'J': 35,\n", " 'K': 36,\n", " 'L': 37,\n", " 'M': 38,\n", " 'N': 39,\n", " 'O': 40,\n", " 'P': 41,\n", " 'Q': 42,\n", " 'R': 43,\n", " 'S': 44,\n", " 'T': 45,\n", " 'U': 46,\n", " 'V': 47,\n", " 'W': 48,\n", " 'X': 49,\n", " 'Y': 50,\n", " 'Z': 51,\n", " '[': 52,\n", " ']': 53,\n", " '_': 54,\n", " '`': 55,\n", " 'a': 56,\n", " 'b': 57,\n", " 'c': 58,\n", " 'd': 59,\n", " 'e': 60,\n", " 'f': 61,\n", " 'g': 62,\n", " 'h': 63,\n", " 'i': 64,\n", " 'j': 65,\n", " 'k': 66,\n", " 'l': 67,\n", " 'm': 68,\n", " 'n': 69,\n", " 'o': 70,\n", " 'p': 71,\n", " 'q': 72,\n", " 'r': 73,\n", " 's': 74,\n", " 't': 75,\n", " 'u': 76,\n", " 'v': 77,\n", " 'w': 78,\n", " 'x': 79,\n", " 'y': 80,\n", " 'z': 81,\n", " '|': 82,\n", " '}': 83}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "char_to_ind" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [], "source": [ "ind_to_char = np.array(vocab)" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "33" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "char_to_ind['H']" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "'H'" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ind_to_char[33]" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "encoded_text = np.array([char_to_ind[c] for c in text])" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0, 1, 1, ..., 30, 39, 29])" ] }, "execution_count": 29, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoded_text" ] }, { "cell_type": "code", "execution_count": 30, 
"metadata": {}, "outputs": [ { "data": { "text/plain": [ "(5445609,)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoded_text.shape" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", " 1\n", " From fairest creatures we desire increase,\n", " That thereby beauty's rose might never die,\n", " But as the riper should by time decease,\n", " His tender heir might bear his memory:\n", " But thou contracted to thine own bright eyes,\n", " Feed'st thy light's flame with self-substantial fuel,\n", " Making a famine where abundance lies,\n", " Thy self thy foe, to thy sweet self too cruel:\n", " Thou that art now the world's fresh ornament,\n", " And only herald to the gaudy spring,\n", " Within thine own bu\n" ] } ], "source": [ "sample = text[:500]\n", "print(sample)" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([ 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n", " 1, 1, 1, 1, 1, 12, 0, 1, 1, 31, 73, 70, 68, 1, 61, 56, 64,\n", " 73, 60, 74, 75, 1, 58, 73, 60, 56, 75, 76, 73, 60, 74, 1, 78, 60,\n", " 1, 59, 60, 74, 64, 73, 60, 1, 64, 69, 58, 73, 60, 56, 74, 60, 8,\n", " 0, 1, 1, 45, 63, 56, 75, 1, 75, 63, 60, 73, 60, 57, 80, 1, 57,\n", " 60, 56, 76, 75, 80, 5, 74, 1, 73, 70, 74, 60, 1, 68, 64, 62, 63,\n", " 75, 1, 69, 60, 77, 60, 73, 1, 59, 64, 60, 8, 0, 1, 1, 27, 76,\n", " 75, 1, 56, 74, 1, 75, 63, 60, 1, 73, 64, 71, 60, 73, 1, 74, 63,\n", " 70, 76, 67, 59, 1, 57, 80, 1, 75, 64, 68, 60, 1, 59, 60, 58, 60,\n", " 56, 74, 60, 8, 0, 1, 1, 33, 64, 74, 1, 75, 60, 69, 59, 60, 73,\n", " 1, 63, 60, 64, 73, 1, 68, 64, 62, 63, 75, 1, 57, 60, 56, 73, 1,\n", " 63, 64, 74, 1, 68, 60, 68, 70, 73, 80, 21, 0, 1, 1, 27, 76, 75,\n", " 1, 75, 63, 70, 76, 1, 58, 70, 69, 75, 73, 56, 58, 75, 60, 59, 1,\n", " 75, 70, 1, 75, 63, 64, 69, 60, 1, 70, 78, 69, 1, 57, 73, 64, 62,\n", " 63, 
75, 1, 60, 80, 60, 74, 8, 0, 1, 1, 31, 60, 60, 59, 5, 74,\n", " 75, 1, 75, 63, 80, 1, 67, 64, 62, 63, 75, 5, 74, 1, 61, 67, 56,\n", " 68, 60, 1, 78, 64, 75, 63, 1, 74, 60, 67, 61, 9, 74, 76, 57, 74,\n", " 75, 56, 69, 75, 64, 56, 67, 1, 61, 76, 60, 67, 8, 0, 1, 1, 38,\n", " 56, 66, 64, 69, 62, 1, 56, 1, 61, 56, 68, 64, 69, 60, 1, 78, 63,\n", " 60, 73, 60, 1, 56, 57, 76, 69, 59, 56, 69, 58, 60, 1, 67, 64, 60,\n", " 74, 8, 0, 1, 1, 45, 63, 80, 1, 74, 60, 67, 61, 1, 75, 63, 80,\n", " 1, 61, 70, 60, 8, 1, 75, 70, 1, 75, 63, 80, 1, 74, 78, 60, 60,\n", " 75, 1, 74, 60, 67, 61, 1, 75, 70, 70, 1, 58, 73, 76, 60, 67, 21,\n", " 0, 1, 1, 45, 63, 70, 76, 1, 75, 63, 56, 75, 1, 56, 73, 75, 1,\n", " 69, 70, 78, 1, 75, 63, 60, 1, 78, 70, 73, 67, 59, 5, 74, 1, 61,\n", " 73, 60, 74, 63, 1, 70, 73, 69, 56, 68, 60, 69, 75, 8, 0, 1, 1,\n", " 26, 69, 59, 1, 70, 69, 67, 80, 1, 63, 60, 73, 56, 67, 59, 1, 75,\n", " 70, 1, 75, 63, 60, 1, 62, 56, 76, 59, 80, 1, 74, 71, 73, 64, 69,\n", " 62, 8, 0, 1, 1, 48, 64, 75, 63, 64, 69, 1, 75, 63, 64, 69, 60,\n", " 1, 70, 78, 69, 1, 57, 76])" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "encoded_text[:500]" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "42" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "line = \"From fairest creatures we desire increase,\"\n", "len(line)" ] }, { "cell_type": "code", "execution_count": 36, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "133" ] }, "execution_count": 36, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lines = '''\n", "From fairest creatures we desire increase,\n", " That thereby beauty's rose might never die,\n", " But as the riper should by time decease,\n", "'''\n", "\n", "len(lines)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", 
"execution_count": 37, "metadata": {}, "outputs": [], "source": [ "seq_len = 120" ] }, { "cell_type": "code", "execution_count": 38, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "45005" ] }, "execution_count": 38, "metadata": {}, "output_type": "execute_result" } ], "source": [ "total_num_seq = len(text) // (seq_len + 1)\n", "total_num_seq" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [], "source": [ "char_dataset = tf.data.Dataset.from_tensor_slices(encoded_text)" ] }, { "cell_type": "code", "execution_count": 40, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "tensorflow.python.data.ops.from_tensor_slices_op._TensorSliceDataset" ] }, "execution_count": 40, "metadata": {}, "output_type": "execute_result" } ], "source": [ "type(char_dataset)" ] }, { "cell_type": "code", "execution_count": 42, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "1\n", "\n", "\n", " \n", " \n", "F\n", "r\n", "o\n", "m\n", " \n", "f\n", "a\n", "i\n", "r\n", "e\n", "s\n", "t\n", " \n", "c\n", "r\n", "e\n", "a\n", "t\n", "u\n", "r\n", "e\n", "s\n", " \n", "w\n", "e\n", " \n", "d\n", "e\n", "s\n", "i\n", "r\n", "e\n", " \n", "i\n", "n\n", "c\n", "r\n", "e\n", "a\n", "s\n", "e\n", ",\n", "\n", "\n", " \n", " \n", "T\n", "h\n", "a\n", "t\n", " \n", "t\n", "h\n", "e\n", "r\n", "e\n", "b\n", "y\n", " \n", "b\n", "e\n", "a\n", "u\n", "t\n", "y\n", "'\n", "s\n", " \n", "r\n", "o\n", "s\n", "e\n", " \n", "m\n", "i\n", "g\n", "h\n", "t\n", " \n", "n\n", "e\n", "v\n", "e\n", "r\n", " \n", "d\n", "i\n", "e\n", ",\n", "\n", "\n", " \n", " \n", "B\n", "u\n", "t\n", " \n", "a\n", "s\n", " \n", "t\n", "h\n", "e\n", " \n", "r\n", "i\n", "p\n", "e\n", "r\n", " \n", "s\n", "h\n", "o\n", "u\n", "l\n", "d\n", " \n", "b\n", "y\n", " \n", "t\n", "i\n", "m\n", "e\n", 
" \n", "d\n", "e\n", "c\n", "e\n", "a\n", "s\n", "e\n", ",\n", "\n", "\n", " \n", " \n", "H\n", "i\n", "s\n", " \n", "t\n", "e\n", "n\n", "d\n", "e\n", "r\n", " \n", "h\n", "e\n", "i\n", "r\n", " \n", "m\n", "i\n", "g\n", "h\n", "t\n", " \n", "b\n", "e\n", "a\n", "r\n", " \n", "h\n", "i\n", "s\n", " \n", "m\n", "e\n", "m\n", "o\n", "r\n", "y\n", ":\n", "\n", "\n", " \n", " \n", "B\n", "u\n", "t\n", " \n", "t\n", "h\n", "o\n", "u\n", " \n", "c\n", "o\n", "n\n", "t\n", "r\n", "a\n", "c\n", "t\n", "e\n", "d\n", " \n", "t\n", "o\n", " \n", "t\n", "h\n", "i\n", "n\n", "e\n", " \n", "o\n", "w\n", "n\n", " \n", "b\n", "r\n", "i\n", "g\n", "h\n", "t\n", " \n", "e\n", "y\n", "e\n", "s\n", ",\n", "\n", "\n", " \n", " \n", "F\n", "e\n", "e\n", "d\n", "'\n", "s\n", "t\n", " \n", "t\n", "h\n", "y\n", " \n", "l\n", "i\n", "g\n", "h\n", "t\n", "'\n", "s\n", " \n", "f\n", "l\n", "a\n", "m\n", "e\n", " \n", "w\n", "i\n", "t\n", "h\n", " \n", "s\n", "e\n", "l\n", "f\n", "-\n", "s\n", "u\n", "b\n", "s\n", "t\n", "a\n", "n\n", "t\n", "i\n", "a\n", "l\n", " \n", "f\n", "u\n", "e\n", "l\n", ",\n", "\n", "\n", " \n", " \n", "M\n", "a\n", "k\n", "i\n", "n\n", "g\n", " \n", "a\n", " \n", "f\n", "a\n", "m\n", "i\n", "n\n", "e\n", " \n", "w\n", "h\n", "e\n", "r\n", "e\n", " \n", "a\n", "b\n", "u\n", "n\n", "d\n", "a\n", "n\n", "c\n", "e\n", " \n", "l\n", "i\n", "e\n", "s\n", ",\n", "\n", "\n", " \n", " \n", "T\n", "h\n", "y\n", " \n", "s\n", "e\n", "l\n", "f\n", " \n", "t\n", "h\n", "y\n", " \n", "f\n", "o\n", "e\n", ",\n", " \n", "t\n", "o\n", " \n", "t\n", "h\n", "y\n", " \n", "s\n", "w\n", "e\n", "e\n", "t\n", " \n", "s\n", "e\n", "l\n", "f\n", " \n", "t\n", "o\n", "o\n", " \n", "c\n", "r\n", "u\n", "e\n", "l\n", ":\n", "\n", "\n", " \n", " \n", "T\n", "h\n", "o\n", "u\n", " \n", "t\n", "h\n", "a\n", "t\n", " \n", "a\n", "r\n", "t\n", " \n", "n\n", "o\n", "w\n", " \n", "t\n", "h\n", "e\n", " \n", "w\n", "o\n", "r\n", "l\n", "d\n", "'\n", "s\n", " \n", "f\n", "r\n", "e\n", "s\n", "h\n", " 
\n", "o\n", "r\n", "n\n", "a\n", "m\n", "e\n", "n\n", "t\n", ",\n", "\n", "\n", " \n", " \n", "A\n", "n\n", "d\n", " \n", "o\n", "n\n", "l\n", "y\n", " \n", "h\n", "e\n", "r\n", "a\n", "l\n", "d\n", " \n", "t\n", "o\n", " \n", "t\n", "h\n", "e\n", " \n", "g\n", "a\n", "u\n", "d\n", "y\n", " \n", "s\n", "p\n", "r\n", "i\n", "n\n", "g\n", ",\n", "\n", "\n", " \n", " \n", "W\n", "i\n", "t\n", "h\n", "i\n", "n\n", " \n", "t\n", "h\n", "i\n", "n\n", "e\n", " \n", "o\n", "w\n", "n\n", " \n", "b\n", "u\n" ] } ], "source": [ "for item in char_dataset.take(500):\n", " print(ind_to_char[item.numpy()])" ] }, { "cell_type": "code", "execution_count": 43, "metadata": {}, "outputs": [], "source": [ "sequences = char_dataset.batch(seq_len+1, drop_remainder=True)" ] }, { "cell_type": "code", "execution_count": 44, "metadata": {}, "outputs": [], "source": [ "def create_seq_targets(seq):\n", " input_txt = seq[:-1]\n", " target_txt = seq[1:]\n", " return input_txt, target_txt" ] }, { "cell_type": "code", "execution_count": 45, "metadata": {}, "outputs": [], "source": [ "dataset = sequences.map(create_seq_targets)" ] }, { "cell_type": "code", "execution_count": 46, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "[ 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 12 0\n", " 1 1 31 73 70 68 1 61 56 64 73 60 74 75 1 58 73 60 56 75 76 73 60 74\n", " 1 78 60 1 59 60 74 64 73 60 1 64 69 58 73 60 56 74 60 8 0 1 1 45\n", " 63 56 75 1 75 63 60 73 60 57 80 1 57 60 56 76 75 80 5 74 1 73 70 74\n", " 60 1 68 64 62 63 75 1 69 60 77 60 73 1 59 64 60 8 0 1 1 27 76 75]\n", "\n", " 1\n", " From fairest creatures we desire increase,\n", " That thereby beauty's rose might never die,\n", " But\n", "[ 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 12 0 1\n", " 1 31 73 70 68 1 61 56 64 73 60 74 75 1 58 73 60 56 75 76 73 60 74 1\n", " 78 60 1 59 60 74 64 73 60 1 64 69 58 73 60 56 74 60 8 0 1 1 45 63\n", " 56 75 1 75 63 60 73 60 57 80 1 57 60 56 76 75 80 5 74 1 73 70 74 60\n", " 1 
68 64 62 63 75 1 69 60 77 60 73 1 59 64 60 8 0 1 1 27 76 75 1]\n", " 1\n", " From fairest creatures we desire increase,\n", " That thereby beauty's rose might never die,\n", " But \n" ] } ], "source": [ "for input_txt, target_txt in dataset.take(1):\n", " print(input_txt.numpy())\n", " print(''.join(ind_to_char[input_txt.numpy()]))\n", " print(target_txt.numpy())\n", " print(''.join(ind_to_char[target_txt.numpy()]))" ] }, { "cell_type": "code", "execution_count": 47, "metadata": {}, "outputs": [], "source": [ "batch_size = 128" ] }, { "cell_type": "code", "execution_count": 48, "metadata": {}, "outputs": [], "source": [ "buffer_size = 10000" ] }, { "cell_type": "code", "execution_count": 49, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "<_BatchDataset element_spec=(TensorSpec(shape=(128, 120), dtype=tf.int32, name=None), TensorSpec(shape=(128, 120), dtype=tf.int32, name=None))>" ] }, "execution_count": 49, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)\n", "dataset" ] }, { "cell_type": "code", "execution_count": 50, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "84" ] }, "execution_count": 50, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vocab_size = len(vocab)\n", "vocab_size" ] }, { "cell_type": "code", "execution_count": 51, "metadata": {}, "outputs": [], "source": [ "embed_dim = 64" ] }, { "cell_type": "code", "execution_count": 52, "metadata": {}, "outputs": [], "source": [ "rnn_neurons = 1026" ] }, { "cell_type": "code", "execution_count": 53, "metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.losses import sparse_categorical_crossentropy" ] }, { "cell_type": "code", "execution_count": 54, "metadata": {}, "outputs": [], "source": [ "def sparse_cat_loss(y_true, y_pred):\n", " return sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)" ] }, { "cell_type": "code", "execution_count": 55, 
"metadata": {}, "outputs": [], "source": [ "from tensorflow.keras.models import Sequential\n", "from tensorflow.keras.layers import Embedding, GRU, Dense" ] }, { "cell_type": "code", "execution_count": 56, "metadata": {}, "outputs": [], "source": [ "def create_model(vocab_size, embed_dim, rnn_neurons, batch_size):\n", " model = Sequential()\n", "\n", " model.add(Embedding(vocab_size, embed_dim, batch_input_shape=[batch_size, None]))\n", "\n", " model.add(GRU(rnn_neurons, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'))\n", "\n", " model.add(Dense(vocab_size))\n", "\n", " model.compile(optimizer='adam', loss=sparse_cat_loss)\n", "\n", " return model" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Model: \"sequential\"\n", "_________________________________________________________________\n", " Layer (type) Output Shape Param # \n", "=================================================================\n", " embedding (Embedding) (128, None, 64) 5376 \n", " \n", " gru (GRU) (128, None, 1026) 3361176 \n", " \n", " dense (Dense) (128, None, 84) 86268 \n", " \n", "=================================================================\n", "Total params: 3452820 (13.17 MB)\n", "Trainable params: 3452820 (13.17 MB)\n", "Non-trainable params: 0 (0.00 Byte)\n", "_________________________________________________________________\n" ] } ], "source": [ "model = create_model(vocab_size=vocab_size, \n", " embed_dim=embed_dim, \n", " rnn_neurons=rnn_neurons,\n", " batch_size=batch_size)\n", "\n", "model.summary()" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [], "source": [ "for input_example_batch, target_example_batch in dataset.take(1):\n", " input_example_predictions = model(input_example_batch)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 59, 
"metadata": {}, "output_type": "execute_result" } ], "source": [ "input_example_predictions" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "sampled_indices = tf.random.categorical(input_example_predictions[0], num_samples=1)" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 61, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sampled_indices" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [], "source": [ "sampled_indices = tf.squeeze(sampled_indices, axis=1).numpy()" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array([16, 31, 40, 44, 81, 61, 42, 31, 38, 73, 0, 57, 20, 32, 41, 9, 44,\n", " 6, 78, 72, 12, 77, 48, 37, 6, 73, 52, 72, 16, 44, 10, 72, 45, 63,\n", " 29, 57, 44, 35, 33, 50, 78, 33, 44, 24, 46, 17, 34, 22, 74, 61, 51,\n", " 26, 17, 24, 16, 38, 61, 58, 42, 66, 17, 44, 24, 42, 44, 54, 53, 30,\n", " 50, 17, 73, 21, 21, 31, 35, 52, 24, 67, 44, 30, 21, 40, 3, 11, 54,\n", " 1, 56, 79, 80, 22, 82, 58, 21, 44, 30, 64, 59, 53, 40, 48, 35, 83,\n", " 15, 70, 20, 16, 75, 61, 81, 25, 0, 62, 16, 57, 23, 43, 47, 48, 0,\n", " 67], dtype=int64)" ] }, "execution_count": 63, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sampled_indices" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "array(['5', 'F', 'O', 'S', 'z', 'f', 'Q', 'F', 'M', 'r', '\\n', 'b', '9',\n", " 'G', 'P', '-', 'S', '(', 'w', 'q', '1', 'v', 'W', 'L', '(', 'r',\n", " '[', 'q', '5', 'S', '.', 'q', 'T', 'h', 'D', 'b', 'S', 'J', 'H',\n", " 'Y', 'w', 'H', 'S', '>', 'U', '6', 'I', ';', 's', 'f', 'Z', 'A',\n", " '6', '>', '5', 'M', 'f', 'c', 'Q', 'k', '6', 'S', '>', 'Q', 'S',\n", " '_', ']', 'E', 'Y', '6', 'r', ':', ':', 'F', 'J', '[', '>', 'l',\n", " 'S', 'E', ':', 'O', '\"', '0', '_', ' ', 'a', 
'x', 'y', ';', '|',\n", " 'c', ':', 'S', 'E', 'i', 'd', ']', 'O', 'W', 'J', '}', '4', 'o',\n", " '9', '5', 't', 'f', 'z', '?', '\\n', 'g', '5', 'b', '<', 'R', 'V',\n", " 'W', '\\n', 'l'], dtype='