AshleyBanksNIHR commited on
Commit
58b7cb1
·
verified ·
1 Parent(s): 8c64f5f

Delete inference/inference_script.ipynb

Browse files
Files changed (1) hide show
  1. inference/inference_script.ipynb +0 -1490
inference/inference_script.ipynb DELETED
@@ -1,1490 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "provenance": []
7
- },
8
- "kernelspec": {
9
- "name": "python3",
10
- "display_name": "Python 3"
11
- },
12
- "language_info": {
13
- "name": "python"
14
- },
15
- "widgets": {
16
- "application/vnd.jupyter.widget-state+json": {
17
- "d6d7368c1c7b4a549d9a706ee1a62785": {
18
- "model_module": "@jupyter-widgets/controls",
19
- "model_name": "HBoxModel",
20
- "model_module_version": "1.5.0",
21
- "state": {
22
- "_dom_classes": [],
23
- "_model_module": "@jupyter-widgets/controls",
24
- "_model_module_version": "1.5.0",
25
- "_model_name": "HBoxModel",
26
- "_view_count": null,
27
- "_view_module": "@jupyter-widgets/controls",
28
- "_view_module_version": "1.5.0",
29
- "_view_name": "HBoxView",
30
- "box_style": "",
31
- "children": [
32
- "IPY_MODEL_0e22f97db509402a9b4d4a2fa8c80fe0",
33
- "IPY_MODEL_475728820f8e4b6fa0c33a1e9989179a",
34
- "IPY_MODEL_cf5e926c8c004525b5aad36d821ef928"
35
- ],
36
- "layout": "IPY_MODEL_9275c5cf7b314dc69419611128cb0c09"
37
- }
38
- },
39
- "0e22f97db509402a9b4d4a2fa8c80fe0": {
40
- "model_module": "@jupyter-widgets/controls",
41
- "model_name": "HTMLModel",
42
- "model_module_version": "1.5.0",
43
- "state": {
44
- "_dom_classes": [],
45
- "_model_module": "@jupyter-widgets/controls",
46
- "_model_module_version": "1.5.0",
47
- "_model_name": "HTMLModel",
48
- "_view_count": null,
49
- "_view_module": "@jupyter-widgets/controls",
50
- "_view_module_version": "1.5.0",
51
- "_view_name": "HTMLView",
52
- "description": "",
53
- "description_tooltip": null,
54
- "layout": "IPY_MODEL_926c9a0562f3455b958c71bd0a138558",
55
- "placeholder": "​",
56
- "style": "IPY_MODEL_ef276c365d604dcc83b4a0133c6b7652",
57
- "value": "Loading weights: 100%"
58
- }
59
- },
60
- "475728820f8e4b6fa0c33a1e9989179a": {
61
- "model_module": "@jupyter-widgets/controls",
62
- "model_name": "FloatProgressModel",
63
- "model_module_version": "1.5.0",
64
- "state": {
65
- "_dom_classes": [],
66
- "_model_module": "@jupyter-widgets/controls",
67
- "_model_module_version": "1.5.0",
68
- "_model_name": "FloatProgressModel",
69
- "_view_count": null,
70
- "_view_module": "@jupyter-widgets/controls",
71
- "_view_module_version": "1.5.0",
72
- "_view_name": "ProgressView",
73
- "bar_style": "success",
74
- "description": "",
75
- "description_tooltip": null,
76
- "layout": "IPY_MODEL_3dc315e087f44330badbb6749b0611a9",
77
- "max": 393,
78
- "min": 0,
79
- "orientation": "horizontal",
80
- "style": "IPY_MODEL_bc6c346b95424eb881b793bf6fd2f53e",
81
- "value": 393
82
- }
83
- },
84
- "cf5e926c8c004525b5aad36d821ef928": {
85
- "model_module": "@jupyter-widgets/controls",
86
- "model_name": "HTMLModel",
87
- "model_module_version": "1.5.0",
88
- "state": {
89
- "_dom_classes": [],
90
- "_model_module": "@jupyter-widgets/controls",
91
- "_model_module_version": "1.5.0",
92
- "_model_name": "HTMLModel",
93
- "_view_count": null,
94
- "_view_module": "@jupyter-widgets/controls",
95
- "_view_module_version": "1.5.0",
96
- "_view_name": "HTMLView",
97
- "description": "",
98
- "description_tooltip": null,
99
- "layout": "IPY_MODEL_c26db3dd5a86405b81f510fd759c3af7",
100
- "placeholder": "​",
101
- "style": "IPY_MODEL_dae2b5afeee243e682322c1ecb8238e6",
102
- "value": " 393/393 [00:01<00:00, 353.83it/s, Materializing param=classifier.weight]"
103
- }
104
- },
105
- "9275c5cf7b314dc69419611128cb0c09": {
106
- "model_module": "@jupyter-widgets/base",
107
- "model_name": "LayoutModel",
108
- "model_module_version": "1.2.0",
109
- "state": {
110
- "_model_module": "@jupyter-widgets/base",
111
- "_model_module_version": "1.2.0",
112
- "_model_name": "LayoutModel",
113
- "_view_count": null,
114
- "_view_module": "@jupyter-widgets/base",
115
- "_view_module_version": "1.2.0",
116
- "_view_name": "LayoutView",
117
- "align_content": null,
118
- "align_items": null,
119
- "align_self": null,
120
- "border": null,
121
- "bottom": null,
122
- "display": null,
123
- "flex": null,
124
- "flex_flow": null,
125
- "grid_area": null,
126
- "grid_auto_columns": null,
127
- "grid_auto_flow": null,
128
- "grid_auto_rows": null,
129
- "grid_column": null,
130
- "grid_gap": null,
131
- "grid_row": null,
132
- "grid_template_areas": null,
133
- "grid_template_columns": null,
134
- "grid_template_rows": null,
135
- "height": null,
136
- "justify_content": null,
137
- "justify_items": null,
138
- "left": null,
139
- "margin": null,
140
- "max_height": null,
141
- "max_width": null,
142
- "min_height": null,
143
- "min_width": null,
144
- "object_fit": null,
145
- "object_position": null,
146
- "order": null,
147
- "overflow": null,
148
- "overflow_x": null,
149
- "overflow_y": null,
150
- "padding": null,
151
- "right": null,
152
- "top": null,
153
- "visibility": null,
154
- "width": null
155
- }
156
- },
157
- "926c9a0562f3455b958c71bd0a138558": {
158
- "model_module": "@jupyter-widgets/base",
159
- "model_name": "LayoutModel",
160
- "model_module_version": "1.2.0",
161
- "state": {
162
- "_model_module": "@jupyter-widgets/base",
163
- "_model_module_version": "1.2.0",
164
- "_model_name": "LayoutModel",
165
- "_view_count": null,
166
- "_view_module": "@jupyter-widgets/base",
167
- "_view_module_version": "1.2.0",
168
- "_view_name": "LayoutView",
169
- "align_content": null,
170
- "align_items": null,
171
- "align_self": null,
172
- "border": null,
173
- "bottom": null,
174
- "display": null,
175
- "flex": null,
176
- "flex_flow": null,
177
- "grid_area": null,
178
- "grid_auto_columns": null,
179
- "grid_auto_flow": null,
180
- "grid_auto_rows": null,
181
- "grid_column": null,
182
- "grid_gap": null,
183
- "grid_row": null,
184
- "grid_template_areas": null,
185
- "grid_template_columns": null,
186
- "grid_template_rows": null,
187
- "height": null,
188
- "justify_content": null,
189
- "justify_items": null,
190
- "left": null,
191
- "margin": null,
192
- "max_height": null,
193
- "max_width": null,
194
- "min_height": null,
195
- "min_width": null,
196
- "object_fit": null,
197
- "object_position": null,
198
- "order": null,
199
- "overflow": null,
200
- "overflow_x": null,
201
- "overflow_y": null,
202
- "padding": null,
203
- "right": null,
204
- "top": null,
205
- "visibility": null,
206
- "width": null
207
- }
208
- },
209
- "ef276c365d604dcc83b4a0133c6b7652": {
210
- "model_module": "@jupyter-widgets/controls",
211
- "model_name": "DescriptionStyleModel",
212
- "model_module_version": "1.5.0",
213
- "state": {
214
- "_model_module": "@jupyter-widgets/controls",
215
- "_model_module_version": "1.5.0",
216
- "_model_name": "DescriptionStyleModel",
217
- "_view_count": null,
218
- "_view_module": "@jupyter-widgets/base",
219
- "_view_module_version": "1.2.0",
220
- "_view_name": "StyleView",
221
- "description_width": ""
222
- }
223
- },
224
- "3dc315e087f44330badbb6749b0611a9": {
225
- "model_module": "@jupyter-widgets/base",
226
- "model_name": "LayoutModel",
227
- "model_module_version": "1.2.0",
228
- "state": {
229
- "_model_module": "@jupyter-widgets/base",
230
- "_model_module_version": "1.2.0",
231
- "_model_name": "LayoutModel",
232
- "_view_count": null,
233
- "_view_module": "@jupyter-widgets/base",
234
- "_view_module_version": "1.2.0",
235
- "_view_name": "LayoutView",
236
- "align_content": null,
237
- "align_items": null,
238
- "align_self": null,
239
- "border": null,
240
- "bottom": null,
241
- "display": null,
242
- "flex": null,
243
- "flex_flow": null,
244
- "grid_area": null,
245
- "grid_auto_columns": null,
246
- "grid_auto_flow": null,
247
- "grid_auto_rows": null,
248
- "grid_column": null,
249
- "grid_gap": null,
250
- "grid_row": null,
251
- "grid_template_areas": null,
252
- "grid_template_columns": null,
253
- "grid_template_rows": null,
254
- "height": null,
255
- "justify_content": null,
256
- "justify_items": null,
257
- "left": null,
258
- "margin": null,
259
- "max_height": null,
260
- "max_width": null,
261
- "min_height": null,
262
- "min_width": null,
263
- "object_fit": null,
264
- "object_position": null,
265
- "order": null,
266
- "overflow": null,
267
- "overflow_x": null,
268
- "overflow_y": null,
269
- "padding": null,
270
- "right": null,
271
- "top": null,
272
- "visibility": null,
273
- "width": null
274
- }
275
- },
276
- "bc6c346b95424eb881b793bf6fd2f53e": {
277
- "model_module": "@jupyter-widgets/controls",
278
- "model_name": "ProgressStyleModel",
279
- "model_module_version": "1.5.0",
280
- "state": {
281
- "_model_module": "@jupyter-widgets/controls",
282
- "_model_module_version": "1.5.0",
283
- "_model_name": "ProgressStyleModel",
284
- "_view_count": null,
285
- "_view_module": "@jupyter-widgets/base",
286
- "_view_module_version": "1.2.0",
287
- "_view_name": "StyleView",
288
- "bar_color": null,
289
- "description_width": ""
290
- }
291
- },
292
- "c26db3dd5a86405b81f510fd759c3af7": {
293
- "model_module": "@jupyter-widgets/base",
294
- "model_name": "LayoutModel",
295
- "model_module_version": "1.2.0",
296
- "state": {
297
- "_model_module": "@jupyter-widgets/base",
298
- "_model_module_version": "1.2.0",
299
- "_model_name": "LayoutModel",
300
- "_view_count": null,
301
- "_view_module": "@jupyter-widgets/base",
302
- "_view_module_version": "1.2.0",
303
- "_view_name": "LayoutView",
304
- "align_content": null,
305
- "align_items": null,
306
- "align_self": null,
307
- "border": null,
308
- "bottom": null,
309
- "display": null,
310
- "flex": null,
311
- "flex_flow": null,
312
- "grid_area": null,
313
- "grid_auto_columns": null,
314
- "grid_auto_flow": null,
315
- "grid_auto_rows": null,
316
- "grid_column": null,
317
- "grid_gap": null,
318
- "grid_row": null,
319
- "grid_template_areas": null,
320
- "grid_template_columns": null,
321
- "grid_template_rows": null,
322
- "height": null,
323
- "justify_content": null,
324
- "justify_items": null,
325
- "left": null,
326
- "margin": null,
327
- "max_height": null,
328
- "max_width": null,
329
- "min_height": null,
330
- "min_width": null,
331
- "object_fit": null,
332
- "object_position": null,
333
- "order": null,
334
- "overflow": null,
335
- "overflow_x": null,
336
- "overflow_y": null,
337
- "padding": null,
338
- "right": null,
339
- "top": null,
340
- "visibility": null,
341
- "width": null
342
- }
343
- },
344
- "dae2b5afeee243e682322c1ecb8238e6": {
345
- "model_module": "@jupyter-widgets/controls",
346
- "model_name": "DescriptionStyleModel",
347
- "model_module_version": "1.5.0",
348
- "state": {
349
- "_model_module": "@jupyter-widgets/controls",
350
- "_model_module_version": "1.5.0",
351
- "_model_name": "DescriptionStyleModel",
352
- "_view_count": null,
353
- "_view_module": "@jupyter-widgets/base",
354
- "_view_module_version": "1.2.0",
355
- "_view_name": "StyleView",
356
- "description_width": ""
357
- }
358
- },
359
- "f66e121a24e441c2a4cb32fb87d22632": {
360
- "model_module": "@jupyter-widgets/controls",
361
- "model_name": "HBoxModel",
362
- "model_module_version": "1.5.0",
363
- "state": {
364
- "_dom_classes": [],
365
- "_model_module": "@jupyter-widgets/controls",
366
- "_model_module_version": "1.5.0",
367
- "_model_name": "HBoxModel",
368
- "_view_count": null,
369
- "_view_module": "@jupyter-widgets/controls",
370
- "_view_module_version": "1.5.0",
371
- "_view_name": "HBoxView",
372
- "box_style": "",
373
- "children": [
374
- "IPY_MODEL_1776c0895e7f400f85d3a2123cb573f1",
375
- "IPY_MODEL_e7fd82b589eb49aaba73aeafcc340421",
376
- "IPY_MODEL_210ec65a874a4f51ab9f059d3de804a9"
377
- ],
378
- "layout": "IPY_MODEL_c91d89b903ce4f02b635454d3bd87877"
379
- }
380
- },
381
- "1776c0895e7f400f85d3a2123cb573f1": {
382
- "model_module": "@jupyter-widgets/controls",
383
- "model_name": "HTMLModel",
384
- "model_module_version": "1.5.0",
385
- "state": {
386
- "_dom_classes": [],
387
- "_model_module": "@jupyter-widgets/controls",
388
- "_model_module_version": "1.5.0",
389
- "_model_name": "HTMLModel",
390
- "_view_count": null,
391
- "_view_module": "@jupyter-widgets/controls",
392
- "_view_module_version": "1.5.0",
393
- "_view_name": "HTMLView",
394
- "description": "",
395
- "description_tooltip": null,
396
- "layout": "IPY_MODEL_9cba2384346645499516a7923ad6fcef",
397
- "placeholder": "​",
398
- "style": "IPY_MODEL_491f71a8889c4c7aa0042e99f374be79",
399
- "value": "Loading weights: 100%"
400
- }
401
- },
402
- "e7fd82b589eb49aaba73aeafcc340421": {
403
- "model_module": "@jupyter-widgets/controls",
404
- "model_name": "FloatProgressModel",
405
- "model_module_version": "1.5.0",
406
- "state": {
407
- "_dom_classes": [],
408
- "_model_module": "@jupyter-widgets/controls",
409
- "_model_module_version": "1.5.0",
410
- "_model_name": "FloatProgressModel",
411
- "_view_count": null,
412
- "_view_module": "@jupyter-widgets/controls",
413
- "_view_module_version": "1.5.0",
414
- "_view_name": "ProgressView",
415
- "bar_style": "success",
416
- "description": "",
417
- "description_tooltip": null,
418
- "layout": "IPY_MODEL_9f60072613d84754a0efe7813d853565",
419
- "max": 393,
420
- "min": 0,
421
- "orientation": "horizontal",
422
- "style": "IPY_MODEL_bf48a3e692ae4e6da5379152f51baa09",
423
- "value": 393
424
- }
425
- },
426
- "210ec65a874a4f51ab9f059d3de804a9": {
427
- "model_module": "@jupyter-widgets/controls",
428
- "model_name": "HTMLModel",
429
- "model_module_version": "1.5.0",
430
- "state": {
431
- "_dom_classes": [],
432
- "_model_module": "@jupyter-widgets/controls",
433
- "_model_module_version": "1.5.0",
434
- "_model_name": "HTMLModel",
435
- "_view_count": null,
436
- "_view_module": "@jupyter-widgets/controls",
437
- "_view_module_version": "1.5.0",
438
- "_view_name": "HTMLView",
439
- "description": "",
440
- "description_tooltip": null,
441
- "layout": "IPY_MODEL_25529d785e9047e1b9bd197a7ec56773",
442
- "placeholder": "​",
443
- "style": "IPY_MODEL_9dda1a77133c4247bac106a9c4699153",
444
- "value": " 393/393 [00:00<00:00, 581.79it/s, Materializing param=classifier.weight]"
445
- }
446
- },
447
- "c91d89b903ce4f02b635454d3bd87877": {
448
- "model_module": "@jupyter-widgets/base",
449
- "model_name": "LayoutModel",
450
- "model_module_version": "1.2.0",
451
- "state": {
452
- "_model_module": "@jupyter-widgets/base",
453
- "_model_module_version": "1.2.0",
454
- "_model_name": "LayoutModel",
455
- "_view_count": null,
456
- "_view_module": "@jupyter-widgets/base",
457
- "_view_module_version": "1.2.0",
458
- "_view_name": "LayoutView",
459
- "align_content": null,
460
- "align_items": null,
461
- "align_self": null,
462
- "border": null,
463
- "bottom": null,
464
- "display": null,
465
- "flex": null,
466
- "flex_flow": null,
467
- "grid_area": null,
468
- "grid_auto_columns": null,
469
- "grid_auto_flow": null,
470
- "grid_auto_rows": null,
471
- "grid_column": null,
472
- "grid_gap": null,
473
- "grid_row": null,
474
- "grid_template_areas": null,
475
- "grid_template_columns": null,
476
- "grid_template_rows": null,
477
- "height": null,
478
- "justify_content": null,
479
- "justify_items": null,
480
- "left": null,
481
- "margin": null,
482
- "max_height": null,
483
- "max_width": null,
484
- "min_height": null,
485
- "min_width": null,
486
- "object_fit": null,
487
- "object_position": null,
488
- "order": null,
489
- "overflow": null,
490
- "overflow_x": null,
491
- "overflow_y": null,
492
- "padding": null,
493
- "right": null,
494
- "top": null,
495
- "visibility": null,
496
- "width": null
497
- }
498
- },
499
- "9cba2384346645499516a7923ad6fcef": {
500
- "model_module": "@jupyter-widgets/base",
501
- "model_name": "LayoutModel",
502
- "model_module_version": "1.2.0",
503
- "state": {
504
- "_model_module": "@jupyter-widgets/base",
505
- "_model_module_version": "1.2.0",
506
- "_model_name": "LayoutModel",
507
- "_view_count": null,
508
- "_view_module": "@jupyter-widgets/base",
509
- "_view_module_version": "1.2.0",
510
- "_view_name": "LayoutView",
511
- "align_content": null,
512
- "align_items": null,
513
- "align_self": null,
514
- "border": null,
515
- "bottom": null,
516
- "display": null,
517
- "flex": null,
518
- "flex_flow": null,
519
- "grid_area": null,
520
- "grid_auto_columns": null,
521
- "grid_auto_flow": null,
522
- "grid_auto_rows": null,
523
- "grid_column": null,
524
- "grid_gap": null,
525
- "grid_row": null,
526
- "grid_template_areas": null,
527
- "grid_template_columns": null,
528
- "grid_template_rows": null,
529
- "height": null,
530
- "justify_content": null,
531
- "justify_items": null,
532
- "left": null,
533
- "margin": null,
534
- "max_height": null,
535
- "max_width": null,
536
- "min_height": null,
537
- "min_width": null,
538
- "object_fit": null,
539
- "object_position": null,
540
- "order": null,
541
- "overflow": null,
542
- "overflow_x": null,
543
- "overflow_y": null,
544
- "padding": null,
545
- "right": null,
546
- "top": null,
547
- "visibility": null,
548
- "width": null
549
- }
550
- },
551
- "491f71a8889c4c7aa0042e99f374be79": {
552
- "model_module": "@jupyter-widgets/controls",
553
- "model_name": "DescriptionStyleModel",
554
- "model_module_version": "1.5.0",
555
- "state": {
556
- "_model_module": "@jupyter-widgets/controls",
557
- "_model_module_version": "1.5.0",
558
- "_model_name": "DescriptionStyleModel",
559
- "_view_count": null,
560
- "_view_module": "@jupyter-widgets/base",
561
- "_view_module_version": "1.2.0",
562
- "_view_name": "StyleView",
563
- "description_width": ""
564
- }
565
- },
566
- "9f60072613d84754a0efe7813d853565": {
567
- "model_module": "@jupyter-widgets/base",
568
- "model_name": "LayoutModel",
569
- "model_module_version": "1.2.0",
570
- "state": {
571
- "_model_module": "@jupyter-widgets/base",
572
- "_model_module_version": "1.2.0",
573
- "_model_name": "LayoutModel",
574
- "_view_count": null,
575
- "_view_module": "@jupyter-widgets/base",
576
- "_view_module_version": "1.2.0",
577
- "_view_name": "LayoutView",
578
- "align_content": null,
579
- "align_items": null,
580
- "align_self": null,
581
- "border": null,
582
- "bottom": null,
583
- "display": null,
584
- "flex": null,
585
- "flex_flow": null,
586
- "grid_area": null,
587
- "grid_auto_columns": null,
588
- "grid_auto_flow": null,
589
- "grid_auto_rows": null,
590
- "grid_column": null,
591
- "grid_gap": null,
592
- "grid_row": null,
593
- "grid_template_areas": null,
594
- "grid_template_columns": null,
595
- "grid_template_rows": null,
596
- "height": null,
597
- "justify_content": null,
598
- "justify_items": null,
599
- "left": null,
600
- "margin": null,
601
- "max_height": null,
602
- "max_width": null,
603
- "min_height": null,
604
- "min_width": null,
605
- "object_fit": null,
606
- "object_position": null,
607
- "order": null,
608
- "overflow": null,
609
- "overflow_x": null,
610
- "overflow_y": null,
611
- "padding": null,
612
- "right": null,
613
- "top": null,
614
- "visibility": null,
615
- "width": null
616
- }
617
- },
618
- "bf48a3e692ae4e6da5379152f51baa09": {
619
- "model_module": "@jupyter-widgets/controls",
620
- "model_name": "ProgressStyleModel",
621
- "model_module_version": "1.5.0",
622
- "state": {
623
- "_model_module": "@jupyter-widgets/controls",
624
- "_model_module_version": "1.5.0",
625
- "_model_name": "ProgressStyleModel",
626
- "_view_count": null,
627
- "_view_module": "@jupyter-widgets/base",
628
- "_view_module_version": "1.2.0",
629
- "_view_name": "StyleView",
630
- "bar_color": null,
631
- "description_width": ""
632
- }
633
- },
634
- "25529d785e9047e1b9bd197a7ec56773": {
635
- "model_module": "@jupyter-widgets/base",
636
- "model_name": "LayoutModel",
637
- "model_module_version": "1.2.0",
638
- "state": {
639
- "_model_module": "@jupyter-widgets/base",
640
- "_model_module_version": "1.2.0",
641
- "_model_name": "LayoutModel",
642
- "_view_count": null,
643
- "_view_module": "@jupyter-widgets/base",
644
- "_view_module_version": "1.2.0",
645
- "_view_name": "LayoutView",
646
- "align_content": null,
647
- "align_items": null,
648
- "align_self": null,
649
- "border": null,
650
- "bottom": null,
651
- "display": null,
652
- "flex": null,
653
- "flex_flow": null,
654
- "grid_area": null,
655
- "grid_auto_columns": null,
656
- "grid_auto_flow": null,
657
- "grid_auto_rows": null,
658
- "grid_column": null,
659
- "grid_gap": null,
660
- "grid_row": null,
661
- "grid_template_areas": null,
662
- "grid_template_columns": null,
663
- "grid_template_rows": null,
664
- "height": null,
665
- "justify_content": null,
666
- "justify_items": null,
667
- "left": null,
668
- "margin": null,
669
- "max_height": null,
670
- "max_width": null,
671
- "min_height": null,
672
- "min_width": null,
673
- "object_fit": null,
674
- "object_position": null,
675
- "order": null,
676
- "overflow": null,
677
- "overflow_x": null,
678
- "overflow_y": null,
679
- "padding": null,
680
- "right": null,
681
- "top": null,
682
- "visibility": null,
683
- "width": null
684
- }
685
- },
686
- "9dda1a77133c4247bac106a9c4699153": {
687
- "model_module": "@jupyter-widgets/controls",
688
- "model_name": "DescriptionStyleModel",
689
- "model_module_version": "1.5.0",
690
- "state": {
691
- "_model_module": "@jupyter-widgets/controls",
692
- "_model_module_version": "1.5.0",
693
- "_model_name": "DescriptionStyleModel",
694
- "_view_count": null,
695
- "_view_module": "@jupyter-widgets/base",
696
- "_view_module_version": "1.2.0",
697
- "_view_name": "StyleView",
698
- "description_width": ""
699
- }
700
- }
701
- }
702
- }
703
- },
704
- "cells": [
705
- {
706
- "cell_type": "code",
707
- "execution_count": 1,
708
- "metadata": {
709
- "colab": {
710
- "base_uri": "https://localhost:8080/",
711
- "height": 262
712
- },
713
- "collapsed": true,
714
- "id": "PsYjjGH4bfrR",
715
- "outputId": "a5976404-da8d-4d09-ed23-fa176ae51cc1"
716
- },
717
- "outputs": [
718
- {
719
- "output_type": "display_data",
720
- "data": {
721
- "text/plain": [
722
- "<IPython.core.display.HTML object>"
723
- ],
724
- "text/html": [
725
- "\n",
726
- " <input type=\"file\" id=\"files-a67e1396-3138-4501-bfed-a8cfae954d7f\" name=\"files[]\" multiple disabled\n",
727
- " style=\"border:none\" />\n",
728
- " <output id=\"result-a67e1396-3138-4501-bfed-a8cfae954d7f\">\n",
729
- " Upload widget is only available when the cell has been executed in the\n",
730
- " current browser session. Please rerun this cell to enable.\n",
731
- " </output>\n",
732
- " <script>// Copyright 2017 Google LLC\n",
733
- "//\n",
734
- "// Licensed under the Apache License, Version 2.0 (the \"License\");\n",
735
- "// you may not use this file except in compliance with the License.\n",
736
- "// You may obtain a copy of the License at\n",
737
- "//\n",
738
- "// http://www.apache.org/licenses/LICENSE-2.0\n",
739
- "//\n",
740
- "// Unless required by applicable law or agreed to in writing, software\n",
741
- "// distributed under the License is distributed on an \"AS IS\" BASIS,\n",
742
- "// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
743
- "// See the License for the specific language governing permissions and\n",
744
- "// limitations under the License.\n",
745
- "\n",
746
- "/**\n",
747
- " * @fileoverview Helpers for google.colab Python module.\n",
748
- " */\n",
749
- "(function(scope) {\n",
750
- "function span(text, styleAttributes = {}) {\n",
751
- " const element = document.createElement('span');\n",
752
- " element.textContent = text;\n",
753
- " for (const key of Object.keys(styleAttributes)) {\n",
754
- " element.style[key] = styleAttributes[key];\n",
755
- " }\n",
756
- " return element;\n",
757
- "}\n",
758
- "\n",
759
- "// Max number of bytes which will be uploaded at a time.\n",
760
- "const MAX_PAYLOAD_SIZE = 100 * 1024;\n",
761
- "\n",
762
- "function _uploadFiles(inputId, outputId) {\n",
763
- " const steps = uploadFilesStep(inputId, outputId);\n",
764
- " const outputElement = document.getElementById(outputId);\n",
765
- " // Cache steps on the outputElement to make it available for the next call\n",
766
- " // to uploadFilesContinue from Python.\n",
767
- " outputElement.steps = steps;\n",
768
- "\n",
769
- " return _uploadFilesContinue(outputId);\n",
770
- "}\n",
771
- "\n",
772
- "// This is roughly an async generator (not supported in the browser yet),\n",
773
- "// where there are multiple asynchronous steps and the Python side is going\n",
774
- "// to poll for completion of each step.\n",
775
- "// This uses a Promise to block the python side on completion of each step,\n",
776
- "// then passes the result of the previous step as the input to the next step.\n",
777
- "function _uploadFilesContinue(outputId) {\n",
778
- " const outputElement = document.getElementById(outputId);\n",
779
- " const steps = outputElement.steps;\n",
780
- "\n",
781
- " const next = steps.next(outputElement.lastPromiseValue);\n",
782
- " return Promise.resolve(next.value.promise).then((value) => {\n",
783
- " // Cache the last promise value to make it available to the next\n",
784
- " // step of the generator.\n",
785
- " outputElement.lastPromiseValue = value;\n",
786
- " return next.value.response;\n",
787
- " });\n",
788
- "}\n",
789
- "\n",
790
- "/**\n",
791
- " * Generator function which is called between each async step of the upload\n",
792
- " * process.\n",
793
- " * @param {string} inputId Element ID of the input file picker element.\n",
794
- " * @param {string} outputId Element ID of the output display.\n",
795
- " * @return {!Iterable<!Object>} Iterable of next steps.\n",
796
- " */\n",
797
- "function* uploadFilesStep(inputId, outputId) {\n",
798
- " const inputElement = document.getElementById(inputId);\n",
799
- " inputElement.disabled = false;\n",
800
- "\n",
801
- " const outputElement = document.getElementById(outputId);\n",
802
- " outputElement.innerHTML = '';\n",
803
- "\n",
804
- " const pickedPromise = new Promise((resolve) => {\n",
805
- " inputElement.addEventListener('change', (e) => {\n",
806
- " resolve(e.target.files);\n",
807
- " });\n",
808
- " });\n",
809
- "\n",
810
- " const cancel = document.createElement('button');\n",
811
- " inputElement.parentElement.appendChild(cancel);\n",
812
- " cancel.textContent = 'Cancel upload';\n",
813
- " const cancelPromise = new Promise((resolve) => {\n",
814
- " cancel.onclick = () => {\n",
815
- " resolve(null);\n",
816
- " };\n",
817
- " });\n",
818
- "\n",
819
- " // Wait for the user to pick the files.\n",
820
- " const files = yield {\n",
821
- " promise: Promise.race([pickedPromise, cancelPromise]),\n",
822
- " response: {\n",
823
- " action: 'starting',\n",
824
- " }\n",
825
- " };\n",
826
- "\n",
827
- " cancel.remove();\n",
828
- "\n",
829
- " // Disable the input element since further picks are not allowed.\n",
830
- " inputElement.disabled = true;\n",
831
- "\n",
832
- " if (!files) {\n",
833
- " return {\n",
834
- " response: {\n",
835
- " action: 'complete',\n",
836
- " }\n",
837
- " };\n",
838
- " }\n",
839
- "\n",
840
- " for (const file of files) {\n",
841
- " const li = document.createElement('li');\n",
842
- " li.append(span(file.name, {fontWeight: 'bold'}));\n",
843
- " li.append(span(\n",
844
- " `(${file.type || 'n/a'}) - ${file.size} bytes, ` +\n",
845
- " `last modified: ${\n",
846
- " file.lastModifiedDate ? file.lastModifiedDate.toLocaleDateString() :\n",
847
- " 'n/a'} - `));\n",
848
- " const percent = span('0% done');\n",
849
- " li.appendChild(percent);\n",
850
- "\n",
851
- " outputElement.appendChild(li);\n",
852
- "\n",
853
- " const fileDataPromise = new Promise((resolve) => {\n",
854
- " const reader = new FileReader();\n",
855
- " reader.onload = (e) => {\n",
856
- " resolve(e.target.result);\n",
857
- " };\n",
858
- " reader.readAsArrayBuffer(file);\n",
859
- " });\n",
860
- " // Wait for the data to be ready.\n",
861
- " let fileData = yield {\n",
862
- " promise: fileDataPromise,\n",
863
- " response: {\n",
864
- " action: 'continue',\n",
865
- " }\n",
866
- " };\n",
867
- "\n",
868
- " // Use a chunked sending to avoid message size limits. See b/62115660.\n",
869
- " let position = 0;\n",
870
- " do {\n",
871
- " const length = Math.min(fileData.byteLength - position, MAX_PAYLOAD_SIZE);\n",
872
- " const chunk = new Uint8Array(fileData, position, length);\n",
873
- " position += length;\n",
874
- "\n",
875
- " const base64 = btoa(String.fromCharCode.apply(null, chunk));\n",
876
- " yield {\n",
877
- " response: {\n",
878
- " action: 'append',\n",
879
- " file: file.name,\n",
880
- " data: base64,\n",
881
- " },\n",
882
- " };\n",
883
- "\n",
884
- " let percentDone = fileData.byteLength === 0 ?\n",
885
- " 100 :\n",
886
- " Math.round((position / fileData.byteLength) * 100);\n",
887
- " percent.textContent = `${percentDone}% done`;\n",
888
- "\n",
889
- " } while (position < fileData.byteLength);\n",
890
- " }\n",
891
- "\n",
892
- " // All done.\n",
893
- " yield {\n",
894
- " response: {\n",
895
- " action: 'complete',\n",
896
- " }\n",
897
- " };\n",
898
- "}\n",
899
- "\n",
900
- "scope.google = scope.google || {};\n",
901
- "scope.google.colab = scope.google.colab || {};\n",
902
- "scope.google.colab._files = {\n",
903
- " _uploadFiles,\n",
904
- " _uploadFilesContinue,\n",
905
- "};\n",
906
- "})(self);\n",
907
- "</script> "
908
- ]
909
- },
910
- "metadata": {}
911
- },
912
- {
913
- "output_type": "stream",
914
- "name": "stdout",
915
- "text": [
916
- "Saving test_data.csv to test_data.csv\n"
917
- ]
918
- },
919
- {
920
- "output_type": "execute_result",
921
- "data": {
922
- "text/plain": [
923
- " ID AwardTitle \\\n",
924
- "0 synthetic1 Phase III trial of novel immunotherapy for adv... \n",
925
- "1 synthetic2 Genetic and environmental risk factors in earl... \n",
926
- "2 synthetic3 Community-based dietary interventions to reduc... \n",
927
- "3 synthetic4 Structural analysis of the viral envelope protein \n",
928
- "4 synthetic5 Improving palliative care pathways for advance... \n",
929
- "\n",
930
- " AwardAbstract \n",
931
- "0 This clinical trial will evaluate the efficacy... \n",
932
- "1 We aim to identify the underlying genetic vari... \n",
933
- "2 This project implements a public health initia... \n",
934
- "3 Basic biological research into the molecular s... \n",
935
- "4 A study to assess the implementation of new ca... "
936
- ],
937
- "text/html": [
938
- "\n",
939
- " <div id=\"df-ed7f79b6-7871-48f0-b98c-dd5bca38f647\" class=\"colab-df-container\">\n",
940
- " <div>\n",
941
- "<style scoped>\n",
942
- " .dataframe tbody tr th:only-of-type {\n",
943
- " vertical-align: middle;\n",
944
- " }\n",
945
- "\n",
946
- " .dataframe tbody tr th {\n",
947
- " vertical-align: top;\n",
948
- " }\n",
949
- "\n",
950
- " .dataframe thead th {\n",
951
- " text-align: right;\n",
952
- " }\n",
953
- "</style>\n",
954
- "<table border=\"1\" class=\"dataframe\">\n",
955
- " <thead>\n",
956
- " <tr style=\"text-align: right;\">\n",
957
- " <th></th>\n",
958
- " <th>ID</th>\n",
959
- " <th>AwardTitle</th>\n",
960
- " <th>AwardAbstract</th>\n",
961
- " </tr>\n",
962
- " </thead>\n",
963
- " <tbody>\n",
964
- " <tr>\n",
965
- " <th>0</th>\n",
966
- " <td>synthetic1</td>\n",
967
- " <td>Phase III trial of novel immunotherapy for adv...</td>\n",
968
- " <td>This clinical trial will evaluate the efficacy...</td>\n",
969
- " </tr>\n",
970
- " <tr>\n",
971
- " <th>1</th>\n",
972
- " <td>synthetic2</td>\n",
973
- " <td>Genetic and environmental risk factors in earl...</td>\n",
974
- " <td>We aim to identify the underlying genetic vari...</td>\n",
975
- " </tr>\n",
976
- " <tr>\n",
977
- " <th>2</th>\n",
978
- " <td>synthetic3</td>\n",
979
- " <td>Community-based dietary interventions to reduc...</td>\n",
980
- " <td>This project implements a public health initia...</td>\n",
981
- " </tr>\n",
982
- " <tr>\n",
983
- " <th>3</th>\n",
984
- " <td>synthetic4</td>\n",
985
- " <td>Structural analysis of the viral envelope protein</td>\n",
986
- " <td>Basic biological research into the molecular s...</td>\n",
987
- " </tr>\n",
988
- " <tr>\n",
989
- " <th>4</th>\n",
990
- " <td>synthetic5</td>\n",
991
- " <td>Improving palliative care pathways for advance...</td>\n",
992
- " <td>A study to assess the implementation of new ca...</td>\n",
993
- " </tr>\n",
994
- " </tbody>\n",
995
- "</table>\n",
996
- "</div>\n",
997
- " <div class=\"colab-df-buttons\">\n",
998
- "\n",
999
- " <div class=\"colab-df-container\">\n",
1000
- " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ed7f79b6-7871-48f0-b98c-dd5bca38f647')\"\n",
1001
- " title=\"Convert this dataframe to an interactive table.\"\n",
1002
- " style=\"display:none;\">\n",
1003
- "\n",
1004
- " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\" viewBox=\"0 -960 960 960\">\n",
1005
- " <path d=\"M120-120v-720h720v720H120Zm60-500h600v-160H180v160Zm220 220h160v-160H400v160Zm0 220h160v-160H400v160ZM180-400h160v-160H180v160Zm440 0h160v-160H620v160ZM180-180h160v-160H180v160Zm440 0h160v-160H620v160Z\"/>\n",
1006
- " </svg>\n",
1007
- " </button>\n",
1008
- "\n",
1009
- " <style>\n",
1010
- " .colab-df-container {\n",
1011
- " display:flex;\n",
1012
- " gap: 12px;\n",
1013
- " }\n",
1014
- "\n",
1015
- " .colab-df-convert {\n",
1016
- " background-color: #E8F0FE;\n",
1017
- " border: none;\n",
1018
- " border-radius: 50%;\n",
1019
- " cursor: pointer;\n",
1020
- " display: none;\n",
1021
- " fill: #1967D2;\n",
1022
- " height: 32px;\n",
1023
- " padding: 0 0 0 0;\n",
1024
- " width: 32px;\n",
1025
- " }\n",
1026
- "\n",
1027
- " .colab-df-convert:hover {\n",
1028
- " background-color: #E2EBFA;\n",
1029
- " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
1030
- " fill: #174EA6;\n",
1031
- " }\n",
1032
- "\n",
1033
- " .colab-df-buttons div {\n",
1034
- " margin-bottom: 4px;\n",
1035
- " }\n",
1036
- "\n",
1037
- " [theme=dark] .colab-df-convert {\n",
1038
- " background-color: #3B4455;\n",
1039
- " fill: #D2E3FC;\n",
1040
- " }\n",
1041
- "\n",
1042
- " [theme=dark] .colab-df-convert:hover {\n",
1043
- " background-color: #434B5C;\n",
1044
- " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
1045
- " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
1046
- " fill: #FFFFFF;\n",
1047
- " }\n",
1048
- " </style>\n",
1049
- "\n",
1050
- " <script>\n",
1051
- " const buttonEl =\n",
1052
- " document.querySelector('#df-ed7f79b6-7871-48f0-b98c-dd5bca38f647 button.colab-df-convert');\n",
1053
- " buttonEl.style.display =\n",
1054
- " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
1055
- "\n",
1056
- " async function convertToInteractive(key) {\n",
1057
- " const element = document.querySelector('#df-ed7f79b6-7871-48f0-b98c-dd5bca38f647');\n",
1058
- " const dataTable =\n",
1059
- " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
1060
- " [key], {});\n",
1061
- " if (!dataTable) return;\n",
1062
- "\n",
1063
- " const docLinkHtml = 'Like what you see? Visit the ' +\n",
1064
- " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
1065
- " + ' to learn more about interactive tables.';\n",
1066
- " element.innerHTML = '';\n",
1067
- " dataTable['output_type'] = 'display_data';\n",
1068
- " await google.colab.output.renderOutput(dataTable, element);\n",
1069
- " const docLink = document.createElement('div');\n",
1070
- " docLink.innerHTML = docLinkHtml;\n",
1071
- " element.appendChild(docLink);\n",
1072
- " }\n",
1073
- " </script>\n",
1074
- " </div>\n",
1075
- "\n",
1076
- "\n",
1077
- " </div>\n",
1078
- " </div>\n"
1079
- ],
1080
- "application/vnd.google.colaboratory.intrinsic+json": {
1081
- "type": "dataframe",
1082
- "variable_name": "test_df",
1083
- "summary": "{\n \"name\": \"test_df\",\n \"rows\": 5,\n \"fields\": [\n {\n \"column\": \"ID\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"synthetic2\",\n \"synthetic5\",\n \"synthetic3\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"AwardTitle\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"Genetic and environmental risk factors in early-onset schizophrenia\",\n \"Improving palliative care pathways for advanced dementia patients\",\n \"Community-based dietary interventions to reduce hypertension\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n },\n {\n \"column\": \"AwardAbstract\",\n \"properties\": {\n \"dtype\": \"string\",\n \"num_unique_values\": 5,\n \"samples\": [\n \"We aim to identify the underlying genetic variants and social determinants contributing to the development of schizophrenia in adolescents using a large-scale longitudinal cohort.\",\n \"A study to assess the implementation of new care guidelines in residential homes, aiming to improve the quality of life and symptom management for patients with severe dementia.\",\n \"This project implements a public health initiative focused on sodium reduction and diet modifications in urban populations to prevent cardiovascular diseases.\"\n ],\n \"semantic_type\": \"\",\n \"description\": \"\"\n }\n }\n ]\n}"
1084
- }
1085
- },
1086
- "metadata": {},
1087
- "execution_count": 1
1088
- }
1089
- ],
1090
- "source": [
1091
- "import pandas as pd\n",
1092
- "import os\n",
1093
- "from google.colab import files\n",
1094
- "\n",
1095
- "\n",
1096
- "uploaded = files.upload()\n",
1097
- "test_data = list(uploaded.keys())[0]\n",
1098
- "test_df = pd.read_csv(test_data)\n",
1099
- "test_df.head()\n"
1100
- ]
1101
- },
1102
- {
1103
- "cell_type": "code",
1104
- "source": [
1105
- "TEST_FILENAME = test_data\n",
1106
- "OUTPUT_FILENAME = r\"C:\\Users\\Nicoy.Downes\\OneDrive - LGC Group\\_User Profile\\Documents\\VSCode\\hc_rac_predictions.csv\""
1107
- ],
1108
- "metadata": {
1109
- "id": "XCQOEYvzcNgW"
1110
- },
1111
- "execution_count": 3,
1112
- "outputs": []
1113
- },
1114
- {
1115
- "cell_type": "code",
1116
- "source": [
1117
- "\"\"\"\n",
1118
- "HRCS Health Category & Research Activity Code - Inference Script\n",
1119
- "Developed by National Institute for Health and Care Research (NIHR)\n",
1120
- "\n",
1121
- "This script downloads the trained models from Hugging Face and runs\n",
1122
- "predictions locally on a provided CSV of research awards.\n",
1123
- "\n",
1124
- "Dependencies required:\n",
1125
- "pip install torch pandas numpy tqdm transformers huggingface_hub\n",
1126
- "\"\"\"\n",
1127
- "\n",
1128
- "import os, json\n",
1129
- "from typing import List\n",
1130
- "import numpy as np\n",
1131
- "import pandas as pd\n",
1132
- "from tqdm import tqdm\n",
1133
- "\n",
1134
- "import torch\n",
1135
- "from torch.utils.data import Dataset as TorchDataset, DataLoader\n",
1136
- "from torch.nn.functional import sigmoid\n",
1137
- "\n",
1138
- "from transformers import AutoTokenizer, AutoModelForSequenceClassification\n",
1139
- "from huggingface_hub import hf_hub_download\n",
1140
- "\n",
1141
- "# --- USER SETTINGS ---\n",
1142
- "# Set the path to the folder containing your CSV and the desired filenames\n",
1143
- "DATA_FOLDERS = [\"./\"] # can input more than one folder if needed\n",
1144
- "#TEST_FILENAME = r\"C:\\Users\\Nicoy.Downes\\OneDrive - LGC Group\\_User Profile\\Documents\\VSCode\\test_data.csv\"\n",
1145
- "OUTPUT_FILENAME = \"hc_rac_predictions.csv\"\n",
1146
- "\n",
1147
- "# --- HUGGING FACE REPO SETTINGS ---\n",
1148
- "HC_REPO_ID = \"NIHRDataInsights/HRCSHealthCategories\"\n",
1149
- "RAC_REPO_ID = \"NIHRDataInsights/HRCSResearchActivityCodes\"\n",
1150
- "\n",
1151
- "BATCH_SIZE = 16 # Adjust based on your hardware (use 64+ for GPUs, 2-16 for CPUs)\n",
1152
- "MAX_LEN = 512\n",
1153
- "USE_GPU = torch.cuda.is_available()\n",
1154
- "\n",
1155
- "# -------------------- DATASET --------------------\n",
1156
- "class TextDataset(TorchDataset):\n",
1157
- " \"\"\"\n",
1158
- " This section takes your spreadsheet and packages it so the AI can read it.\n",
1159
- " It combines the 'AwardTitle' and 'AwardAbstract' together into 'text' (to be categorised).\n",
1160
- " \"\"\"\n",
1161
- " def __init__(self, df: pd.DataFrame, tokenizer, max_len: int = 512):\n",
1162
- " df = df.copy()\n",
1163
- " df[\"text\"] = (\n",
1164
- " df[\"AwardTitle\"].fillna(\"\").astype(str) + \"\\n\" +\n",
1165
- " df[\"AwardAbstract\"].fillna(\"\").astype(str)\n",
1166
- " )\n",
1167
- " self.texts = df[\"text\"].tolist()\n",
1168
- " self.tokenizer = tokenizer\n",
1169
- " self.max_len = max_len\n",
1170
- "\n",
1171
- " def __len__(self):\n",
1172
- " return len(self.texts)\n",
1173
- "\n",
1174
- " def __getitem__(self, idx):\n",
1175
- " enc = self.tokenizer(\n",
1176
- " self.texts[idx],\n",
1177
- " truncation=True, # cuts if too long\n",
1178
- " padding=\"max_length\", # adds spaces if under limit\n",
1179
- " max_length=self.max_len,\n",
1180
- " return_tensors=\"pt\"\n",
1181
- " )\n",
1182
- " return {k: v.squeeze(0) for k, v in enc.items()}\n",
1183
- "\n",
1184
- "# -------------------- MODEL LOADING --------------------\n",
1185
- "def load_hf_model(repo_id: str):\n",
1186
- " \"\"\"\n",
1187
- " This function downloads a model and its tokenizer from Hugging Face, along with the labels and optimal thresholds from the repo's metadata.json.\n",
1188
- " \"\"\"\n",
1189
- " print(f\"Loading model and tokenizer from: {repo_id}...\")\n",
1190
- "\n",
1191
- " tok = AutoTokenizer.from_pretrained(repo_id, use_fast=True)\n",
1192
- " model = AutoModelForSequenceClassification.from_pretrained(repo_id)\n",
1193
- "\n",
1194
- " model.eval()\n",
1195
- " device = torch.device(\"cuda\" if USE_GPU else \"cpu\")\n",
1196
- " model.to(device)\n",
1197
- "\n",
1198
- " # Download custom metadata.json from the repo for specific threshold values\n",
1199
- " try:\n",
1200
- " meta_path = hf_hub_download(repo_id=repo_id, filename=\"metadata.json\")\n",
1201
- " with open(meta_path, \"r\") as f:\n",
1202
- " meta = json.load(f)\n",
1203
- " labels = meta.get(\"labels\")\n",
1204
- " thresholds = np.array(meta.get(\"optimal_thresholds\", []), dtype=float)\n",
1205
- " except Exception as e:\n",
1206
- " print(f\"Warning: Could not load metadata.json from {repo_id}. {e}\")\n",
1207
- " labels, thresholds = None, None\n",
1208
- "\n",
1209
- " if labels is None:\n",
1210
- " num_labels = int(model.config.num_labels)\n",
1211
- " labels = [model.config.id2label.get(i, f\"L{i}\") for i in range(num_labels)]\n",
1212
- "\n",
1213
- " if thresholds is None or len(thresholds) != len(labels):\n",
1214
- " print(f\"Using default 0.5 threshold for {repo_id}\")\n",
1215
- " thresholds = np.full(len(labels), 0.5, dtype=float)\n",
1216
- "\n",
1217
- " return tok, model, labels, thresholds\n",
1218
- "\n",
1219
- "# -------------------- INFERENCE & BUILD --------------------\n",
1220
- "def predict(model, dataloader: DataLoader, thresholds: np.ndarray):\n",
1221
- " device = next(model.parameters()).device\n",
1222
- " all_logits = []\n",
1223
- "\n",
1224
- " with torch.no_grad():\n",
1225
- " for batch in tqdm(dataloader, desc=\"Inferencing\"):\n",
1226
- " inputs = {k: v.to(device) for k, v in batch.items()}\n",
1227
- " logits = model(**inputs).logits\n",
1228
- " all_logits.append(logits.cpu())\n",
1229
- "\n",
1230
- " logits = torch.cat(all_logits, dim=0).numpy()\n",
1231
- " probs = sigmoid(torch.tensor(logits)).numpy()\n",
1232
- " preds = (probs > thresholds.reshape(1, -1)).astype(int)\n",
1233
- " return logits, probs, preds\n",
1234
- "\n",
1235
- "def build_outputs(df, logits, probs, preds, label_list, thresholds, prefix: str):\n",
1236
- " \"\"\"\n",
1237
- " This section builds the final columns for your spreadsheet.\n",
1238
- " Crucially, it calculates the 'Smallest Logit Diff'.\n",
1239
- " This can be considered an 'AI Certainty Score'. The closer this number is to 0,\n",
1240
- " the more confused the AI was. Please see the Hugging Face README for each model, which contains details on the impact on performance.\n",
1241
- " \"\"\"\n",
1242
- " # Clip thresholds to avoid log(0) warnings\n",
1243
- " safe_thresholds = np.clip(thresholds, 1e-7, 1 - 1e-7)\n",
1244
- " thr_logits = np.log(safe_thresholds / (1 - safe_thresholds))\n",
1245
- "\n",
1246
- " rows = {}\n",
1247
- " for i in range(len(df)):\n",
1248
- " logit_row = logits[i]\n",
1249
- " pred_row = preds[i]\n",
1250
- "\n",
1251
- " # gather categories labelled as 'yes'\n",
1252
- " chosen = [label_list[j] for j, val in enumerate(pred_row) if val == 1]\n",
1253
- "\n",
1254
- " # Calculate the uncertainty score\n",
1255
- " smallest_diff = float(np.min(np.abs(logit_row - thr_logits)))\n",
1256
- " rows[i] = {\n",
1257
- " f\"{prefix}_PredictedLabels\": \";\".join(chosen),\n",
1258
- " f\"{prefix}_SmallestLogitDiff\": smallest_diff\n",
1259
- " }\n",
1260
- " return rows\n",
1261
- "\n",
1262
- "# -------------------- PROCESS FOLDER --------------------\n",
1263
- "def run_folder(data_folder: str, hc_assets, rac_assets):\n",
1264
- " \"\"\"\n",
1265
- " This function opens your specific folder (selected in the first section), reads your CSV, passes it through\n",
1266
- " both AI models (Health Categories and Research Activity Codes), and saves the result.\n",
1267
- " \"\"\"\n",
1268
- " test_file = os.path.join(data_folder, TEST_FILENAME)\n",
1269
- " output_file = os.path.join(data_folder, OUTPUT_FILENAME)\n",
1270
- "\n",
1271
- " if not os.path.exists(test_file):\n",
1272
- " print(f\"Skipping: {test_file} not found in {data_folder}.\")\n",
1273
- " return\n",
1274
- "\n",
1275
- " print(f\"\\n=== Processing folder: {data_folder} ===\")\n",
1276
- " df = pd.read_csv(test_data).reset_index(drop=True)\n",
1277
- "\n",
1278
- " # Validate Input Data contains title and abstract\n",
1279
- " if \"AwardTitle\" not in df.columns or \"AwardAbstract\" not in df.columns:\n",
1280
- " print(f\"ERROR: '{TEST_FILENAME}' must contain 'AwardTitle' and 'AwardAbstract' columns. Skipping.\")\n",
1281
- " return\n",
1282
- "\n",
1283
- " # HC Inference\n",
1284
- " hc_tok, hc_model, hc_labels, hc_thr = hc_assets\n",
1285
- " hc_loader = DataLoader(TextDataset(df, hc_tok), batch_size=BATCH_SIZE)\n",
1286
- " hc_logits, _, hc_preds = predict(hc_model, hc_loader, hc_thr)\n",
1287
- " hc_rows = build_outputs(df, hc_logits, None, hc_preds, hc_labels, hc_thr, \"HC\")\n",
1288
- "\n",
1289
- " # RAC Inference\n",
1290
- " rac_tok, rac_model, rac_labels, rac_thr = rac_assets\n",
1291
- " rac_loader = DataLoader(TextDataset(df, rac_tok), batch_size=BATCH_SIZE)\n",
1292
- " rac_logits, _, rac_preds = predict(rac_model, rac_loader, rac_thr)\n",
1293
- " rac_rows = build_outputs(df, rac_logits, None, rac_preds, rac_labels, rac_thr, \"RAC\")\n",
1294
- "\n",
1295
- " # Construct final dataframe\n",
1296
- " output_data = []\n",
1297
- " for i in range(len(df)):\n",
1298
- " row = {\n",
1299
- " \"ID\": df.loc[i, \"ID\"] if \"ID\" in df.columns else i,\n",
1300
- " \"FunderAcronym\": df.loc[i, \"FunderAcronym\"] if \"FunderAcronym\" in df.columns else \"\",\n",
1301
- " \"AwardTitle\": df.loc[i, \"AwardTitle\"],\n",
1302
- " \"AwardAbstract\": df.loc[i, \"AwardAbstract\"],\n",
1303
- " **hc_rows[i],\n",
1304
- " **rac_rows[i]\n",
1305
- " }\n",
1306
- " output_data.append(row)\n",
1307
- "\n",
1308
- " pd.DataFrame(output_data).to_csv(OUTPUT_FILENAME, index=False)\n",
1309
- " files.download(OUTPUT_FILENAME)\n",
1310
- " print(f\"Success! Predictions saved to: {output_file}\")\n",
1311
- "\n",
1312
- "# -------------------- MAIN --------------------\n",
1313
- "if __name__ == \"__main__\":\n",
1314
- " print(\"Initialising HRCS Classifier Pipeline...\")\n",
1315
- "\n",
1316
- " # Load models once outside the loop to save memory/time\n",
1317
- " hc_assets = load_hf_model(HC_REPO_ID)\n",
1318
- " rac_assets = load_hf_model(RAC_REPO_ID)\n",
1319
- "\n",
1320
- " for folder in DATA_FOLDERS:\n",
1321
- " run_folder(folder, hc_assets, rac_assets)"
1322
- ],
1323
- "metadata": {
1324
- "colab": {
1325
- "base_uri": "https://localhost:8080/",
1326
- "height": 220,
1327
- "referenced_widgets": [
1328
- "d6d7368c1c7b4a549d9a706ee1a62785",
1329
- "0e22f97db509402a9b4d4a2fa8c80fe0",
1330
- "475728820f8e4b6fa0c33a1e9989179a",
1331
- "cf5e926c8c004525b5aad36d821ef928",
1332
- "9275c5cf7b314dc69419611128cb0c09",
1333
- "926c9a0562f3455b958c71bd0a138558",
1334
- "ef276c365d604dcc83b4a0133c6b7652",
1335
- "3dc315e087f44330badbb6749b0611a9",
1336
- "bc6c346b95424eb881b793bf6fd2f53e",
1337
- "c26db3dd5a86405b81f510fd759c3af7",
1338
- "dae2b5afeee243e682322c1ecb8238e6",
1339
- "f66e121a24e441c2a4cb32fb87d22632",
1340
- "1776c0895e7f400f85d3a2123cb573f1",
1341
- "e7fd82b589eb49aaba73aeafcc340421",
1342
- "210ec65a874a4f51ab9f059d3de804a9",
1343
- "c91d89b903ce4f02b635454d3bd87877",
1344
- "9cba2384346645499516a7923ad6fcef",
1345
- "491f71a8889c4c7aa0042e99f374be79",
1346
- "9f60072613d84754a0efe7813d853565",
1347
- "bf48a3e692ae4e6da5379152f51baa09",
1348
- "25529d785e9047e1b9bd197a7ec56773",
1349
- "9dda1a77133c4247bac106a9c4699153"
1350
- ]
1351
- },
1352
- "id": "Sm-TvQo5chAc",
1353
- "outputId": "031ba45c-fc02-4ee6-ddd4-eeca67c6f92e"
1354
- },
1355
- "execution_count": 6,
1356
- "outputs": [
1357
- {
1358
- "output_type": "stream",
1359
- "name": "stdout",
1360
- "text": [
1361
- "Initialising HRCS Classifier Pipeline...\n",
1362
- "Loading model and tokenizer from: NIHRDataInsights/HRCSHealthCategories...\n"
1363
- ]
1364
- },
1365
- {
1366
- "output_type": "display_data",
1367
- "data": {
1368
- "text/plain": [
1369
- "Loading weights: 0%| | 0/393 [00:00<?, ?it/s]"
1370
- ],
1371
- "application/vnd.jupyter.widget-view+json": {
1372
- "version_major": 2,
1373
- "version_minor": 0,
1374
- "model_id": "d6d7368c1c7b4a549d9a706ee1a62785"
1375
- }
1376
- },
1377
- "metadata": {}
1378
- },
1379
- {
1380
- "output_type": "stream",
1381
- "name": "stdout",
1382
- "text": [
1383
- "Loading model and tokenizer from: NIHRDataInsights/HRCSResearchActivityCodes...\n"
1384
- ]
1385
- },
1386
- {
1387
- "output_type": "display_data",
1388
- "data": {
1389
- "text/plain": [
1390
- "Loading weights: 0%| | 0/393 [00:00<?, ?it/s]"
1391
- ],
1392
- "application/vnd.jupyter.widget-view+json": {
1393
- "version_major": 2,
1394
- "version_minor": 0,
1395
- "model_id": "f66e121a24e441c2a4cb32fb87d22632"
1396
- }
1397
- },
1398
- "metadata": {}
1399
- },
1400
- {
1401
- "output_type": "stream",
1402
- "name": "stdout",
1403
- "text": [
1404
- "\n",
1405
- "=== Processing folder: ./ ===\n"
1406
- ]
1407
- },
1408
- {
1409
- "output_type": "stream",
1410
- "name": "stderr",
1411
- "text": [
1412
- "Inferencing: 100%|██████████| 1/1 [00:30<00:00, 30.99s/it]\n",
1413
- "Inferencing: 100%|██████████| 1/1 [00:29<00:00, 29.16s/it]\n"
1414
- ]
1415
- },
1416
- {
1417
- "output_type": "display_data",
1418
- "data": {
1419
- "text/plain": [
1420
- "<IPython.core.display.Javascript object>"
1421
- ],
1422
- "application/javascript": [
1423
- "\n",
1424
- " async function download(id, filename, size) {\n",
1425
- " if (!google.colab.kernel.accessAllowed) {\n",
1426
- " return;\n",
1427
- " }\n",
1428
- " const div = document.createElement('div');\n",
1429
- " const label = document.createElement('label');\n",
1430
- " label.textContent = `Downloading \"${filename}\": `;\n",
1431
- " div.appendChild(label);\n",
1432
- " const progress = document.createElement('progress');\n",
1433
- " progress.max = size;\n",
1434
- " div.appendChild(progress);\n",
1435
- " document.body.appendChild(div);\n",
1436
- "\n",
1437
- " const buffers = [];\n",
1438
- " let downloaded = 0;\n",
1439
- "\n",
1440
- " const channel = await google.colab.kernel.comms.open(id);\n",
1441
- " // Send a message to notify the kernel that we're ready.\n",
1442
- " channel.send({})\n",
1443
- "\n",
1444
- " for await (const message of channel.messages) {\n",
1445
- " // Send a message to notify the kernel that we're ready.\n",
1446
- " channel.send({})\n",
1447
- " if (message.buffers) {\n",
1448
- " for (const buffer of message.buffers) {\n",
1449
- " buffers.push(buffer);\n",
1450
- " downloaded += buffer.byteLength;\n",
1451
- " progress.value = downloaded;\n",
1452
- " }\n",
1453
- " }\n",
1454
- " }\n",
1455
- " const blob = new Blob(buffers, {type: 'application/binary'});\n",
1456
- " const a = document.createElement('a');\n",
1457
- " a.href = window.URL.createObjectURL(blob);\n",
1458
- " a.download = filename;\n",
1459
- " div.appendChild(a);\n",
1460
- " a.click();\n",
1461
- " div.remove();\n",
1462
- " }\n",
1463
- " "
1464
- ]
1465
- },
1466
- "metadata": {}
1467
- },
1468
- {
1469
- "output_type": "display_data",
1470
- "data": {
1471
- "text/plain": [
1472
- "<IPython.core.display.Javascript object>"
1473
- ],
1474
- "application/javascript": [
1475
- "download(\"download_d84f50b6-a57a-4840-866b-dc789f372d10\", \"hc_rac_predictions.csv\", 1874)"
1476
- ]
1477
- },
1478
- "metadata": {}
1479
- },
1480
- {
1481
- "output_type": "stream",
1482
- "name": "stdout",
1483
- "text": [
1484
- "Success! Predictions saved to: ./hc_rac_predictions.csv\n"
1485
- ]
1486
- }
1487
- ]
1488
- }
1489
- ]
1490
- }