fulviodeo commited on
Commit
da2daa6
·
1 Parent(s): 99f68cd

voila deployment

Browse files
.DS_Store ADDED
Binary file (6.15 kB). View file
 
.github/workflows/update-hf.yml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: Update Hugging Face repository
2
+ on:
3
+ push:
4
+ branches: [main]
5
+
6
+ workflow_dispatch:
7
+
8
+ jobs:
9
+ push-to-hf:
10
+ runs-on: ubuntu-latest
11
+ steps:
12
+ - uses: actions/checkout@v3
13
+ with:
14
+ fetch-depth: 0
15
+ lfs: true
16
+ - name: Push to Hugging Face
17
+ env:
18
+ HF_USER: ${{ secrets.HF_USER }}
19
+ HF_TOKEN: ${{ secrets.HF_TOKEN }}
20
+ run: git push https://$HF_USER:$HF_TOKEN@huggingface.co/spaces/ncdrisc/AI-Literature-Screening-for-Population-Health main --force
environment.yml ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: voila
2
+ channels:
3
+ - conda-forge
4
+ dependencies:
5
+ - python=3.10
6
+ - pip
7
+ - pip:
8
+ - anyio==4.9.0
9
+ - argon2-cffi==23.1.0
10
+ - argon2-cffi-bindings==21.2.0
11
+ - arrow==1.3.0
12
+ - attrs==25.3.0
13
+ - babel==2.17.0
14
+ - beautifulsoup4==4.13.3
15
+ - biopython==1.85
16
+ - bleach==6.2.0
17
+ - certifi==2025.1.31
18
+ - cffi==1.17.1
19
+ - charset-normalizer==3.4.1
20
+ - comm==0.2.2
21
+ - debugpy==1.8.13
22
+ - decorator==5.2.1
23
+ - defusedxml==0.7.1
24
+ - fastjsonschema==2.21.1
25
+ - filelock==3.18.0
26
+ - fsspec==2025.3.2
27
+ - huggingface-hub==0.30.1
28
+ - idna==3.10
29
+ - ipykernel==6.29.5
30
+ - ipython==8.35.0
31
+ - ipywidgets==8.1.5
32
+ - jedi==0.19.2
33
+ - Jinja2==3.1.6
34
+ - jsonschema==4.23.0
35
+ - jupyter-events==0.12.0
36
+ - jupyter_client==8.6.3
37
+ - jupyter_core==5.7.2
38
+ - jupyter_server==2.15.0
39
+ - jupyter_server_terminals==0.5.3
40
+ - jupyterlab_pygments==0.3.0
41
+ - jupyterlab_server==2.27.3
42
+ - jupyterlab_widgets==3.0.13
43
+ - MarkupSafe==3.0.2
44
+ - mistune==3.1.3
45
+ - nbclient==0.10.2
46
+ - nbconvert==7.16.6
47
+ - nbformat==5.10.4
48
+ - nest-asyncio==1.6.0
49
+ - networkx==3.4.2
50
+ - numpy==1.26.4
51
+ - packaging==24.2
52
+ - pandas==2.2.3
53
+ - platformdirs==4.3.7
54
+ - prometheus_client==0.21.1
55
+ - prompt_toolkit==3.0.50
56
+ - psutil==7.0.0
57
+ - pure_eval==0.2.3
58
+ - pycparser==2.22
59
+ - Pygments==2.19.1
60
+ - python-dateutil==2.9.0.post0
61
+ - python-json-logger==3.3.0
62
+ - pytz==2025.2
63
+ - PyYAML==6.0.2
64
+ - pyzmq==26.4.0
65
+ - referencing==0.36.2
66
+ - regex==2024.11.6
67
+ - requests==2.32.3
68
+ - rpds-py==0.24.0
69
+ - safetensors==0.5.3
70
+ - six==1.17.0
71
+ - sniffio==1.3.1
72
+ - soupsieve==2.6
73
+ - sympy==1.13.3
74
+ - terminado==0.18.1
75
+ - tinycss2==1.4.0
76
+ - tokenizers==0.15.2
77
+ - torch==2.2.2
78
+ - tornado==6.4.2
79
+ - tqdm==4.67.1
80
+ - traitlets==5.14.3
81
+ - transformers==4.38.2
82
+ - typing_extensions==4.13.1
83
+ - tzdata==2025.2
84
+ - urllib3==2.3.0
85
+ - voila==0.5.8
86
+ - wcwidth==0.2.13
87
+ - webencodings==0.5.1
88
+ - websocket-client==1.8.0
89
+ - websockets==15.0.1
90
+ - widgetsnbextension==4.0.13
notebooks/.DS_Store ADDED
Binary file (6.15 kB). View file
 
notebooks/notebook.ipynb ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "metadata": {
5
+ "ExecuteTime": {
6
+ "end_time": "2025-04-10T15:52:17.264471Z",
7
+ "start_time": "2025-04-10T15:52:17.156072Z"
8
+ }
9
+ },
10
+ "cell_type": "code",
11
+ "source": "import subprocess\nfrom datetime import date\nfrom dateutil.relativedelta import relativedelta\nimport os\nimport ipywidgets as widgets\nfrom IPython.display import HTML, clear_output\nimport warnings\nwarnings.filterwarnings(\"ignore\")\n\n# --- UI Elements ---\n\n# Dropdown for risk factor selection\nrisk_factor = widgets.Dropdown(\n options=['Anthro', 'BP', 'Lipids', 'Diabetes'],\n value='Anthro',\n description='',\n disabled=False\n)\n\nquery_method = widgets.RadioButtons(\n options=[\"default\", \"custom\"],\n value=\"default\",\n layout=widgets.Layout(width=\"180px\")\n)\n\n# Custom query input field\ncustom_query_file = widgets.Text(\n placeholder=\"Enter query filename\",\n description=\"\",\n disabled=True,\n layout=widgets.Layout(width=\"300px\")\n)\n\n# Enable text box only when 'custom' is selected\ndef toggle_custom_query(change):\n custom_query_file.disabled = (change.new != \"custom\")\n\nquery_method.observe(toggle_custom_query, names=\"value\")\n\nstart_date_value = date.today() - relativedelta(months=1)\nend_date_value = date.today()\n\nstart_date = widgets.DatePicker(\n description=\"Start date: \",\n value=start_date_value\n)\n\nend_date = widgets.DatePicker(\n description=\"End Date: \",\n value=end_date_value\n)\n\nrecall_target = widgets.Dropdown(\n options=[('95%', 95), ('90%', 90), ('80%', 80), ('70%', 70), ('60%', 60)],\n value=95,\n description='',\n layout=widgets.Layout(width=\"120px\")\n)\n\n# Buttons\n\n# Button style\ndisplay(HTML(\"\"\"\n<style>\n.widget-button {\n justify-content: flex-start !important;\n text-align: left !important;\n font-weight: bold !important;\n font-size: 15px !important;\n padding-left: 12px !important;\n}\n</style>\n\"\"\"))\n\ndownload_btn = widgets.Button(\n description=\"Download abstracts from PubMed\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\nload_model_btn = widgets.Button(\n description=\"Load the model\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\npredict_btn = widgets.Button(\n description=\"Run the model and screen articles\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\n# Output areas\ndownload_output = widgets.Output()\nload_model_output = widgets.Output()\npredict_output = widgets.Output()\n\n# --- Helper function for UI feedback ---\ndef mark_done(output_placeholder, header_widget):\n output_placeholder.clear_output()\n with output_placeholder:\n display(header_widget)\n display(widgets.HTML(\n \"<span style='color: green; font-weight: bold; font-size:15px; margin-left:12px;'>&#10004; Done</span>\"\n ))\n\n# --- Event Handlers ---\n\ndef toggle_query_visibility(change):\n \"\"\"Enable/Disable custom query input based on checkbox state.\"\"\"\n custom_query_file.disabled = not change.new\n\nquery_method.observe(toggle_custom_query, names=\"value\")\n\ndef run_download_citations(b):\n b.close()\n with download_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Download abstracts from PubMed\"\n \"</div>\"\n )\n display(header)\n try:\n print(f\" Downloading articles published between {start_date.value} and {end_date.value}...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/download_citations.py\"))) as f:\n exec(f.read(), globals())\n mark_done(download_output, header)\n path_to_articles = os.path.join(globals().get('directory'), 'downloaded_articles.csv')\n display(widgets.HTML(\n f\"<p style='font-size:12px; color:gray; margin-top:6px; margin-left:12px; margin-bottom:0px;'>\"\n f\"Article abstracts downloaded to: {path_to_articles}</p>\"\n ))\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef run_load_model(b):\n b.close()\n with load_model_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Load the model\"\n \"</div>\"\n )\n display(header)\n try:\n print(f\" Loading the model...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/load_model.py\"))) as f:\n exec(f.read(), globals())\n mark_done(load_model_output, header)\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef open_ris_file(b):\n path = os.path.join(globals().get('directory'), 'articles_to_review.ris')\n \"\"\"Opens the .ris file using the system's default application correctly.\"\"\"\n try:\n if os.name == \"posix\": # macOS / Linux\n subprocess.Popen([\"open\", path])\n elif os.name == \"nt\": # Windows\n subprocess.Popen([\"start\", \"\", path], shell=True)\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef run_prediction(b):\n b.close()\n with predict_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Run the model and screen articles\"\n \"</div>\"\n )\n display(header)\n try:\n print(\" Running the model...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/predict.py\"))) as f:\n exec(f.read(), globals())\n mark_done(predict_output, header)\n except Exception:\n import traceback\n traceback.print_exc()\n\n path_to_ris = os.path.join(globals().get('directory'), 'articles_to_review.ris')\n\n open_file_btn = widgets.Button(description=f\"📄 Open in EndNote\", layout=widgets.Layout(width=\"350px\"))\n open_file_btn.on_click(open_ris_file)\n\n path_display = widgets.HTML(\n f\"<p style='font-size:12px; color:gray; margin-top:6px; margin-left:12px;'>\"\n f\"Path to the EndNote file: {path_to_ris}</p>\"\n )\n final_message_1 = widgets.HTML(\n \"<p style='font-size:14px; color:black; margin-top:10px; margin-left:12px;'>\"\n \"Open the .ris file in EndNote by clicking on the button above or navigating to the file</p>\"\n )\n final_message_2 = widgets.HTML(\n \"<p style='font-size:14px; color:black; margin-top:10px; margin-left:12px;'>\"\n \"Select RefMan - RIS as the input file format</p>\"\n )\n\n with predict_output:\n display(widgets.VBox([\n open_file_btn,\n path_display,\n final_message_1,\n final_message_2\n ]))\n\n\n# Attach event handlers to buttons\ndownload_btn.on_click(run_download_citations)\nload_model_btn.on_click(run_load_model)\npredict_btn.on_click(run_prediction)\n\n# --- Layout ---\n\ntitle_style = \"font-size: 22px; font-weight: bold; margin-bottom: 40px;\"\nsection_style = \"font-size: 18px; font-weight: bold; margin-bottom: 15px;\"\ntext_style = \"font-size: 16px;\"\nspacing = widgets.Layout(margin=\"20px 0px\")\n\n# Titles\nheader = widgets.HTML(\n f\"<h2 style='margin-left:12px; {title_style}'>Automated screening of the literature</h2>\"\n)\n\nsection1 = widgets.VBox([\n\n widgets.HTML(\"<b style='font-size:16px;'>Choose a risk factor</b>\"),\n risk_factor,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n widgets.HTML(\"<b style='font-size:16px;'>Define the PubMed search</b>\"),\n widgets.HBox([\n widgets.VBox([\n start_date,\n end_date\n ], layout=widgets.Layout(margin=\"0 20px 0 0\")),\n\n widgets.VBox([\n widgets.HBox([\n widgets.Label(\"Query: \", layout=widgets.Layout(width=\"60px\")),\n query_method\n ]),\n custom_query_file\n ], layout=widgets.Layout(margin=\"0 0 0 20px\"))\n ], layout=widgets.Layout(justify_content=\"flex-start\", gap=\"20px\", margin=\"10px 0\"))\n], layout=widgets.Layout(margin=\"0 0 0 12px\"))\n\nsection2 = widgets.VBox([\n # Line 3 — vertically stacked buttons with equal width\n widgets.VBox([\n download_btn,\n download_output,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n load_model_btn,\n load_model_output,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n widgets.HTML(\"<b style='font-size:16px;'>Define how inclusive the model should be</b>\"),\n widgets.HTML(\"<span style='font-size:13px; color:gray; margin-top:2px; display:block;'>Based on the recall achieved in previous testing; the higher the recall, the more inclusive the model</span>\"),\n widgets.HTML(\"<div style='height:8px;'></div>\"),\n recall_target,\n widgets.HTML(\"<div style='height:15px;'></div>\"),\n\n predict_btn,\n predict_output\n ], layout=widgets.Layout(margin=\"30px 0\"))\n])\n\n\ndisplay(header, section1, section2)",
12
+ "id": "35bcc7331ef5d1dc",
13
+ "outputs": [],
14
+ "execution_count": null
15
+ },
16
+ {
17
+ "metadata": {},
18
+ "cell_type": "code",
19
+ "outputs": [],
20
+ "execution_count": null,
21
+ "source": "",
22
+ "id": "eeebfa287ae99109"
23
+ }
24
+ ],
25
+ "metadata": {
26
+ "kernelspec": {
27
+ "display_name": "Python 3 (ipykernel)",
28
+ "language": "python",
29
+ "name": "python3"
30
+ },
31
+ "language_info": {
32
+ "codemirror_mode": {
33
+ "name": "ipython",
34
+ "version": 3
35
+ },
36
+ "file_extension": ".py",
37
+ "mimetype": "text/x-python",
38
+ "name": "python",
39
+ "nbconvert_exporter": "python",
40
+ "pygments_lexer": "ipython3",
41
+ "version": "3.12.9"
42
+ },
43
+ "trusted": true
44
+ },
45
+ "nbformat": 4,
46
+ "nbformat_minor": 5
47
+ }
notebooks/src/.DS_Store ADDED
Binary file (6.15 kB). View file
 
notebooks/src/download_citations.py ADDED
@@ -0,0 +1,79 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from datetime import date
3
+ from Bio import Entrez
4
+ import pandas as pd
5
+
6
+ risk_factor = globals().get('risk_factor').value
7
+
8
+ risk_factor_directory = os.path.abspath(os.path.join(os.getcwd(), risk_factor))
9
+
10
+ queries_directory = os.path.abspath(os.path.join(os.getcwd(), 'queries', risk_factor))
11
+ use_default_query = globals().get('query_method').value == "default"
12
+ if use_default_query:
13
+ with open(os.path.join(queries_directory, 'default.txt'), 'r') as file:
14
+ base_query = file.read()
15
+ query_suffix = ""
16
+ else:
17
+ custom_query_file = globals().get('custom_query_file').value
18
+ with open(os.path.join(queries_directory, f'{custom_query_file}.txt'), 'r') as file:
19
+ base_query = file.read()
20
+ query_suffix = f"_{custom_query_file}"
21
+
22
+ base_folder_name = date.today().strftime("%Y-%m-%d") + query_suffix
23
+ if not os.path.isdir(os.path.join(risk_factor_directory, base_folder_name)):
24
+ directory = os.path.join(risk_factor_directory, base_folder_name)
25
+ os.makedirs(directory, exist_ok=True)
26
+ else:
27
+ version = 2
28
+ folder_name = f"{base_folder_name}-v{version}"
29
+ directory = os.path.join(risk_factor_directory, folder_name)
30
+ while os.path.isdir(directory):
31
+ version += 1
32
+ folder_name = f"{base_folder_name}-v{version}"
33
+ directory = os.path.join(risk_factor_directory, folder_name)
34
+ os.makedirs(directory, exist_ok=True)
35
+
36
+ start_date = globals().get('start_date').value
37
+ end_date = globals().get('end_date').value
38
+ query = base_query + f'AND (("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication]))'
39
+
40
+ Entrez.email = os.getenv('email')
41
+ search_handle = Entrez.esearch(db="pubmed", term=query, retmax=10000)
42
+ search_results = Entrez.read(search_handle)
43
+ search_handle.close()
44
+ id_list = search_results['IdList']
45
+ fetch_handle = Entrez.efetch(db="pubmed", id=id_list, rettype="xml")
46
+ fetch_results = Entrez.read(fetch_handle)
47
+ fetch_handle.close()
48
+
49
+ papers = []
50
+ for article in fetch_results['PubmedArticle']:
51
+ medline = article['MedlineCitation']
52
+ article_data = medline['Article']
53
+
54
+ title = str(article_data.get('ArticleTitle', ''))
55
+ abstract_text = ' '.join(article_data.get('Abstract', {}).get('AbstractText', ['']))
56
+ abstract_text = (title + ' ' + abstract_text).strip()
57
+ authors = ', '.join(['{} {}'.format(a.get('ForeName', ''), a.get('LastName', ''))
58
+ for a in article_data.get('AuthorList', []) if 'LastName' in a])
59
+ journal = article_data.get('Journal', {}).get('Title', '')
60
+ year = article_data.get('Journal', {}).get('JournalIssue', {}).get('PubDate', {}).get('Year', '')
61
+ pmid = medline.get('PMID', '')
62
+
63
+ doi = ''
64
+ for eid in article_data.get('ELocationID', []):
65
+ if eid.attributes.get('EIdType') == 'doi':
66
+ doi = str(eid)
67
+
68
+ papers.append({
69
+ 'PMID': pmid,
70
+ 'Title': title,
71
+ 'Abstract': abstract_text,
72
+ 'Authors': authors,
73
+ 'Journal': journal,
74
+ 'Year': year,
75
+ 'DOI': doi
76
+ })
77
+
78
+ df = pd.DataFrame(papers)
79
+ df.to_csv(os.path.join(directory, 'downloaded_articles.csv'), index=False)
notebooks/src/handlers/.DS_Store ADDED
Binary file (6.15 kB). View file
 
notebooks/src/handlers/hyperparameters_handler.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class Hyperparameters:
2
+
3
+ def __init__(self,
4
+ ds,
5
+ input_col,
6
+ output_col,
7
+ test_size,
8
+ seed,
9
+ pre_trained_model,
10
+ max_length=512,
11
+ freezed_layers=0,
12
+ batch_size=8,
13
+ learning_rate=0.00003,
14
+ max_epochs=None,
15
+ stop_loss_epochs=5
16
+ ):
17
+ self.ds = ds
18
+ self.input_col = input_col
19
+ self.output_col = output_col
20
+ self.test_size = test_size
21
+ self.seed = seed
22
+ self.pre_trained_model = pre_trained_model
23
+ self.max_length = max_length
24
+ self.freezed_layers = freezed_layers
25
+ self.batch_size = batch_size
26
+ self.learning_rate = learning_rate
27
+ self.max_epochs = max_epochs
28
+ if self.max_epochs is None:
29
+ self.stop_loss_epochs = stop_loss_epochs
notebooks/src/handlers/model_IO_handler.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import numpy as np
3
+ import pandas as pd
4
+ from src.utils import file_handling
5
+
6
+
7
+ class IOHandler:
8
+
9
+ def __init__(self, directory, model_dir_path):
10
+ self.directory = directory
11
+ self.model_path = os.path.join(model_dir_path, 'model.pth')
12
+ prc_path = os.path.join(model_dir_path, 'precision_recall_curve.csv')
13
+ self._prc = pd.read_csv(prc_path)
14
+
15
+ def get_threshold(self, recall_target):
16
+ """Return the highest threshold that still achieves at least recall_target recall."""
17
+ above = self._prc[self._prc['Recall'] >= recall_target]
18
+ if above.empty:
19
+ return float(self._prc['Threshold'].iloc[-1])
20
+ return float(above['Threshold'].max())
21
+
22
+ def write_predictions(self, data, y_prob, threshold):
23
+ data = data.copy()
24
+ data['y_prob'] = y_prob
25
+ data['y_pred'] = (np.array(y_prob) >= threshold).astype(int)
26
+ file_handling.write_file(data, self.directory, 'articles_with_predictions.csv')
27
+
28
+ def write_review_file(self, data, y_prob, threshold):
29
+ data = data.loc[np.array(y_prob) >= threshold].copy()
30
+ data = data.drop(columns=['y_prob', 'y_pred'], errors='ignore')
31
+ file_handling.write_file(data, self.directory, 'articles_to_review.csv')
32
+ file_handling.write_file(data, self.directory, 'articles_to_review.ris')
notebooks/src/handlers/model_code_translator.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+ from itertools import product
3
+
4
+
5
+ MODELS = {
6
+ 'BiomedBERT_abstract': '1A2B06A00'
7
+ }
8
+
9
+
10
+ class Dataset(Enum):
11
+ _1 = 'ds1'
12
+ _2 = 'ds2'
13
+ _3 = 'ds3'
14
+
15
+
16
+ class PretrainedModel(Enum):
17
+ O = 'BERT'
18
+ A = 'bioBERT'
19
+ B = 'BiomedBERT'
20
+ L = 'Longformer'
21
+ M = 'BigBird'
22
+
23
+
24
+ class InputColumn(Enum):
25
+ A = 'Abstract'
26
+ T = 'Text'
27
+ M = 'Methods'
28
+ N = 'A+Methods'
29
+
30
+
31
+ class OutputColumn(Enum):
32
+ _1 = 'CLASS--stage_1'
33
+ _2 = 'CLASS'
34
+
35
+
36
+ class LearningRate(Enum):
37
+ A = 0.00002
38
+ B = 0.00003
39
+ C = 0.00004
40
+
41
+
42
+ TEST_SIZE = 0.1
43
+ SEED = 100
44
+
45
+
46
+ class ModelBatchSize(Enum):
47
+ O = 8
48
+ A = 8
49
+ B = 8
50
+ L = 2
51
+ M = 2
52
+
53
+
54
+ class ModelMaxLength(Enum):
55
+ O = 512
56
+ A = 512
57
+ B = 512
58
+ L = 4096
59
+ M = 4096
60
+
61
+
62
+ POSSIBLE_CODE_ELEMENTS = {
63
+ 0: [e.name[-1] for e in Dataset],
64
+ 1: [e.name for e in InputColumn],
65
+ 2: [e.name[-1] for e in OutputColumn],
66
+ 3: [e.name for e in PretrainedModel],
67
+ 4: [i for i in range(10)],
68
+ 5: [i for i in range(10)],
69
+ 6: [e.name for e in LearningRate],
70
+ 7: [i for i in range(10)],
71
+ 8: [i for i in range(10)],
72
+ }
73
+
74
+
75
+ def get_model_specs(code):
76
+ """
77
+ Generate model specifications from the provided code.
78
+ The code must not contain 'x'.
79
+ """
80
+
81
+ model_specs = {
82
+ 'ds': Dataset[f"_{code[0]}"].value,
83
+ 'test_size': TEST_SIZE,
84
+ 'seed': SEED,
85
+ 'input_col': InputColumn[code[1]].value,
86
+ 'output_col': OutputColumn[f"_{code[2]}"].value,
87
+ 'pre_trained_model': PretrainedModel[code[3]].value,
88
+ 'max_length': ModelMaxLength[code[3]].value,
89
+ 'freezed_layers': int(code[4:6]),
90
+ 'learning_rate': LearningRate[code[6]].value,
91
+ 'batch_size': ModelBatchSize[code[3]].value,
92
+ }
93
+ if code[7:] != "00":
94
+ model_specs["max_epochs"] = int(code[7:])
95
+
96
+ return model_specs
97
+
98
+
99
+ class ModelCodeTranslator:
100
+ def __init__(self, code):
101
+ if len(code) != 9:
102
+ raise Exception("Code must be of length 9")
103
+ self.code = code
104
+
105
+ if not 'x' in self.code:
106
+ self.model_specs = get_model_specs(self.code)
107
+ self.model_specs_list = None
108
+ else:
109
+ iterables = [POSSIBLE_CODE_ELEMENTS[i] if char == 'x' else [char] for i, char in enumerate(self.code)]
110
+ codes = [''.join(combination) for combination in product(*iterables)]
111
+ self.model_specs = None
112
+ self.model_specs_list = [get_model_specs(code) for code in codes]
113
+
notebooks/src/handlers/model_loader.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from enum import Enum
2
+
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from src.handlers.hyperparameters_handler import Hyperparameters
5
+ from src.nn import pytorch_models
6
+
7
+
8
+ class Library(Enum):
9
+ PYTORCH = 'pytorch'
10
+
11
+
12
+ class ModelUrl(Enum):
13
+ BERT = 'bert-base-uncased'
14
+ bioBERT = 'dmis-lab/biobert-v1.1'
15
+ BiomedBERT = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
16
+ Longformer = 'allenai/longformer-base-4096'
17
+
18
+
19
+ class ModelLoader:
20
+ def __init__(self, logger, model_path, specs: dict, library: Library, ):
21
+ self.logger = logger
22
+
23
+ self.target_library = library
24
+ self.hyperparameters = Hyperparameters(**specs)
25
+
26
+ self.model_url = ModelUrl[self.hyperparameters.pre_trained_model].value
27
+ self.tokenizer = AutoTokenizer.from_pretrained(self.model_url, use_fast=False)
28
+ if self.target_library == Library.PYTORCH:
29
+ self.model = AutoModelForSequenceClassification.from_pretrained(self.model_url, num_labels=1,
30
+ problem_type='multi_label_classification')
31
+ self.model_wrapper = pytorch_models.NLPClassifier(
32
+ self.logger, model_path, self.tokenizer, self.model, self.hyperparameters
33
+ )
34
+ else:
35
+ raise NotImplementedError
notebooks/src/load_model.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ from src.utils import logging
4
+ from src.handlers.model_code_translator import ModelCodeTranslator, MODELS
5
+ from src.handlers.model_loader import ModelLoader, Library
6
+ from src.handlers.model_IO_handler import IOHandler
7
+
8
+ risk_factor = globals().get("risk_factor")
9
+ directory = globals().get("directory")
10
+ model_name = 'BiomedBERT_abstract'
11
+
12
+ log_dir = os.path.join(directory, 'logs')
13
+ logger = logging.Logger(log_dir=log_dir)
14
+
15
+ logger.info(f"Creating the model")
16
+ model_dir = os.path.abspath(os.path.join(os.getcwd(), 'models', f'PopulationHealthScreener ({risk_factor})'))
17
+ model_io_handler = IOHandler(directory, model_dir)
18
+ model_specs = ModelCodeTranslator(MODELS[model_name]).model_specs
19
+ library = Library.PYTORCH
20
+ model_path = os.path.join(model_dir, 'model.pth')
21
+ model = ModelLoader(logger, model_path, model_specs, library).model_wrapper
22
+ input_col = model_specs["input_col"]
23
+
24
+ logger.info(f"Loading model weights from {model_path}")
25
+ model.load_fine_tuned_weights()
notebooks/src/nn/pytorch_models.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ import numpy as np
4
+ import torch
5
+ from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
6
+
7
+
8
+ class NLPClassifier:
9
+ def __init__(self, logger, model_path, tokenizer, model, hyperparameters):
10
+ self.logger = logger
11
+ self.model_path = model_path
12
+ self.hyperparameters = hyperparameters
13
+
14
+ self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+
16
+ self.tokenizer = tokenizer
17
+ self.model = model
18
+
19
+ self.model_arch = self.get_model_arch()
20
+ self.freeze_layers()
21
+ self.model.to(self.device)
22
+
23
+ def get_model_arch(self):
24
+ if 'BERT' in self.hyperparameters.pre_trained_model:
25
+ return self.model.bert
26
+ elif self.hyperparameters.pre_trained_model == "Longformer":
27
+ return self.model.longformer
28
+ else:
29
+ raise ValueError("Invalid model type")
30
+
31
+ def load_training_data(self, data, labels):
32
+ labels = torch.tensor(np.array(labels).reshape(-1, 1))
33
+ classes, class_counts = torch.unique(labels, sorted=True, return_counts=True)
34
+ class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
35
+ weights_dict = {cls.item(): weight for cls, weight in zip(classes, class_weights)}
36
+ sample_weights = torch.tensor([weights_dict[t.item()] for t in labels])
37
+ sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
38
+
39
+ encodings = self.tokenizer(data, truncation=True, padding='max_length',
40
+ max_length=self.hyperparameters.max_length, return_tensors='pt')
41
+ dataset = TensorDataset(
42
+ encodings['input_ids'], encodings['attention_mask'], labels.to(torch.float32)
43
+ )
44
+ return DataLoader(dataset, batch_size=self.hyperparameters.batch_size, sampler=sampler)
45
+
46
+ def freeze_layers(self):
47
+ for param in self.model_arch.embeddings.parameters():
48
+ param.requires_grad = False
49
+
50
+ for layer in self.model_arch.encoder.layer[:self.hyperparameters.freezed_layers]:
51
+ for param in layer.parameters():
52
+ param.requires_grad = False
53
+
54
+ def fit(self, x_train, y_train, trained_model_path):
55
+ train_data = self.load_training_data(x_train, y_train)
56
+ optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.hyperparameters.learning_rate)
57
+
58
+ for epoch in range(3):
59
+ self.logger.info(f'Epoch {epoch + 1}/3')
60
+ self.model.train()
61
+ for i, batch in enumerate(train_data):
62
+ progress = f"Batch {i+1}/{len(train_data)}"
63
+ sys.stdout.write('\r' + progress)
64
+ sys.stdout.flush()
65
+ optimizer.zero_grad()
66
+ input_ids = batch[0].to(self.device)
67
+ attention_mask = batch[1].to(self.device)
68
+ labels = batch[2].to(self.device)
69
+ outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
70
+ loss = outputs.loss
71
+ loss.backward()
72
+ optimizer.step()
73
+
74
+ self.logger.info("Saving the model at {}".format(trained_model_path))
75
+ torch.save(self.model.state_dict(), trained_model_path)
76
+
77
+ def load_fine_tuned_weights(self):
78
+ state_dict = torch.load(self.model_path, map_location=self.device)
79
+ self.model.load_state_dict(state_dict, strict=False)
80
+
81
+ def load_test_data(self, data):
82
+ encodings = self.tokenizer(data, truncation=True, padding='max_length',
83
+ max_length=self.hyperparameters.max_length, return_tensors='pt')
84
+ dataset = TensorDataset(
85
+ encodings['input_ids'], encodings['attention_mask']
86
+ )
87
+ return DataLoader(dataset, batch_size=self.hyperparameters.batch_size, shuffle=False)
88
+
89
+ def predict(self, data):
90
+ test_data = self.load_test_data(data)
91
+
92
+ predictions = []
93
+ self.model.eval()
94
+
95
+ for i, batch in enumerate(test_data):
96
+ msg = f" Processing article {8 * (i+1)}/{8 *len(test_data)} ⏳"
97
+ sys.stdout.write('\r' + msg)
98
+ sys.stdout.flush()
99
+
100
+ input_ids = batch[0].to(self.device)
101
+ attention_mask = batch[1].to(self.device)
102
+
103
+ with torch.no_grad():
104
+ outputs = self.model(input_ids, attention_mask=attention_mask,
105
+ output_hidden_states=True)
106
+ predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())
107
+
108
+ return predictions
notebooks/src/predict.py ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from src.utils import file_handling
3
+
4
+ risk_factor = globals().get('risk_factor')
5
+ directory = globals().get('directory')
6
+ logger = globals().get("logger")
7
+ model_io_handler = globals().get("model_io_handler")
8
+ model = globals().get('model')
9
+ input_col = globals().get('input_col')
10
+
11
+ logger.info(f"Reading new titles and abstracts to screen")
12
+ data = file_handling.read_file(directory, file="downloaded_articles.csv").dropna(subset=[input_col])
13
+ x_test = data[input_col].astype(str).to_list()
14
+
15
+ logger.info("Truncating abstracts to max token length before prediction")
16
+ x_test = [
17
+ model.tokenizer.decode(
18
+ model.tokenizer.encode(text, max_length=model.hyperparameters.max_length, truncation=True),
19
+ skip_special_tokens=True
20
+ )
21
+ for text in x_test
22
+ ]
23
+
24
+ recall_target = globals().get('recall_target').value / 100.0
25
+ threshold = model_io_handler.get_threshold(recall_target)
26
+ logger.info(f"Using threshold {threshold:.4f} for {int(recall_target * 100)}% recall target")
27
+
28
+ logger.info("Running the model and getting predictions")
29
+ y_prob = model.predict(x_test)
30
+ model_io_handler.write_predictions(data, y_prob, threshold)
31
+ logger.info("Writing predictions")
32
+ model_io_handler.write_review_file(data, y_prob, threshold)
33
+ logger.info(f"Writing the review file -- find it at {os.path.join(directory, 'articles_to_review.ris')}")
notebooks/src/utils/file_handling.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ import xml.etree.ElementTree as ET
4
+
5
+
6
+ def read_file(dir: str, file: str):
7
+ os.makedirs(dir, exist_ok=True)
8
+ extension = file.split('.')[1]
9
+ func_name = f'read_{extension}'
10
+ func = globals()[func_name]
11
+ return func(dir=dir, file=file)
12
+
13
+
14
+ def write_file(data_element, dir: str, file_name: str):
15
+ os.makedirs(dir, exist_ok=True)
16
+ extension = file_name.split('.')[1]
17
+ func_name = f'write_{extension}'
18
+ func = globals()[func_name]
19
+ return func(data_element, dir=dir, file=file_name)
20
+
21
+
22
+ def read_csv(dir: str, file: str) -> pd.DataFrame:
23
+ return pd.read_csv(os.path.join(dir, file), escapechar='\\', index_col=False)
24
+
25
+
26
+ def read_xml(dir: str, file: str) -> pd.DataFrame:
27
+ tree = ET.parse(os.path.join(dir, file))
28
+ root = tree.getroot()
29
+ records = []
30
+ for rec in root.findall('.//record'):
31
+ title = rec.findtext('.//titles/title') or ""
32
+ abstract = rec.findtext('.//abstract') or ""
33
+ pmid = rec.findtext('.//electronic-resource-num')
34
+ label = rec.findtext('.//label/style')
35
+ records.append({
36
+ 'Title': title,
37
+ 'Abstract': abstract,
38
+ 'PMID': pmid,
39
+ 'Label': label
40
+ })
41
+ return pd.DataFrame(records)
42
+
43
+
44
+ def write_csv(df: pd.DataFrame, dir: str, file: str):
45
+ df.to_csv(os.path.join(dir, file), escapechar='\\', index=False)
46
+
47
+
48
+ def write_xlsx(df: pd.DataFrame, dir: str, file: str):
49
+ df.to_excel(os.path.join(dir, file), engine="openpyxl", index=False)
50
+
51
+
52
+ def write_ris(df: pd.DataFrame, dir: str, file: str):
53
+ with open(os.path.join(dir, file), 'w', encoding='utf-8') as f:
54
+ for _, row in df.iterrows():
55
+ f.write("TI - " + str(row.get('Title') or '') + "\n")
56
+ f.write("AB - " + str(row.get('Abstract') or '') + "\n")
57
+ for author in str(row.get('Authors') or '').split(','):
58
+ f.write("AU - " + author.strip() + "\n")
59
+ f.write("DO - " + str(row.get('DOI') or '') + "\n")
60
+ f.write("PM - " + str(row.get('PMID') or '') + "\n")
61
+ f.write("PY - " + str(row.get('Year') or '') + "\n")
62
+ f.write("JO - " + str(row.get('Journal') or '') + "\n")
63
+ f.write("ER - \n\n")
notebooks/src/utils/logging.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datetime
2
+ import logging
3
+ import os
4
+
5
+
6
+ class LogLevel:
7
+ INFO = logging.INFO
8
+ DEBUG = logging.DEBUG
9
+ WARN = logging.WARN
10
+ ERROR = logging.ERROR
11
+
12
+
13
+ class Logger:
14
+ def __init__(self, log_dir, log_name='log', log_level=LogLevel.INFO):
15
+ self.log_dir = log_dir
16
+ self.log_filename = self.__generate_log_filename(log_name)
17
+
18
+ if not os.path.exists(self.log_dir):
19
+ os.makedirs(self.log_dir, exist_ok=True)
20
+
21
+ try:
22
+ self.__configure_logging(_level=log_level)
23
+ except FileNotFoundError:
24
+ os.makedirs(self.log_dir, exist_ok=True)
25
+ self.__configure_logging(_level=log_level)
26
+
27
+ logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
28
+
29
+ @staticmethod
30
+ def __generate_log_filename(log_name):
31
+ current_datetime = datetime.datetime.now()
32
+ formatted_date_time = current_datetime.strftime('%d_%m_%Y_%H_%M_%S')
33
+ return f'{log_name}_{formatted_date_time}.log'
34
+
35
+ def __configure_logging(self, _level=LogLevel.INFO):
36
+ logging.basicConfig(level=_level, format="%(asctime)s [%(levelname)s] %(message)s",
37
+ handlers=[logging.FileHandler(os.path.join(self.log_dir, self.log_filename)),
38
+ logging.StreamHandler()])
39
+
40
+ def info(self, msg):
41
+ # logging.info(msg)
42
+ pass # TODO change back
43
+
44
+ def warn(self, msg):
45
+ logging.warning(msg)
46
+
47
+ def error(self, msg):
48
+ logging.error(msg)