Baldezo313 committed on
Commit 8ee57c7 · 1 Parent(s): bbfce04

First model version

datasets/shakespeare.txt ADDED
The diff for this file is too large to render. See raw diff
 
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ pywin32
+ m2-base
src/text_generation.ipynb ADDED
@@ -0,0 +1,1839 @@
+ {
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import numpy as np\n",
+ "import pandas as pd\n",
+ "import matplotlib.pyplot as plt\n",
+ "import tensorflow as tf"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "path_to_file = \"C:/Users/balde/Desktop/DSTI/Msc Applied Data Science & AI/Deep Learning/NLP/NPL-Text_Generation/datasets/shakespeare.txt\""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "text = open(path_to_file, 'r').read()"
+ ]
+ },
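The absolute Windows path above only resolves on the author's machine. A minimal sketch of a more portable loader follows, assuming the notebook is run from `src/` so that the committed `datasets/shakespeare.txt` sits one directory up. The helper name `load_text`, the relative layout, and the UTF-8 encoding are assumptions, not part of the commit.

```python
from pathlib import Path


def load_text(path, encoding="utf-8"):
    """Read a whole text file and close the handle afterwards (hypothetical helper)."""
    with open(path, "r", encoding=encoding) as handle:
        return handle.read()


# Assumed layout: src/text_generation.ipynb next to ../datasets/shakespeare.txt
DEFAULT_PATH = Path("..") / "datasets" / "shakespeare.txt"

if DEFAULT_PATH.exists():
    text = load_text(DEFAULT_PATH)
    print(len(text))
```

Using a `with` block also closes the file handle, which the bare `open(...).read()` call leaves to the garbage collector.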
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "reward.\n",
+ " HELENA. Inspired merit so by breath is barr'd.\n",
+ " It is not so with Him that all things knows,\n",
+ " As 'tis with us that square our guess by shows;\n",
+ " But most it is presumption in us when\n",
+ " The help of heaven we count the act of men.\n",
+ " Dear sir, to my endeavours give consent;\n",
+ " Of heaven, not me, make an experiment.\n",
+ " I am not an impostor, that proclaim \n",
+ " Myself against the level of mine aim;\n",
+ " But know I think, and think I know most sure,\n",
+ " My art is not past power nor you past cure.\n",
+ " KING. Art thou so confident? Within what space\n",
+ " Hop'st thou my cure?\n",
+ " HELENA. The greatest Grace lending grace.\n",
+ " Ere twice the horses of the sun shall bring\n",
+ " Their fiery torcher his diurnal ring,\n",
+ " Ere twice in murk and occidental damp\n",
+ " Moist Hesperus hath quench'd his sleepy lamp,\n",
+ " Or four and twenty times the pilot's glass\n",
+ " Hath told the thievish minutes how they pass,\n",
+ " What is infirm from your sound parts shall fly,\n",
+ " Health shall live free, and s\n"
+ ]
+ }
+ ],
+ "source": [
+ "# print(text[:500])\n",
+ "print(text[140500:141500])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 20,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "['\\n',\n",
+ " ' ',\n",
+ " '!',\n",
+ " '\"',\n",
+ " '&',\n",
+ " \"'\",\n",
+ " '(',\n",
+ " ')',\n",
+ " ',',\n",
+ " '-',\n",
+ " '.',\n",
+ " '0',\n",
+ " '1',\n",
+ " '2',\n",
+ " '3',\n",
+ " '4',\n",
+ " '5',\n",
+ " '6',\n",
+ " '7',\n",
+ " '8',\n",
+ " '9',\n",
+ " ':',\n",
+ " ';',\n",
+ " '<',\n",
+ " '>',\n",
+ " '?',\n",
+ " 'A',\n",
+ " 'B',\n",
+ " 'C',\n",
+ " 'D',\n",
+ " 'E',\n",
+ " 'F',\n",
+ " 'G',\n",
+ " 'H',\n",
+ " 'I',\n",
+ " 'J',\n",
+ " 'K',\n",
+ " 'L',\n",
+ " 'M',\n",
+ " 'N',\n",
+ " 'O',\n",
+ " 'P',\n",
+ " 'Q',\n",
+ " 'R',\n",
+ " 'S',\n",
+ " 'T',\n",
+ " 'U',\n",
+ " 'V',\n",
+ " 'W',\n",
+ " 'X',\n",
+ " 'Y',\n",
+ " 'Z',\n",
+ " '[',\n",
+ " ']',\n",
+ " '_',\n",
+ " '`',\n",
+ " 'a',\n",
+ " 'b',\n",
+ " 'c',\n",
+ " 'd',\n",
+ " 'e',\n",
+ " 'f',\n",
+ " 'g',\n",
+ " 'h',\n",
+ " 'i',\n",
+ " 'j',\n",
+ " 'k',\n",
+ " 'l',\n",
+ " 'm',\n",
+ " 'n',\n",
+ " 'o',\n",
+ " 'p',\n",
+ " 'q',\n",
+ " 'r',\n",
+ " 's',\n",
+ " 't',\n",
+ " 'u',\n",
+ " 'v',\n",
+ " 'w',\n",
+ " 'x',\n",
+ " 'y',\n",
+ " 'z',\n",
+ " '|',\n",
+ " '}']"
+ ]
+ },
+ "execution_count": 20,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vocab = sorted(set(text))\n",
+ "vocab"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 21,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "84"
+ ]
+ },
+ "execution_count": 21,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "len(vocab)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 22,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "(0, '\\n')\n",
+ "(1, ' ')\n",
+ "(2, '!')\n",
+ "(3, '\"')\n",
+ "(4, '&')\n",
+ "(5, \"'\")\n",
+ "(6, '(')\n",
+ "(7, ')')\n",
+ "(8, ',')\n",
+ "(9, '-')\n",
+ "(10, '.')\n",
+ "(11, '0')\n",
+ "(12, '1')\n",
+ "(13, '2')\n",
+ "(14, '3')\n",
+ "(15, '4')\n",
+ "(16, '5')\n",
+ "(17, '6')\n",
+ "(18, '7')\n",
+ "(19, '8')\n",
+ "(20, '9')\n",
+ "(21, ':')\n",
+ "(22, ';')\n",
+ "(23, '<')\n",
+ "(24, '>')\n",
+ "(25, '?')\n",
+ "(26, 'A')\n",
+ "(27, 'B')\n",
+ "(28, 'C')\n",
+ "(29, 'D')\n",
+ "(30, 'E')\n",
+ "(31, 'F')\n",
+ "(32, 'G')\n",
+ "(33, 'H')\n",
+ "(34, 'I')\n",
+ "(35, 'J')\n",
+ "(36, 'K')\n",
+ "(37, 'L')\n",
+ "(38, 'M')\n",
+ "(39, 'N')\n",
+ "(40, 'O')\n",
+ "(41, 'P')\n",
+ "(42, 'Q')\n",
+ "(43, 'R')\n",
+ "(44, 'S')\n",
+ "(45, 'T')\n",
+ "(46, 'U')\n",
+ "(47, 'V')\n",
+ "(48, 'W')\n",
+ "(49, 'X')\n",
+ "(50, 'Y')\n",
+ "(51, 'Z')\n",
+ "(52, '[')\n",
+ "(53, ']')\n",
+ "(54, '_')\n",
+ "(55, '`')\n",
+ "(56, 'a')\n",
+ "(57, 'b')\n",
+ "(58, 'c')\n",
+ "(59, 'd')\n",
+ "(60, 'e')\n",
+ "(61, 'f')\n",
+ "(62, 'g')\n",
+ "(63, 'h')\n",
+ "(64, 'i')\n",
+ "(65, 'j')\n",
+ "(66, 'k')\n",
+ "(67, 'l')\n",
+ "(68, 'm')\n",
+ "(69, 'n')\n",
+ "(70, 'o')\n",
+ "(71, 'p')\n",
+ "(72, 'q')\n",
+ "(73, 'r')\n",
+ "(74, 's')\n",
+ "(75, 't')\n",
+ "(76, 'u')\n",
+ "(77, 'v')\n",
+ "(78, 'w')\n",
+ "(79, 'x')\n",
+ "(80, 'y')\n",
+ "(81, 'z')\n",
+ "(82, '|')\n",
+ "(83, '}')\n"
+ ]
+ }
+ ],
+ "source": [
+ "for pair in enumerate(vocab):\n",
+ "    print(pair)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 23,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "char_to_ind = {char:ind for ind, char in enumerate(vocab)}"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 24,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "{'\\n': 0,\n",
+ " ' ': 1,\n",
+ " '!': 2,\n",
+ " '\"': 3,\n",
+ " '&': 4,\n",
+ " \"'\": 5,\n",
+ " '(': 6,\n",
+ " ')': 7,\n",
+ " ',': 8,\n",
+ " '-': 9,\n",
+ " '.': 10,\n",
+ " '0': 11,\n",
+ " '1': 12,\n",
+ " '2': 13,\n",
+ " '3': 14,\n",
+ " '4': 15,\n",
+ " '5': 16,\n",
+ " '6': 17,\n",
+ " '7': 18,\n",
+ " '8': 19,\n",
+ " '9': 20,\n",
+ " ':': 21,\n",
+ " ';': 22,\n",
+ " '<': 23,\n",
+ " '>': 24,\n",
+ " '?': 25,\n",
+ " 'A': 26,\n",
+ " 'B': 27,\n",
+ " 'C': 28,\n",
+ " 'D': 29,\n",
+ " 'E': 30,\n",
+ " 'F': 31,\n",
+ " 'G': 32,\n",
+ " 'H': 33,\n",
+ " 'I': 34,\n",
+ " 'J': 35,\n",
+ " 'K': 36,\n",
+ " 'L': 37,\n",
+ " 'M': 38,\n",
+ " 'N': 39,\n",
+ " 'O': 40,\n",
+ " 'P': 41,\n",
+ " 'Q': 42,\n",
+ " 'R': 43,\n",
+ " 'S': 44,\n",
+ " 'T': 45,\n",
+ " 'U': 46,\n",
+ " 'V': 47,\n",
+ " 'W': 48,\n",
+ " 'X': 49,\n",
+ " 'Y': 50,\n",
+ " 'Z': 51,\n",
+ " '[': 52,\n",
+ " ']': 53,\n",
+ " '_': 54,\n",
+ " '`': 55,\n",
+ " 'a': 56,\n",
+ " 'b': 57,\n",
+ " 'c': 58,\n",
+ " 'd': 59,\n",
+ " 'e': 60,\n",
+ " 'f': 61,\n",
+ " 'g': 62,\n",
+ " 'h': 63,\n",
+ " 'i': 64,\n",
+ " 'j': 65,\n",
+ " 'k': 66,\n",
+ " 'l': 67,\n",
+ " 'm': 68,\n",
+ " 'n': 69,\n",
+ " 'o': 70,\n",
+ " 'p': 71,\n",
+ " 'q': 72,\n",
+ " 'r': 73,\n",
+ " 's': 74,\n",
+ " 't': 75,\n",
+ " 'u': 76,\n",
+ " 'v': 77,\n",
+ " 'w': 78,\n",
+ " 'x': 79,\n",
+ " 'y': 80,\n",
+ " 'z': 81,\n",
+ " '|': 82,\n",
+ " '}': 83}"
+ ]
+ },
+ "execution_count": 24,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "char_to_ind"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 25,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "ind_to_char = np.array(vocab)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 26,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "33"
+ ]
+ },
+ "execution_count": 26,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "char_to_ind['H']"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 27,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "'H'"
+ ]
+ },
+ "execution_count": 27,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ind_to_char[33]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 28,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "encoded_text = np.array([char_to_ind[c] for c in text])"
+ ]
+ },
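The cells above build a character vocabulary, a character-to-index dictionary, and an index-to-character lookup, then encode the whole corpus. A minimal pure-Python sketch of the same round trip on a short string follows. `sample_text` is an illustrative input rather than part of the notebook, and a plain list stands in for `np.array(vocab)`.

```python
sample_text = "To be, or not to be"

# Sorted unique characters form the vocabulary, as in the notebook.
vocab = sorted(set(sample_text))
char_to_ind = {char: ind for ind, char in enumerate(vocab)}
ind_to_char = vocab  # a list is enough for single-index lookups

# Encode every character to its integer id, then decode it back.
encoded = [char_to_ind[c] for c in sample_text]
decoded = "".join(ind_to_char[i] for i in encoded)

# The mapping is lossless: decoding the ids restores the original text.
assert decoded == sample_text
print(len(vocab), encoded[:5])
```

Because the vocabulary is sorted, the ids are deterministic for a given corpus, which matters when a saved model is reloaded against the same text.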
+ {
+ "cell_type": "code",
+ "execution_count": 29,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 0, 1, 1, ..., 30, 39, 29])"
+ ]
+ },
+ "execution_count": 29,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoded_text"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 30,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "(5445609,)"
+ ]
+ },
+ "execution_count": 30,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoded_text.shape"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 32,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ " 1\n",
+ " From fairest creatures we desire increase,\n",
+ " That thereby beauty's rose might never die,\n",
+ " But as the riper should by time decease,\n",
+ " His tender heir might bear his memory:\n",
+ " But thou contracted to thine own bright eyes,\n",
+ " Feed'st thy light's flame with self-substantial fuel,\n",
+ " Making a famine where abundance lies,\n",
+ " Thy self thy foe, to thy sweet self too cruel:\n",
+ " Thou that art now the world's fresh ornament,\n",
+ " And only herald to the gaudy spring,\n",
+ " Within thine own bu\n"
+ ]
+ }
+ ],
+ "source": [
+ "sample = text[:500]\n",
+ "print(sample)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 33,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "array([ 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,\n",
+ " 1, 1, 1, 1, 1, 12, 0, 1, 1, 31, 73, 70, 68, 1, 61, 56, 64,\n",
+ " 73, 60, 74, 75, 1, 58, 73, 60, 56, 75, 76, 73, 60, 74, 1, 78, 60,\n",
+ " 1, 59, 60, 74, 64, 73, 60, 1, 64, 69, 58, 73, 60, 56, 74, 60, 8,\n",
+ " 0, 1, 1, 45, 63, 56, 75, 1, 75, 63, 60, 73, 60, 57, 80, 1, 57,\n",
+ " 60, 56, 76, 75, 80, 5, 74, 1, 73, 70, 74, 60, 1, 68, 64, 62, 63,\n",
+ " 75, 1, 69, 60, 77, 60, 73, 1, 59, 64, 60, 8, 0, 1, 1, 27, 76,\n",
+ " 75, 1, 56, 74, 1, 75, 63, 60, 1, 73, 64, 71, 60, 73, 1, 74, 63,\n",
+ " 70, 76, 67, 59, 1, 57, 80, 1, 75, 64, 68, 60, 1, 59, 60, 58, 60,\n",
+ " 56, 74, 60, 8, 0, 1, 1, 33, 64, 74, 1, 75, 60, 69, 59, 60, 73,\n",
+ " 1, 63, 60, 64, 73, 1, 68, 64, 62, 63, 75, 1, 57, 60, 56, 73, 1,\n",
+ " 63, 64, 74, 1, 68, 60, 68, 70, 73, 80, 21, 0, 1, 1, 27, 76, 75,\n",
+ " 1, 75, 63, 70, 76, 1, 58, 70, 69, 75, 73, 56, 58, 75, 60, 59, 1,\n",
+ " 75, 70, 1, 75, 63, 64, 69, 60, 1, 70, 78, 69, 1, 57, 73, 64, 62,\n",
+ " 63, 75, 1, 60, 80, 60, 74, 8, 0, 1, 1, 31, 60, 60, 59, 5, 74,\n",
+ " 75, 1, 75, 63, 80, 1, 67, 64, 62, 63, 75, 5, 74, 1, 61, 67, 56,\n",
+ " 68, 60, 1, 78, 64, 75, 63, 1, 74, 60, 67, 61, 9, 74, 76, 57, 74,\n",
+ " 75, 56, 69, 75, 64, 56, 67, 1, 61, 76, 60, 67, 8, 0, 1, 1, 38,\n",
+ " 56, 66, 64, 69, 62, 1, 56, 1, 61, 56, 68, 64, 69, 60, 1, 78, 63,\n",
+ " 60, 73, 60, 1, 56, 57, 76, 69, 59, 56, 69, 58, 60, 1, 67, 64, 60,\n",
+ " 74, 8, 0, 1, 1, 45, 63, 80, 1, 74, 60, 67, 61, 1, 75, 63, 80,\n",
+ " 1, 61, 70, 60, 8, 1, 75, 70, 1, 75, 63, 80, 1, 74, 78, 60, 60,\n",
+ " 75, 1, 74, 60, 67, 61, 1, 75, 70, 70, 1, 58, 73, 76, 60, 67, 21,\n",
+ " 0, 1, 1, 45, 63, 70, 76, 1, 75, 63, 56, 75, 1, 56, 73, 75, 1,\n",
+ " 69, 70, 78, 1, 75, 63, 60, 1, 78, 70, 73, 67, 59, 5, 74, 1, 61,\n",
+ " 73, 60, 74, 63, 1, 70, 73, 69, 56, 68, 60, 69, 75, 8, 0, 1, 1,\n",
+ " 26, 69, 59, 1, 70, 69, 67, 80, 1, 63, 60, 73, 56, 67, 59, 1, 75,\n",
+ " 70, 1, 75, 63, 60, 1, 62, 56, 76, 59, 80, 1, 74, 71, 73, 64, 69,\n",
+ " 62, 8, 0, 1, 1, 48, 64, 75, 63, 64, 69, 1, 75, 63, 64, 69, 60,\n",
+ " 1, 70, 78, 69, 1, 57, 76])"
+ ]
+ },
+ "execution_count": 33,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "encoded_text[:500]"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 34,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "42"
+ ]
+ },
+ "execution_count": 34,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "line = \"From fairest creatures we desire increase,\"\n",
+ "len(line)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 36,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "133"
+ ]
+ },
+ "execution_count": 36,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "lines = '''\n",
+ "From fairest creatures we desire increase,\n",
+ " That thereby beauty's rose might never die,\n",
+ " But as the riper should by time decease,\n",
+ "'''\n",
+ "\n",
+ "len(lines)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 37,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "seq_len = 120"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 38,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "45005"
+ ]
+ },
+ "execution_count": 38,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "total_num_seq = len(text) // (seq_len + 1)\n",
+ "total_num_seq"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 39,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "char_dataset = tf.data.Dataset.from_tensor_slices(encoded_text)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 40,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "tensorflow.python.data.ops.from_tensor_slices_op._TensorSliceDataset"
+ ]
+ },
+ "execution_count": 40,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "type(char_dataset)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 42,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ " \n",
+ "1\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "F\n",
+ "r\n",
+ "o\n",
+ "m\n",
+ " \n",
+ "f\n",
+ "a\n",
+ "i\n",
+ "r\n",
+ "e\n",
+ "s\n",
+ "t\n",
+ " \n",
+ "c\n",
+ "r\n",
+ "e\n",
+ "a\n",
+ "t\n",
+ "u\n",
+ "r\n",
+ "e\n",
+ "s\n",
+ " \n",
+ "w\n",
+ "e\n",
+ " \n",
+ "d\n",
+ "e\n",
+ "s\n",
+ "i\n",
+ "r\n",
+ "e\n",
+ " \n",
+ "i\n",
+ "n\n",
+ "c\n",
+ "r\n",
+ "e\n",
+ "a\n",
+ "s\n",
+ "e\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "T\n",
+ "h\n",
+ "a\n",
+ "t\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "e\n",
+ "r\n",
+ "e\n",
+ "b\n",
+ "y\n",
+ " \n",
+ "b\n",
+ "e\n",
+ "a\n",
+ "u\n",
+ "t\n",
+ "y\n",
+ "'\n",
+ "s\n",
+ " \n",
+ "r\n",
+ "o\n",
+ "s\n",
+ "e\n",
+ " \n",
+ "m\n",
+ "i\n",
+ "g\n",
+ "h\n",
+ "t\n",
+ " \n",
+ "n\n",
+ "e\n",
+ "v\n",
+ "e\n",
+ "r\n",
+ " \n",
+ "d\n",
+ "i\n",
+ "e\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "B\n",
+ "u\n",
+ "t\n",
+ " \n",
+ "a\n",
+ "s\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "e\n",
+ " \n",
+ "r\n",
+ "i\n",
+ "p\n",
+ "e\n",
+ "r\n",
+ " \n",
+ "s\n",
+ "h\n",
+ "o\n",
+ "u\n",
+ "l\n",
+ "d\n",
+ " \n",
+ "b\n",
+ "y\n",
+ " \n",
+ "t\n",
+ "i\n",
+ "m\n",
+ "e\n",
+ " \n",
+ "d\n",
+ "e\n",
+ "c\n",
+ "e\n",
+ "a\n",
+ "s\n",
+ "e\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "H\n",
+ "i\n",
+ "s\n",
+ " \n",
+ "t\n",
+ "e\n",
+ "n\n",
+ "d\n",
+ "e\n",
+ "r\n",
+ " \n",
+ "h\n",
+ "e\n",
+ "i\n",
+ "r\n",
+ " \n",
+ "m\n",
+ "i\n",
+ "g\n",
+ "h\n",
+ "t\n",
+ " \n",
+ "b\n",
+ "e\n",
+ "a\n",
+ "r\n",
+ " \n",
+ "h\n",
+ "i\n",
+ "s\n",
+ " \n",
+ "m\n",
+ "e\n",
+ "m\n",
+ "o\n",
+ "r\n",
+ "y\n",
+ ":\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "B\n",
+ "u\n",
+ "t\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "o\n",
+ "u\n",
+ " \n",
+ "c\n",
+ "o\n",
+ "n\n",
+ "t\n",
+ "r\n",
+ "a\n",
+ "c\n",
+ "t\n",
+ "e\n",
+ "d\n",
+ " \n",
+ "t\n",
+ "o\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "i\n",
+ "n\n",
+ "e\n",
+ " \n",
+ "o\n",
+ "w\n",
+ "n\n",
+ " \n",
+ "b\n",
+ "r\n",
+ "i\n",
+ "g\n",
+ "h\n",
+ "t\n",
+ " \n",
+ "e\n",
+ "y\n",
+ "e\n",
+ "s\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "F\n",
+ "e\n",
+ "e\n",
+ "d\n",
+ "'\n",
+ "s\n",
+ "t\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "y\n",
+ " \n",
+ "l\n",
+ "i\n",
+ "g\n",
+ "h\n",
+ "t\n",
+ "'\n",
+ "s\n",
+ " \n",
+ "f\n",
+ "l\n",
+ "a\n",
+ "m\n",
+ "e\n",
+ " \n",
+ "w\n",
+ "i\n",
+ "t\n",
+ "h\n",
+ " \n",
+ "s\n",
+ "e\n",
+ "l\n",
+ "f\n",
+ "-\n",
+ "s\n",
+ "u\n",
+ "b\n",
+ "s\n",
+ "t\n",
+ "a\n",
+ "n\n",
+ "t\n",
+ "i\n",
+ "a\n",
+ "l\n",
+ " \n",
+ "f\n",
+ "u\n",
+ "e\n",
+ "l\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "M\n",
+ "a\n",
+ "k\n",
+ "i\n",
+ "n\n",
+ "g\n",
+ " \n",
+ "a\n",
+ " \n",
+ "f\n",
+ "a\n",
+ "m\n",
+ "i\n",
+ "n\n",
+ "e\n",
+ " \n",
+ "w\n",
+ "h\n",
+ "e\n",
+ "r\n",
+ "e\n",
+ " \n",
+ "a\n",
+ "b\n",
+ "u\n",
+ "n\n",
+ "d\n",
+ "a\n",
+ "n\n",
+ "c\n",
+ "e\n",
+ " \n",
+ "l\n",
+ "i\n",
+ "e\n",
+ "s\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "T\n",
+ "h\n",
+ "y\n",
+ " \n",
+ "s\n",
+ "e\n",
+ "l\n",
+ "f\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "y\n",
+ " \n",
+ "f\n",
+ "o\n",
+ "e\n",
+ ",\n",
+ " \n",
+ "t\n",
+ "o\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "y\n",
+ " \n",
+ "s\n",
+ "w\n",
+ "e\n",
+ "e\n",
+ "t\n",
+ " \n",
+ "s\n",
+ "e\n",
+ "l\n",
+ "f\n",
+ " \n",
+ "t\n",
+ "o\n",
+ "o\n",
+ " \n",
+ "c\n",
+ "r\n",
+ "u\n",
+ "e\n",
+ "l\n",
+ ":\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "T\n",
+ "h\n",
+ "o\n",
+ "u\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "a\n",
+ "t\n",
+ " \n",
+ "a\n",
+ "r\n",
+ "t\n",
+ " \n",
+ "n\n",
+ "o\n",
+ "w\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "e\n",
+ " \n",
+ "w\n",
+ "o\n",
+ "r\n",
+ "l\n",
+ "d\n",
+ "'\n",
+ "s\n",
+ " \n",
+ "f\n",
+ "r\n",
+ "e\n",
+ "s\n",
+ "h\n",
+ " \n",
+ "o\n",
+ "r\n",
+ "n\n",
+ "a\n",
+ "m\n",
+ "e\n",
+ "n\n",
+ "t\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "A\n",
+ "n\n",
+ "d\n",
+ " \n",
+ "o\n",
+ "n\n",
+ "l\n",
+ "y\n",
+ " \n",
+ "h\n",
+ "e\n",
+ "r\n",
+ "a\n",
+ "l\n",
+ "d\n",
+ " \n",
+ "t\n",
+ "o\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "e\n",
+ " \n",
+ "g\n",
+ "a\n",
+ "u\n",
+ "d\n",
+ "y\n",
+ " \n",
+ "s\n",
+ "p\n",
+ "r\n",
+ "i\n",
+ "n\n",
+ "g\n",
+ ",\n",
+ "\n",
+ "\n",
+ " \n",
+ " \n",
+ "W\n",
+ "i\n",
+ "t\n",
+ "h\n",
+ "i\n",
+ "n\n",
+ " \n",
+ "t\n",
+ "h\n",
+ "i\n",
+ "n\n",
+ "e\n",
+ " \n",
+ "o\n",
+ "w\n",
+ "n\n",
+ " \n",
+ "b\n",
+ "u\n"
+ ]
+ }
+ ],
+ "source": [
+ "for item in char_dataset.take(500):\n",
+ "    print(ind_to_char[item.numpy()])"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 43,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "sequences = char_dataset.batch(seq_len+1, drop_remainder=True)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 44,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def create_seq_targets(seq):\n",
+ "    input_txt = seq[:-1]\n",
+ "    target_txt = seq[1:]\n",
+ "    return input_txt, target_txt"
+ ]
+ },
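The two cells above cut the encoded text into chunks of `seq_len + 1` ids and split each chunk into an input and a target shifted by one position, so the model learns to predict the next character at every step. A minimal pure-Python sketch of that windowing on a toy list follows. `make_sequences` is a hypothetical stand-in for `char_dataset.batch(seq_len+1, drop_remainder=True)`, not a TensorFlow API.

```python
def make_sequences(encoded, seq_len):
    """Split ids into chunks of seq_len + 1 items, dropping the incomplete tail."""
    chunk = seq_len + 1
    usable = len(encoded) - len(encoded) % chunk
    return [encoded[i:i + chunk] for i in range(0, usable, chunk)]


def create_seq_targets(seq):
    """Input is the chunk without its last id; target is the chunk shifted by one."""
    return seq[:-1], seq[1:]


encoded = list(range(10))  # toy stand-in for encoded_text
pairs = [create_seq_targets(s) for s in make_sequences(encoded, seq_len=3)]
for input_ids, target_ids in pairs:
    print(input_ids, target_ids)
```

Each target element is the id that follows the input element at the same position, which is why the chunk has to be one item longer than `seq_len`.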
+ {
+ "cell_type": "code",
+ "execution_count": 45,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "dataset = sequences.map(create_seq_targets)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 46,
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "[ 0 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 12 0\n",
+ " 1 1 31 73 70 68 1 61 56 64 73 60 74 75 1 58 73 60 56 75 76 73 60 74\n",
+ " 1 78 60 1 59 60 74 64 73 60 1 64 69 58 73 60 56 74 60 8 0 1 1 45\n",
+ " 63 56 75 1 75 63 60 73 60 57 80 1 57 60 56 76 75 80 5 74 1 73 70 74\n",
+ " 60 1 68 64 62 63 75 1 69 60 77 60 73 1 59 64 60 8 0 1 1 27 76 75]\n",
+ "\n",
+ " 1\n",
+ " From fairest creatures we desire increase,\n",
+ " That thereby beauty's rose might never die,\n",
+ " But\n",
+ "[ 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 12 0 1\n",
+ " 1 31 73 70 68 1 61 56 64 73 60 74 75 1 58 73 60 56 75 76 73 60 74 1\n",
+ " 78 60 1 59 60 74 64 73 60 1 64 69 58 73 60 56 74 60 8 0 1 1 45 63\n",
+ " 56 75 1 75 63 60 73 60 57 80 1 57 60 56 76 75 80 5 74 1 73 70 74 60\n",
+ " 1 68 64 62 63 75 1 69 60 77 60 73 1 59 64 60 8 0 1 1 27 76 75 1]\n",
+ " 1\n",
+ " From fairest creatures we desire increase,\n",
+ " That thereby beauty's rose might never die,\n",
+ " But \n"
+ ]
+ }
+ ],
+ "source": [
+ "for input_txt, target_txt in dataset.take(1):\n",
+ "    print(input_txt.numpy())\n",
+ "    print(''.join(ind_to_char[input_txt.numpy()]))\n",
+ "    print(target_txt.numpy())\n",
+ "    print(''.join(ind_to_char[target_txt.numpy()]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 47,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "batch_size = 128"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 48,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "buffer_size = 10000"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 49,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "<_BatchDataset element_spec=(TensorSpec(shape=(128, 120), dtype=tf.int32, name=None), TensorSpec(shape=(128, 120), dtype=tf.int32, name=None))>"
+ ]
+ },
+ "execution_count": 49,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "dataset = dataset.shuffle(buffer_size).batch(batch_size, drop_remainder=True)\n",
+ "dataset"
+ ]
+ },
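With `drop_remainder=True` at both batching stages, the dataset size follows from integer division alone. The arithmetic below is a sketch using the values reported in the notebook outputs (text length 5445609, `seq_len = 120`, `batch_size = 128`); the variable names are illustrative.

```python
text_length = 5445609  # encoded_text.shape[0] reported above
seq_len = 120
batch_size = 128

# Each sequence consumes seq_len + 1 characters (input plus shifted target).
num_sequences = text_length // (seq_len + 1)

# Batching with drop_remainder=True discards the final partial batch.
num_batches = num_sequences // batch_size
leftover_sequences = num_sequences % batch_size

print(num_sequences, num_batches, leftover_sequences)
```

This reproduces the `total_num_seq` value of 45005 shown earlier and gives 351 full batches per pass over the data, with 77 sequences left out of each pass.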
+ {
+ "cell_type": "code",
+ "execution_count": 50,
+ "metadata": {},
+ "outputs": [
+ {
+ "data": {
+ "text/plain": [
+ "84"
+ ]
+ },
+ "execution_count": 50,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "vocab_size = len(vocab)\n",
+ "vocab_size"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 51,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "embed_dim = 64"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 52,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "rnn_neurons = 1026"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 53,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from tensorflow.keras.losses import sparse_categorical_crossentropy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 54,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "def sparse_cat_loss(y_true, y_pred):\n",
+ "    return sparse_categorical_crossentropy(y_true, y_pred, from_logits=True)"
+ ]
+ },
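The wrapper above passes `from_logits=True` because the model's final `Dense` layer emits raw scores with no softmax applied. A pure-Python sketch of what that loss computes for a single time step follows. `sparse_cat_loss_single` is a hypothetical helper written for illustration, not the Keras implementation.

```python
import math


def sparse_cat_loss_single(label, logits):
    """Cross-entropy for one time step: -log(softmax(logits)[label])."""
    peak = max(logits)  # subtract the max before exponentiating for stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[label]


# With all-zero logits over 84 characters the prediction is uniform,
# so the loss equals log(84).
uniform_loss = sparse_cat_loss_single(0, [0.0] * 84)
print(uniform_loss, math.log(84))
```

The uniform value is a useful baseline: an untrained model should start near it, and training is only helping once the loss drops below it.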
1395
+ {
1396
+ "cell_type": "code",
1397
+ "execution_count": 55,
1398
+ "metadata": {},
1399
+ "outputs": [],
1400
+ "source": [
1401
+ "from tensorflow.keras.models import Sequential\n",
1402
+ "from tensorflow.keras.layers import Embedding, GRU, Dense"
1403
+ ]
1404
+ },
1405
+ {
1406
+ "cell_type": "code",
1407
+ "execution_count": 56,
1408
+ "metadata": {},
1409
+ "outputs": [],
1410
+ "source": [
1411
+ "def create_model(vocab_size, embed_dim, rnn_neurons, batch_size):\n",
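1412
+ "    model = Sequential()\n",
1413
+ "    # Embedding: map each character index to a dense vector of size embed_dim.\n",
1414
+ "    model.add(Embedding(vocab_size, embed_dim, batch_input_shape=[batch_size, None]))\n",
1415
+ "    # Stateful GRU: emits an output at every time step and carries its state across batches.\n",
1416
+ "    model.add(GRU(rnn_neurons, return_sequences=True, stateful=True, recurrent_initializer='glorot_uniform'))\n",
1417
+ "    # Dense head: one unnormalized logit per vocabulary entry (no softmax here).\n",
1418
+ "    model.add(Dense(vocab_size))\n",
1419
+ "    # sparse_cat_loss applies the softmax itself via from_logits=True.\n",
1420
+ "    model.compile(optimizer='adam', loss=sparse_cat_loss)\n",
1421
+ "\n",
1422
+ "    return model"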
1423
+ ]
1424
+ },
1425
+ {
1426
+ "cell_type": "code",
1427
+ "execution_count": 57,
1428
+ "metadata": {},
1429
+ "outputs": [
1430
+ {
1431
+ "name": "stdout",
1432
+ "output_type": "stream",
1433
+ "text": [
1434
+ "Model: \"sequential\"\n",
1435
+ "_________________________________________________________________\n",
1436
+ " Layer (type) Output Shape Param # \n",
1437
+ "=================================================================\n",
1438
+ " embedding (Embedding) (128, None, 64) 5376 \n",
1439
+ " \n",
1440
+ " gru (GRU) (128, None, 1026) 3361176 \n",
1441
+ " \n",
1442
+ " dense (Dense) (128, None, 84) 86268 \n",
1443
+ " \n",
1444
+ "=================================================================\n",
1445
+ "Total params: 3452820 (13.17 MB)\n",
1446
+ "Trainable params: 3452820 (13.17 MB)\n",
1447
+ "Non-trainable params: 0 (0.00 Byte)\n",
1448
+ "_________________________________________________________________\n"
1449
+ ]
1450
+ }
1451
+ ],
1452
+ "source": [
1453
+ "model = create_model(vocab_size=vocab_size,\n",
1454
+ "                     embed_dim=embed_dim,\n",
1455
+ "                     rnn_neurons=rnn_neurons,\n",
1456
+ "                     batch_size=batch_size)\n",
1457
+ "\n",
1458
+ "model.summary()"
1459
+ ]
1460
+ },
1461
+ {
1462
+ "cell_type": "code",
1463
+ "execution_count": 58,
1464
+ "metadata": {},
1465
+ "outputs": [],
1466
+ "source": [
1467
+ "for input_example_batch, target_example_batch in dataset.take(1):\n",
1468
+ "    input_example_predictions = model(input_example_batch)"
1469
+ ]
1470
+ },
1471
+ {
1472
+ "cell_type": "code",
1473
+ "execution_count": 59,
1474
+ "metadata": {},
1475
+ "outputs": [
1476
+ {
1477
+ "data": {
1478
+ "text/plain": [
1479
+ "<tf.Tensor: shape=(128, 120, 84), dtype=float32, numpy=\n",
1480
+ "array([[[-1.26658962e-03, -1.02281375e-02, -5.01847127e-03, ...,\n",
1481
+ " -3.02844611e-03, 6.41892198e-03, -3.31320823e-03],\n",
1482
+ " [-1.21301366e-03, 3.18358769e-04, -9.04226955e-03, ...,\n",
1483
+ " -7.06026657e-03, -4.07771766e-03, 3.90136405e-03],\n",
1484
+ " [ 2.12129857e-03, 3.73512739e-05, 1.03873445e-03, ...,\n",
1485
+ " -5.67088660e-04, -3.33711831e-03, 7.81264342e-03],\n",
1486
+ " ...,\n",
1487
+ " [-1.05437764e-03, 6.62492588e-03, -2.61027599e-04, ...,\n",
1488
+ " -1.16620697e-02, -3.73046333e-03, 4.27998928e-03],\n",
1489
+ " [-4.87042870e-03, 8.28131475e-03, -3.26290075e-03, ...,\n",
1490
+ " -1.36158746e-02, -6.28873426e-03, 4.68202401e-03],\n",
1491
+ " [ 3.73602519e-03, 7.49139953e-03, 2.62855785e-03, ...,\n",
1492
+ " -9.11762752e-03, -2.22274661e-03, -1.46359345e-03]],\n",
1493
+ "\n",
1494
+ " [[ 1.65691471e-03, 2.78515508e-05, -1.75164896e-05, ...,\n",
1495
+ " -1.03196641e-02, -1.14688405e-03, 7.50818709e-03],\n",
1496
+ " [ 4.01926955e-04, 8.56293645e-03, -2.48706527e-03, ...,\n",
1497
+ " -7.24339113e-03, -8.28842982e-04, -3.51517042e-03],\n",
1498
+ " [ 2.97096046e-03, 3.00624478e-03, 5.37311751e-03, ...,\n",
1499
+ " -7.55489280e-04, -2.63659190e-03, 3.83156352e-03],\n",
1500
+ " ...,\n",
1501
+ " [-6.00246480e-04, -9.83457896e-04, -3.51762777e-04, ...,\n",
1502
+ " 6.29320042e-04, -9.89628583e-03, 8.98226909e-03],\n",
1503
+ " [-4.09764796e-03, 5.64620737e-03, 1.21265789e-03, ...,\n",
1504
+ " -1.10058172e-03, -4.23033535e-03, 7.76559464e-04],\n",
1505
+ " [-1.02544213e-02, -4.39250330e-03, 4.08628071e-03, ...,\n",
1506
+ " -1.39716011e-03, -7.45914457e-03, -8.94208997e-03]],\n",
1507
+ "\n",
1508
+ " [[-7.31695490e-03, 7.70223187e-03, -3.47627047e-03, ...,\n",
1509
+ " -3.14399763e-03, -4.83559561e-05, -1.66273641e-03],\n",
1510
+ " [ 3.19275842e-03, 7.04296818e-03, 4.52343049e-03, ...,\n",
1511
+ " -3.37669649e-03, 6.39380887e-05, -3.89098749e-03],\n",
1512
+ " [ 2.26991810e-03, 7.66665256e-03, -3.74295679e-03, ...,\n",
1513
+ " -6.55478938e-03, -8.11303221e-03, 4.04081633e-03],\n",
1514
+ " ...,\n",
1515
+ " [-1.24336348e-03, 3.24544404e-03, -2.19549867e-03, ...,\n",
1516
+ " -1.23574454e-02, -7.09445961e-03, 1.07077677e-02],\n",
1517
+ " [-2.10441439e-03, -3.01999808e-03, 4.96061705e-03, ...,\n",
1518
+ " -1.12426355e-02, -1.87711930e-03, 1.15814880e-02],\n",
1519
+ " [-2.55338964e-03, 2.37546698e-03, 7.44714448e-03, ...,\n",
1520
+ " -1.10822935e-02, -4.92575718e-03, 7.61651807e-03]],\n",
1521
+ "\n",
1522
+ " ...,\n",
1523
+ "\n",
1524
+ " [[-3.04947514e-03, -4.48098173e-04, -4.28649737e-03, ...,\n",
1525
+ " -3.67468246e-03, -6.06621569e-03, 4.82408609e-03],\n",
1526
+ " [ 3.38326930e-03, 2.39760685e-03, -4.43490362e-03, ...,\n",
1527
+ " -6.41413406e-03, -2.42703035e-03, 8.07784870e-03],\n",
1528
+ " [ 3.51318740e-03, 1.01979310e-03, -2.10099528e-03, ...,\n",
1529
+ " -1.37671335e-02, -3.03332042e-03, 1.16969068e-02],\n",
1530
+ " ...,\n",
1531
+ " [-2.51456769e-03, 3.08659696e-03, -4.09391802e-03, ...,\n",
1532
+ " -3.73627315e-03, -8.52109678e-03, 2.93066399e-03],\n",
1533
+ " [-3.25050764e-03, -4.07967670e-03, 8.40548746e-05, ...,\n",
1534
+ " -1.96301658e-03, 5.53767779e-04, 1.12879812e-03],\n",
1535
+ " [-3.65182478e-03, -1.62036659e-03, 6.76601776e-04, ...,\n",
1536
+ " 1.74052105e-03, -4.68559912e-04, -2.90953903e-06]],\n",
1537
+ "\n",
1538
+ " [[ 2.71482952e-03, -6.13818120e-04, -4.99741035e-03, ...,\n",
1539
+ " 1.30506267e-03, 1.26352720e-03, -3.16191092e-03],\n",
1540
+ " [-2.35046493e-03, 1.47341855e-03, 1.92235212e-03, ...,\n",
1541
+ " 4.15714085e-03, 7.70311628e-04, -3.43234511e-03],\n",
1542
+ " [-2.23669480e-03, -1.88404694e-03, 3.89048969e-03, ...,\n",
1543
+ " -1.00363541e-04, -4.45647072e-03, -1.20103837e-03],\n",
1544
+ " ...,\n",
1545
+ " [ 1.74792449e-03, -1.49176281e-04, 2.31775595e-03, ...,\n",
1546
+ " -2.45657354e-03, -5.49030956e-03, 1.05382213e-02],\n",
1547
+ " [-2.09265738e-03, 7.97000830e-04, 3.14242928e-03, ...,\n",
1548
+ " 1.77775964e-03, -3.20373825e-03, 2.90514552e-03],\n",
1549
+ " [-1.35548890e-03, -2.83561996e-03, 2.86235218e-03, ...,\n",
1550
+ " -1.70795049e-03, -6.73092064e-03, 1.74498581e-03]],\n",
1551
+ "\n",
1552
+ " [[ 5.28363232e-03, 1.80142722e-03, 3.59522342e-03, ...,\n",
1553
+ " 1.53130700e-03, 9.43017891e-04, 9.18017875e-04],\n",
1554
+ " [-4.08536335e-03, -4.53328993e-03, 4.62532882e-03, ...,\n",
1555
+ " 1.57068600e-04, -6.36877678e-03, -9.60698817e-03],\n",
1556
+ " [ 4.26323013e-03, 3.35310609e-03, 3.84865189e-03, ...,\n",
1557
+ " 5.65373246e-03, -3.62121896e-03, -3.28896893e-03],\n",
1558
+ " ...,\n",
1559
+ " [-2.07619974e-03, -3.01369117e-03, -2.72819865e-03, ...,\n",
1560
+ " -5.44426683e-03, -2.23890383e-04, 7.91562069e-03],\n",
1561
+ " [-2.60874163e-03, 6.16421178e-03, -4.43145679e-03, ...,\n",
1562
+ " -4.63165995e-03, 8.77770421e-04, -3.59299872e-03],\n",
1563
+ " [-3.71062336e-03, -3.75270541e-03, -2.04372234e-04, ...,\n",
1564
+ " -2.80867866e-03, 5.48168877e-03, -2.82955472e-03]]],\n",
1565
+ " dtype=float32)>"
1566
+ ]
1567
+ },
1568
+ "execution_count": 59,
1569
+ "metadata": {},
1570
+ "output_type": "execute_result"
1571
+ }
1572
+ ],
1573
+ "source": [
1574
+ "input_example_predictions"
1575
+ ]
1576
+ },
1577
+ {
1578
+ "cell_type": "code",
1579
+ "execution_count": 60,
1580
+ "metadata": {},
1581
+ "outputs": [],
1582
+ "source": [
1583
+ "sampled_indices = tf.random.categorical(input_example_predictions[0], num_samples=1)"
1584
+ ]
1585
+ },
1586
+ {
1587
+ "cell_type": "code",
1588
+ "execution_count": 61,
1589
+ "metadata": {},
1590
+ "outputs": [
1591
+ {
1592
+ "data": {
1593
+ "text/plain": [
1594
+ "<tf.Tensor: shape=(120, 1), dtype=int64, numpy=\n",
1595
+ "array([[16],\n",
1596
+ " [31],\n",
1597
+ " [40],\n",
1598
+ " [44],\n",
1599
+ " [81],\n",
1600
+ " [61],\n",
1601
+ " [42],\n",
1602
+ " [31],\n",
1603
+ " [38],\n",
1604
+ " [73],\n",
1605
+ " [ 0],\n",
1606
+ " [57],\n",
1607
+ " [20],\n",
1608
+ " [32],\n",
1609
+ " [41],\n",
1610
+ " [ 9],\n",
1611
+ " [44],\n",
1612
+ " [ 6],\n",
1613
+ " [78],\n",
1614
+ " [72],\n",
1615
+ " [12],\n",
1616
+ " [77],\n",
1617
+ " [48],\n",
1618
+ " [37],\n",
1619
+ " [ 6],\n",
1620
+ " [73],\n",
1621
+ " [52],\n",
1622
+ " [72],\n",
1623
+ " [16],\n",
1624
+ " [44],\n",
1625
+ " [10],\n",
1626
+ " [72],\n",
1627
+ " [45],\n",
1628
+ " [63],\n",
1629
+ " [29],\n",
1630
+ " [57],\n",
1631
+ " [44],\n",
1632
+ " [35],\n",
1633
+ " [33],\n",
1634
+ " [50],\n",
1635
+ " [78],\n",
1636
+ " [33],\n",
1637
+ " [44],\n",
1638
+ " [24],\n",
1639
+ " [46],\n",
1640
+ " [17],\n",
1641
+ " [34],\n",
1642
+ " [22],\n",
1643
+ " [74],\n",
1644
+ " [61],\n",
1645
+ " [51],\n",
1646
+ " [26],\n",
1647
+ " [17],\n",
1648
+ " [24],\n",
1649
+ " [16],\n",
1650
+ " [38],\n",
1651
+ " [61],\n",
1652
+ " [58],\n",
1653
+ " [42],\n",
1654
+ " [66],\n",
1655
+ " [17],\n",
1656
+ " [44],\n",
1657
+ " [24],\n",
1658
+ " [42],\n",
1659
+ " [44],\n",
1660
+ " [54],\n",
1661
+ " [53],\n",
1662
+ " [30],\n",
1663
+ " [50],\n",
1664
+ " [17],\n",
1665
+ " [73],\n",
1666
+ " [21],\n",
1667
+ " [21],\n",
1668
+ " [31],\n",
1669
+ " [35],\n",
1670
+ " [52],\n",
1671
+ " [24],\n",
1672
+ " [67],\n",
1673
+ " [44],\n",
1674
+ " [30],\n",
1675
+ " [21],\n",
1676
+ " [40],\n",
1677
+ " [ 3],\n",
1678
+ " [11],\n",
1679
+ " [54],\n",
1680
+ " [ 1],\n",
1681
+ " [56],\n",
1682
+ " [79],\n",
1683
+ " [80],\n",
1684
+ " [22],\n",
1685
+ " [82],\n",
1686
+ " [58],\n",
1687
+ " [21],\n",
1688
+ " [44],\n",
1689
+ " [30],\n",
1690
+ " [64],\n",
1691
+ " [59],\n",
1692
+ " [53],\n",
1693
+ " [40],\n",
1694
+ " [48],\n",
1695
+ " [35],\n",
1696
+ " [83],\n",
1697
+ " [15],\n",
1698
+ " [70],\n",
1699
+ " [20],\n",
1700
+ " [16],\n",
1701
+ " [75],\n",
1702
+ " [61],\n",
1703
+ " [81],\n",
1704
+ " [25],\n",
1705
+ " [ 0],\n",
1706
+ " [62],\n",
1707
+ " [16],\n",
1708
+ " [57],\n",
1709
+ " [23],\n",
1710
+ " [43],\n",
1711
+ " [47],\n",
1712
+ " [48],\n",
1713
+ " [ 0],\n",
1714
+ " [67]], dtype=int64)>"
1715
+ ]
1716
+ },
1717
+ "execution_count": 61,
1718
+ "metadata": {},
1719
+ "output_type": "execute_result"
1720
+ }
1721
+ ],
1722
+ "source": [
1723
+ "sampled_indices"
1724
+ ]
1725
+ },
1726
+ {
1727
+ "cell_type": "code",
1728
+ "execution_count": 62,
1729
+ "metadata": {},
1730
+ "outputs": [],
1731
+ "source": [
1732
+ "sampled_indices = tf.squeeze(sampled_indices, axis=1).numpy()"
1733
+ ]
1734
+ },
1735
+ {
1736
+ "cell_type": "code",
1737
+ "execution_count": 63,
1738
+ "metadata": {},
1739
+ "outputs": [
1740
+ {
1741
+ "data": {
1742
+ "text/plain": [
1743
+ "array([16, 31, 40, 44, 81, 61, 42, 31, 38, 73, 0, 57, 20, 32, 41, 9, 44,\n",
1744
+ " 6, 78, 72, 12, 77, 48, 37, 6, 73, 52, 72, 16, 44, 10, 72, 45, 63,\n",
1745
+ " 29, 57, 44, 35, 33, 50, 78, 33, 44, 24, 46, 17, 34, 22, 74, 61, 51,\n",
1746
+ " 26, 17, 24, 16, 38, 61, 58, 42, 66, 17, 44, 24, 42, 44, 54, 53, 30,\n",
1747
+ " 50, 17, 73, 21, 21, 31, 35, 52, 24, 67, 44, 30, 21, 40, 3, 11, 54,\n",
1748
+ " 1, 56, 79, 80, 22, 82, 58, 21, 44, 30, 64, 59, 53, 40, 48, 35, 83,\n",
1749
+ " 15, 70, 20, 16, 75, 61, 81, 25, 0, 62, 16, 57, 23, 43, 47, 48, 0,\n",
1750
+ " 67], dtype=int64)"
1751
+ ]
1752
+ },
1753
+ "execution_count": 63,
1754
+ "metadata": {},
1755
+ "output_type": "execute_result"
1756
+ }
1757
+ ],
1758
+ "source": [
1759
+ "sampled_indices"
1760
+ ]
1761
+ },
1762
+ {
1763
+ "cell_type": "code",
1764
+ "execution_count": 64,
1765
+ "metadata": {},
1766
+ "outputs": [
1767
+ {
1768
+ "data": {
1769
+ "text/plain": [
1770
+ "array(['5', 'F', 'O', 'S', 'z', 'f', 'Q', 'F', 'M', 'r', '\\n', 'b', '9',\n",
1771
+ " 'G', 'P', '-', 'S', '(', 'w', 'q', '1', 'v', 'W', 'L', '(', 'r',\n",
1772
+ " '[', 'q', '5', 'S', '.', 'q', 'T', 'h', 'D', 'b', 'S', 'J', 'H',\n",
1773
+ " 'Y', 'w', 'H', 'S', '>', 'U', '6', 'I', ';', 's', 'f', 'Z', 'A',\n",
1774
+ " '6', '>', '5', 'M', 'f', 'c', 'Q', 'k', '6', 'S', '>', 'Q', 'S',\n",
1775
+ " '_', ']', 'E', 'Y', '6', 'r', ':', ':', 'F', 'J', '[', '>', 'l',\n",
1776
+ " 'S', 'E', ':', 'O', '\"', '0', '_', ' ', 'a', 'x', 'y', ';', '|',\n",
1777
+ " 'c', ':', 'S', 'E', 'i', 'd', ']', 'O', 'W', 'J', '}', '4', 'o',\n",
1778
+ " '9', '5', 't', 'f', 'z', '?', '\\n', 'g', '5', 'b', '<', 'R', 'V',\n",
1779
+ " 'W', '\\n', 'l'], dtype='<U1')"
1780
+ ]
1781
+ },
1782
+ "execution_count": 64,
1783
+ "metadata": {},
1784
+ "output_type": "execute_result"
1785
+ }
1786
+ ],
1787
+ "source": [
1788
+ "ind_to_char[sampled_indices]"
1789
+ ]
1790
+ },
1791
+ {
1792
+ "cell_type": "code",
1793
+ "execution_count": 65,
1794
+ "metadata": {},
1795
+ "outputs": [],
1796
+ "source": [
1797
+ "epochs = 30"
1798
+ ]
1799
+ },
1800
+ {
1801
+ "cell_type": "code",
1802
+ "execution_count": null,
1803
+ "metadata": {},
1804
+ "outputs": [],
1805
+ "source": [
1806
+ "model.fit(dataset, epochs=epochs)"
1807
+ ]
1808
+ },
1809
+ {
1810
+ "cell_type": "code",
1811
+ "execution_count": null,
1812
+ "metadata": {},
1813
+ "outputs": [],
1814
+ "source": []
1815
+ }
1816
+ ],
1817
+ "metadata": {
1818
+ "kernelspec": {
1819
+ "display_name": "env_nlp",
1820
+ "language": "python",
1821
+ "name": "python3"
1822
+ },
1823
+ "language_info": {
1824
+ "codemirror_mode": {
1825
+ "name": "ipython",
1826
+ "version": 3
1827
+ },
1828
+ "file_extension": ".py",
1829
+ "mimetype": "text/x-python",
1830
+ "name": "python",
1831
+ "nbconvert_exporter": "python",
1832
+ "pygments_lexer": "ipython3",
1833
+ "version": "3.11.5"
1834
+ },
1835
+ "orig_nbformat": 4
1836
+ },
1837
+ "nbformat": 4,
1838
+ "nbformat_minor": 2
1839
+ }