Spaces:
Sleeping
Sleeping
voila deployment
Browse files- .DS_Store +0 -0
- .github/workflows/update-hf.yml +20 -0
- environment.yml +90 -0
- notebooks/.DS_Store +0 -0
- notebooks/notebook.ipynb +47 -0
- notebooks/src/.DS_Store +0 -0
- notebooks/src/download_citations.py +79 -0
- notebooks/src/handlers/.DS_Store +0 -0
- notebooks/src/handlers/hyperparameters_handler.py +29 -0
- notebooks/src/handlers/model_IO_handler.py +32 -0
- notebooks/src/handlers/model_code_translator.py +113 -0
- notebooks/src/handlers/model_loader.py +35 -0
- notebooks/src/load_model.py +25 -0
- notebooks/src/nn/pytorch_models.py +108 -0
- notebooks/src/predict.py +33 -0
- notebooks/src/utils/file_handling.py +63 -0
- notebooks/src/utils/logging.py +48 -0
.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
.github/workflows/update-hf.yml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Update Hugging Face repository
|
| 2 |
+
on:
|
| 3 |
+
push:
|
| 4 |
+
branches: [main]
|
| 5 |
+
|
| 6 |
+
workflow_dispatch:
|
| 7 |
+
|
| 8 |
+
jobs:
|
| 9 |
+
push-to-hf:
|
| 10 |
+
runs-on: ubuntu-latest
|
| 11 |
+
steps:
|
| 12 |
+
- uses: actions/checkout@v3
|
| 13 |
+
with:
|
| 14 |
+
fetch-depth: 0
|
| 15 |
+
lfs: true
|
| 16 |
+
- name: Push to Hugging Face
|
| 17 |
+
env:
|
| 18 |
+
HF_USER: ${{ secrets.HF_USER }}
|
| 19 |
+
HF_TOKEN: ${{ secrets.HF_TOKEN }}
|
| 20 |
+
run: git push https://$HF_USER:$HF_TOKEN@huggingface.co/spaces/ncdrisc/AI-Literature-Screening-for-Population-Health main --force
|
environment.yml
ADDED
|
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: voila
|
| 2 |
+
channels:
|
| 3 |
+
- conda-forge
|
| 4 |
+
dependencies:
|
| 5 |
+
- python=3.10
|
| 6 |
+
- pip
|
| 7 |
+
- pip:
|
| 8 |
+
- anyio==4.9.0
|
| 9 |
+
- argon2-cffi==23.1.0
|
| 10 |
+
- argon2-cffi-bindings==21.2.0
|
| 11 |
+
- arrow==1.3.0
|
| 12 |
+
- attrs==25.3.0
|
| 13 |
+
- babel==2.17.0
|
| 14 |
+
- beautifulsoup4==4.13.3
|
| 15 |
+
- biopython==1.85
|
| 16 |
+
- bleach==6.2.0
|
| 17 |
+
- certifi==2025.1.31
|
| 18 |
+
- cffi==1.17.1
|
| 19 |
+
- charset-normalizer==3.4.1
|
| 20 |
+
- comm==0.2.2
|
| 21 |
+
- debugpy==1.8.13
|
| 22 |
+
- decorator==5.2.1
|
| 23 |
+
- defusedxml==0.7.1
|
| 24 |
+
- fastjsonschema==2.21.1
|
| 25 |
+
- filelock==3.18.0
|
| 26 |
+
- fsspec==2025.3.2
|
| 27 |
+
- huggingface-hub==0.30.1
|
| 28 |
+
- idna==3.10
|
| 29 |
+
- ipykernel==6.29.5
|
| 30 |
+
- ipython==8.35.0
|
| 31 |
+
- ipywidgets==8.1.5
|
| 32 |
+
- jedi==0.19.2
|
| 33 |
+
- Jinja2==3.1.6
|
| 34 |
+
- jsonschema==4.23.0
|
| 35 |
+
- jupyter-events==0.12.0
|
| 36 |
+
- jupyter_client==8.6.3
|
| 37 |
+
- jupyter_core==5.7.2
|
| 38 |
+
- jupyter_server==2.15.0
|
| 39 |
+
- jupyter_server_terminals==0.5.3
|
| 40 |
+
- jupyterlab_pygments==0.3.0
|
| 41 |
+
- jupyterlab_server==2.27.3
|
| 42 |
+
- jupyterlab_widgets==3.0.13
|
| 43 |
+
- MarkupSafe==3.0.2
|
| 44 |
+
- mistune==3.1.3
|
| 45 |
+
- nbclient==0.10.2
|
| 46 |
+
- nbconvert==7.16.6
|
| 47 |
+
- nbformat==5.10.4
|
| 48 |
+
- nest-asyncio==1.6.0
|
| 49 |
+
- networkx==3.4.2
|
| 50 |
+
- numpy==1.26.4
|
| 51 |
+
- packaging==24.2
|
| 52 |
+
- pandas==2.2.3
|
| 53 |
+
- platformdirs==4.3.7
|
| 54 |
+
- prometheus_client==0.21.1
|
| 55 |
+
- prompt_toolkit==3.0.50
|
| 56 |
+
- psutil==7.0.0
|
| 57 |
+
- pure_eval==0.2.3
|
| 58 |
+
- pycparser==2.22
|
| 59 |
+
- Pygments==2.19.1
|
| 60 |
+
- python-dateutil==2.9.0.post0
|
| 61 |
+
- python-json-logger==3.3.0
|
| 62 |
+
- pytz==2025.2
|
| 63 |
+
- PyYAML==6.0.2
|
| 64 |
+
- pyzmq==26.4.0
|
| 65 |
+
- referencing==0.36.2
|
| 66 |
+
- regex==2024.11.6
|
| 67 |
+
- requests==2.32.3
|
| 68 |
+
- rpds-py==0.24.0
|
| 69 |
+
- safetensors==0.5.3
|
| 70 |
+
- six==1.17.0
|
| 71 |
+
- sniffio==1.3.1
|
| 72 |
+
- soupsieve==2.6
|
| 73 |
+
- sympy==1.13.3
|
| 74 |
+
- terminado==0.18.1
|
| 75 |
+
- tinycss2==1.4.0
|
| 76 |
+
- tokenizers==0.15.2
|
| 77 |
+
- torch==2.2.2
|
| 78 |
+
- tornado==6.4.2
|
| 79 |
+
- tqdm==4.67.1
|
| 80 |
+
- traitlets==5.14.3
|
| 81 |
+
- transformers==4.38.2
|
| 82 |
+
- typing_extensions==4.13.1
|
| 83 |
+
- tzdata==2025.2
|
| 84 |
+
- urllib3==2.3.0
|
| 85 |
+
- voila==0.5.8
|
| 86 |
+
- wcwidth==0.2.13
|
| 87 |
+
- webencodings==0.5.1
|
| 88 |
+
- websocket-client==1.8.0
|
| 89 |
+
- websockets==15.0.1
|
| 90 |
+
- widgetsnbextension==4.0.13
|
notebooks/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
notebooks/notebook.ipynb
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"cells": [
|
| 3 |
+
{
|
| 4 |
+
"metadata": {
|
| 5 |
+
"ExecuteTime": {
|
| 6 |
+
"end_time": "2025-04-10T15:52:17.264471Z",
|
| 7 |
+
"start_time": "2025-04-10T15:52:17.156072Z"
|
| 8 |
+
}
|
| 9 |
+
},
|
| 10 |
+
"cell_type": "code",
|
| 11 |
+
"source": "import subprocess\nfrom datetime import date\nfrom dateutil.relativedelta import relativedelta\nimport os\nimport ipywidgets as widgets\nfrom IPython.display import HTML, clear_output\nimport warnings\nwarnings.filterwarnings(\"ignore\")\n\n# --- UI Elements ---\n\n# Dropdown for risk factor selection\nrisk_factor = widgets.Dropdown(\n options=['Anthro', 'BP', 'Lipids', 'Diabetes'],\n value='Anthro',\n description='',\n disabled=False\n)\n\nquery_method = widgets.RadioButtons(\n options=[\"default\", \"custom\"],\n value=\"default\",\n layout=widgets.Layout(width=\"180px\")\n)\n\n# Custom query input field\ncustom_query_file = widgets.Text(\n placeholder=\"Enter query filename\",\n description=\"\",\n disabled=True,\n layout=widgets.Layout(width=\"300px\")\n)\n\n# Enable text box only when 'custom' is selected\ndef toggle_custom_query(change):\n custom_query_file.disabled = (change.new != \"custom\")\n\nquery_method.observe(toggle_custom_query, names=\"value\")\n\nstart_date_value = date.today() - relativedelta(months=1)\nend_date_value = date.today()\n\nstart_date = widgets.DatePicker(\n description=\"Start date: \",\n value=start_date_value\n)\n\nend_date = widgets.DatePicker(\n description=\"End Date: \",\n value=end_date_value\n)\n\nrecall_target = widgets.Dropdown(\n options=[('95%', 95), ('90%', 90), ('80%', 80), ('70%', 70), ('60%', 60)],\n value=95,\n description='',\n layout=widgets.Layout(width=\"120px\")\n)\n\n# Buttons\n\n# Button style\ndisplay(HTML(\"\"\"\n<style>\n.widget-button {\n justify-content: flex-start !important;\n text-align: left !important;\n font-weight: bold !important;\n font-size: 15px !important;\n padding-left: 12px !important;\n}\n</style>\n\"\"\"))\n\ndownload_btn = widgets.Button(\n description=\"Download abstracts from PubMed\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\nload_model_btn = widgets.Button(\n description=\"Load the model\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\npredict_btn = widgets.Button(\n description=\"Run the model and screen articles\",\n layout=widgets.Layout(width=\"300px\"),\n style=widgets.ButtonStyle(font_weight=\"bold\")\n)\n\n# Output areas\ndownload_output = widgets.Output()\nload_model_output = widgets.Output()\npredict_output = widgets.Output()\n\n# --- Helper function for UI feedback ---\ndef mark_done(output_placeholder, header_widget):\n output_placeholder.clear_output()\n with output_placeholder:\n display(header_widget)\n display(widgets.HTML(\n \"<span style='color: green; font-weight: bold; font-size:15px; margin-left:12px;'>✔ Done</span>\"\n ))\n\n# --- Event Handlers ---\n\ndef toggle_query_visibility(change):\n \"\"\"Enable/Disable custom query input based on checkbox state.\"\"\"\n custom_query_file.disabled = not change.new\n\nquery_method.observe(toggle_custom_query, names=\"value\")\n\ndef run_download_citations(b):\n b.close()\n with download_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Download abstracts from PubMed\"\n \"</div>\"\n )\n display(header)\n try:\n print(f\" Downloading articles published between {start_date.value} and {end_date.value}...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/download_citations.py\"))) as f:\n exec(f.read(), globals())\n mark_done(download_output, header)\n path_to_articles = os.path.join(globals().get('directory'), 'downloaded_articles.csv')\n display(widgets.HTML(\n f\"<p style='font-size:12px; color:gray; margin-top:6px; margin-left:12px; margin-bottom:0px;'>\"\n f\"Article abstracts downloaded to: {path_to_articles}</p>\"\n ))\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef run_load_model(b):\n b.close()\n with load_model_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Load the model\"\n \"</div>\"\n )\n display(header)\n try:\n print(f\" Loading the model...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/load_model.py\"))) as f:\n exec(f.read(), globals())\n mark_done(load_model_output, header)\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef open_ris_file(b):\n path = os.path.join(globals().get('directory'), 'articles_to_review.ris')\n \"\"\"Opens the .ris file using the system's default application correctly.\"\"\"\n try:\n if os.name == \"posix\": # macOS / Linux\n subprocess.Popen([\"open\", path])\n elif os.name == \"nt\": # Windows\n subprocess.Popen([\"start\", \"\", path], shell=True)\n except Exception:\n import traceback\n traceback.print_exc()\n\ndef run_prediction(b):\n b.close()\n with predict_output:\n clear_output(wait=True)\n header = widgets.HTML(\n \"<div style='font-size:15px; font-weight:bold; margin-left:12px;'>\"\n \"Run the model and screen articles\"\n \"</div>\"\n )\n display(header)\n try:\n print(\" Running the model...\")\n with open(os.path.abspath(os.path.join(os.getcwd(), \"src/predict.py\"))) as f:\n exec(f.read(), globals())\n mark_done(predict_output, header)\n except Exception:\n import traceback\n traceback.print_exc()\n\n path_to_ris = os.path.join(globals().get('directory'), 'articles_to_review.ris')\n\n open_file_btn = widgets.Button(description=f\"📄 Open in EndNote\", layout=widgets.Layout(width=\"350px\"))\n open_file_btn.on_click(open_ris_file)\n\n path_display = widgets.HTML(\n f\"<p style='font-size:12px; color:gray; margin-top:6px; margin-left:12px;'>\"\n f\"Path to the EndNote file: {path_to_ris}</p>\"\n )\n final_message_1 = widgets.HTML(\n \"<p style='font-size:14px; color:black; margin-top:10px; margin-left:12px;'>\"\n \"Open the .ris file in EndNote by clicking on the button above or navigating to the file</p>\"\n )\n final_message_2 = widgets.HTML(\n \"<p style='font-size:14px; color:black; margin-top:10px; margin-left:12px;'>\"\n \"Select RefMan - RIS as the input file format</p>\"\n )\n\n with predict_output:\n display(widgets.VBox([\n open_file_btn,\n path_display,\n final_message_1,\n final_message_2\n ]))\n\n\n# Attach event handlers to buttons\ndownload_btn.on_click(run_download_citations)\nload_model_btn.on_click(run_load_model)\npredict_btn.on_click(run_prediction)\n\n# --- Layout ---\n\ntitle_style = \"font-size: 22px; font-weight: bold; margin-bottom: 40px;\"\nsection_style = \"font-size: 18px; font-weight: bold; margin-bottom: 15px;\"\ntext_style = \"font-size: 16px;\"\nspacing = widgets.Layout(margin=\"20px 0px\")\n\n# Titles\nheader = widgets.HTML(\n f\"<h2 style='margin-left:12px; {title_style}'>Automated screening of the literature</h2>\"\n)\n\nsection1 = widgets.VBox([\n\n widgets.HTML(\"<b style='font-size:16px;'>Choose a risk factor</b>\"),\n risk_factor,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n widgets.HTML(\"<b style='font-size:16px;'>Define the PubMed search</b>\"),\n widgets.HBox([\n widgets.VBox([\n start_date,\n end_date\n ], layout=widgets.Layout(margin=\"0 20px 0 0\")),\n\n widgets.VBox([\n widgets.HBox([\n widgets.Label(\"Query: \", layout=widgets.Layout(width=\"60px\")),\n query_method\n ]),\n custom_query_file\n ], layout=widgets.Layout(margin=\"0 0 0 20px\"))\n ], layout=widgets.Layout(justify_content=\"flex-start\", gap=\"20px\", margin=\"10px 0\"))\n], layout=widgets.Layout(margin=\"0 0 0 12px\"))\n\nsection2 = widgets.VBox([\n # Line 3 — vertically stacked buttons with equal width\n widgets.VBox([\n download_btn,\n download_output,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n load_model_btn,\n load_model_output,\n widgets.HTML(\"<div style='height:25px;'></div>\"),\n\n widgets.HTML(\"<b style='font-size:16px;'>Define how inclusive the model should be</b>\"),\n widgets.HTML(\"<span style='font-size:13px; color:gray; margin-top:2px; display:block;'>Based on the recall achieved in previous testing; the higher the recall, the more inclusive the model</span>\"),\n widgets.HTML(\"<div style='height:8px;'></div>\"),\n recall_target,\n widgets.HTML(\"<div style='height:15px;'></div>\"),\n\n predict_btn,\n predict_output\n ], layout=widgets.Layout(margin=\"30px 0\"))\n])\n\n\ndisplay(header, section1, section2)",
|
| 12 |
+
"id": "35bcc7331ef5d1dc",
|
| 13 |
+
"outputs": [],
|
| 14 |
+
"execution_count": null
|
| 15 |
+
},
|
| 16 |
+
{
|
| 17 |
+
"metadata": {},
|
| 18 |
+
"cell_type": "code",
|
| 19 |
+
"outputs": [],
|
| 20 |
+
"execution_count": null,
|
| 21 |
+
"source": "",
|
| 22 |
+
"id": "eeebfa287ae99109"
|
| 23 |
+
}
|
| 24 |
+
],
|
| 25 |
+
"metadata": {
|
| 26 |
+
"kernelspec": {
|
| 27 |
+
"display_name": "Python 3 (ipykernel)",
|
| 28 |
+
"language": "python",
|
| 29 |
+
"name": "python3"
|
| 30 |
+
},
|
| 31 |
+
"language_info": {
|
| 32 |
+
"codemirror_mode": {
|
| 33 |
+
"name": "ipython",
|
| 34 |
+
"version": 3
|
| 35 |
+
},
|
| 36 |
+
"file_extension": ".py",
|
| 37 |
+
"mimetype": "text/x-python",
|
| 38 |
+
"name": "python",
|
| 39 |
+
"nbconvert_exporter": "python",
|
| 40 |
+
"pygments_lexer": "ipython3",
|
| 41 |
+
"version": "3.12.9"
|
| 42 |
+
},
|
| 43 |
+
"trusted": true
|
| 44 |
+
},
|
| 45 |
+
"nbformat": 4,
|
| 46 |
+
"nbformat_minor": 5
|
| 47 |
+
}
|
notebooks/src/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
notebooks/src/download_citations.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from datetime import date
|
| 3 |
+
from Bio import Entrez
|
| 4 |
+
import pandas as pd
|
| 5 |
+
|
| 6 |
+
risk_factor = globals().get('risk_factor').value
|
| 7 |
+
|
| 8 |
+
risk_factor_directory = os.path.abspath(os.path.join(os.getcwd(), risk_factor))
|
| 9 |
+
|
| 10 |
+
queries_directory = os.path.abspath(os.path.join(os.getcwd(), 'queries', risk_factor))
|
| 11 |
+
use_default_query = globals().get('query_method').value == "default"
|
| 12 |
+
if use_default_query:
|
| 13 |
+
with open(os.path.join(queries_directory, 'default.txt'), 'r') as file:
|
| 14 |
+
base_query = file.read()
|
| 15 |
+
query_suffix = ""
|
| 16 |
+
else:
|
| 17 |
+
custom_query_file = globals().get('custom_query_file').value
|
| 18 |
+
with open(os.path.join(queries_directory, f'{custom_query_file}.txt'), 'r') as file:
|
| 19 |
+
base_query = file.read()
|
| 20 |
+
query_suffix = f"_{custom_query_file}"
|
| 21 |
+
|
| 22 |
+
base_folder_name = date.today().strftime("%Y-%m-%d") + query_suffix
|
| 23 |
+
if not os.path.isdir(os.path.join(risk_factor_directory, base_folder_name)):
|
| 24 |
+
directory = os.path.join(risk_factor_directory, base_folder_name)
|
| 25 |
+
os.makedirs(directory, exist_ok=True)
|
| 26 |
+
else:
|
| 27 |
+
version = 2
|
| 28 |
+
folder_name = f"{base_folder_name}-v{version}"
|
| 29 |
+
directory = os.path.join(risk_factor_directory, folder_name)
|
| 30 |
+
while os.path.isdir(directory):
|
| 31 |
+
version += 1
|
| 32 |
+
folder_name = f"{base_folder_name}-v{version}"
|
| 33 |
+
directory = os.path.join(risk_factor_directory, folder_name)
|
| 34 |
+
os.makedirs(directory, exist_ok=True)
|
| 35 |
+
|
| 36 |
+
start_date = globals().get('start_date').value
|
| 37 |
+
end_date = globals().get('end_date').value
|
| 38 |
+
query = base_query + f'AND (("{start_date}"[Date - Publication] : "{end_date}"[Date - Publication]))'
|
| 39 |
+
|
| 40 |
+
Entrez.email = os.getenv('email')
|
| 41 |
+
search_handle = Entrez.esearch(db="pubmed", term=query, retmax=10000)
|
| 42 |
+
search_results = Entrez.read(search_handle)
|
| 43 |
+
search_handle.close()
|
| 44 |
+
id_list = search_results['IdList']
|
| 45 |
+
fetch_handle = Entrez.efetch(db="pubmed", id=id_list, rettype="xml")
|
| 46 |
+
fetch_results = Entrez.read(fetch_handle)
|
| 47 |
+
fetch_handle.close()
|
| 48 |
+
|
| 49 |
+
papers = []
|
| 50 |
+
for article in fetch_results['PubmedArticle']:
|
| 51 |
+
medline = article['MedlineCitation']
|
| 52 |
+
article_data = medline['Article']
|
| 53 |
+
|
| 54 |
+
title = str(article_data.get('ArticleTitle', ''))
|
| 55 |
+
abstract_text = ' '.join(article_data.get('Abstract', {}).get('AbstractText', ['']))
|
| 56 |
+
abstract_text = (title + ' ' + abstract_text).strip()
|
| 57 |
+
authors = ', '.join(['{} {}'.format(a.get('ForeName', ''), a.get('LastName', ''))
|
| 58 |
+
for a in article_data.get('AuthorList', []) if 'LastName' in a])
|
| 59 |
+
journal = article_data.get('Journal', {}).get('Title', '')
|
| 60 |
+
year = article_data.get('Journal', {}).get('JournalIssue', {}).get('PubDate', {}).get('Year', '')
|
| 61 |
+
pmid = medline.get('PMID', '')
|
| 62 |
+
|
| 63 |
+
doi = ''
|
| 64 |
+
for eid in article_data.get('ELocationID', []):
|
| 65 |
+
if eid.attributes.get('EIdType') == 'doi':
|
| 66 |
+
doi = str(eid)
|
| 67 |
+
|
| 68 |
+
papers.append({
|
| 69 |
+
'PMID': pmid,
|
| 70 |
+
'Title': title,
|
| 71 |
+
'Abstract': abstract_text,
|
| 72 |
+
'Authors': authors,
|
| 73 |
+
'Journal': journal,
|
| 74 |
+
'Year': year,
|
| 75 |
+
'DOI': doi
|
| 76 |
+
})
|
| 77 |
+
|
| 78 |
+
df = pd.DataFrame(papers)
|
| 79 |
+
df.to_csv(os.path.join(directory, 'downloaded_articles.csv'), index=False)
|
notebooks/src/handlers/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
notebooks/src/handlers/hyperparameters_handler.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
class Hyperparameters:
|
| 2 |
+
|
| 3 |
+
def __init__(self,
|
| 4 |
+
ds,
|
| 5 |
+
input_col,
|
| 6 |
+
output_col,
|
| 7 |
+
test_size,
|
| 8 |
+
seed,
|
| 9 |
+
pre_trained_model,
|
| 10 |
+
max_length=512,
|
| 11 |
+
freezed_layers=0,
|
| 12 |
+
batch_size=8,
|
| 13 |
+
learning_rate=0.00003,
|
| 14 |
+
max_epochs=None,
|
| 15 |
+
stop_loss_epochs=5
|
| 16 |
+
):
|
| 17 |
+
self.ds = ds
|
| 18 |
+
self.input_col = input_col
|
| 19 |
+
self.output_col = output_col
|
| 20 |
+
self.test_size = test_size
|
| 21 |
+
self.seed = seed
|
| 22 |
+
self.pre_trained_model = pre_trained_model
|
| 23 |
+
self.max_length = max_length
|
| 24 |
+
self.freezed_layers = freezed_layers
|
| 25 |
+
self.batch_size = batch_size
|
| 26 |
+
self.learning_rate = learning_rate
|
| 27 |
+
self.max_epochs = max_epochs
|
| 28 |
+
if self.max_epochs is None:
|
| 29 |
+
self.stop_loss_epochs = stop_loss_epochs
|
notebooks/src/handlers/model_IO_handler.py
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import numpy as np
|
| 3 |
+
import pandas as pd
|
| 4 |
+
from src.utils import file_handling
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class IOHandler:
|
| 8 |
+
|
| 9 |
+
def __init__(self, directory, model_dir_path):
|
| 10 |
+
self.directory = directory
|
| 11 |
+
self.model_path = os.path.join(model_dir_path, 'model.pth')
|
| 12 |
+
prc_path = os.path.join(model_dir_path, 'precision_recall_curve.csv')
|
| 13 |
+
self._prc = pd.read_csv(prc_path)
|
| 14 |
+
|
| 15 |
+
def get_threshold(self, recall_target):
|
| 16 |
+
"""Return the highest threshold that still achieves at least recall_target recall."""
|
| 17 |
+
above = self._prc[self._prc['Recall'] >= recall_target]
|
| 18 |
+
if above.empty:
|
| 19 |
+
return float(self._prc['Threshold'].iloc[-1])
|
| 20 |
+
return float(above['Threshold'].max())
|
| 21 |
+
|
| 22 |
+
def write_predictions(self, data, y_prob, threshold):
|
| 23 |
+
data = data.copy()
|
| 24 |
+
data['y_prob'] = y_prob
|
| 25 |
+
data['y_pred'] = (np.array(y_prob) >= threshold).astype(int)
|
| 26 |
+
file_handling.write_file(data, self.directory, 'articles_with_predictions.csv')
|
| 27 |
+
|
| 28 |
+
def write_review_file(self, data, y_prob, threshold):
|
| 29 |
+
data = data.loc[np.array(y_prob) >= threshold].copy()
|
| 30 |
+
data = data.drop(columns=['y_prob', 'y_pred'], errors='ignore')
|
| 31 |
+
file_handling.write_file(data, self.directory, 'articles_to_review.csv')
|
| 32 |
+
file_handling.write_file(data, self.directory, 'articles_to_review.ris')
|
notebooks/src/handlers/model_code_translator.py
ADDED
|
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
from itertools import product
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
MODELS = {
|
| 6 |
+
'BiomedBERT_abstract': '1A2B06A00'
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Dataset(Enum):
|
| 11 |
+
_1 = 'ds1'
|
| 12 |
+
_2 = 'ds2'
|
| 13 |
+
_3 = 'ds3'
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class PretrainedModel(Enum):
|
| 17 |
+
O = 'BERT'
|
| 18 |
+
A = 'bioBERT'
|
| 19 |
+
B = 'BiomedBERT'
|
| 20 |
+
L = 'Longformer'
|
| 21 |
+
M = 'BigBird'
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class InputColumn(Enum):
|
| 25 |
+
A = 'Abstract'
|
| 26 |
+
T = 'Text'
|
| 27 |
+
M = 'Methods'
|
| 28 |
+
N = 'A+Methods'
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class OutputColumn(Enum):
|
| 32 |
+
_1 = 'CLASS--stage_1'
|
| 33 |
+
_2 = 'CLASS'
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class LearningRate(Enum):
|
| 37 |
+
A = 0.00002
|
| 38 |
+
B = 0.00003
|
| 39 |
+
C = 0.00004
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
TEST_SIZE = 0.1
|
| 43 |
+
SEED = 100
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
class ModelBatchSize(Enum):
|
| 47 |
+
O = 8
|
| 48 |
+
A = 8
|
| 49 |
+
B = 8
|
| 50 |
+
L = 2
|
| 51 |
+
M = 2
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ModelMaxLength(Enum):
|
| 55 |
+
O = 512
|
| 56 |
+
A = 512
|
| 57 |
+
B = 512
|
| 58 |
+
L = 4096
|
| 59 |
+
M = 4096
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
POSSIBLE_CODE_ELEMENTS = {
|
| 63 |
+
0: [e.name[-1] for e in Dataset],
|
| 64 |
+
1: [e.name for e in InputColumn],
|
| 65 |
+
2: [e.name[-1] for e in OutputColumn],
|
| 66 |
+
3: [e.name for e in PretrainedModel],
|
| 67 |
+
4: [i for i in range(10)],
|
| 68 |
+
5: [i for i in range(10)],
|
| 69 |
+
6: [e.name for e in LearningRate],
|
| 70 |
+
7: [i for i in range(10)],
|
| 71 |
+
8: [i for i in range(10)],
|
| 72 |
+
}
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def get_model_specs(code):
|
| 76 |
+
"""
|
| 77 |
+
Generate model specifications from the provided code.
|
| 78 |
+
The code must not contain 'x'.
|
| 79 |
+
"""
|
| 80 |
+
|
| 81 |
+
model_specs = {
|
| 82 |
+
'ds': Dataset[f"_{code[0]}"].value,
|
| 83 |
+
'test_size': TEST_SIZE,
|
| 84 |
+
'seed': SEED,
|
| 85 |
+
'input_col': InputColumn[code[1]].value,
|
| 86 |
+
'output_col': OutputColumn[f"_{code[2]}"].value,
|
| 87 |
+
'pre_trained_model': PretrainedModel[code[3]].value,
|
| 88 |
+
'max_length': ModelMaxLength[code[3]].value,
|
| 89 |
+
'freezed_layers': int(code[4:6]),
|
| 90 |
+
'learning_rate': LearningRate[code[6]].value,
|
| 91 |
+
'batch_size': ModelBatchSize[code[3]].value,
|
| 92 |
+
}
|
| 93 |
+
if code[7:] != "00":
|
| 94 |
+
model_specs["max_epochs"] = int(code[7:])
|
| 95 |
+
|
| 96 |
+
return model_specs
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
class ModelCodeTranslator:
|
| 100 |
+
def __init__(self, code):
|
| 101 |
+
if len(code) != 9:
|
| 102 |
+
raise Exception("Code must be of length 9")
|
| 103 |
+
self.code = code
|
| 104 |
+
|
| 105 |
+
if not 'x' in self.code:
|
| 106 |
+
self.model_specs = get_model_specs(self.code)
|
| 107 |
+
self.model_specs_list = None
|
| 108 |
+
else:
|
| 109 |
+
iterables = [POSSIBLE_CODE_ELEMENTS[i] if char == 'x' else [char] for i, char in enumerate(self.code)]
|
| 110 |
+
codes = [''.join(combination) for combination in product(*iterables)]
|
| 111 |
+
self.model_specs = None
|
| 112 |
+
self.model_specs_list = [get_model_specs(code) for code in codes]
|
| 113 |
+
|
notebooks/src/handlers/model_loader.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from enum import Enum
|
| 2 |
+
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 4 |
+
from src.handlers.hyperparameters_handler import Hyperparameters
|
| 5 |
+
from src.nn import pytorch_models
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Library(Enum):
|
| 9 |
+
PYTORCH = 'pytorch'
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class ModelUrl(Enum):
|
| 13 |
+
BERT = 'bert-base-uncased'
|
| 14 |
+
bioBERT = 'dmis-lab/biobert-v1.1'
|
| 15 |
+
BiomedBERT = 'microsoft/BiomedNLP-BiomedBERT-base-uncased-abstract-fulltext'
|
| 16 |
+
Longformer = 'allenai/longformer-base-4096'
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class ModelLoader:
|
| 20 |
+
def __init__(self, logger, model_path, specs: dict, library: Library, ):
|
| 21 |
+
self.logger = logger
|
| 22 |
+
|
| 23 |
+
self.target_library = library
|
| 24 |
+
self.hyperparameters = Hyperparameters(**specs)
|
| 25 |
+
|
| 26 |
+
self.model_url = ModelUrl[self.hyperparameters.pre_trained_model].value
|
| 27 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_url, use_fast=False)
|
| 28 |
+
if self.target_library == Library.PYTORCH:
|
| 29 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(self.model_url, num_labels=1,
|
| 30 |
+
problem_type='multi_label_classification')
|
| 31 |
+
self.model_wrapper = pytorch_models.NLPClassifier(
|
| 32 |
+
self.logger, model_path, self.tokenizer, self.model, self.hyperparameters
|
| 33 |
+
)
|
| 34 |
+
else:
|
| 35 |
+
raise NotImplementedError
|
notebooks/src/load_model.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
from src.utils import logging
|
| 4 |
+
from src.handlers.model_code_translator import ModelCodeTranslator, MODELS
|
| 5 |
+
from src.handlers.model_loader import ModelLoader, Library
|
| 6 |
+
from src.handlers.model_IO_handler import IOHandler
|
| 7 |
+
|
| 8 |
+
risk_factor = globals().get("risk_factor")
|
| 9 |
+
directory = globals().get("directory")
|
| 10 |
+
model_name = 'BiomedBERT_abstract'
|
| 11 |
+
|
| 12 |
+
log_dir = os.path.join(directory, 'logs')
|
| 13 |
+
logger = logging.Logger(log_dir=log_dir)
|
| 14 |
+
|
| 15 |
+
logger.info(f"Creating the model")
|
| 16 |
+
model_dir = os.path.abspath(os.path.join(os.getcwd(), 'models', f'PopulationHealthScreener ({risk_factor})'))
|
| 17 |
+
model_io_handler = IOHandler(directory, model_dir)
|
| 18 |
+
model_specs = ModelCodeTranslator(MODELS[model_name]).model_specs
|
| 19 |
+
library = Library.PYTORCH
|
| 20 |
+
model_path = os.path.join(model_dir, 'model.pth')
|
| 21 |
+
model = ModelLoader(logger, model_path, model_specs, library).model_wrapper
|
| 22 |
+
input_col = model_specs["input_col"]
|
| 23 |
+
|
| 24 |
+
logger.info(f"Loading model weights from {model_path}")
|
| 25 |
+
model.load_fine_tuned_weights()
|
notebooks/src/nn/pytorch_models.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NLPClassifier:
|
| 9 |
+
def __init__(self, logger, model_path, tokenizer, model, hyperparameters):
|
| 10 |
+
self.logger = logger
|
| 11 |
+
self.model_path = model_path
|
| 12 |
+
self.hyperparameters = hyperparameters
|
| 13 |
+
|
| 14 |
+
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
| 15 |
+
|
| 16 |
+
self.tokenizer = tokenizer
|
| 17 |
+
self.model = model
|
| 18 |
+
|
| 19 |
+
self.model_arch = self.get_model_arch()
|
| 20 |
+
self.freeze_layers()
|
| 21 |
+
self.model.to(self.device)
|
| 22 |
+
|
| 23 |
+
def get_model_arch(self):
|
| 24 |
+
if 'BERT' in self.hyperparameters.pre_trained_model:
|
| 25 |
+
return self.model.bert
|
| 26 |
+
elif self.hyperparameters.pre_trained_model == "Longformer":
|
| 27 |
+
return self.model.longformer
|
| 28 |
+
else:
|
| 29 |
+
raise ValueError("Invalid model type")
|
| 30 |
+
|
| 31 |
+
def load_training_data(self, data, labels):
|
| 32 |
+
labels = torch.tensor(np.array(labels).reshape(-1, 1))
|
| 33 |
+
classes, class_counts = torch.unique(labels, sorted=True, return_counts=True)
|
| 34 |
+
class_weights = 1.0 / torch.tensor(class_counts, dtype=torch.float)
|
| 35 |
+
weights_dict = {cls.item(): weight for cls, weight in zip(classes, class_weights)}
|
| 36 |
+
sample_weights = torch.tensor([weights_dict[t.item()] for t in labels])
|
| 37 |
+
sampler = WeightedRandomSampler(sample_weights, len(sample_weights))
|
| 38 |
+
|
| 39 |
+
encodings = self.tokenizer(data, truncation=True, padding='max_length',
|
| 40 |
+
max_length=self.hyperparameters.max_length, return_tensors='pt')
|
| 41 |
+
dataset = TensorDataset(
|
| 42 |
+
encodings['input_ids'], encodings['attention_mask'], labels.to(torch.float32)
|
| 43 |
+
)
|
| 44 |
+
return DataLoader(dataset, batch_size=self.hyperparameters.batch_size, sampler=sampler)
|
| 45 |
+
|
| 46 |
+
def freeze_layers(self):
|
| 47 |
+
for param in self.model_arch.embeddings.parameters():
|
| 48 |
+
param.requires_grad = False
|
| 49 |
+
|
| 50 |
+
for layer in self.model_arch.encoder.layer[:self.hyperparameters.freezed_layers]:
|
| 51 |
+
for param in layer.parameters():
|
| 52 |
+
param.requires_grad = False
|
| 53 |
+
|
| 54 |
+
def fit(self, x_train, y_train, trained_model_path):
|
| 55 |
+
train_data = self.load_training_data(x_train, y_train)
|
| 56 |
+
optimizer = torch.optim.AdamW(self.model.parameters(), lr=self.hyperparameters.learning_rate)
|
| 57 |
+
|
| 58 |
+
for epoch in range(3):
|
| 59 |
+
self.logger.info(f'Epoch {epoch + 1}/3')
|
| 60 |
+
self.model.train()
|
| 61 |
+
for i, batch in enumerate(train_data):
|
| 62 |
+
progress = f"Batch {i+1}/{len(train_data)}"
|
| 63 |
+
sys.stdout.write('\r' + progress)
|
| 64 |
+
sys.stdout.flush()
|
| 65 |
+
optimizer.zero_grad()
|
| 66 |
+
input_ids = batch[0].to(self.device)
|
| 67 |
+
attention_mask = batch[1].to(self.device)
|
| 68 |
+
labels = batch[2].to(self.device)
|
| 69 |
+
outputs = self.model(input_ids, attention_mask=attention_mask, labels=labels)
|
| 70 |
+
loss = outputs.loss
|
| 71 |
+
loss.backward()
|
| 72 |
+
optimizer.step()
|
| 73 |
+
|
| 74 |
+
self.logger.info("Saving the model at {}".format(trained_model_path))
|
| 75 |
+
torch.save(self.model.state_dict(), trained_model_path)
|
| 76 |
+
|
| 77 |
+
def load_fine_tuned_weights(self):
|
| 78 |
+
state_dict = torch.load(self.model_path, map_location=self.device)
|
| 79 |
+
self.model.load_state_dict(state_dict, strict=False)
|
| 80 |
+
|
| 81 |
+
def load_test_data(self, data):
|
| 82 |
+
encodings = self.tokenizer(data, truncation=True, padding='max_length',
|
| 83 |
+
max_length=self.hyperparameters.max_length, return_tensors='pt')
|
| 84 |
+
dataset = TensorDataset(
|
| 85 |
+
encodings['input_ids'], encodings['attention_mask']
|
| 86 |
+
)
|
| 87 |
+
return DataLoader(dataset, batch_size=self.hyperparameters.batch_size, shuffle=False)
|
| 88 |
+
|
| 89 |
+
def predict(self, data):
|
| 90 |
+
test_data = self.load_test_data(data)
|
| 91 |
+
|
| 92 |
+
predictions = []
|
| 93 |
+
self.model.eval()
|
| 94 |
+
|
| 95 |
+
for i, batch in enumerate(test_data):
|
| 96 |
+
msg = f" Processing article {8 * (i+1)}/{8 *len(test_data)} ⏳"
|
| 97 |
+
sys.stdout.write('\r' + msg)
|
| 98 |
+
sys.stdout.flush()
|
| 99 |
+
|
| 100 |
+
input_ids = batch[0].to(self.device)
|
| 101 |
+
attention_mask = batch[1].to(self.device)
|
| 102 |
+
|
| 103 |
+
with torch.no_grad():
|
| 104 |
+
outputs = self.model(input_ids, attention_mask=attention_mask,
|
| 105 |
+
output_hidden_states=True)
|
| 106 |
+
predictions.extend(torch.sigmoid(outputs.logits).flatten().tolist())
|
| 107 |
+
|
| 108 |
+
return predictions
|
notebooks/src/predict.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from src.utils import file_handling
|
| 3 |
+
|
| 4 |
+
risk_factor = globals().get('risk_factor')
|
| 5 |
+
directory = globals().get('directory')
|
| 6 |
+
logger = globals().get("logger")
|
| 7 |
+
model_io_handler = globals().get("model_io_handler")
|
| 8 |
+
model = globals().get('model')
|
| 9 |
+
input_col = globals().get('input_col')
|
| 10 |
+
|
| 11 |
+
logger.info(f"Reading new titles and abstracts to screen")
|
| 12 |
+
data = file_handling.read_file(directory, file="downloaded_articles.csv").dropna(subset=[input_col])
|
| 13 |
+
x_test = data[input_col].astype(str).to_list()
|
| 14 |
+
|
| 15 |
+
logger.info("Truncating abstracts to max token length before prediction")
|
| 16 |
+
x_test = [
|
| 17 |
+
model.tokenizer.decode(
|
| 18 |
+
model.tokenizer.encode(text, max_length=model.hyperparameters.max_length, truncation=True),
|
| 19 |
+
skip_special_tokens=True
|
| 20 |
+
)
|
| 21 |
+
for text in x_test
|
| 22 |
+
]
|
| 23 |
+
|
| 24 |
+
recall_target = globals().get('recall_target').value / 100.0
|
| 25 |
+
threshold = model_io_handler.get_threshold(recall_target)
|
| 26 |
+
logger.info(f"Using threshold {threshold:.4f} for {int(recall_target * 100)}% recall target")
|
| 27 |
+
|
| 28 |
+
logger.info("Running the model and getting predictions")
|
| 29 |
+
y_prob = model.predict(x_test)
|
| 30 |
+
model_io_handler.write_predictions(data, y_prob, threshold)
|
| 31 |
+
logger.info("Writing predictions")
|
| 32 |
+
model_io_handler.write_review_file(data, y_prob, threshold)
|
| 33 |
+
logger.info(f"Writing the review file -- find it at {os.path.join(directory, 'articles_to_review.ris')}")
|
notebooks/src/utils/file_handling.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
import xml.etree.ElementTree as ET
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
def read_file(dir: str, file: str):
|
| 7 |
+
os.makedirs(dir, exist_ok=True)
|
| 8 |
+
extension = file.split('.')[1]
|
| 9 |
+
func_name = f'read_{extension}'
|
| 10 |
+
func = globals()[func_name]
|
| 11 |
+
return func(dir=dir, file=file)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def write_file(data_element, dir: str, file_name: str):
|
| 15 |
+
os.makedirs(dir, exist_ok=True)
|
| 16 |
+
extension = file_name.split('.')[1]
|
| 17 |
+
func_name = f'write_{extension}'
|
| 18 |
+
func = globals()[func_name]
|
| 19 |
+
return func(data_element, dir=dir, file=file_name)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def read_csv(dir: str, file: str) -> pd.DataFrame:
|
| 23 |
+
return pd.read_csv(os.path.join(dir, file), escapechar='\\', index_col=False)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def read_xml(dir: str, file: str) -> pd.DataFrame:
|
| 27 |
+
tree = ET.parse(os.path.join(dir, file))
|
| 28 |
+
root = tree.getroot()
|
| 29 |
+
records = []
|
| 30 |
+
for rec in root.findall('.//record'):
|
| 31 |
+
title = rec.findtext('.//titles/title') or ""
|
| 32 |
+
abstract = rec.findtext('.//abstract') or ""
|
| 33 |
+
pmid = rec.findtext('.//electronic-resource-num')
|
| 34 |
+
label = rec.findtext('.//label/style')
|
| 35 |
+
records.append({
|
| 36 |
+
'Title': title,
|
| 37 |
+
'Abstract': abstract,
|
| 38 |
+
'PMID': pmid,
|
| 39 |
+
'Label': label
|
| 40 |
+
})
|
| 41 |
+
return pd.DataFrame(records)
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def write_csv(df: pd.DataFrame, dir: str, file: str):
|
| 45 |
+
df.to_csv(os.path.join(dir, file), escapechar='\\', index=False)
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def write_xlsx(df: pd.DataFrame, dir: str, file: str):
|
| 49 |
+
df.to_excel(os.path.join(dir, file), engine="openpyxl", index=False)
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def write_ris(df: pd.DataFrame, dir: str, file: str):
|
| 53 |
+
with open(os.path.join(dir, file), 'w', encoding='utf-8') as f:
|
| 54 |
+
for _, row in df.iterrows():
|
| 55 |
+
f.write("TI - " + str(row.get('Title') or '') + "\n")
|
| 56 |
+
f.write("AB - " + str(row.get('Abstract') or '') + "\n")
|
| 57 |
+
for author in str(row.get('Authors') or '').split(','):
|
| 58 |
+
f.write("AU - " + author.strip() + "\n")
|
| 59 |
+
f.write("DO - " + str(row.get('DOI') or '') + "\n")
|
| 60 |
+
f.write("PM - " + str(row.get('PMID') or '') + "\n")
|
| 61 |
+
f.write("PY - " + str(row.get('Year') or '') + "\n")
|
| 62 |
+
f.write("JO - " + str(row.get('Journal') or '') + "\n")
|
| 63 |
+
f.write("ER - \n\n")
|
notebooks/src/utils/logging.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import datetime
|
| 2 |
+
import logging
|
| 3 |
+
import os
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LogLevel:
|
| 7 |
+
INFO = logging.INFO
|
| 8 |
+
DEBUG = logging.DEBUG
|
| 9 |
+
WARN = logging.WARN
|
| 10 |
+
ERROR = logging.ERROR
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Logger:
|
| 14 |
+
def __init__(self, log_dir, log_name='log', log_level=LogLevel.INFO):
|
| 15 |
+
self.log_dir = log_dir
|
| 16 |
+
self.log_filename = self.__generate_log_filename(log_name)
|
| 17 |
+
|
| 18 |
+
if not os.path.exists(self.log_dir):
|
| 19 |
+
os.makedirs(self.log_dir, exist_ok=True)
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
self.__configure_logging(_level=log_level)
|
| 23 |
+
except FileNotFoundError:
|
| 24 |
+
os.makedirs(self.log_dir, exist_ok=True)
|
| 25 |
+
self.__configure_logging(_level=log_level)
|
| 26 |
+
|
| 27 |
+
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
|
| 28 |
+
|
| 29 |
+
@staticmethod
|
| 30 |
+
def __generate_log_filename(log_name):
|
| 31 |
+
current_datetime = datetime.datetime.now()
|
| 32 |
+
formatted_date_time = current_datetime.strftime('%d_%m_%Y_%H_%M_%S')
|
| 33 |
+
return f'{log_name}_{formatted_date_time}.log'
|
| 34 |
+
|
| 35 |
+
def __configure_logging(self, _level=LogLevel.INFO):
|
| 36 |
+
logging.basicConfig(level=_level, format="%(asctime)s [%(levelname)s] %(message)s",
|
| 37 |
+
handlers=[logging.FileHandler(os.path.join(self.log_dir, self.log_filename)),
|
| 38 |
+
logging.StreamHandler()])
|
| 39 |
+
|
| 40 |
+
def info(self, msg):
|
| 41 |
+
# logging.info(msg)
|
| 42 |
+
pass # TODO change back
|
| 43 |
+
|
| 44 |
+
def warn(self, msg):
|
| 45 |
+
logging.warning(msg)
|
| 46 |
+
|
| 47 |
+
def error(self, msg):
|
| 48 |
+
logging.error(msg)
|