Ioana-Gidiuta commited on
Commit
f4e5875
·
verified ·
1 Parent(s): d72f861

Upload Run_ensemble_2 (1).ipynb

Browse files
Files changed (1) hide show
  1. Run_ensemble_2 (1).ipynb +1847 -0
Run_ensemble_2 (1).ipynb ADDED
@@ -0,0 +1,1847 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "collapsed_sections": [
8
+ "GirPusJtYPsP",
9
+ "C-7gft4ddTzo"
10
+ ],
11
+ "gpuType": "T4"
12
+ },
13
+ "kernelspec": {
14
+ "name": "python3",
15
+ "display_name": "Python 3"
16
+ },
17
+ "language_info": {
18
+ "name": "python"
19
+ },
20
+ "accelerator": "GPU",
21
+ "widgets": {
22
+ "application/vnd.jupyter.widget-state+json": {
23
+ "c46dc091acd34be2887c59bf95838529": {
24
+ "model_module": "@jupyter-widgets/controls",
25
+ "model_name": "HBoxModel",
26
+ "model_module_version": "1.5.0",
27
+ "state": {
28
+ "_dom_classes": [],
29
+ "_model_module": "@jupyter-widgets/controls",
30
+ "_model_module_version": "1.5.0",
31
+ "_model_name": "HBoxModel",
32
+ "_view_count": null,
33
+ "_view_module": "@jupyter-widgets/controls",
34
+ "_view_module_version": "1.5.0",
35
+ "_view_name": "HBoxView",
36
+ "box_style": "",
37
+ "children": [
38
+ "IPY_MODEL_af6cce709f8a478a87c4c89222193d8b",
39
+ "IPY_MODEL_4c32916fe36e4466b9c8a96bbd7db71b",
40
+ "IPY_MODEL_48d80beb4cc646328036336431b01278"
41
+ ],
42
+ "layout": "IPY_MODEL_755b9a30f525493382f771840e4b04f4"
43
+ }
44
+ },
45
+ "af6cce709f8a478a87c4c89222193d8b": {
46
+ "model_module": "@jupyter-widgets/controls",
47
+ "model_name": "HTMLModel",
48
+ "model_module_version": "1.5.0",
49
+ "state": {
50
+ "_dom_classes": [],
51
+ "_model_module": "@jupyter-widgets/controls",
52
+ "_model_module_version": "1.5.0",
53
+ "_model_name": "HTMLModel",
54
+ "_view_count": null,
55
+ "_view_module": "@jupyter-widgets/controls",
56
+ "_view_module_version": "1.5.0",
57
+ "_view_name": "HTMLView",
58
+ "description": "",
59
+ "description_tooltip": null,
60
+ "layout": "IPY_MODEL_481a5ffb8c394dfe88d9df74b3edd372",
61
+ "placeholder": "​",
62
+ "style": "IPY_MODEL_7ddb099532f44b8fbc77bfd94232ff8f",
63
+ "value": "README.md: 100%"
64
+ }
65
+ },
66
+ "4c32916fe36e4466b9c8a96bbd7db71b": {
67
+ "model_module": "@jupyter-widgets/controls",
68
+ "model_name": "FloatProgressModel",
69
+ "model_module_version": "1.5.0",
70
+ "state": {
71
+ "_dom_classes": [],
72
+ "_model_module": "@jupyter-widgets/controls",
73
+ "_model_module_version": "1.5.0",
74
+ "_model_name": "FloatProgressModel",
75
+ "_view_count": null,
76
+ "_view_module": "@jupyter-widgets/controls",
77
+ "_view_module_version": "1.5.0",
78
+ "_view_name": "ProgressView",
79
+ "bar_style": "success",
80
+ "description": "",
81
+ "description_tooltip": null,
82
+ "layout": "IPY_MODEL_ba7928e2ec0e459e870e7f8420fc26f6",
83
+ "max": 360,
84
+ "min": 0,
85
+ "orientation": "horizontal",
86
+ "style": "IPY_MODEL_2422979a81b0425fa82df01a1b4b170a",
87
+ "value": 360
88
+ }
89
+ },
90
+ "48d80beb4cc646328036336431b01278": {
91
+ "model_module": "@jupyter-widgets/controls",
92
+ "model_name": "HTMLModel",
93
+ "model_module_version": "1.5.0",
94
+ "state": {
95
+ "_dom_classes": [],
96
+ "_model_module": "@jupyter-widgets/controls",
97
+ "_model_module_version": "1.5.0",
98
+ "_model_name": "HTMLModel",
99
+ "_view_count": null,
100
+ "_view_module": "@jupyter-widgets/controls",
101
+ "_view_module_version": "1.5.0",
102
+ "_view_name": "HTMLView",
103
+ "description": "",
104
+ "description_tooltip": null,
105
+ "layout": "IPY_MODEL_93694c9083bc4a19a51b854829180dc9",
106
+ "placeholder": "​",
107
+ "style": "IPY_MODEL_8d3e47bf8478457d9db128a9927fc568",
108
+ "value": " 360/360 [00:00<00:00, 16.6kB/s]"
109
+ }
110
+ },
111
+ "755b9a30f525493382f771840e4b04f4": {
112
+ "model_module": "@jupyter-widgets/base",
113
+ "model_name": "LayoutModel",
114
+ "model_module_version": "1.2.0",
115
+ "state": {
116
+ "_model_module": "@jupyter-widgets/base",
117
+ "_model_module_version": "1.2.0",
118
+ "_model_name": "LayoutModel",
119
+ "_view_count": null,
120
+ "_view_module": "@jupyter-widgets/base",
121
+ "_view_module_version": "1.2.0",
122
+ "_view_name": "LayoutView",
123
+ "align_content": null,
124
+ "align_items": null,
125
+ "align_self": null,
126
+ "border": null,
127
+ "bottom": null,
128
+ "display": null,
129
+ "flex": null,
130
+ "flex_flow": null,
131
+ "grid_area": null,
132
+ "grid_auto_columns": null,
133
+ "grid_auto_flow": null,
134
+ "grid_auto_rows": null,
135
+ "grid_column": null,
136
+ "grid_gap": null,
137
+ "grid_row": null,
138
+ "grid_template_areas": null,
139
+ "grid_template_columns": null,
140
+ "grid_template_rows": null,
141
+ "height": null,
142
+ "justify_content": null,
143
+ "justify_items": null,
144
+ "left": null,
145
+ "margin": null,
146
+ "max_height": null,
147
+ "max_width": null,
148
+ "min_height": null,
149
+ "min_width": null,
150
+ "object_fit": null,
151
+ "object_position": null,
152
+ "order": null,
153
+ "overflow": null,
154
+ "overflow_x": null,
155
+ "overflow_y": null,
156
+ "padding": null,
157
+ "right": null,
158
+ "top": null,
159
+ "visibility": null,
160
+ "width": null
161
+ }
162
+ },
163
+ "481a5ffb8c394dfe88d9df74b3edd372": {
164
+ "model_module": "@jupyter-widgets/base",
165
+ "model_name": "LayoutModel",
166
+ "model_module_version": "1.2.0",
167
+ "state": {
168
+ "_model_module": "@jupyter-widgets/base",
169
+ "_model_module_version": "1.2.0",
170
+ "_model_name": "LayoutModel",
171
+ "_view_count": null,
172
+ "_view_module": "@jupyter-widgets/base",
173
+ "_view_module_version": "1.2.0",
174
+ "_view_name": "LayoutView",
175
+ "align_content": null,
176
+ "align_items": null,
177
+ "align_self": null,
178
+ "border": null,
179
+ "bottom": null,
180
+ "display": null,
181
+ "flex": null,
182
+ "flex_flow": null,
183
+ "grid_area": null,
184
+ "grid_auto_columns": null,
185
+ "grid_auto_flow": null,
186
+ "grid_auto_rows": null,
187
+ "grid_column": null,
188
+ "grid_gap": null,
189
+ "grid_row": null,
190
+ "grid_template_areas": null,
191
+ "grid_template_columns": null,
192
+ "grid_template_rows": null,
193
+ "height": null,
194
+ "justify_content": null,
195
+ "justify_items": null,
196
+ "left": null,
197
+ "margin": null,
198
+ "max_height": null,
199
+ "max_width": null,
200
+ "min_height": null,
201
+ "min_width": null,
202
+ "object_fit": null,
203
+ "object_position": null,
204
+ "order": null,
205
+ "overflow": null,
206
+ "overflow_x": null,
207
+ "overflow_y": null,
208
+ "padding": null,
209
+ "right": null,
210
+ "top": null,
211
+ "visibility": null,
212
+ "width": null
213
+ }
214
+ },
215
+ "7ddb099532f44b8fbc77bfd94232ff8f": {
216
+ "model_module": "@jupyter-widgets/controls",
217
+ "model_name": "DescriptionStyleModel",
218
+ "model_module_version": "1.5.0",
219
+ "state": {
220
+ "_model_module": "@jupyter-widgets/controls",
221
+ "_model_module_version": "1.5.0",
222
+ "_model_name": "DescriptionStyleModel",
223
+ "_view_count": null,
224
+ "_view_module": "@jupyter-widgets/base",
225
+ "_view_module_version": "1.2.0",
226
+ "_view_name": "StyleView",
227
+ "description_width": ""
228
+ }
229
+ },
230
+ "ba7928e2ec0e459e870e7f8420fc26f6": {
231
+ "model_module": "@jupyter-widgets/base",
232
+ "model_name": "LayoutModel",
233
+ "model_module_version": "1.2.0",
234
+ "state": {
235
+ "_model_module": "@jupyter-widgets/base",
236
+ "_model_module_version": "1.2.0",
237
+ "_model_name": "LayoutModel",
238
+ "_view_count": null,
239
+ "_view_module": "@jupyter-widgets/base",
240
+ "_view_module_version": "1.2.0",
241
+ "_view_name": "LayoutView",
242
+ "align_content": null,
243
+ "align_items": null,
244
+ "align_self": null,
245
+ "border": null,
246
+ "bottom": null,
247
+ "display": null,
248
+ "flex": null,
249
+ "flex_flow": null,
250
+ "grid_area": null,
251
+ "grid_auto_columns": null,
252
+ "grid_auto_flow": null,
253
+ "grid_auto_rows": null,
254
+ "grid_column": null,
255
+ "grid_gap": null,
256
+ "grid_row": null,
257
+ "grid_template_areas": null,
258
+ "grid_template_columns": null,
259
+ "grid_template_rows": null,
260
+ "height": null,
261
+ "justify_content": null,
262
+ "justify_items": null,
263
+ "left": null,
264
+ "margin": null,
265
+ "max_height": null,
266
+ "max_width": null,
267
+ "min_height": null,
268
+ "min_width": null,
269
+ "object_fit": null,
270
+ "object_position": null,
271
+ "order": null,
272
+ "overflow": null,
273
+ "overflow_x": null,
274
+ "overflow_y": null,
275
+ "padding": null,
276
+ "right": null,
277
+ "top": null,
278
+ "visibility": null,
279
+ "width": null
280
+ }
281
+ },
282
+ "2422979a81b0425fa82df01a1b4b170a": {
283
+ "model_module": "@jupyter-widgets/controls",
284
+ "model_name": "ProgressStyleModel",
285
+ "model_module_version": "1.5.0",
286
+ "state": {
287
+ "_model_module": "@jupyter-widgets/controls",
288
+ "_model_module_version": "1.5.0",
289
+ "_model_name": "ProgressStyleModel",
290
+ "_view_count": null,
291
+ "_view_module": "@jupyter-widgets/base",
292
+ "_view_module_version": "1.2.0",
293
+ "_view_name": "StyleView",
294
+ "bar_color": null,
295
+ "description_width": ""
296
+ }
297
+ },
298
+ "93694c9083bc4a19a51b854829180dc9": {
299
+ "model_module": "@jupyter-widgets/base",
300
+ "model_name": "LayoutModel",
301
+ "model_module_version": "1.2.0",
302
+ "state": {
303
+ "_model_module": "@jupyter-widgets/base",
304
+ "_model_module_version": "1.2.0",
305
+ "_model_name": "LayoutModel",
306
+ "_view_count": null,
307
+ "_view_module": "@jupyter-widgets/base",
308
+ "_view_module_version": "1.2.0",
309
+ "_view_name": "LayoutView",
310
+ "align_content": null,
311
+ "align_items": null,
312
+ "align_self": null,
313
+ "border": null,
314
+ "bottom": null,
315
+ "display": null,
316
+ "flex": null,
317
+ "flex_flow": null,
318
+ "grid_area": null,
319
+ "grid_auto_columns": null,
320
+ "grid_auto_flow": null,
321
+ "grid_auto_rows": null,
322
+ "grid_column": null,
323
+ "grid_gap": null,
324
+ "grid_row": null,
325
+ "grid_template_areas": null,
326
+ "grid_template_columns": null,
327
+ "grid_template_rows": null,
328
+ "height": null,
329
+ "justify_content": null,
330
+ "justify_items": null,
331
+ "left": null,
332
+ "margin": null,
333
+ "max_height": null,
334
+ "max_width": null,
335
+ "min_height": null,
336
+ "min_width": null,
337
+ "object_fit": null,
338
+ "object_position": null,
339
+ "order": null,
340
+ "overflow": null,
341
+ "overflow_x": null,
342
+ "overflow_y": null,
343
+ "padding": null,
344
+ "right": null,
345
+ "top": null,
346
+ "visibility": null,
347
+ "width": null
348
+ }
349
+ },
350
+ "8d3e47bf8478457d9db128a9927fc568": {
351
+ "model_module": "@jupyter-widgets/controls",
352
+ "model_name": "DescriptionStyleModel",
353
+ "model_module_version": "1.5.0",
354
+ "state": {
355
+ "_model_module": "@jupyter-widgets/controls",
356
+ "_model_module_version": "1.5.0",
357
+ "_model_name": "DescriptionStyleModel",
358
+ "_view_count": null,
359
+ "_view_module": "@jupyter-widgets/base",
360
+ "_view_module_version": "1.2.0",
361
+ "_view_name": "StyleView",
362
+ "description_width": ""
363
+ }
364
+ },
365
+ "6d93d4f16f7f409d895c928f4c091619": {
366
+ "model_module": "@jupyter-widgets/controls",
367
+ "model_name": "HBoxModel",
368
+ "model_module_version": "1.5.0",
369
+ "state": {
370
+ "_dom_classes": [],
371
+ "_model_module": "@jupyter-widgets/controls",
372
+ "_model_module_version": "1.5.0",
373
+ "_model_name": "HBoxModel",
374
+ "_view_count": null,
375
+ "_view_module": "@jupyter-widgets/controls",
376
+ "_view_module_version": "1.5.0",
377
+ "_view_name": "HBoxView",
378
+ "box_style": "",
379
+ "children": [
380
+ "IPY_MODEL_bdd1de9927bf4183a39c4eb417b4ee65",
381
+ "IPY_MODEL_ed4e47b387a347f98567e87f4dce2dff",
382
+ "IPY_MODEL_95b6f6c28acc4ff3a7da7d1ac5d1fc2d"
383
+ ],
384
+ "layout": "IPY_MODEL_dcda5b897d8c482f8ce32387af5fdb2b"
385
+ }
386
+ },
387
+ "bdd1de9927bf4183a39c4eb417b4ee65": {
388
+ "model_module": "@jupyter-widgets/controls",
389
+ "model_name": "HTMLModel",
390
+ "model_module_version": "1.5.0",
391
+ "state": {
392
+ "_dom_classes": [],
393
+ "_model_module": "@jupyter-widgets/controls",
394
+ "_model_module_version": "1.5.0",
395
+ "_model_name": "HTMLModel",
396
+ "_view_count": null,
397
+ "_view_module": "@jupyter-widgets/controls",
398
+ "_view_module_version": "1.5.0",
399
+ "_view_name": "HTMLView",
400
+ "description": "",
401
+ "description_tooltip": null,
402
+ "layout": "IPY_MODEL_b2da78bd47b144f49a8202289cc6745a",
403
+ "placeholder": "​",
404
+ "style": "IPY_MODEL_1ed7259217474bcfa1e5f80071fb708e",
405
+ "value": "train-00000-of-00001.parquet: 100%"
406
+ }
407
+ },
408
+ "ed4e47b387a347f98567e87f4dce2dff": {
409
+ "model_module": "@jupyter-widgets/controls",
410
+ "model_name": "FloatProgressModel",
411
+ "model_module_version": "1.5.0",
412
+ "state": {
413
+ "_dom_classes": [],
414
+ "_model_module": "@jupyter-widgets/controls",
415
+ "_model_module_version": "1.5.0",
416
+ "_model_name": "FloatProgressModel",
417
+ "_view_count": null,
418
+ "_view_module": "@jupyter-widgets/controls",
419
+ "_view_module_version": "1.5.0",
420
+ "_view_name": "ProgressView",
421
+ "bar_style": "success",
422
+ "description": "",
423
+ "description_tooltip": null,
424
+ "layout": "IPY_MODEL_2589e545cc864d4095becc8d1f75f263",
425
+ "max": 306697640,
426
+ "min": 0,
427
+ "orientation": "horizontal",
428
+ "style": "IPY_MODEL_c1584c81502c471da4c9d89c3e922813",
429
+ "value": 306697640
430
+ }
431
+ },
432
+ "95b6f6c28acc4ff3a7da7d1ac5d1fc2d": {
433
+ "model_module": "@jupyter-widgets/controls",
434
+ "model_name": "HTMLModel",
435
+ "model_module_version": "1.5.0",
436
+ "state": {
437
+ "_dom_classes": [],
438
+ "_model_module": "@jupyter-widgets/controls",
439
+ "_model_module_version": "1.5.0",
440
+ "_model_name": "HTMLModel",
441
+ "_view_count": null,
442
+ "_view_module": "@jupyter-widgets/controls",
443
+ "_view_module_version": "1.5.0",
444
+ "_view_name": "HTMLView",
445
+ "description": "",
446
+ "description_tooltip": null,
447
+ "layout": "IPY_MODEL_fd455eaad05b4614acaee95e03a44fa0",
448
+ "placeholder": "​",
449
+ "style": "IPY_MODEL_2df5051a2bc743258ef138f14173ccc2",
450
+ "value": " 307M/307M [00:12<00:00, 22.8MB/s]"
451
+ }
452
+ },
453
+ "dcda5b897d8c482f8ce32387af5fdb2b": {
454
+ "model_module": "@jupyter-widgets/base",
455
+ "model_name": "LayoutModel",
456
+ "model_module_version": "1.2.0",
457
+ "state": {
458
+ "_model_module": "@jupyter-widgets/base",
459
+ "_model_module_version": "1.2.0",
460
+ "_model_name": "LayoutModel",
461
+ "_view_count": null,
462
+ "_view_module": "@jupyter-widgets/base",
463
+ "_view_module_version": "1.2.0",
464
+ "_view_name": "LayoutView",
465
+ "align_content": null,
466
+ "align_items": null,
467
+ "align_self": null,
468
+ "border": null,
469
+ "bottom": null,
470
+ "display": null,
471
+ "flex": null,
472
+ "flex_flow": null,
473
+ "grid_area": null,
474
+ "grid_auto_columns": null,
475
+ "grid_auto_flow": null,
476
+ "grid_auto_rows": null,
477
+ "grid_column": null,
478
+ "grid_gap": null,
479
+ "grid_row": null,
480
+ "grid_template_areas": null,
481
+ "grid_template_columns": null,
482
+ "grid_template_rows": null,
483
+ "height": null,
484
+ "justify_content": null,
485
+ "justify_items": null,
486
+ "left": null,
487
+ "margin": null,
488
+ "max_height": null,
489
+ "max_width": null,
490
+ "min_height": null,
491
+ "min_width": null,
492
+ "object_fit": null,
493
+ "object_position": null,
494
+ "order": null,
495
+ "overflow": null,
496
+ "overflow_x": null,
497
+ "overflow_y": null,
498
+ "padding": null,
499
+ "right": null,
500
+ "top": null,
501
+ "visibility": null,
502
+ "width": null
503
+ }
504
+ },
505
+ "b2da78bd47b144f49a8202289cc6745a": {
506
+ "model_module": "@jupyter-widgets/base",
507
+ "model_name": "LayoutModel",
508
+ "model_module_version": "1.2.0",
509
+ "state": {
510
+ "_model_module": "@jupyter-widgets/base",
511
+ "_model_module_version": "1.2.0",
512
+ "_model_name": "LayoutModel",
513
+ "_view_count": null,
514
+ "_view_module": "@jupyter-widgets/base",
515
+ "_view_module_version": "1.2.0",
516
+ "_view_name": "LayoutView",
517
+ "align_content": null,
518
+ "align_items": null,
519
+ "align_self": null,
520
+ "border": null,
521
+ "bottom": null,
522
+ "display": null,
523
+ "flex": null,
524
+ "flex_flow": null,
525
+ "grid_area": null,
526
+ "grid_auto_columns": null,
527
+ "grid_auto_flow": null,
528
+ "grid_auto_rows": null,
529
+ "grid_column": null,
530
+ "grid_gap": null,
531
+ "grid_row": null,
532
+ "grid_template_areas": null,
533
+ "grid_template_columns": null,
534
+ "grid_template_rows": null,
535
+ "height": null,
536
+ "justify_content": null,
537
+ "justify_items": null,
538
+ "left": null,
539
+ "margin": null,
540
+ "max_height": null,
541
+ "max_width": null,
542
+ "min_height": null,
543
+ "min_width": null,
544
+ "object_fit": null,
545
+ "object_position": null,
546
+ "order": null,
547
+ "overflow": null,
548
+ "overflow_x": null,
549
+ "overflow_y": null,
550
+ "padding": null,
551
+ "right": null,
552
+ "top": null,
553
+ "visibility": null,
554
+ "width": null
555
+ }
556
+ },
557
+ "1ed7259217474bcfa1e5f80071fb708e": {
558
+ "model_module": "@jupyter-widgets/controls",
559
+ "model_name": "DescriptionStyleModel",
560
+ "model_module_version": "1.5.0",
561
+ "state": {
562
+ "_model_module": "@jupyter-widgets/controls",
563
+ "_model_module_version": "1.5.0",
564
+ "_model_name": "DescriptionStyleModel",
565
+ "_view_count": null,
566
+ "_view_module": "@jupyter-widgets/base",
567
+ "_view_module_version": "1.2.0",
568
+ "_view_name": "StyleView",
569
+ "description_width": ""
570
+ }
571
+ },
572
+ "2589e545cc864d4095becc8d1f75f263": {
573
+ "model_module": "@jupyter-widgets/base",
574
+ "model_name": "LayoutModel",
575
+ "model_module_version": "1.2.0",
576
+ "state": {
577
+ "_model_module": "@jupyter-widgets/base",
578
+ "_model_module_version": "1.2.0",
579
+ "_model_name": "LayoutModel",
580
+ "_view_count": null,
581
+ "_view_module": "@jupyter-widgets/base",
582
+ "_view_module_version": "1.2.0",
583
+ "_view_name": "LayoutView",
584
+ "align_content": null,
585
+ "align_items": null,
586
+ "align_self": null,
587
+ "border": null,
588
+ "bottom": null,
589
+ "display": null,
590
+ "flex": null,
591
+ "flex_flow": null,
592
+ "grid_area": null,
593
+ "grid_auto_columns": null,
594
+ "grid_auto_flow": null,
595
+ "grid_auto_rows": null,
596
+ "grid_column": null,
597
+ "grid_gap": null,
598
+ "grid_row": null,
599
+ "grid_template_areas": null,
600
+ "grid_template_columns": null,
601
+ "grid_template_rows": null,
602
+ "height": null,
603
+ "justify_content": null,
604
+ "justify_items": null,
605
+ "left": null,
606
+ "margin": null,
607
+ "max_height": null,
608
+ "max_width": null,
609
+ "min_height": null,
610
+ "min_width": null,
611
+ "object_fit": null,
612
+ "object_position": null,
613
+ "order": null,
614
+ "overflow": null,
615
+ "overflow_x": null,
616
+ "overflow_y": null,
617
+ "padding": null,
618
+ "right": null,
619
+ "top": null,
620
+ "visibility": null,
621
+ "width": null
622
+ }
623
+ },
624
+ "c1584c81502c471da4c9d89c3e922813": {
625
+ "model_module": "@jupyter-widgets/controls",
626
+ "model_name": "ProgressStyleModel",
627
+ "model_module_version": "1.5.0",
628
+ "state": {
629
+ "_model_module": "@jupyter-widgets/controls",
630
+ "_model_module_version": "1.5.0",
631
+ "_model_name": "ProgressStyleModel",
632
+ "_view_count": null,
633
+ "_view_module": "@jupyter-widgets/base",
634
+ "_view_module_version": "1.2.0",
635
+ "_view_name": "StyleView",
636
+ "bar_color": null,
637
+ "description_width": ""
638
+ }
639
+ },
640
+ "fd455eaad05b4614acaee95e03a44fa0": {
641
+ "model_module": "@jupyter-widgets/base",
642
+ "model_name": "LayoutModel",
643
+ "model_module_version": "1.2.0",
644
+ "state": {
645
+ "_model_module": "@jupyter-widgets/base",
646
+ "_model_module_version": "1.2.0",
647
+ "_model_name": "LayoutModel",
648
+ "_view_count": null,
649
+ "_view_module": "@jupyter-widgets/base",
650
+ "_view_module_version": "1.2.0",
651
+ "_view_name": "LayoutView",
652
+ "align_content": null,
653
+ "align_items": null,
654
+ "align_self": null,
655
+ "border": null,
656
+ "bottom": null,
657
+ "display": null,
658
+ "flex": null,
659
+ "flex_flow": null,
660
+ "grid_area": null,
661
+ "grid_auto_columns": null,
662
+ "grid_auto_flow": null,
663
+ "grid_auto_rows": null,
664
+ "grid_column": null,
665
+ "grid_gap": null,
666
+ "grid_row": null,
667
+ "grid_template_areas": null,
668
+ "grid_template_columns": null,
669
+ "grid_template_rows": null,
670
+ "height": null,
671
+ "justify_content": null,
672
+ "justify_items": null,
673
+ "left": null,
674
+ "margin": null,
675
+ "max_height": null,
676
+ "max_width": null,
677
+ "min_height": null,
678
+ "min_width": null,
679
+ "object_fit": null,
680
+ "object_position": null,
681
+ "order": null,
682
+ "overflow": null,
683
+ "overflow_x": null,
684
+ "overflow_y": null,
685
+ "padding": null,
686
+ "right": null,
687
+ "top": null,
688
+ "visibility": null,
689
+ "width": null
690
+ }
691
+ },
692
+ "2df5051a2bc743258ef138f14173ccc2": {
693
+ "model_module": "@jupyter-widgets/controls",
694
+ "model_name": "DescriptionStyleModel",
695
+ "model_module_version": "1.5.0",
696
+ "state": {
697
+ "_model_module": "@jupyter-widgets/controls",
698
+ "_model_module_version": "1.5.0",
699
+ "_model_name": "DescriptionStyleModel",
700
+ "_view_count": null,
701
+ "_view_module": "@jupyter-widgets/base",
702
+ "_view_module_version": "1.2.0",
703
+ "_view_name": "StyleView",
704
+ "description_width": ""
705
+ }
706
+ },
707
+ "e2ef3cf0e3ff4ea3a8a0dff3dd73a5f1": {
708
+ "model_module": "@jupyter-widgets/controls",
709
+ "model_name": "HBoxModel",
710
+ "model_module_version": "1.5.0",
711
+ "state": {
712
+ "_dom_classes": [],
713
+ "_model_module": "@jupyter-widgets/controls",
714
+ "_model_module_version": "1.5.0",
715
+ "_model_name": "HBoxModel",
716
+ "_view_count": null,
717
+ "_view_module": "@jupyter-widgets/controls",
718
+ "_view_module_version": "1.5.0",
719
+ "_view_name": "HBoxView",
720
+ "box_style": "",
721
+ "children": [
722
+ "IPY_MODEL_7bac50c73a644c9f9e3369b763cb5db7",
723
+ "IPY_MODEL_24206922f4c64c8aadbaec122804aadf",
724
+ "IPY_MODEL_c1f30aa01b434b0d8f9799503d9601f9"
725
+ ],
726
+ "layout": "IPY_MODEL_7b671494c6754864931f43c546578dcb"
727
+ }
728
+ },
729
+ "7bac50c73a644c9f9e3369b763cb5db7": {
730
+ "model_module": "@jupyter-widgets/controls",
731
+ "model_name": "HTMLModel",
732
+ "model_module_version": "1.5.0",
733
+ "state": {
734
+ "_dom_classes": [],
735
+ "_model_module": "@jupyter-widgets/controls",
736
+ "_model_module_version": "1.5.0",
737
+ "_model_name": "HTMLModel",
738
+ "_view_count": null,
739
+ "_view_module": "@jupyter-widgets/controls",
740
+ "_view_module_version": "1.5.0",
741
+ "_view_name": "HTMLView",
742
+ "description": "",
743
+ "description_tooltip": null,
744
+ "layout": "IPY_MODEL_698b1bbf0fdc47e389e9d8eb5aca93d6",
745
+ "placeholder": "​",
746
+ "style": "IPY_MODEL_82a61b5594dd49f2b1ca5dea552b8d87",
747
+ "value": "Generating train split: 100%"
748
+ }
749
+ },
750
+ "24206922f4c64c8aadbaec122804aadf": {
751
+ "model_module": "@jupyter-widgets/controls",
752
+ "model_name": "FloatProgressModel",
753
+ "model_module_version": "1.5.0",
754
+ "state": {
755
+ "_dom_classes": [],
756
+ "_model_module": "@jupyter-widgets/controls",
757
+ "_model_module_version": "1.5.0",
758
+ "_model_name": "FloatProgressModel",
759
+ "_view_count": null,
760
+ "_view_module": "@jupyter-widgets/controls",
761
+ "_view_module_version": "1.5.0",
762
+ "_view_name": "ProgressView",
763
+ "bar_style": "success",
764
+ "description": "",
765
+ "description_tooltip": null,
766
+ "layout": "IPY_MODEL_be75f5be99c246d5a01186a17181a3c3",
767
+ "max": 100,
768
+ "min": 0,
769
+ "orientation": "horizontal",
770
+ "style": "IPY_MODEL_80300b7040c349ed92ecefb4d3402a7b",
771
+ "value": 100
772
+ }
773
+ },
774
+ "c1f30aa01b434b0d8f9799503d9601f9": {
775
+ "model_module": "@jupyter-widgets/controls",
776
+ "model_name": "HTMLModel",
777
+ "model_module_version": "1.5.0",
778
+ "state": {
779
+ "_dom_classes": [],
780
+ "_model_module": "@jupyter-widgets/controls",
781
+ "_model_module_version": "1.5.0",
782
+ "_model_name": "HTMLModel",
783
+ "_view_count": null,
784
+ "_view_module": "@jupyter-widgets/controls",
785
+ "_view_module_version": "1.5.0",
786
+ "_view_name": "HTMLView",
787
+ "description": "",
788
+ "description_tooltip": null,
789
+ "layout": "IPY_MODEL_0a1996f8fe29482aa0b972c07040d97d",
790
+ "placeholder": "​",
791
+ "style": "IPY_MODEL_fa4c605349df4638a8c71e7aa52db1ad",
792
+ "value": " 100/100 [00:01<00:00, 71.49 examples/s]"
793
+ }
794
+ },
795
+ "7b671494c6754864931f43c546578dcb": {
796
+ "model_module": "@jupyter-widgets/base",
797
+ "model_name": "LayoutModel",
798
+ "model_module_version": "1.2.0",
799
+ "state": {
800
+ "_model_module": "@jupyter-widgets/base",
801
+ "_model_module_version": "1.2.0",
802
+ "_model_name": "LayoutModel",
803
+ "_view_count": null,
804
+ "_view_module": "@jupyter-widgets/base",
805
+ "_view_module_version": "1.2.0",
806
+ "_view_name": "LayoutView",
807
+ "align_content": null,
808
+ "align_items": null,
809
+ "align_self": null,
810
+ "border": null,
811
+ "bottom": null,
812
+ "display": null,
813
+ "flex": null,
814
+ "flex_flow": null,
815
+ "grid_area": null,
816
+ "grid_auto_columns": null,
817
+ "grid_auto_flow": null,
818
+ "grid_auto_rows": null,
819
+ "grid_column": null,
820
+ "grid_gap": null,
821
+ "grid_row": null,
822
+ "grid_template_areas": null,
823
+ "grid_template_columns": null,
824
+ "grid_template_rows": null,
825
+ "height": null,
826
+ "justify_content": null,
827
+ "justify_items": null,
828
+ "left": null,
829
+ "margin": null,
830
+ "max_height": null,
831
+ "max_width": null,
832
+ "min_height": null,
833
+ "min_width": null,
834
+ "object_fit": null,
835
+ "object_position": null,
836
+ "order": null,
837
+ "overflow": null,
838
+ "overflow_x": null,
839
+ "overflow_y": null,
840
+ "padding": null,
841
+ "right": null,
842
+ "top": null,
843
+ "visibility": null,
844
+ "width": null
845
+ }
846
+ },
847
+ "698b1bbf0fdc47e389e9d8eb5aca93d6": {
848
+ "model_module": "@jupyter-widgets/base",
849
+ "model_name": "LayoutModel",
850
+ "model_module_version": "1.2.0",
851
+ "state": {
852
+ "_model_module": "@jupyter-widgets/base",
853
+ "_model_module_version": "1.2.0",
854
+ "_model_name": "LayoutModel",
855
+ "_view_count": null,
856
+ "_view_module": "@jupyter-widgets/base",
857
+ "_view_module_version": "1.2.0",
858
+ "_view_name": "LayoutView",
859
+ "align_content": null,
860
+ "align_items": null,
861
+ "align_self": null,
862
+ "border": null,
863
+ "bottom": null,
864
+ "display": null,
865
+ "flex": null,
866
+ "flex_flow": null,
867
+ "grid_area": null,
868
+ "grid_auto_columns": null,
869
+ "grid_auto_flow": null,
870
+ "grid_auto_rows": null,
871
+ "grid_column": null,
872
+ "grid_gap": null,
873
+ "grid_row": null,
874
+ "grid_template_areas": null,
875
+ "grid_template_columns": null,
876
+ "grid_template_rows": null,
877
+ "height": null,
878
+ "justify_content": null,
879
+ "justify_items": null,
880
+ "left": null,
881
+ "margin": null,
882
+ "max_height": null,
883
+ "max_width": null,
884
+ "min_height": null,
885
+ "min_width": null,
886
+ "object_fit": null,
887
+ "object_position": null,
888
+ "order": null,
889
+ "overflow": null,
890
+ "overflow_x": null,
891
+ "overflow_y": null,
892
+ "padding": null,
893
+ "right": null,
894
+ "top": null,
895
+ "visibility": null,
896
+ "width": null
897
+ }
898
+ },
899
+ "82a61b5594dd49f2b1ca5dea552b8d87": {
900
+ "model_module": "@jupyter-widgets/controls",
901
+ "model_name": "DescriptionStyleModel",
902
+ "model_module_version": "1.5.0",
903
+ "state": {
904
+ "_model_module": "@jupyter-widgets/controls",
905
+ "_model_module_version": "1.5.0",
906
+ "_model_name": "DescriptionStyleModel",
907
+ "_view_count": null,
908
+ "_view_module": "@jupyter-widgets/base",
909
+ "_view_module_version": "1.2.0",
910
+ "_view_name": "StyleView",
911
+ "description_width": ""
912
+ }
913
+ },
914
+ "be75f5be99c246d5a01186a17181a3c3": {
915
+ "model_module": "@jupyter-widgets/base",
916
+ "model_name": "LayoutModel",
917
+ "model_module_version": "1.2.0",
918
+ "state": {
919
+ "_model_module": "@jupyter-widgets/base",
920
+ "_model_module_version": "1.2.0",
921
+ "_model_name": "LayoutModel",
922
+ "_view_count": null,
923
+ "_view_module": "@jupyter-widgets/base",
924
+ "_view_module_version": "1.2.0",
925
+ "_view_name": "LayoutView",
926
+ "align_content": null,
927
+ "align_items": null,
928
+ "align_self": null,
929
+ "border": null,
930
+ "bottom": null,
931
+ "display": null,
932
+ "flex": null,
933
+ "flex_flow": null,
934
+ "grid_area": null,
935
+ "grid_auto_columns": null,
936
+ "grid_auto_flow": null,
937
+ "grid_auto_rows": null,
938
+ "grid_column": null,
939
+ "grid_gap": null,
940
+ "grid_row": null,
941
+ "grid_template_areas": null,
942
+ "grid_template_columns": null,
943
+ "grid_template_rows": null,
944
+ "height": null,
945
+ "justify_content": null,
946
+ "justify_items": null,
947
+ "left": null,
948
+ "margin": null,
949
+ "max_height": null,
950
+ "max_width": null,
951
+ "min_height": null,
952
+ "min_width": null,
953
+ "object_fit": null,
954
+ "object_position": null,
955
+ "order": null,
956
+ "overflow": null,
957
+ "overflow_x": null,
958
+ "overflow_y": null,
959
+ "padding": null,
960
+ "right": null,
961
+ "top": null,
962
+ "visibility": null,
963
+ "width": null
964
+ }
965
+ },
966
+ "80300b7040c349ed92ecefb4d3402a7b": {
967
+ "model_module": "@jupyter-widgets/controls",
968
+ "model_name": "ProgressStyleModel",
969
+ "model_module_version": "1.5.0",
970
+ "state": {
971
+ "_model_module": "@jupyter-widgets/controls",
972
+ "_model_module_version": "1.5.0",
973
+ "_model_name": "ProgressStyleModel",
974
+ "_view_count": null,
975
+ "_view_module": "@jupyter-widgets/base",
976
+ "_view_module_version": "1.2.0",
977
+ "_view_name": "StyleView",
978
+ "bar_color": null,
979
+ "description_width": ""
980
+ }
981
+ },
982
+ "0a1996f8fe29482aa0b972c07040d97d": {
983
+ "model_module": "@jupyter-widgets/base",
984
+ "model_name": "LayoutModel",
985
+ "model_module_version": "1.2.0",
986
+ "state": {
987
+ "_model_module": "@jupyter-widgets/base",
988
+ "_model_module_version": "1.2.0",
989
+ "_model_name": "LayoutModel",
990
+ "_view_count": null,
991
+ "_view_module": "@jupyter-widgets/base",
992
+ "_view_module_version": "1.2.0",
993
+ "_view_name": "LayoutView",
994
+ "align_content": null,
995
+ "align_items": null,
996
+ "align_self": null,
997
+ "border": null,
998
+ "bottom": null,
999
+ "display": null,
1000
+ "flex": null,
1001
+ "flex_flow": null,
1002
+ "grid_area": null,
1003
+ "grid_auto_columns": null,
1004
+ "grid_auto_flow": null,
1005
+ "grid_auto_rows": null,
1006
+ "grid_column": null,
1007
+ "grid_gap": null,
1008
+ "grid_row": null,
1009
+ "grid_template_areas": null,
1010
+ "grid_template_columns": null,
1011
+ "grid_template_rows": null,
1012
+ "height": null,
1013
+ "justify_content": null,
1014
+ "justify_items": null,
1015
+ "left": null,
1016
+ "margin": null,
1017
+ "max_height": null,
1018
+ "max_width": null,
1019
+ "min_height": null,
1020
+ "min_width": null,
1021
+ "object_fit": null,
1022
+ "object_position": null,
1023
+ "order": null,
1024
+ "overflow": null,
1025
+ "overflow_x": null,
1026
+ "overflow_y": null,
1027
+ "padding": null,
1028
+ "right": null,
1029
+ "top": null,
1030
+ "visibility": null,
1031
+ "width": null
1032
+ }
1033
+ },
1034
+ "fa4c605349df4638a8c71e7aa52db1ad": {
1035
+ "model_module": "@jupyter-widgets/controls",
1036
+ "model_name": "DescriptionStyleModel",
1037
+ "model_module_version": "1.5.0",
1038
+ "state": {
1039
+ "_model_module": "@jupyter-widgets/controls",
1040
+ "_model_module_version": "1.5.0",
1041
+ "_model_name": "DescriptionStyleModel",
1042
+ "_view_count": null,
1043
+ "_view_module": "@jupyter-widgets/base",
1044
+ "_view_module_version": "1.2.0",
1045
+ "_view_name": "StyleView",
1046
+ "description_width": ""
1047
+ }
1048
+ }
1049
+ }
1050
+ }
1051
+ },
1052
+ "cells": [
1053
+ {
1054
+ "cell_type": "markdown",
1055
+ "source": [
1056
+ "# Imports and Hugging Face Login"
1057
+ ],
1058
+ "metadata": {
1059
+ "id": "GirPusJtYPsP"
1060
+ }
1061
+ },
1062
+ {
1063
+ "cell_type": "code",
1064
+ "execution_count": 1,
1065
+ "metadata": {
1066
+ "id": "hKdN-6CXXV12",
1067
+ "colab": {
1068
+ "base_uri": "https://localhost:8080/"
1069
+ },
1070
+ "outputId": "a1d9a131-8a76-436c-ac24-c3467b5dcc01"
1071
+ },
1072
+ "outputs": [
1073
+ {
1074
+ "output_type": "stream",
1075
+ "name": "stdout",
1076
+ "text": [
1077
+ "Requirement already satisfied: huggingface-hub in /usr/local/lib/python3.10/dist-packages (0.26.5)\n",
1078
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (3.16.1)\n",
1079
+ "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2024.10.0)\n",
1080
+ "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (24.2)\n",
1081
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (6.0.2)\n",
1082
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (2.32.3)\n",
1083
+ "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.66.6)\n",
1084
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub) (4.12.2)\n",
1085
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.4.0)\n",
1086
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (3.10)\n",
1087
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2.2.3)\n",
1088
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface-hub) (2024.8.30)\n",
1089
+ "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
1090
+ "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n",
1091
+ "\u001b[0m"
1092
+ ]
1093
+ }
1094
+ ],
1095
+ "source": [
1096
+ "!pip install huggingface-hub\n",
1097
+ "!pip install datasets > delete.txt"
1098
+ ]
1099
+ },
1100
+ {
1101
+ "cell_type": "code",
1102
+ "source": [
1103
+ "import torch\n",
1104
+ "import pickle\n",
1105
+ "from huggingface_hub import hf_hub_download\n",
1106
+ "from datasets import load_dataset, Image\n",
1107
+ "import torch\n",
1108
+ "from torch import nn, optim\n",
1109
+ "from torch.utils.data import DataLoader, Dataset\n",
1110
+ "import numpy as np\n",
1111
+ "from geopy.distance import geodesic\n",
1112
+ "import matplotlib.pyplot as plt\n",
1113
+ "from torchvision import models"
1114
+ ],
1115
+ "metadata": {
1116
+ "id": "SPzgZOzxYYiT"
1117
+ },
1118
+ "execution_count": 2,
1119
+ "outputs": []
1120
+ },
1121
+ {
1122
+ "cell_type": "code",
1123
+ "source": [
1124
+ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
1125
+ "print(device)"
1126
+ ],
1127
+ "metadata": {
1128
+ "id": "PJquO0g1YaMU",
1129
+ "colab": {
1130
+ "base_uri": "https://localhost:8080/"
1131
+ },
1132
+ "outputId": "d82f4fdc-32ee-4f91-e6ce-558ad3e3c837"
1133
+ },
1134
+ "execution_count": 3,
1135
+ "outputs": [
1136
+ {
1137
+ "output_type": "stream",
1138
+ "name": "stdout",
1139
+ "text": [
1140
+ "cuda\n"
1141
+ ]
1142
+ }
1143
+ ]
1144
+ },
1145
+ {
1146
+ "cell_type": "code",
1147
+ "source": [
1148
+ "!huggingface-cli login\n",
1149
+ "# use appropiate token"
1150
+ ],
1151
+ "metadata": {
1152
+ "id": "IcGfZSsoZgau",
1153
+ "colab": {
1154
+ "base_uri": "https://localhost:8080/"
1155
+ },
1156
+ "outputId": "436dcc6f-a924-4be8-e9a8-39c197e5e1e1"
1157
+ },
1158
+ "execution_count": 4,
1159
+ "outputs": [
1160
+ {
1161
+ "output_type": "stream",
1162
+ "name": "stdout",
1163
+ "text": [
1164
+ "\n",
1165
+ " _| _| _| _| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _|_|_|_| _|_| _|_|_| _|_|_|_|\n",
1166
+ " _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
1167
+ " _|_|_|_| _| _| _| _|_| _| _|_| _| _| _| _| _| _|_| _|_|_| _|_|_|_| _| _|_|_|\n",
1168
+ " _| _| _| _| _| _| _| _| _| _| _|_| _| _| _| _| _| _| _|\n",
1169
+ " _| _| _|_| _|_|_| _|_|_| _|_|_| _| _| _|_|_| _| _| _| _|_|_| _|_|_|_|\n",
1170
+ "\n",
1171
+ " To log in, `huggingface_hub` requires a token generated from https://huggingface.co/settings/tokens .\n",
1172
+ "Enter your token (input will not be visible): \n",
1173
+ "Add token as git credential? (Y/n) y\n",
1174
+ "Token is valid (permission: fineGrained).\n",
1175
+ "The token `CIS 5190 Project 3` has been saved to /root/.cache/huggingface/stored_tokens\n",
1176
+ "\u001b[1m\u001b[31mCannot authenticate through git-credential as no helper is defined on your machine.\n",
1177
+ "You might have to re-authenticate when pushing to the Hugging Face Hub.\n",
1178
+ "Run the following command in your terminal in case you want to set the 'store' credential helper as default.\n",
1179
+ "\n",
1180
+ "git config --global credential.helper store\n",
1181
+ "\n",
1182
+ "Read https://git-scm.com/book/en/v2/Git-Tools-Credential-Storage for more details.\u001b[0m\n",
1183
+ "Token has not been saved to git credential helper.\n",
1184
+ "Your token has been saved to /root/.cache/huggingface/token\n",
1185
+ "Login successful.\n",
1186
+ "The current active token is: `CIS 5190 Project 3`\n"
1187
+ ]
1188
+ }
1189
+ ]
1190
+ },
1191
+ {
1192
+ "cell_type": "markdown",
1193
+ "source": [
1194
+ "# Models and Classes"
1195
+ ],
1196
+ "metadata": {
1197
+ "id": "LplsJ-PXXbtm"
1198
+ }
1199
+ },
1200
+ {
1201
+ "cell_type": "code",
1202
+ "source": [
1203
+ "class EnsembleModel(nn.Module):\n",
1204
+ " def __init__(self, models, num_models):\n",
1205
+ " super(EnsembleModel, self).__init__()\n",
1206
+ " self.models = nn.ModuleList(models)\n",
1207
+ " self.weights = nn.Parameter(torch.ones(num_models) / num_models)\n",
1208
+ "\n",
1209
+ " def forward(self, x):\n",
1210
+ " outputs = torch.stack([model(x) for model in self.models], dim=-1)\n",
1211
+ " weighted_output = torch.einsum('bij,j->bi', outputs, self.weights)\n",
1212
+ " return weighted_output"
1213
+ ],
1214
+ "metadata": {
1215
+ "id": "ofOTpLIPcylC"
1216
+ },
1217
+ "execution_count": 9,
1218
+ "outputs": []
1219
+ },
1220
+ {
1221
+ "cell_type": "code",
1222
+ "source": [
1223
+ "class Model1(nn.Module):\n",
1224
+ " def __init__(self, dropout):\n",
1225
+ " super(Model1, self).__init__()\n",
1226
+ " self.features = nn.Sequential(\n",
1227
+ " nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2),\n",
1228
+ " nn.ReLU(inplace=True),\n",
1229
+ " nn.MaxPool2d(kernel_size=3, stride=2),\n",
1230
+ " nn.Conv2d(64, 192, kernel_size=5, padding=2),\n",
1231
+ " nn.ReLU(inplace=True),\n",
1232
+ " nn.MaxPool2d(kernel_size=3, stride=2),\n",
1233
+ " nn.Conv2d(192, 384, kernel_size=3, padding=1),\n",
1234
+ " nn.ReLU(inplace=True),\n",
1235
+ " nn.Conv2d(384, 256, kernel_size=3, padding=1),\n",
1236
+ " nn.ReLU(inplace=True),\n",
1237
+ " nn.Conv2d(256, 256, kernel_size=3, padding=1),\n",
1238
+ " nn.ReLU(inplace=True),\n",
1239
+ " nn.MaxPool2d(kernel_size=3, stride=2),\n",
1240
+ " )\n",
1241
+ " self.classifier = nn.Sequential(\n",
1242
+ " nn.Dropout(p=dropout),\n",
1243
+ " nn.Linear(256 * 6 * 6, 1024),\n",
1244
+ " nn.ReLU(inplace=True),\n",
1245
+ " nn.Dropout(p=dropout),\n",
1246
+ " nn.Linear(1024, 512),\n",
1247
+ " nn.ReLU(inplace=True),\n",
1248
+ " nn.Linear(512, 2),\n",
1249
+ " )\n",
1250
+ "\n",
1251
+ " def forward(self, x):\n",
1252
+ " x = self.features(x)\n",
1253
+ " x = torch.flatten(x, 1)\n",
1254
+ " x = self.classifier(x)\n",
1255
+ " return x\n",
1256
+ "\n",
1257
+ "\n",
1258
+ "def model_fn(dropout):\n",
1259
+ " return Model1(dropout)"
1260
+ ],
1261
+ "metadata": {
1262
+ "id": "fbtZvQrlYGfU"
1263
+ },
1264
+ "execution_count": 10,
1265
+ "outputs": []
1266
+ },
1267
+ {
1268
+ "cell_type": "code",
1269
+ "source": [
1270
+ "class Model2(nn.Module):\n",
1271
+ " def __init__(self, num_blocks=3, dropout_rate=0.5):\n",
1272
+ " super(Model2, self).__init__()\n",
1273
+ "\n",
1274
+ " resnet = models.resnet34(pretrained=True)\n",
1275
+ "\n",
1276
+ " for param in list(resnet.parameters())[:num_blocks]:\n",
1277
+ " param.requires_grad = False\n",
1278
+ "\n",
1279
+ " self.features = nn.Sequential(*list(resnet.children())[:-2])\n",
1280
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
1281
+ "\n",
1282
+ " self.classifier = nn.Sequential(\n",
1283
+ " nn.Flatten(),\n",
1284
+ " nn.Dropout(p=dropout_rate),\n",
1285
+ " nn.Linear(resnet.fc.in_features, 512),\n",
1286
+ " nn.ReLU(inplace=True),\n",
1287
+ " nn.Dropout(p=dropout_rate),\n",
1288
+ " nn.Linear(512, 2)\n",
1289
+ " )\n",
1290
+ "\n",
1291
+ " def forward(self, x):\n",
1292
+ " x = self.features(x)\n",
1293
+ " x = self.avgpool(x)\n",
1294
+ " x = self.classifier(x)\n",
1295
+ " return x"
1296
+ ],
1297
+ "metadata": {
1298
+ "id": "iBssHEtGXdWi"
1299
+ },
1300
+ "execution_count": 11,
1301
+ "outputs": []
1302
+ },
1303
+ {
1304
+ "cell_type": "code",
1305
+ "source": [
1306
+ "class InceptionModule(nn.Module):\n",
1307
+ " def __init__(self, in_channels, ch1x1, ch3x3_reduce, ch3x3, ch5x5_reduce, ch5x5, pool_proj):\n",
1308
+ " super(InceptionModule, self).__init__()\n",
1309
+ "\n",
1310
+ " self.branch1 = nn.Sequential(\n",
1311
+ " nn.Conv2d(in_channels, ch1x1, kernel_size=1),\n",
1312
+ " nn.ReLU(inplace=True)\n",
1313
+ " )\n",
1314
+ " self.branch2 = nn.Sequential(\n",
1315
+ " nn.Conv2d(in_channels, ch3x3_reduce, kernel_size=1),\n",
1316
+ " nn.ReLU(inplace=True),\n",
1317
+ " nn.Conv2d(ch3x3_reduce, ch3x3, kernel_size=3, padding=1),\n",
1318
+ " nn.ReLU(inplace=True)\n",
1319
+ " )\n",
1320
+ "\n",
1321
+ " self.branch3 = nn.Sequential(\n",
1322
+ " nn.Conv2d(in_channels, ch5x5_reduce, kernel_size=1),\n",
1323
+ " nn.ReLU(inplace=True),\n",
1324
+ " nn.Conv2d(ch5x5_reduce, ch5x5, kernel_size=5, padding=2),\n",
1325
+ " nn.ReLU(inplace=True)\n",
1326
+ " )\n",
1327
+ "\n",
1328
+ " self.branch4 = nn.Sequential(\n",
1329
+ " nn.MaxPool2d(kernel_size=3, stride=1, padding=1),\n",
1330
+ " nn.Conv2d(in_channels, pool_proj, kernel_size=1),\n",
1331
+ " nn.ReLU(inplace=True)\n",
1332
+ " )\n",
1333
+ "\n",
1334
+ " def forward(self, x):\n",
1335
+ " branch1 = self.branch1(x)\n",
1336
+ " branch2 = self.branch2(x)\n",
1337
+ " branch3 = self.branch3(x)\n",
1338
+ " branch4 = self.branch4(x)\n",
1339
+ " outputs = torch.cat([branch1, branch2, branch3, branch4], 1)\n",
1340
+ " return outputs\n",
1341
+ "\n",
1342
+ "class Model4(nn.Module):\n",
1343
+ " def __init__(self, dropout_rate=0.5):\n",
1344
+ " super(Model4, self).__init__()\n",
1345
+ "\n",
1346
+ " self.pre_layers = nn.Sequential(\n",
1347
+ " nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3),\n",
1348
+ " nn.ReLU(inplace=True),\n",
1349
+ " nn.MaxPool2d(kernel_size=3, stride=2, padding=1),\n",
1350
+ " nn.Conv2d(64, 192, kernel_size=3, padding=1),\n",
1351
+ " nn.ReLU(inplace=True),\n",
1352
+ " nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
1353
+ " )\n",
1354
+ "\n",
1355
+ "\n",
1356
+ " self.inception1 = InceptionModule(192, 64, 96, 128, 16, 32, 32)\n",
1357
+ " self.inception2 = InceptionModule(256, 128, 128, 192, 32, 96, 64)\n",
1358
+ "\n",
1359
+ " self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)\n",
1360
+ "\n",
1361
+ " self.inception3 = InceptionModule(480, 192, 96, 208, 16, 48, 64)\n",
1362
+ " self.inception4 = InceptionModule(512, 160, 112, 224, 24, 64, 64)\n",
1363
+ "\n",
1364
+ " self.avgpool = nn.AdaptiveAvgPool2d((1, 1))\n",
1365
+ " self.classifier = nn.Sequential(\n",
1366
+ " nn.Flatten(),\n",
1367
+ " nn.Dropout(p=dropout_rate),\n",
1368
+ " nn.Linear(512, 1024),\n",
1369
+ " nn.ReLU(inplace=True),\n",
1370
+ " nn.Dropout(p=dropout_rate),\n",
1371
+ " nn.Linear(1024, 512),\n",
1372
+ " nn.ReLU(inplace=True),\n",
1373
+ " nn.Linear(512, 2)\n",
1374
+ " )\n",
1375
+ "\n",
1376
+ " def forward(self, x):\n",
1377
+ " x = self.pre_layers(x)\n",
1378
+ " x = self.inception1(x)\n",
1379
+ " x = self.inception2(x)\n",
1380
+ " x = self.maxpool(x)\n",
1381
+ " x = self.inception3(x)\n",
1382
+ " x = self.inception4(x)\n",
1383
+ " x = self.avgpool(x)\n",
1384
+ " x = self.classifier(x)\n",
1385
+ " return x"
1386
+ ],
1387
+ "metadata": {
1388
+ "id": "c4y6R0A3XjcI"
1389
+ },
1390
+ "execution_count": 12,
1391
+ "outputs": []
1392
+ },
1393
+ {
1394
+ "cell_type": "markdown",
1395
+ "source": [
1396
+ "# Load Test Dataset"
1397
+ ],
1398
+ "metadata": {
1399
+ "id": "ybwRXm3zYg_I"
1400
+ }
1401
+ },
1402
+ {
1403
+ "cell_type": "code",
1404
+ "source": [
1405
+ "from torch.utils.data import Dataset\n",
1406
+ "class GPSImageDataset(Dataset):\n",
1407
+ " def __init__(self, hf_dataset, transform, lat_mean=None, lat_std=None, lon_mean=None, lon_std=None):\n",
1408
+ " self.hf_dataset = hf_dataset\n",
1409
+ " self.transform = transform\n",
1410
+ "\n",
1411
+ " # Normalize the latitude and longitude\n",
1412
+ " self.latitudes = np.array(hf_dataset['Latitude'])\n",
1413
+ " self.longitudes = np.array(hf_dataset['Longitude'])\n",
1414
+ " self.latitude_mean = lat_mean if lat_mean is not None else self.latitudes.mean()\n",
1415
+ " self.latitude_std = lat_std if lat_std is not None else self.latitudes.std()\n",
1416
+ " self.longitude_mean = lon_mean if lon_mean is not None else self.longitudes.mean()\n",
1417
+ " self.longitude_std = lon_std if lon_std is not None else self.longitudes.std()\n",
1418
+ "\n",
1419
+ " self.normalized_latitudes = (self.latitudes - self.latitude_mean) / self.latitude_std\n",
1420
+ " self.normalized_longitudes = (self.longitudes - self.longitude_mean) / self.longitude_std\n",
1421
+ "\n",
1422
+ " def __len__(self):\n",
1423
+ " return len(self.hf_dataset)\n",
1424
+ "\n",
1425
+ " def __getitem__(self, idx):\n",
1426
+ " image = self.hf_dataset[idx]['image']\n",
1427
+ " latitude = self.normalized_latitudes[idx]\n",
1428
+ " longitude = self.normalized_longitudes[idx]\n",
1429
+ "\n",
1430
+ " if self.transform:\n",
1431
+ " image = self.transform(image)\n",
1432
+ "\n",
1433
+ " return image, torch.tensor([latitude, longitude], dtype=torch.float)"
1434
+ ],
1435
+ "metadata": {
1436
+ "id": "EfCxgZxMY7b6"
1437
+ },
1438
+ "execution_count": 14,
1439
+ "outputs": []
1440
+ },
1441
+ {
1442
+ "cell_type": "code",
1443
+ "source": [
1444
+ "from torchvision import transforms, models\n",
1445
+ "transform = transforms.Compose([\n",
1446
+ " transforms.RandomResizedCrop(224),\n",
1447
+ " transforms.RandomHorizontalFlip(),\n",
1448
+ " transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),\n",
1449
+ " transforms.ToTensor(),\n",
1450
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
1451
+ "])\n",
1452
+ "\n",
1453
+ "inference_transform = transforms.Compose([\n",
1454
+ " transforms.Resize((224, 224)),\n",
1455
+ " transforms.ToTensor(),\n",
1456
+ " transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])\n",
1457
+ "])"
1458
+ ],
1459
+ "metadata": {
1460
+ "id": "P4Gx6KLQXz4E"
1461
+ },
1462
+ "execution_count": 15,
1463
+ "outputs": []
1464
+ },
1465
+ {
1466
+ "cell_type": "code",
1467
+ "source": [
1468
+ "dataset_test = load_dataset(\"gydou/released_img\")"
1469
+ ],
1470
+ "metadata": {
1471
+ "id": "NTFvFWpRYgcM",
1472
+ "colab": {
1473
+ "base_uri": "https://localhost:8080/",
1474
+ "height": 217,
1475
+ "referenced_widgets": [
1476
+ "c46dc091acd34be2887c59bf95838529",
1477
+ "af6cce709f8a478a87c4c89222193d8b",
1478
+ "4c32916fe36e4466b9c8a96bbd7db71b",
1479
+ "48d80beb4cc646328036336431b01278",
1480
+ "755b9a30f525493382f771840e4b04f4",
1481
+ "481a5ffb8c394dfe88d9df74b3edd372",
1482
+ "7ddb099532f44b8fbc77bfd94232ff8f",
1483
+ "ba7928e2ec0e459e870e7f8420fc26f6",
1484
+ "2422979a81b0425fa82df01a1b4b170a",
1485
+ "93694c9083bc4a19a51b854829180dc9",
1486
+ "8d3e47bf8478457d9db128a9927fc568",
1487
+ "6d93d4f16f7f409d895c928f4c091619",
1488
+ "bdd1de9927bf4183a39c4eb417b4ee65",
1489
+ "ed4e47b387a347f98567e87f4dce2dff",
1490
+ "95b6f6c28acc4ff3a7da7d1ac5d1fc2d",
1491
+ "dcda5b897d8c482f8ce32387af5fdb2b",
1492
+ "b2da78bd47b144f49a8202289cc6745a",
1493
+ "1ed7259217474bcfa1e5f80071fb708e",
1494
+ "2589e545cc864d4095becc8d1f75f263",
1495
+ "c1584c81502c471da4c9d89c3e922813",
1496
+ "fd455eaad05b4614acaee95e03a44fa0",
1497
+ "2df5051a2bc743258ef138f14173ccc2",
1498
+ "e2ef3cf0e3ff4ea3a8a0dff3dd73a5f1",
1499
+ "7bac50c73a644c9f9e3369b763cb5db7",
1500
+ "24206922f4c64c8aadbaec122804aadf",
1501
+ "c1f30aa01b434b0d8f9799503d9601f9",
1502
+ "7b671494c6754864931f43c546578dcb",
1503
+ "698b1bbf0fdc47e389e9d8eb5aca93d6",
1504
+ "82a61b5594dd49f2b1ca5dea552b8d87",
1505
+ "be75f5be99c246d5a01186a17181a3c3",
1506
+ "80300b7040c349ed92ecefb4d3402a7b",
1507
+ "0a1996f8fe29482aa0b972c07040d97d",
1508
+ "fa4c605349df4638a8c71e7aa52db1ad"
1509
+ ]
1510
+ },
1511
+ "outputId": "877c8003-7541-4eb2-bfd5-92540f2d2381"
1512
+ },
1513
+ "execution_count": 16,
1514
+ "outputs": [
1515
+ {
1516
+ "output_type": "stream",
1517
+ "name": "stderr",
1518
+ "text": [
1519
+ "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n",
1520
+ "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
1521
+ "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
1522
+ "You will be able to reuse this secret in all of your notebooks.\n",
1523
+ "Please note that authentication is recommended but still optional to access public models or datasets.\n",
1524
+ " warnings.warn(\n"
1525
+ ]
1526
+ },
1527
+ {
1528
+ "output_type": "display_data",
1529
+ "data": {
1530
+ "text/plain": [
1531
+ "README.md: 0%| | 0.00/360 [00:00<?, ?B/s]"
1532
+ ],
1533
+ "application/vnd.jupyter.widget-view+json": {
1534
+ "version_major": 2,
1535
+ "version_minor": 0,
1536
+ "model_id": "c46dc091acd34be2887c59bf95838529"
1537
+ }
1538
+ },
1539
+ "metadata": {}
1540
+ },
1541
+ {
1542
+ "output_type": "display_data",
1543
+ "data": {
1544
+ "text/plain": [
1545
+ "train-00000-of-00001.parquet: 0%| | 0.00/307M [00:00<?, ?B/s]"
1546
+ ],
1547
+ "application/vnd.jupyter.widget-view+json": {
1548
+ "version_major": 2,
1549
+ "version_minor": 0,
1550
+ "model_id": "6d93d4f16f7f409d895c928f4c091619"
1551
+ }
1552
+ },
1553
+ "metadata": {}
1554
+ },
1555
+ {
1556
+ "output_type": "display_data",
1557
+ "data": {
1558
+ "text/plain": [
1559
+ "Generating train split: 0%| | 0/100 [00:00<?, ? examples/s]"
1560
+ ],
1561
+ "application/vnd.jupyter.widget-view+json": {
1562
+ "version_major": 2,
1563
+ "version_minor": 0,
1564
+ "model_id": "e2ef3cf0e3ff4ea3a8a0dff3dd73a5f1"
1565
+ }
1566
+ },
1567
+ "metadata": {}
1568
+ }
1569
+ ]
1570
+ },
1571
+ {
1572
+ "cell_type": "code",
1573
+ "source": [
1574
+ "lat_mean = 39.9517411499467\n",
1575
+ "lat_std = 0.0006914493505038013\n",
1576
+ "lon_mean = -75.19143213125122\n",
1577
+ "lon_std = 0.0006539239061573955\n",
1578
+ "\n",
1579
+ "test_dataset = GPSImageDataset(\n",
1580
+ " hf_dataset=dataset_test['train'],\n",
1581
+ " transform=inference_transform,\n",
1582
+ " lat_mean=lat_mean,\n",
1583
+ " lat_std=lat_std,\n",
1584
+ " lon_mean=lon_mean,\n",
1585
+ " lon_std=lon_std\n",
1586
+ ")\n",
1587
+ "\n",
1588
+ "test_dataloader = DataLoader(\n",
1589
+ " test_dataset,\n",
1590
+ " batch_size=32,\n",
1591
+ " shuffle=False,\n",
1592
+ " num_workers=4\n",
1593
+ ")"
1594
+ ],
1595
+ "metadata": {
1596
+ "id": "S2nsXhmOZTiS"
1597
+ },
1598
+ "execution_count": 41,
1599
+ "outputs": []
1600
+ },
1601
+ {
1602
+ "cell_type": "markdown",
1603
+ "source": [
1604
+ "# Loading Our Model from Pickle File"
1605
+ ],
1606
+ "metadata": {
1607
+ "id": "VOYuBGqYZUKR"
1608
+ }
1609
+ },
1610
+ {
1611
+ "cell_type": "code",
1612
+ "source": [
1613
+ "pickle_file_path = hf_hub_download(repo_id= \"CIS-5190-CIA/Ensemble_Version_2\", filename=\"ensemble_model_ver2.pkl\")"
1614
+ ],
1615
+ "metadata": {
1616
+ "id": "ELSgBmAGZaUJ"
1617
+ },
1618
+ "execution_count": 34,
1619
+ "outputs": []
1620
+ },
1621
+ {
1622
+ "cell_type": "code",
1623
+ "source": [
1624
+ "def load_ensemble(file_name, model_classes, device=\"cpu\"):\n",
1625
+ " \"\"\"\n",
1626
+ " Load the ensemble model and individual model weights from a pickle file.\n",
1627
+ "\n",
1628
+ " Args:\n",
1629
+ " file_name: Path to the saved pickle file.\n",
1630
+ " model_classes: A dictionary mapping model names to their classes (e.g., {\"Model1\": Model1, ...}).\n",
1631
+ " device: Device to load the models onto (default is \"cpu\").\n",
1632
+ "\n",
1633
+ " Returns:\n",
1634
+ " trained_models: A dictionary of reloaded models (key -> list of models for each type).\n",
1635
+ " ensemble_weights: Numpy array of ensemble weights.\n",
1636
+ " \"\"\"\n",
1637
+ " # Load the pickle file\n",
1638
+ " with open(file_name, \"rb\") as f:\n",
1639
+ " ensemble_data = pickle.load(f)\n",
1640
+ "\n",
1641
+ " # Extract the ensemble weights\n",
1642
+ " ensemble_weights = ensemble_data[\"ensemble_weights\"]\n",
1643
+ "\n",
1644
+ " # Reload the individual models\n",
1645
+ " trained_models = {}\n",
1646
+ " for model_name, state_dicts in ensemble_data[\"models\"].items():\n",
1647
+ " trained_models[model_name] = []\n",
1648
+ " for state_dict in state_dicts:\n",
1649
+ " model = model_classes[model_name]()\n",
1650
+ " model.load_state_dict(state_dict)\n",
1651
+ " model = model.to(device)\n",
1652
+ " trained_models[model_name].append(model)\n",
1653
+ "\n",
1654
+ " return trained_models, ensemble_weights"
1655
+ ],
1656
+ "metadata": {
1657
+ "id": "1PygE9aMZ4xm"
1658
+ },
1659
+ "execution_count": 43,
1660
+ "outputs": []
1661
+ },
1662
+ {
1663
+ "cell_type": "code",
1664
+ "source": [
1665
+ "model_classes = {\n",
1666
+ " \"Model1\": lambda: Model1(dropout=0.5),\n",
1667
+ " \"Model2\": lambda: Model2(num_blocks=3, dropout_rate=0.5),\n",
1668
+ " \"Model4\": lambda: Model4(dropout_rate=0.5)\n",
1669
+ "}\n",
1670
+ "\n",
1671
+ "# Load the ensemble\n",
1672
+ "trained_models, ensemble_weights = load_ensemble(pickle_file_path, model_classes, device=\"cuda\")\n",
1673
+ "models_ensemble = []\n",
1674
+ "for model_list in trained_models.values():\n",
1675
+ " models_ensemble.extend(model_list)\n",
1676
+ "\n",
1677
+ "# ensemble model\n",
1678
+ "ensemble_model = EnsembleModel(models=models_ensemble, num_models=len(models_ensemble))\n",
1679
+ "ensemble_model.weights.data = torch.tensor(ensemble_weights, dtype=torch.float32, device=\"cuda\")\n",
1680
+ "ensemble_model = ensemble_model.to(\"cuda\")"
1681
+ ],
1682
+ "metadata": {
1683
+ "id": "WpGJ4SIrZ9G2"
1684
+ },
1685
+ "execution_count": 44,
1686
+ "outputs": []
1687
+ },
1688
+ {
1689
+ "cell_type": "markdown",
1690
+ "source": [
1691
+ "# Evaluation"
1692
+ ],
1693
+ "metadata": {
1694
+ "id": "PN94YVq0dMX1"
1695
+ }
1696
+ },
1697
+ {
1698
+ "cell_type": "code",
1699
+ "source": [
1700
+ "def evaluate_final_rmse(ensemble_model, data_loader, lat_mean, lon_mean, lat_std, lon_std):\n",
1701
+ " \"\"\"\n",
1702
+ " Evaluate the ensemble model on a given dataset and compute final RMSE in meters.\n",
1703
+ " \"\"\"\n",
1704
+ " ensemble_model.eval()\n",
1705
+ " total_loss = 0.0\n",
1706
+ " total_samples = 0\n",
1707
+ "\n",
1708
+ " with torch.no_grad():\n",
1709
+ " for images, targets in data_loader:\n",
1710
+ " images, targets = images.to(device), targets.to(device)\n",
1711
+ " outputs = ensemble_model(images)\n",
1712
+ " preds_denorm = outputs.cpu().numpy() * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1713
+ " actuals_denorm = targets.cpu().numpy() * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1714
+ "\n",
1715
+ " for pred, actual in zip(preds_denorm, actuals_denorm):\n",
1716
+ " distance = geodesic((actual[0], actual[1]), (pred[0], pred[1])).meters\n",
1717
+ " total_loss += distance ** 2\n",
1718
+ " total_samples += targets.size(0)\n",
1719
+ "\n",
1720
+ " final_loss = total_loss / total_samples\n",
1721
+ " final_rmse = np.sqrt(final_loss)\n",
1722
+ "\n",
1723
+ " return final_loss, final_rmse"
1724
+ ],
1725
+ "metadata": {
1726
+ "id": "zUhrqOv5cNag"
1727
+ },
1728
+ "execution_count": 47,
1729
+ "outputs": []
1730
+ },
1731
+ {
1732
+ "cell_type": "code",
1733
+ "source": [
1734
+ "final_test_loss, final_test_rmse = evaluate_final_rmse(\n",
1735
+ " ensemble_model=ensemble_model,\n",
1736
+ " data_loader=test_dataloader,\n",
1737
+ " lat_mean=lat_mean,\n",
1738
+ " lon_mean=lon_mean,\n",
1739
+ " lat_std=lat_std,\n",
1740
+ " lon_std=lon_std\n",
1741
+ ")\n",
1742
+ "\n",
1743
+ "print(f\"Test Loss (meters^2): {final_test_loss:.2f}\")\n",
1744
+ "print(f\"Test RMSE (meters): {final_test_rmse:.2f}\")"
1745
+ ],
1746
+ "metadata": {
1747
+ "colab": {
1748
+ "base_uri": "https://localhost:8080/"
1749
+ },
1750
+ "id": "-UZcLgmBcM-q",
1751
+ "outputId": "5ed71053-5017-48e5-d9ec-825ca01a8124"
1752
+ },
1753
+ "execution_count": 48,
1754
+ "outputs": [
1755
+ {
1756
+ "output_type": "stream",
1757
+ "name": "stdout",
1758
+ "text": [
1759
+ "Test Loss (meters^2): 8089.13\n",
1760
+ "Test RMSE (meters): 89.94\n"
1761
+ ]
1762
+ }
1763
+ ]
1764
+ },
1765
+ {
1766
+ "cell_type": "markdown",
1767
+ "source": [
1768
+ "# Visualizatoin"
1769
+ ],
1770
+ "metadata": {
1771
+ "id": "C-7gft4ddTzo"
1772
+ }
1773
+ },
1774
+ {
1775
+ "cell_type": "code",
1776
+ "source": [
1777
+ "def visualize_predictions(all_preds, all_actuals, lat_mean, lon_mean, lat_std, lon_std):\n",
1778
+ " \"\"\"\n",
1779
+ " Visualizes actual and predicted GPS coordinates on a scatter plot,\n",
1780
+ " including error lines connecting each prediction to its corresponding actual point.\n",
1781
+ " \"\"\"\n",
1782
+ "\n",
1783
+ " all_preds_denorm = all_preds * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1784
+ " all_actuals_denorm = all_actuals * np.array([lat_std, lon_std]) + np.array([lat_mean, lon_mean])\n",
1785
+ "\n",
1786
+ " plt.figure(figsize=(10, 5))\n",
1787
+ "\n",
1788
+ " plt.scatter(all_actuals_denorm[:, 1], all_actuals_denorm[:, 0], label='Actual', color='blue', alpha=0.6)\n",
1789
+ " plt.scatter(all_preds_denorm[:, 1], all_preds_denorm[:, 0], label='Predicted', color='red', alpha=0.6)\n",
1790
+ " for i in range(len(all_actuals_denorm)):\n",
1791
+ " plt.plot(\n",
1792
+ " [all_actuals_denorm[i, 1], all_preds_denorm[i, 1]],\n",
1793
+ " [all_actuals_denorm[i, 0], all_preds_denorm[i, 0]],\n",
1794
+ " color='gray', linewidth=0.5\n",
1795
+ " )\n",
1796
+ "\n",
1797
+ " plt.legend()\n",
1798
+ " plt.xlabel('Longitude')\n",
1799
+ " plt.ylabel('Latitude')\n",
1800
+ " plt.title('Actual vs. Predicted GPS Coordinates with Error Lines')\n",
1801
+ " plt.grid(True)\n",
1802
+ " plt.show()"
1803
+ ],
1804
+ "metadata": {
1805
+ "id": "W1O4anKmd1o7"
1806
+ },
1807
+ "execution_count": 49,
1808
+ "outputs": []
1809
+ },
1810
+ {
1811
+ "cell_type": "code",
1812
+ "source": [
1813
+ "ensemble_model.eval()\n",
1814
+ "\n",
1815
+ "all_preds = []\n",
1816
+ "all_actuals = []\n",
1817
+ "\n",
1818
+ "with torch.no_grad():\n",
1819
+ " for images, targets in test_dataloader:\n",
1820
+ " images = images.to(\"cuda\")\n",
1821
+ " targets = targets.to(\"cuda\")\n",
1822
+ "\n",
1823
+ " preds = ensemble_model(images)\n",
1824
+ "\n",
1825
+ " all_preds.append(preds.cpu().numpy())\n",
1826
+ " all_actuals.append(targets.cpu().numpy())\n",
1827
+ "\n",
1828
+ "all_preds = np.concatenate(all_preds, axis=0)\n",
1829
+ "all_actuals = np.concatenate(all_actuals, axis=0)\n",
1830
+ "\n",
1831
+ "visualize_predictions(\n",
1832
+ " all_preds=all_preds,\n",
1833
+ " all_actuals=all_actuals,\n",
1834
+ " lat_mean=lat_mean,\n",
1835
+ " lon_mean=lon_mean,\n",
1836
+ " lat_std=lat_std,\n",
1837
+ " lon_std=lon_std\n",
1838
+ ")"
1839
+ ],
1840
+ "metadata": {
1841
+ "id": "m8IiYdxJdYy_"
1842
+ },
1843
+ "execution_count": null,
1844
+ "outputs": []
1845
+ }
1846
+ ]
1847
+ }