Chandrasekar A Pasumarthi commited on
Commit
20691fe
·
verified ·
1 Parent(s): 1cfe2e1

Upload LSTMPytorchandLightning (1).ipynb

Browse files
Files changed (1) hide show
  1. LSTMPytorchandLightning (1).ipynb +1349 -0
LSTMPytorchandLightning (1).ipynb ADDED
@@ -0,0 +1,1349 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 14,
6
+ "metadata": {
7
+ "colab": {
8
+ "base_uri": "https://localhost:8080/"
9
+ },
10
+ "id": "N4PnG_qEpFB3",
11
+ "outputId": "f7dcdf6a-c1a3-4faa-9675-cc4b4f25c232"
12
+ },
13
+ "outputs": [],
14
+ "source": [
15
+ "import torch\n",
16
+ "import torch.nn as nn\n",
17
+ "import torch.nn.functional as F\n",
18
+ "from torch.optim import Adam\n",
19
+ "\n",
20
+ "import lightning as L\n",
21
+ "from torch.utils.data import TensorDataset, DataLoader"
22
+ ]
23
+ },
24
+ {
25
+ "cell_type": "markdown",
26
+ "metadata": {},
27
+ "source": [
28
+ "LSTM from Scratch:"
29
+ ]
30
+ },
31
+ {
32
+ "cell_type": "code",
33
+ "execution_count": 15,
34
+ "metadata": {
35
+ "id": "c8iFQOAFsOfC"
36
+ },
37
+ "outputs": [],
38
+ "source": [
39
+ "#Outline of an LSTM Class:\n",
40
+ "class LSTMfromScratch(L.LightningModule):\n",
41
+ " def __init__(self):\n",
42
+ " # Initialize weights and biases\n",
43
+ " super().__init__()\n",
44
+ " mean = torch.tensor(0.0)\n",
45
+ " std = torch.tensor(1.0)\n",
46
+ "\n",
47
+ " self.wfp1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True) # The wf means the weight at the forget gate and the p means this weight is used in the sigmoid later to get the percentage\n",
48
+ " self.wfp2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
49
+ " self.bfp1 = nn.Parameter(torch.tensor(0.0), requires_grad=True) # The bf means the bias at the forget gate and the p means this bias is used in the sigmoid later to get the percentage\n",
50
+ "\n",
51
+ " self.wip1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)# The wi means the weight at the input gate and the p means this weight is used in the sigmoid later to get the percentage\n",
52
+ " self.wip2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
53
+ " self.bip1 = nn.Parameter(torch.tensor(0.0), requires_grad=True)# The bi means the bias at the input gate and the p means this bias is used in the sigmoid later to get the percentage\n",
54
+ "\n",
55
+ " self.wi3 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True) # These do not have p because they are used in the tanh activation function that produces the potential memory, not in a sigmoid percentage\n",
56
+ " self.wi4 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
57
+ " self.bi2 = nn.Parameter(torch.tensor(0.0), requires_grad=True)\n",
58
+ "\n",
59
+ " self.wop1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
60
+ " self.wop2 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
61
+ " self.bop1 = nn.Parameter(torch.normal(mean=mean, std=std), requires_grad=True)\n",
62
+ "\n",
63
+ " def lstm_unit(self, input_value, long_mem, short_mem):\n",
64
+ " # This is where the math is done in the lstm\n",
65
+ " long_remem_percent = torch.sigmoid((short_mem*self.wfp1) + (input_value*self.wfp2) + self.bfp1)\n",
66
+ "\n",
67
+ " potenital_long_mem_percent = torch.sigmoid((short_mem*self.wip1) + (input_value*self.wip2) + self.bip1)\n",
68
+ " potential_mem = torch.tanh((short_mem * self.wi3) + (input_value*self.wi4) + self.bi2)\n",
69
+ "\n",
70
+ " updated_long_term_mem = (long_mem * long_remem_percent) + (potential_mem * potenital_long_mem_percent)\n",
71
+ "\n",
72
+ " ouput_percent = torch.sigmoid((short_mem*self.wop1) + (input_value * self.wop2) + self.bop1)\n",
73
+ " updated_short_mem = torch.tanh(updated_long_term_mem) * ouput_percent\n",
74
+ "\n",
75
+ " return [updated_long_term_mem, updated_short_mem]\n",
76
+ "\n",
77
+ " def forward(self, input):\n",
78
+ " # We do forward pass here\n",
79
+ " long_mem = 0\n",
80
+ " short_mem = 0\n",
81
+ " day1 = input[0]\n",
82
+ " day2 = input[1]\n",
83
+ " day3 = input[2]\n",
84
+ " day4 = input[3]\n",
85
+ "\n",
86
+ " long_mem, short_mem = self.lstm_unit(day1, long_mem, short_mem)\n",
87
+ " long_mem, short_mem = self.lstm_unit(day2, long_mem, short_mem)\n",
88
+ " long_mem, short_mem = self.lstm_unit(day3, long_mem, short_mem)\n",
89
+ " long_mem, short_mem = self.lstm_unit(day4, long_mem, short_mem)\n",
90
+ "\n",
91
+ " return short_mem\n",
92
+ "\n",
93
+ " def configure_optimizers(self):\n",
94
+ " # Used to configure the Adam optimizer\n",
95
+ " return Adam(self.parameters())\n",
96
+ " def training_step(self, batch, batch_idx):\n",
97
+ " # Used to calculate loss and log training progress\n",
98
+ " # Logging the loss (or training progress) helps you decide when to stop training\n",
99
+ " input_i, label_i = batch\n",
100
+ " output_i = self.forward(input_i[0])\n",
101
+ " loss = (output_i - label_i)**2\n",
102
+ "\n",
103
+ " self.log(\"train_loss\", loss) # self.log is a method inherited from LightningModule; Lightning writes the logged values to files under a lightning_logs directory so the loss can be reviewed later\n",
104
+ " # Here we log the output according to which company was just predicted (Company A is out_0 and Company B is out_1); this is optional and is only part of the example\n",
105
+ " if label_i == 0:\n",
106
+ " self.log(\"out_0\", output_i)\n",
107
+ " else:\n",
108
+ " self.log(\"out_1\", output_i)\n",
109
+ "\n",
110
+ " return loss"
111
+ ]
112
+ },
113
+ {
114
+ "cell_type": "code",
115
+ "execution_count": 16,
116
+ "metadata": {
117
+ "colab": {
118
+ "base_uri": "https://localhost:8080/"
119
+ },
120
+ "id": "B9FlbMItJxGA",
121
+ "outputId": "cead131b-3bec-4255-da3e-cde515e44039"
122
+ },
123
+ "outputs": [
124
+ {
125
+ "name": "stdout",
126
+ "output_type": "stream",
127
+ "text": [
128
+ "\n",
129
+ "Comparing actual result with predicted result:\n",
130
+ "Company A: Observed = 0, Predicted = tensor(0.2409)\n"
131
+ ]
132
+ }
133
+ ],
134
+ "source": [
135
+ "model = LSTMfromScratch()\n",
136
+ "print(\"\\nComparing actual result with predicted result:\")\n",
137
+ "print(\"Company A: Observed = 0, Predicted = \", model(torch.tensor([0.0, 0.5, 0.25, 1.0])).detach())"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "code",
142
+ "execution_count": 17,
143
+ "metadata": {
144
+ "colab": {
145
+ "base_uri": "https://localhost:8080/"
146
+ },
147
+ "id": "8PC3Y3QiUE-C",
148
+ "outputId": "306887bf-fd30-4452-c7aa-21753e8235f3"
149
+ },
150
+ "outputs": [
151
+ {
152
+ "name": "stdout",
153
+ "output_type": "stream",
154
+ "text": [
155
+ "\n",
156
+ "Comparing actual result with predicted result:\n",
157
+ "Company B: Observed = 1, Predicted = tensor(0.2835)\n"
158
+ ]
159
+ }
160
+ ],
161
+ "source": [
162
+ "print(\"\\nComparing actual result with predicted result:\")\n",
163
+ "print(\"Company B: Observed = 1, Predicted = \", model(torch.tensor([1.0, 0.5, 0.25, 1.0])).detach())"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 19,
169
+ "metadata": {
170
+ "id": "jXl3QMqfVkej"
171
+ },
172
+ "outputs": [],
173
+ "source": [
174
+ "inputs = torch.tensor([[0.0, 0.5, 0.25, 1.0], [1.0, 0.5, 0.25, 1.0]])\n",
175
+ "labels = torch.tensor([0.0, 1.0])\n",
176
+ "dataset = TensorDataset(inputs, labels)\n",
177
+ "dataloader = DataLoader(dataset)"
178
+ ]
179
+ },
180
+ {
181
+ "cell_type": "code",
182
+ "execution_count": 20,
183
+ "metadata": {
184
+ "colab": {
185
+ "base_uri": "https://localhost:8080/",
186
+ "height": 622,
187
+ "referenced_widgets": [
188
+ "77dbe8524d264453acb912fc76795f6e",
189
+ "acc0959eb3c34f989bd50266a74d9996",
190
+ "be0ec436e90942a881f6ede77350e5ac",
191
+ "9744c13d52e047d9b2e86b07070c3649",
192
+ "bcd37bee4dcb45b89856474c38ea9547",
193
+ "9f11d072c8504d139284954a466157fd",
194
+ "f61e61e97ce5416eb99bf3ee2ad73675",
195
+ "ce3da24e0e4241b99e5d641f3deb18ee",
196
+ "7886fd709c0044cd90edd10626f757c5",
197
+ "4b30f751b5874f5e9e12e0bf4f0d2bd8",
198
+ "4c9b7fd2669048bfab33fce44f9aaa2c"
199
+ ]
200
+ },
201
+ "id": "W2qRF_tjYQBu",
202
+ "outputId": "8ea0e0d3-4a32-44cf-e3e0-ff4b481447ff"
203
+ },
204
+ "outputs": [
205
+ {
206
+ "name": "stderr",
207
+ "output_type": "stream",
208
+ "text": [
209
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n",
210
+ "GPU available: True (mps), used: True\n",
211
+ "TPU available: False, using: 0 TPU cores\n",
212
+ "HPU available: False, using: 0 HPUs\n",
213
+ "\n",
214
+ " | Name | Type | Params | Mode\n",
215
+ "---------------------------------------------\n",
216
+ " | other params | n/a | 12 | n/a \n",
217
+ "---------------------------------------------\n",
218
+ "12 Trainable params\n",
219
+ "0 Non-trainable params\n",
220
+ "12 Total params\n",
221
+ "0.000 Total estimated model params size (MB)\n",
222
+ "0 Modules in train mode\n",
223
+ "0 Modules in eval mode\n"
224
+ ]
225
+ },
226
+ {
227
+ "name": "stdout",
228
+ "output_type": "stream",
229
+ "text": [
230
+ "Epoch 1999: 100%|██████████| 2/2 [00:00<00:00, 76.55it/s, v_num=4]"
231
+ ]
232
+ },
233
+ {
234
+ "name": "stderr",
235
+ "output_type": "stream",
236
+ "text": [
237
+ "`Trainer.fit` stopped: `max_epochs=2000` reached.\n"
238
+ ]
239
+ },
240
+ {
241
+ "name": "stdout",
242
+ "output_type": "stream",
243
+ "text": [
244
+ "Epoch 1999: 100%|██████████| 2/2 [00:00<00:00, 50.74it/s, v_num=4]\n"
245
+ ]
246
+ }
247
+ ],
248
+ "source": [
249
+ "trainer = L.Trainer(max_epochs=2000)\n",
250
+ "trainer.fit(model, train_dataloaders=dataloader)"
251
+ ]
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "execution_count": 21,
256
+ "metadata": {
257
+ "colab": {
258
+ "base_uri": "https://localhost:8080/"
259
+ },
260
+ "id": "z6N80MBEau_S",
261
+ "outputId": "2ff21c40-bc65-4c4b-c235-98d81c6db92d"
262
+ },
263
+ "outputs": [
264
+ {
265
+ "name": "stdout",
266
+ "output_type": "stream",
267
+ "text": [
268
+ "\n",
269
+ "Comparing actual result with predicted result:\n",
270
+ "Company A: Observed = 0, Predicted = tensor(0.0005)\n"
271
+ ]
272
+ }
273
+ ],
274
+ "source": [
275
+ "print(\"\\nComparing actual result with predicted result:\")\n",
276
+ "print(\"Company A: Observed = 0, Predicted = \", model(torch.tensor([0.0, 0.5, 0.25, 1.0])).detach())"
277
+ ]
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "execution_count": 22,
282
+ "metadata": {
283
+ "colab": {
284
+ "base_uri": "https://localhost:8080/"
285
+ },
286
+ "id": "wS7uONzFbiOY",
287
+ "outputId": "89718d1d-7ab7-4e67-d3fc-f485f87af77c"
288
+ },
289
+ "outputs": [
290
+ {
291
+ "name": "stdout",
292
+ "output_type": "stream",
293
+ "text": [
294
+ "\n",
295
+ "Comparing actual result with predicted result:\n",
296
+ "Company B: Observed = 1, Predicted = tensor(0.9432)\n"
297
+ ]
298
+ }
299
+ ],
300
+ "source": [
301
+ "print(\"\\nComparing actual result with predicted result:\")\n",
302
+ "print(\"Company B: Observed = 1, Predicted = \", model(torch.tensor([1.0, 0.5, 0.25, 1.0])).detach())"
303
+ ]
304
+ },
305
+ {
306
+ "cell_type": "code",
307
+ "execution_count": 23,
308
+ "metadata": {
309
+ "colab": {
310
+ "base_uri": "https://localhost:8080/",
311
+ "height": 711,
312
+ "referenced_widgets": [
313
+ "fc540f64edce4478866aa52d082eff18",
314
+ "abf235757f694da2b9b6955a6563410f",
315
+ "666ea6217b364c1991b19b3e637b3a10",
316
+ "855f977859ad4e4e91fa160a784b9ca7",
317
+ "abae861c431f4b8d88c02a64d1e203b3",
318
+ "bddc8a3b084b441ab982c51f5a6537da",
319
+ "f23a32759a1241ca9ea96ac85b856eb0",
320
+ "ef556255ed294360945f36982cde4a61",
321
+ "775923ba5d78493d9da37eeeffbc0fb5",
322
+ "aff872b1fee04784bd91c09cf4e54df5",
323
+ "bd8ff4fc35de431d8bb7ded4e9c11347"
324
+ ]
325
+ },
326
+ "id": "_wX54WUXbk-S",
327
+ "outputId": "0f7a4ba2-2f29-494a-c1d8-86ab1048f6fb"
328
+ },
329
+ "outputs": [
330
+ {
331
+ "name": "stderr",
332
+ "output_type": "stream",
333
+ "text": [
334
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n"
335
+ ]
336
+ },
337
+ {
338
+ "name": "stderr",
339
+ "output_type": "stream",
340
+ "text": [
341
+ "GPU available: True (mps), used: True\n",
342
+ "TPU available: False, using: 0 TPU cores\n",
343
+ "HPU available: False, using: 0 HPUs\n",
344
+ "Restoring states from the checkpoint path at /Users/adhithyapasumarthi/Downloads/lightning_logs/version_4/checkpoints/epoch=1999-step=4000.ckpt\n",
345
+ "/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:366: The dirpath has changed from '/Users/adhithyapasumarthi/Downloads/lightning_logs/version_4/checkpoints' to '/Users/adhithyapasumarthi/Downloads/lightning_logs/version_5/checkpoints', therefore `best_model_score`, `kth_best_model_path`, `kth_value`, `last_model_path` and `best_k_models` won't be reloaded. Only `best_model_path` will be reloaded.\n",
346
+ "\n",
347
+ " | Name | Type | Params | Mode\n",
348
+ "---------------------------------------------\n",
349
+ " | other params | n/a | 12 | n/a \n",
350
+ "---------------------------------------------\n",
351
+ "12 Trainable params\n",
352
+ "0 Non-trainable params\n",
353
+ "12 Total params\n",
354
+ "0.000 Total estimated model params size (MB)\n",
355
+ "0 Modules in train mode\n",
356
+ "0 Modules in eval mode\n",
357
+ "Restored all states from the checkpoint at /Users/adhithyapasumarthi/Downloads/lightning_logs/version_4/checkpoints/epoch=1999-step=4000.ckpt\n"
358
+ ]
359
+ },
360
+ {
361
+ "name": "stdout",
362
+ "output_type": "stream",
363
+ "text": [
364
+ "Epoch 2999: 100%|██████████| 2/2 [00:00<00:00, 82.56it/s, v_num=5]"
365
+ ]
366
+ },
367
+ {
368
+ "name": "stderr",
369
+ "output_type": "stream",
370
+ "text": [
371
+ "`Trainer.fit` stopped: `max_epochs=3000` reached.\n"
372
+ ]
373
+ },
374
+ {
375
+ "name": "stdout",
376
+ "output_type": "stream",
377
+ "text": [
378
+ "Epoch 2999: 100%|██████████| 2/2 [00:00<00:00, 56.92it/s, v_num=5]\n"
379
+ ]
380
+ }
381
+ ],
382
+ "source": [
383
+ "path_to_best_checkpoint = trainer.checkpoint_callback.best_model_path\n",
384
+ "trainer = L.Trainer(max_epochs=3000)\n",
385
+ "trainer.fit(model, train_dataloaders=dataloader, ckpt_path=path_to_best_checkpoint)"
386
+ ]
387
+ },
388
+ {
389
+ "cell_type": "code",
390
+ "execution_count": 24,
391
+ "metadata": {
392
+ "colab": {
393
+ "base_uri": "https://localhost:8080/"
394
+ },
395
+ "id": "WfS-3YPhhx1i",
396
+ "outputId": "c37dbe81-b836-49b4-cbd7-c2fbbe5b6f03"
397
+ },
398
+ "outputs": [
399
+ {
400
+ "name": "stdout",
401
+ "output_type": "stream",
402
+ "text": [
403
+ "\n",
404
+ "Comparing labeled values with predicted values: \n",
405
+ "Comapny A labeled value: 0, Predicted: tensor(0.0001)\n"
406
+ ]
407
+ }
408
+ ],
409
+ "source": [
410
+ "print(\"\\nComparing labeled values with predicted values: \")\n",
411
+ "print(\"Comapny A labeled value: 0, Predicted: \", model(torch.tensor([0.0, 0.5, 0.25, 1.0])).detach())"
412
+ ]
413
+ },
414
+ {
415
+ "cell_type": "code",
416
+ "execution_count": 25,
417
+ "metadata": {
418
+ "colab": {
419
+ "base_uri": "https://localhost:8080/"
420
+ },
421
+ "id": "ip8NedXekZiO",
422
+ "outputId": "7d4ac7dd-aa80-4680-9bbd-760461d97050"
423
+ },
424
+ "outputs": [
425
+ {
426
+ "name": "stdout",
427
+ "output_type": "stream",
428
+ "text": [
429
+ "\n",
430
+ "Comparing labeled values with predicted values: \n",
431
+ "Comapny B labeled value: 1, Predicted: tensor(0.9687)\n"
432
+ ]
433
+ }
434
+ ],
435
+ "source": [
436
+ "print(\"\\nComparing labeled values with predicted values: \")\n",
437
+ "print(\"Comapny B labeled value: 1, Predicted: \", model(torch.tensor([1.0, 0.5, 0.25, 1.0])).detach())"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "metadata": {},
443
+ "source": [
444
+ "LSTM using the pytorch nn.LSTM():"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "code",
449
+ "execution_count": 41,
450
+ "metadata": {
451
+ "id": "n7H_kbHIkf6L"
452
+ },
453
+ "outputs": [],
454
+ "source": [
455
+ "class LightningLSTM(L.LightningModule):\n",
456
+ " def __init__(self):\n",
457
+ " super().__init__()\n",
458
+ " # input_size is the number of features that we feed to the LSTM, and hidden_size is the number of output values\n",
459
+ " # It is common to feed the output values from the LSTM into a neural network, so it is possible for the LSTM to have more than one output value. \n",
460
+ " # Example of having multiple output values: if you were predicting the temperature, wind speed, and other features for the next hour, you would take several values from the LSTM and pass them into a feed-forward neural network to predict and classify the general weather pattern for the next hour\n",
461
+ " self.lstm = nn.LSTM(input_size=1, hidden_size=1) \n",
462
+ "\n",
463
+ " def forward(self, input):\n",
464
+ " # .view reshapes the input from a single row into len(input) rows, with the number of columns set to 1 because there is only 1 feature\n",
465
+ " input_transpose = input.view(len(input), 1)\n",
466
+ " # self.lstm() takes the reshaped input and returns a pair: lstm_out, which holds the short-term memory (hidden state) from every unrolled LSTM unit, and temp, a tuple containing the final short-term (hidden) and long-term (cell) memory values\n",
467
+ " # Only lstm_out is used below; temp holds just the final states rather than one value per unrolled unit\n",
468
+ " lstm_out, temp = self.lstm(input_transpose) \n",
469
+ "\n",
470
+ " # Take the last short-term memory value: it is the prediction after the input has passed through all of the unrolled LSTM units\n",
471
+ " pred = lstm_out[-1]\n",
472
+ " return pred\n",
473
+ " def configure_optimizers(self):\n",
474
+ " # Using the Adam optimizer and set the learning rate to 0.1 which is a lot higher than the default 0.001 learning rate\n",
475
+ " return Adam(self.parameters(), lr=0.1)\n",
476
+ " def training_step(self, batch, batch_idx):\n",
477
+ " input_i, label_i = batch\n",
478
+ " output_i = self.forward(input_i[0])\n",
479
+ " loss = (output_i - label_i)**2\n",
480
+ "\n",
481
+ " self.log(\"training_loss\", loss)\n",
482
+ " if label_i == 0:\n",
483
+ " self.log(\"out_0\", output_i)\n",
484
+ " else:\n",
485
+ " self.log(\"out_1\", output_i)\n",
486
+ " return loss"
487
+ ]
488
+ },
489
+ {
490
+ "cell_type": "code",
491
+ "execution_count": 59,
492
+ "metadata": {},
493
+ "outputs": [
494
+ {
495
+ "name": "stdout",
496
+ "output_type": "stream",
497
+ "text": [
498
+ "\n",
499
+ "Comparing label and the predicted values:\n",
500
+ "Label value: 0 and Predicted value: tensor([0.0647])\n"
501
+ ]
502
+ }
503
+ ],
504
+ "source": [
505
+ "model = LightningLSTM()\n",
506
+ "print(\"\\nComparing label and the predicted values:\")\n",
507
+ "print(\"Label value: 0 and Predicted value: \", model(torch.tensor([0.0, 0.5, .25, 1.0])).detach())"
508
+ ]
509
+ },
510
+ {
511
+ "cell_type": "code",
512
+ "execution_count": 60,
513
+ "metadata": {},
514
+ "outputs": [
515
+ {
516
+ "name": "stdout",
517
+ "output_type": "stream",
518
+ "text": [
519
+ "Comparing label and the predicted values:\n",
520
+ "Label value: 1 and Predicted value: tensor([0.0640])\n"
521
+ ]
522
+ }
523
+ ],
524
+ "source": [
525
+ "print(\"Comparing label and the predicted values:\")\n",
526
+ "print(\"Label value: 1 and Predicted value: \", model(torch.tensor([1.0, 0.5, .25, 1.0])).detach())"
527
+ ]
528
+ },
529
+ {
530
+ "cell_type": "code",
531
+ "execution_count": 61,
532
+ "metadata": {},
533
+ "outputs": [
534
+ {
535
+ "name": "stderr",
536
+ "output_type": "stream",
537
+ "text": [
538
+ "💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.\n"
539
+ ]
540
+ },
541
+ {
542
+ "name": "stderr",
543
+ "output_type": "stream",
544
+ "text": [
545
+ "GPU available: True (mps), used: True\n",
546
+ "TPU available: False, using: 0 TPU cores\n",
547
+ "HPU available: False, using: 0 HPUs\n",
548
+ "\n",
549
+ " | Name | Type | Params | Mode \n",
550
+ "--------------------------------------\n",
551
+ "0 | lstm | LSTM | 16 | train\n",
552
+ "--------------------------------------\n",
553
+ "16 Trainable params\n",
554
+ "0 Non-trainable params\n",
555
+ "16 Total params\n",
556
+ "0.000 Total estimated model params size (MB)\n",
557
+ "1 Modules in train mode\n",
558
+ "0 Modules in eval mode\n"
559
+ ]
560
+ },
561
+ {
562
+ "name": "stdout",
563
+ "output_type": "stream",
564
+ "text": [
565
+ "Epoch 299: 100%|██████████| 2/2 [00:00<00:00, 176.08it/s, v_num=11]"
566
+ ]
567
+ },
568
+ {
569
+ "name": "stderr",
570
+ "output_type": "stream",
571
+ "text": [
572
+ "`Trainer.fit` stopped: `max_epochs=300` reached.\n"
573
+ ]
574
+ },
575
+ {
576
+ "name": "stdout",
577
+ "output_type": "stream",
578
+ "text": [
579
+ "Epoch 299: 100%|██████████| 2/2 [00:00<00:00, 125.30it/s, v_num=11]\n"
580
+ ]
581
+ }
582
+ ],
583
+ "source": [
584
+ "# Notice how we changed the # of epochs to 300 instead of 3000 because we set the learning rate to 0.1 instead of using the 0.001 default learning rate\n",
585
+ "# This means our model will take larger steps when doing gradient descent, so it should take less time to reach the minimum loss\n",
586
+ "trainer = L.Trainer(max_epochs=300, log_every_n_steps=2)\n",
587
+ "trainer.fit(model, dataloader)"
588
+ ]
589
+ },
590
+ {
591
+ "cell_type": "code",
592
+ "execution_count": 62,
593
+ "metadata": {},
594
+ "outputs": [
595
+ {
596
+ "name": "stdout",
597
+ "output_type": "stream",
598
+ "text": [
599
+ "\n",
600
+ "Comparing label and the predicted values:\n",
601
+ "Label value: 0 and Predicted value: tensor([4.9227e-05])\n"
602
+ ]
603
+ }
604
+ ],
605
+ "source": [
606
+ "print(\"\\nComparing label and the predicted values:\")\n",
607
+ "print(\"Label value: 0 and Predicted value: \", model(torch.tensor([0.0, 0.5, .25, 1.0])).detach())"
608
+ ]
609
+ },
610
+ {
611
+ "cell_type": "code",
612
+ "execution_count": 63,
613
+ "metadata": {},
614
+ "outputs": [
615
+ {
616
+ "name": "stdout",
617
+ "output_type": "stream",
618
+ "text": [
619
+ "\n",
620
+ "Comparing label and the predicted values:\n",
621
+ "Label value: 1 and Predicted value: tensor([0.9818])\n"
622
+ ]
623
+ }
624
+ ],
625
+ "source": [
626
+ "print(\"\\nComparing label and the predicted values:\")\n",
627
+ "print(\"Label value: 1 and Predicted value: \", model(torch.tensor([1.0, 0.5, .25, 1.0])).detach())"
628
+ ]
629
+ },
630
+ {
631
+ "cell_type": "code",
632
+ "execution_count": null,
633
+ "metadata": {},
634
+ "outputs": [],
635
+ "source": []
636
+ }
637
+ ],
638
+ "metadata": {
639
+ "colab": {
640
+ "provenance": []
641
+ },
642
+ "kernelspec": {
643
+ "display_name": "Python 3",
644
+ "name": "python3"
645
+ },
646
+ "language_info": {
647
+ "codemirror_mode": {
648
+ "name": "ipython",
649
+ "version": 3
650
+ },
651
+ "file_extension": ".py",
652
+ "mimetype": "text/x-python",
653
+ "name": "python",
654
+ "nbconvert_exporter": "python",
655
+ "pygments_lexer": "ipython3",
656
+ "version": "3.11.4"
657
+ },
658
+ "widgets": {
659
+ "application/vnd.jupyter.widget-state+json": {
660
+ "4b30f751b5874f5e9e12e0bf4f0d2bd8": {
661
+ "model_module": "@jupyter-widgets/base",
662
+ "model_module_version": "1.2.0",
663
+ "model_name": "LayoutModel",
664
+ "state": {
665
+ "_model_module": "@jupyter-widgets/base",
666
+ "_model_module_version": "1.2.0",
667
+ "_model_name": "LayoutModel",
668
+ "_view_count": null,
669
+ "_view_module": "@jupyter-widgets/base",
670
+ "_view_module_version": "1.2.0",
671
+ "_view_name": "LayoutView",
672
+ "align_content": null,
673
+ "align_items": null,
674
+ "align_self": null,
675
+ "border": null,
676
+ "bottom": null,
677
+ "display": null,
678
+ "flex": null,
679
+ "flex_flow": null,
680
+ "grid_area": null,
681
+ "grid_auto_columns": null,
682
+ "grid_auto_flow": null,
683
+ "grid_auto_rows": null,
684
+ "grid_column": null,
685
+ "grid_gap": null,
686
+ "grid_row": null,
687
+ "grid_template_areas": null,
688
+ "grid_template_columns": null,
689
+ "grid_template_rows": null,
690
+ "height": null,
691
+ "justify_content": null,
692
+ "justify_items": null,
693
+ "left": null,
694
+ "margin": null,
695
+ "max_height": null,
696
+ "max_width": null,
697
+ "min_height": null,
698
+ "min_width": null,
699
+ "object_fit": null,
700
+ "object_position": null,
701
+ "order": null,
702
+ "overflow": null,
703
+ "overflow_x": null,
704
+ "overflow_y": null,
705
+ "padding": null,
706
+ "right": null,
707
+ "top": null,
708
+ "visibility": null,
709
+ "width": null
710
+ }
711
+ },
712
+ "4c9b7fd2669048bfab33fce44f9aaa2c": {
713
+ "model_module": "@jupyter-widgets/controls",
714
+ "model_module_version": "1.5.0",
715
+ "model_name": "DescriptionStyleModel",
716
+ "state": {
717
+ "_model_module": "@jupyter-widgets/controls",
718
+ "_model_module_version": "1.5.0",
719
+ "_model_name": "DescriptionStyleModel",
720
+ "_view_count": null,
721
+ "_view_module": "@jupyter-widgets/base",
722
+ "_view_module_version": "1.2.0",
723
+ "_view_name": "StyleView",
724
+ "description_width": ""
725
+ }
726
+ },
727
+ "666ea6217b364c1991b19b3e637b3a10": {
728
+ "model_module": "@jupyter-widgets/controls",
729
+ "model_module_version": "1.5.0",
730
+ "model_name": "FloatProgressModel",
731
+ "state": {
732
+ "_dom_classes": [],
733
+ "_model_module": "@jupyter-widgets/controls",
734
+ "_model_module_version": "1.5.0",
735
+ "_model_name": "FloatProgressModel",
736
+ "_view_count": null,
737
+ "_view_module": "@jupyter-widgets/controls",
738
+ "_view_module_version": "1.5.0",
739
+ "_view_name": "ProgressView",
740
+ "bar_style": "success",
741
+ "description": "",
742
+ "description_tooltip": null,
743
+ "layout": "IPY_MODEL_ef556255ed294360945f36982cde4a61",
744
+ "max": 2,
745
+ "min": 0,
746
+ "orientation": "horizontal",
747
+ "style": "IPY_MODEL_775923ba5d78493d9da37eeeffbc0fb5",
748
+ "value": 2
749
+ }
750
+ },
751
+ "775923ba5d78493d9da37eeeffbc0fb5": {
752
+ "model_module": "@jupyter-widgets/controls",
753
+ "model_module_version": "1.5.0",
754
+ "model_name": "ProgressStyleModel",
755
+ "state": {
756
+ "_model_module": "@jupyter-widgets/controls",
757
+ "_model_module_version": "1.5.0",
758
+ "_model_name": "ProgressStyleModel",
759
+ "_view_count": null,
760
+ "_view_module": "@jupyter-widgets/base",
761
+ "_view_module_version": "1.2.0",
762
+ "_view_name": "StyleView",
763
+ "bar_color": null,
764
+ "description_width": ""
765
+ }
766
+ },
767
+ "77dbe8524d264453acb912fc76795f6e": {
768
+ "model_module": "@jupyter-widgets/controls",
769
+ "model_module_version": "1.5.0",
770
+ "model_name": "HBoxModel",
771
+ "state": {
772
+ "_dom_classes": [],
773
+ "_model_module": "@jupyter-widgets/controls",
774
+ "_model_module_version": "1.5.0",
775
+ "_model_name": "HBoxModel",
776
+ "_view_count": null,
777
+ "_view_module": "@jupyter-widgets/controls",
778
+ "_view_module_version": "1.5.0",
779
+ "_view_name": "HBoxView",
780
+ "box_style": "",
781
+ "children": [
782
+ "IPY_MODEL_acc0959eb3c34f989bd50266a74d9996",
783
+ "IPY_MODEL_be0ec436e90942a881f6ede77350e5ac",
784
+ "IPY_MODEL_9744c13d52e047d9b2e86b07070c3649"
785
+ ],
786
+ "layout": "IPY_MODEL_bcd37bee4dcb45b89856474c38ea9547"
787
+ }
788
+ },
789
+ "7886fd709c0044cd90edd10626f757c5": {
790
+ "model_module": "@jupyter-widgets/controls",
791
+ "model_module_version": "1.5.0",
792
+ "model_name": "ProgressStyleModel",
793
+ "state": {
794
+ "_model_module": "@jupyter-widgets/controls",
795
+ "_model_module_version": "1.5.0",
796
+ "_model_name": "ProgressStyleModel",
797
+ "_view_count": null,
798
+ "_view_module": "@jupyter-widgets/base",
799
+ "_view_module_version": "1.2.0",
800
+ "_view_name": "StyleView",
801
+ "bar_color": null,
802
+ "description_width": ""
803
+ }
804
+ },
805
+ "855f977859ad4e4e91fa160a784b9ca7": {
806
+ "model_module": "@jupyter-widgets/controls",
807
+ "model_module_version": "1.5.0",
808
+ "model_name": "HTMLModel",
809
+ "state": {
810
+ "_dom_classes": [],
811
+ "_model_module": "@jupyter-widgets/controls",
812
+ "_model_module_version": "1.5.0",
813
+ "_model_name": "HTMLModel",
814
+ "_view_count": null,
815
+ "_view_module": "@jupyter-widgets/controls",
816
+ "_view_module_version": "1.5.0",
817
+ "_view_name": "HTMLView",
818
+ "description": "",
819
+ "description_tooltip": null,
820
+ "layout": "IPY_MODEL_aff872b1fee04784bd91c09cf4e54df5",
821
+ "placeholder": "​",
822
+ "style": "IPY_MODEL_bd8ff4fc35de431d8bb7ded4e9c11347",
823
+ "value": " 2/2 [00:00&lt;00:00, 32.32it/s, v_num=3]"
824
+ }
825
+ },
826
+ "9744c13d52e047d9b2e86b07070c3649": {
827
+ "model_module": "@jupyter-widgets/controls",
828
+ "model_module_version": "1.5.0",
829
+ "model_name": "HTMLModel",
830
+ "state": {
831
+ "_dom_classes": [],
832
+ "_model_module": "@jupyter-widgets/controls",
833
+ "_model_module_version": "1.5.0",
834
+ "_model_name": "HTMLModel",
835
+ "_view_count": null,
836
+ "_view_module": "@jupyter-widgets/controls",
837
+ "_view_module_version": "1.5.0",
838
+ "_view_name": "HTMLView",
839
+ "description": "",
840
+ "description_tooltip": null,
841
+ "layout": "IPY_MODEL_4b30f751b5874f5e9e12e0bf4f0d2bd8",
842
+ "placeholder": "​",
843
+ "style": "IPY_MODEL_4c9b7fd2669048bfab33fce44f9aaa2c",
844
+ "value": " 2/2 [00:00&lt;00:00, 41.37it/s, v_num=2]"
845
+ }
846
+ },
847
+ "9f11d072c8504d139284954a466157fd": {
848
+ "model_module": "@jupyter-widgets/base",
849
+ "model_module_version": "1.2.0",
850
+ "model_name": "LayoutModel",
851
+ "state": {
852
+ "_model_module": "@jupyter-widgets/base",
853
+ "_model_module_version": "1.2.0",
854
+ "_model_name": "LayoutModel",
855
+ "_view_count": null,
856
+ "_view_module": "@jupyter-widgets/base",
857
+ "_view_module_version": "1.2.0",
858
+ "_view_name": "LayoutView",
859
+ "align_content": null,
860
+ "align_items": null,
861
+ "align_self": null,
862
+ "border": null,
863
+ "bottom": null,
864
+ "display": null,
865
+ "flex": null,
866
+ "flex_flow": null,
867
+ "grid_area": null,
868
+ "grid_auto_columns": null,
869
+ "grid_auto_flow": null,
870
+ "grid_auto_rows": null,
871
+ "grid_column": null,
872
+ "grid_gap": null,
873
+ "grid_row": null,
874
+ "grid_template_areas": null,
875
+ "grid_template_columns": null,
876
+ "grid_template_rows": null,
877
+ "height": null,
878
+ "justify_content": null,
879
+ "justify_items": null,
880
+ "left": null,
881
+ "margin": null,
882
+ "max_height": null,
883
+ "max_width": null,
884
+ "min_height": null,
885
+ "min_width": null,
886
+ "object_fit": null,
887
+ "object_position": null,
888
+ "order": null,
889
+ "overflow": null,
890
+ "overflow_x": null,
891
+ "overflow_y": null,
892
+ "padding": null,
893
+ "right": null,
894
+ "top": null,
895
+ "visibility": null,
896
+ "width": null
897
+ }
898
+ },
899
+ "abae861c431f4b8d88c02a64d1e203b3": {
900
+ "model_module": "@jupyter-widgets/base",
901
+ "model_module_version": "1.2.0",
902
+ "model_name": "LayoutModel",
903
+ "state": {
904
+ "_model_module": "@jupyter-widgets/base",
905
+ "_model_module_version": "1.2.0",
906
+ "_model_name": "LayoutModel",
907
+ "_view_count": null,
908
+ "_view_module": "@jupyter-widgets/base",
909
+ "_view_module_version": "1.2.0",
910
+ "_view_name": "LayoutView",
911
+ "align_content": null,
912
+ "align_items": null,
913
+ "align_self": null,
914
+ "border": null,
915
+ "bottom": null,
916
+ "display": "inline-flex",
917
+ "flex": null,
918
+ "flex_flow": "row wrap",
919
+ "grid_area": null,
920
+ "grid_auto_columns": null,
921
+ "grid_auto_flow": null,
922
+ "grid_auto_rows": null,
923
+ "grid_column": null,
924
+ "grid_gap": null,
925
+ "grid_row": null,
926
+ "grid_template_areas": null,
927
+ "grid_template_columns": null,
928
+ "grid_template_rows": null,
929
+ "height": null,
930
+ "justify_content": null,
931
+ "justify_items": null,
932
+ "left": null,
933
+ "margin": null,
934
+ "max_height": null,
935
+ "max_width": null,
936
+ "min_height": null,
937
+ "min_width": null,
938
+ "object_fit": null,
939
+ "object_position": null,
940
+ "order": null,
941
+ "overflow": null,
942
+ "overflow_x": null,
943
+ "overflow_y": null,
944
+ "padding": null,
945
+ "right": null,
946
+ "top": null,
947
+ "visibility": null,
948
+ "width": "100%"
949
+ }
950
+ },
951
+ "abf235757f694da2b9b6955a6563410f": {
952
+ "model_module": "@jupyter-widgets/controls",
953
+ "model_module_version": "1.5.0",
954
+ "model_name": "HTMLModel",
955
+ "state": {
956
+ "_dom_classes": [],
957
+ "_model_module": "@jupyter-widgets/controls",
958
+ "_model_module_version": "1.5.0",
959
+ "_model_name": "HTMLModel",
960
+ "_view_count": null,
961
+ "_view_module": "@jupyter-widgets/controls",
962
+ "_view_module_version": "1.5.0",
963
+ "_view_name": "HTMLView",
964
+ "description": "",
965
+ "description_tooltip": null,
966
+ "layout": "IPY_MODEL_bddc8a3b084b441ab982c51f5a6537da",
967
+ "placeholder": "​",
968
+ "style": "IPY_MODEL_f23a32759a1241ca9ea96ac85b856eb0",
969
+ "value": "Epoch 2999: 100%"
970
+ }
971
+ },
972
+ "acc0959eb3c34f989bd50266a74d9996": {
973
+ "model_module": "@jupyter-widgets/controls",
974
+ "model_module_version": "1.5.0",
975
+ "model_name": "HTMLModel",
976
+ "state": {
977
+ "_dom_classes": [],
978
+ "_model_module": "@jupyter-widgets/controls",
979
+ "_model_module_version": "1.5.0",
980
+ "_model_name": "HTMLModel",
981
+ "_view_count": null,
982
+ "_view_module": "@jupyter-widgets/controls",
983
+ "_view_module_version": "1.5.0",
984
+ "_view_name": "HTMLView",
985
+ "description": "",
986
+ "description_tooltip": null,
987
+ "layout": "IPY_MODEL_9f11d072c8504d139284954a466157fd",
988
+ "placeholder": "​",
989
+ "style": "IPY_MODEL_f61e61e97ce5416eb99bf3ee2ad73675",
990
+ "value": "Epoch 1999: 100%"
991
+ }
992
+ },
993
+ "aff872b1fee04784bd91c09cf4e54df5": {
994
+ "model_module": "@jupyter-widgets/base",
995
+ "model_module_version": "1.2.0",
996
+ "model_name": "LayoutModel",
997
+ "state": {
998
+ "_model_module": "@jupyter-widgets/base",
999
+ "_model_module_version": "1.2.0",
1000
+ "_model_name": "LayoutModel",
1001
+ "_view_count": null,
1002
+ "_view_module": "@jupyter-widgets/base",
1003
+ "_view_module_version": "1.2.0",
1004
+ "_view_name": "LayoutView",
1005
+ "align_content": null,
1006
+ "align_items": null,
1007
+ "align_self": null,
1008
+ "border": null,
1009
+ "bottom": null,
1010
+ "display": null,
1011
+ "flex": null,
1012
+ "flex_flow": null,
1013
+ "grid_area": null,
1014
+ "grid_auto_columns": null,
1015
+ "grid_auto_flow": null,
1016
+ "grid_auto_rows": null,
1017
+ "grid_column": null,
1018
+ "grid_gap": null,
1019
+ "grid_row": null,
1020
+ "grid_template_areas": null,
1021
+ "grid_template_columns": null,
1022
+ "grid_template_rows": null,
1023
+ "height": null,
1024
+ "justify_content": null,
1025
+ "justify_items": null,
1026
+ "left": null,
1027
+ "margin": null,
1028
+ "max_height": null,
1029
+ "max_width": null,
1030
+ "min_height": null,
1031
+ "min_width": null,
1032
+ "object_fit": null,
1033
+ "object_position": null,
1034
+ "order": null,
1035
+ "overflow": null,
1036
+ "overflow_x": null,
1037
+ "overflow_y": null,
1038
+ "padding": null,
1039
+ "right": null,
1040
+ "top": null,
1041
+ "visibility": null,
1042
+ "width": null
1043
+ }
1044
+ },
1045
+ "bcd37bee4dcb45b89856474c38ea9547": {
1046
+ "model_module": "@jupyter-widgets/base",
1047
+ "model_module_version": "1.2.0",
1048
+ "model_name": "LayoutModel",
1049
+ "state": {
1050
+ "_model_module": "@jupyter-widgets/base",
1051
+ "_model_module_version": "1.2.0",
1052
+ "_model_name": "LayoutModel",
1053
+ "_view_count": null,
1054
+ "_view_module": "@jupyter-widgets/base",
1055
+ "_view_module_version": "1.2.0",
1056
+ "_view_name": "LayoutView",
1057
+ "align_content": null,
1058
+ "align_items": null,
1059
+ "align_self": null,
1060
+ "border": null,
1061
+ "bottom": null,
1062
+ "display": "inline-flex",
1063
+ "flex": null,
1064
+ "flex_flow": "row wrap",
1065
+ "grid_area": null,
1066
+ "grid_auto_columns": null,
1067
+ "grid_auto_flow": null,
1068
+ "grid_auto_rows": null,
1069
+ "grid_column": null,
1070
+ "grid_gap": null,
1071
+ "grid_row": null,
1072
+ "grid_template_areas": null,
1073
+ "grid_template_columns": null,
1074
+ "grid_template_rows": null,
1075
+ "height": null,
1076
+ "justify_content": null,
1077
+ "justify_items": null,
1078
+ "left": null,
1079
+ "margin": null,
1080
+ "max_height": null,
1081
+ "max_width": null,
1082
+ "min_height": null,
1083
+ "min_width": null,
1084
+ "object_fit": null,
1085
+ "object_position": null,
1086
+ "order": null,
1087
+ "overflow": null,
1088
+ "overflow_x": null,
1089
+ "overflow_y": null,
1090
+ "padding": null,
1091
+ "right": null,
1092
+ "top": null,
1093
+ "visibility": null,
1094
+ "width": "100%"
1095
+ }
1096
+ },
1097
+ "bd8ff4fc35de431d8bb7ded4e9c11347": {
1098
+ "model_module": "@jupyter-widgets/controls",
1099
+ "model_module_version": "1.5.0",
1100
+ "model_name": "DescriptionStyleModel",
1101
+ "state": {
1102
+ "_model_module": "@jupyter-widgets/controls",
1103
+ "_model_module_version": "1.5.0",
1104
+ "_model_name": "DescriptionStyleModel",
1105
+ "_view_count": null,
1106
+ "_view_module": "@jupyter-widgets/base",
1107
+ "_view_module_version": "1.2.0",
1108
+ "_view_name": "StyleView",
1109
+ "description_width": ""
1110
+ }
1111
+ },
1112
+ "bddc8a3b084b441ab982c51f5a6537da": {
1113
+ "model_module": "@jupyter-widgets/base",
1114
+ "model_module_version": "1.2.0",
1115
+ "model_name": "LayoutModel",
1116
+ "state": {
1117
+ "_model_module": "@jupyter-widgets/base",
1118
+ "_model_module_version": "1.2.0",
1119
+ "_model_name": "LayoutModel",
1120
+ "_view_count": null,
1121
+ "_view_module": "@jupyter-widgets/base",
1122
+ "_view_module_version": "1.2.0",
1123
+ "_view_name": "LayoutView",
1124
+ "align_content": null,
1125
+ "align_items": null,
1126
+ "align_self": null,
1127
+ "border": null,
1128
+ "bottom": null,
1129
+ "display": null,
1130
+ "flex": null,
1131
+ "flex_flow": null,
1132
+ "grid_area": null,
1133
+ "grid_auto_columns": null,
1134
+ "grid_auto_flow": null,
1135
+ "grid_auto_rows": null,
1136
+ "grid_column": null,
1137
+ "grid_gap": null,
1138
+ "grid_row": null,
1139
+ "grid_template_areas": null,
1140
+ "grid_template_columns": null,
1141
+ "grid_template_rows": null,
1142
+ "height": null,
1143
+ "justify_content": null,
1144
+ "justify_items": null,
1145
+ "left": null,
1146
+ "margin": null,
1147
+ "max_height": null,
1148
+ "max_width": null,
1149
+ "min_height": null,
1150
+ "min_width": null,
1151
+ "object_fit": null,
1152
+ "object_position": null,
1153
+ "order": null,
1154
+ "overflow": null,
1155
+ "overflow_x": null,
1156
+ "overflow_y": null,
1157
+ "padding": null,
1158
+ "right": null,
1159
+ "top": null,
1160
+ "visibility": null,
1161
+ "width": null
1162
+ }
1163
+ },
1164
+ "be0ec436e90942a881f6ede77350e5ac": {
1165
+ "model_module": "@jupyter-widgets/controls",
1166
+ "model_module_version": "1.5.0",
1167
+ "model_name": "FloatProgressModel",
1168
+ "state": {
1169
+ "_dom_classes": [],
1170
+ "_model_module": "@jupyter-widgets/controls",
1171
+ "_model_module_version": "1.5.0",
1172
+ "_model_name": "FloatProgressModel",
1173
+ "_view_count": null,
1174
+ "_view_module": "@jupyter-widgets/controls",
1175
+ "_view_module_version": "1.5.0",
1176
+ "_view_name": "ProgressView",
1177
+ "bar_style": "success",
1178
+ "description": "",
1179
+ "description_tooltip": null,
1180
+ "layout": "IPY_MODEL_ce3da24e0e4241b99e5d641f3deb18ee",
1181
+ "max": 2,
1182
+ "min": 0,
1183
+ "orientation": "horizontal",
1184
+ "style": "IPY_MODEL_7886fd709c0044cd90edd10626f757c5",
1185
+ "value": 2
1186
+ }
1187
+ },
1188
+ "ce3da24e0e4241b99e5d641f3deb18ee": {
1189
+ "model_module": "@jupyter-widgets/base",
1190
+ "model_module_version": "1.2.0",
1191
+ "model_name": "LayoutModel",
1192
+ "state": {
1193
+ "_model_module": "@jupyter-widgets/base",
1194
+ "_model_module_version": "1.2.0",
1195
+ "_model_name": "LayoutModel",
1196
+ "_view_count": null,
1197
+ "_view_module": "@jupyter-widgets/base",
1198
+ "_view_module_version": "1.2.0",
1199
+ "_view_name": "LayoutView",
1200
+ "align_content": null,
1201
+ "align_items": null,
1202
+ "align_self": null,
1203
+ "border": null,
1204
+ "bottom": null,
1205
+ "display": null,
1206
+ "flex": "2",
1207
+ "flex_flow": null,
1208
+ "grid_area": null,
1209
+ "grid_auto_columns": null,
1210
+ "grid_auto_flow": null,
1211
+ "grid_auto_rows": null,
1212
+ "grid_column": null,
1213
+ "grid_gap": null,
1214
+ "grid_row": null,
1215
+ "grid_template_areas": null,
1216
+ "grid_template_columns": null,
1217
+ "grid_template_rows": null,
1218
+ "height": null,
1219
+ "justify_content": null,
1220
+ "justify_items": null,
1221
+ "left": null,
1222
+ "margin": null,
1223
+ "max_height": null,
1224
+ "max_width": null,
1225
+ "min_height": null,
1226
+ "min_width": null,
1227
+ "object_fit": null,
1228
+ "object_position": null,
1229
+ "order": null,
1230
+ "overflow": null,
1231
+ "overflow_x": null,
1232
+ "overflow_y": null,
1233
+ "padding": null,
1234
+ "right": null,
1235
+ "top": null,
1236
+ "visibility": null,
1237
+ "width": null
1238
+ }
1239
+ },
1240
+ "ef556255ed294360945f36982cde4a61": {
1241
+ "model_module": "@jupyter-widgets/base",
1242
+ "model_module_version": "1.2.0",
1243
+ "model_name": "LayoutModel",
1244
+ "state": {
1245
+ "_model_module": "@jupyter-widgets/base",
1246
+ "_model_module_version": "1.2.0",
1247
+ "_model_name": "LayoutModel",
1248
+ "_view_count": null,
1249
+ "_view_module": "@jupyter-widgets/base",
1250
+ "_view_module_version": "1.2.0",
1251
+ "_view_name": "LayoutView",
1252
+ "align_content": null,
1253
+ "align_items": null,
1254
+ "align_self": null,
1255
+ "border": null,
1256
+ "bottom": null,
1257
+ "display": null,
1258
+ "flex": "2",
1259
+ "flex_flow": null,
1260
+ "grid_area": null,
1261
+ "grid_auto_columns": null,
1262
+ "grid_auto_flow": null,
1263
+ "grid_auto_rows": null,
1264
+ "grid_column": null,
1265
+ "grid_gap": null,
1266
+ "grid_row": null,
1267
+ "grid_template_areas": null,
1268
+ "grid_template_columns": null,
1269
+ "grid_template_rows": null,
1270
+ "height": null,
1271
+ "justify_content": null,
1272
+ "justify_items": null,
1273
+ "left": null,
1274
+ "margin": null,
1275
+ "max_height": null,
1276
+ "max_width": null,
1277
+ "min_height": null,
1278
+ "min_width": null,
1279
+ "object_fit": null,
1280
+ "object_position": null,
1281
+ "order": null,
1282
+ "overflow": null,
1283
+ "overflow_x": null,
1284
+ "overflow_y": null,
1285
+ "padding": null,
1286
+ "right": null,
1287
+ "top": null,
1288
+ "visibility": null,
1289
+ "width": null
1290
+ }
1291
+ },
1292
+ "f23a32759a1241ca9ea96ac85b856eb0": {
1293
+ "model_module": "@jupyter-widgets/controls",
1294
+ "model_module_version": "1.5.0",
1295
+ "model_name": "DescriptionStyleModel",
1296
+ "state": {
1297
+ "_model_module": "@jupyter-widgets/controls",
1298
+ "_model_module_version": "1.5.0",
1299
+ "_model_name": "DescriptionStyleModel",
1300
+ "_view_count": null,
1301
+ "_view_module": "@jupyter-widgets/base",
1302
+ "_view_module_version": "1.2.0",
1303
+ "_view_name": "StyleView",
1304
+ "description_width": ""
1305
+ }
1306
+ },
1307
+ "f61e61e97ce5416eb99bf3ee2ad73675": {
1308
+ "model_module": "@jupyter-widgets/controls",
1309
+ "model_module_version": "1.5.0",
1310
+ "model_name": "DescriptionStyleModel",
1311
+ "state": {
1312
+ "_model_module": "@jupyter-widgets/controls",
1313
+ "_model_module_version": "1.5.0",
1314
+ "_model_name": "DescriptionStyleModel",
1315
+ "_view_count": null,
1316
+ "_view_module": "@jupyter-widgets/base",
1317
+ "_view_module_version": "1.2.0",
1318
+ "_view_name": "StyleView",
1319
+ "description_width": ""
1320
+ }
1321
+ },
1322
+ "fc540f64edce4478866aa52d082eff18": {
1323
+ "model_module": "@jupyter-widgets/controls",
1324
+ "model_module_version": "1.5.0",
1325
+ "model_name": "HBoxModel",
1326
+ "state": {
1327
+ "_dom_classes": [],
1328
+ "_model_module": "@jupyter-widgets/controls",
1329
+ "_model_module_version": "1.5.0",
1330
+ "_model_name": "HBoxModel",
1331
+ "_view_count": null,
1332
+ "_view_module": "@jupyter-widgets/controls",
1333
+ "_view_module_version": "1.5.0",
1334
+ "_view_name": "HBoxView",
1335
+ "box_style": "",
1336
+ "children": [
1337
+ "IPY_MODEL_abf235757f694da2b9b6955a6563410f",
1338
+ "IPY_MODEL_666ea6217b364c1991b19b3e637b3a10",
1339
+ "IPY_MODEL_855f977859ad4e4e91fa160a784b9ca7"
1340
+ ],
1341
+ "layout": "IPY_MODEL_abae861c431f4b8d88c02a64d1e203b3"
1342
+ }
1343
+ }
1344
+ }
1345
+ }
1346
+ },
1347
+ "nbformat": 4,
1348
+ "nbformat_minor": 0
1349
+ }